187 lines
6.5 KiB
Rust
187 lines
6.5 KiB
Rust
use axum::{
|
|
extract::ws::{Message, WebSocket},
|
|
http::HeaderMap,
|
|
};
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use tokio_tungstenite::{
|
|
connect_async,
|
|
tungstenite::{Message as TungsteniteMessage, client::IntoClientRequest}
|
|
};
|
|
use tracing::{debug, error, warn};
|
|
use url::Url;
|
|
|
|
pub async fn proxy_websocket(
|
|
socket: WebSocket,
|
|
target_url: &str,
|
|
path: &str,
|
|
query: Option<&str>,
|
|
headers: &HeaderMap,
|
|
) {
|
|
// Build the target WebSocket URL
|
|
let ws_url = if target_url.starts_with("http://") {
|
|
target_url.replacen("http://", "ws://", 1)
|
|
} else if target_url.starts_with("https://") {
|
|
target_url.replacen("https://", "wss://", 1)
|
|
} else {
|
|
format!("ws://{}", target_url)
|
|
};
|
|
|
|
let mut full_url = format!("{}{}", ws_url, path);
|
|
if let Some(q) = query {
|
|
full_url = format!("{}?{}", full_url, q);
|
|
}
|
|
|
|
debug!("Proxying WebSocket connection to: {}", full_url);
|
|
|
|
// Parse the URL
|
|
let url = match Url::parse(&full_url) {
|
|
Ok(url) => url,
|
|
Err(e) => {
|
|
error!("Invalid WebSocket URL {}: {}", full_url, e);
|
|
let _ = socket.close().await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
// Create request with headers
|
|
let mut request = url.into_client_request().unwrap();
|
|
|
|
// Copy relevant headers from the original request
|
|
for (name, value) in headers {
|
|
// Skip headers that shouldn't be forwarded
|
|
let header_name = name.as_str().to_lowercase();
|
|
if header_name == "host"
|
|
|| header_name == "connection"
|
|
|| header_name == "upgrade"
|
|
|| header_name == "sec-websocket-key"
|
|
|| header_name == "sec-websocket-version"
|
|
|| header_name == "sec-websocket-protocol"
|
|
|| header_name == "sec-websocket-extensions" {
|
|
continue;
|
|
}
|
|
|
|
if let Ok(_header_value) = value.to_str() {
|
|
request.headers_mut().insert(name, value.clone());
|
|
}
|
|
}
|
|
|
|
// Connect to the target WebSocket
|
|
let (target_ws, _) = match connect_async(request).await {
|
|
Ok(result) => result,
|
|
Err(e) => {
|
|
error!("Failed to connect to target WebSocket: {}", e);
|
|
let _ = socket.close().await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
debug!("Connected to target WebSocket");
|
|
|
|
let (target_sink, target_stream) = target_ws.split();
|
|
|
|
// Spawn task to forward messages from client to target
|
|
let (client_sink, client_stream) = socket.split();
|
|
|
|
let client_to_target = async move {
|
|
let mut target_sink = target_sink;
|
|
let mut client_stream = client_stream;
|
|
|
|
while let Some(msg) = client_stream.next().await {
|
|
match msg {
|
|
Ok(Message::Text(text)) => {
|
|
if let Err(e) = target_sink.send(TungsteniteMessage::Text(text)).await {
|
|
error!("Failed to send text message to target: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
Ok(Message::Binary(data)) => {
|
|
if let Err(e) = target_sink.send(TungsteniteMessage::Binary(data)).await {
|
|
error!("Failed to send binary message to target: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
Ok(Message::Ping(data)) => {
|
|
if let Err(e) = target_sink.send(TungsteniteMessage::Ping(data)).await {
|
|
error!("Failed to send ping to target: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
Ok(Message::Pong(data)) => {
|
|
if let Err(e) = target_sink.send(TungsteniteMessage::Pong(data)).await {
|
|
error!("Failed to send pong to target: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
Ok(Message::Close(_)) => {
|
|
debug!("Client closed WebSocket connection");
|
|
let _ = target_sink.send(TungsteniteMessage::Close(None)).await;
|
|
break;
|
|
}
|
|
Err(e) => {
|
|
warn!("WebSocket error from client: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
let target_to_client = async move {
|
|
let mut client_sink = client_sink;
|
|
let mut target_stream = target_stream;
|
|
|
|
while let Some(msg) = target_stream.next().await {
|
|
match msg {
|
|
Ok(TungsteniteMessage::Text(text)) => {
|
|
if let Err(e) = client_sink.send(Message::Text(text)).await {
|
|
error!("Failed to send text message to client: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
Ok(TungsteniteMessage::Binary(data)) => {
|
|
if let Err(e) = client_sink.send(Message::Binary(data)).await {
|
|
error!("Failed to send binary message to client: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
Ok(TungsteniteMessage::Ping(data)) => {
|
|
if let Err(e) = client_sink.send(Message::Ping(data)).await {
|
|
error!("Failed to send ping to client: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
Ok(TungsteniteMessage::Pong(data)) => {
|
|
if let Err(e) = client_sink.send(Message::Pong(data)).await {
|
|
error!("Failed to send pong to client: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
Ok(TungsteniteMessage::Close(_)) => {
|
|
debug!("Target closed WebSocket connection");
|
|
let _ = client_sink.send(Message::Close(None)).await;
|
|
break;
|
|
}
|
|
Ok(TungsteniteMessage::Frame(_)) => {
|
|
// Frame messages are low-level and should be handled automatically
|
|
debug!("Received frame message, ignoring");
|
|
}
|
|
Err(e) => {
|
|
warn!("WebSocket error from target: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
// Run both forwarding tasks concurrently
|
|
tokio::select! {
|
|
_ = client_to_target => {
|
|
debug!("Client to target forwarding finished");
|
|
}
|
|
_ = target_to_client => {
|
|
debug!("Target to client forwarding finished");
|
|
}
|
|
}
|
|
|
|
debug!("WebSocket proxy connection closed");
|
|
}
|