use axum::{ extract::ws::{Message, WebSocket}, http::HeaderMap, }; use futures_util::{SinkExt, StreamExt}; use tokio_tungstenite::{ connect_async, tungstenite::{Message as TungsteniteMessage, client::IntoClientRequest} }; use tracing::{debug, error, warn}; use url::Url; pub async fn proxy_websocket( socket: WebSocket, target_url: &str, path: &str, query: Option<&str>, headers: &HeaderMap, ) { // Build the target WebSocket URL let ws_url = if target_url.starts_with("http://") { target_url.replacen("http://", "ws://", 1) } else if target_url.starts_with("https://") { target_url.replacen("https://", "wss://", 1) } else { format!("ws://{}", target_url) }; let mut full_url = format!("{}{}", ws_url, path); if let Some(q) = query { full_url = format!("{}?{}", full_url, q); } debug!("Proxying WebSocket connection to: {}", full_url); // Parse the URL let url = match Url::parse(&full_url) { Ok(url) => url, Err(e) => { error!("Invalid WebSocket URL {}: {}", full_url, e); let _ = socket.close().await; return; } }; // Create request with headers let mut request = url.into_client_request().unwrap(); // Copy relevant headers from the original request for (name, value) in headers { // Skip headers that shouldn't be forwarded let header_name = name.as_str().to_lowercase(); if header_name == "host" || header_name == "connection" || header_name == "upgrade" || header_name == "sec-websocket-key" || header_name == "sec-websocket-version" || header_name == "sec-websocket-protocol" || header_name == "sec-websocket-extensions" { continue; } if let Ok(_header_value) = value.to_str() { request.headers_mut().insert(name, value.clone()); } } // Connect to the target WebSocket let (target_ws, _) = match connect_async(request).await { Ok(result) => result, Err(e) => { error!("Failed to connect to target WebSocket: {}", e); let _ = socket.close().await; return; } }; debug!("Connected to target WebSocket"); let (target_sink, target_stream) = target_ws.split(); // Spawn task to forward messages from client to target let (client_sink, client_stream) = socket.split(); let client_to_target = async move { let mut target_sink = target_sink; let mut client_stream = client_stream; while let Some(msg) = client_stream.next().await { match msg { Ok(Message::Text(text)) => { if let Err(e) = target_sink.send(TungsteniteMessage::Text(text)).await { error!("Failed to send text message to target: {}", e); break; } } Ok(Message::Binary(data)) => { if let Err(e) = target_sink.send(TungsteniteMessage::Binary(data)).await { error!("Failed to send binary message to target: {}", e); break; } } Ok(Message::Ping(data)) => { if let Err(e) = target_sink.send(TungsteniteMessage::Ping(data)).await { error!("Failed to send ping to target: {}", e); break; } } Ok(Message::Pong(data)) => { if let Err(e) = target_sink.send(TungsteniteMessage::Pong(data)).await { error!("Failed to send pong to target: {}", e); break; } } Ok(Message::Close(_)) => { debug!("Client closed WebSocket connection"); let _ = target_sink.send(TungsteniteMessage::Close(None)).await; break; } Err(e) => { warn!("WebSocket error from client: {}", e); break; } } } }; let target_to_client = async move { let mut client_sink = client_sink; let mut target_stream = target_stream; while let Some(msg) = target_stream.next().await { match msg { Ok(TungsteniteMessage::Text(text)) => { if let Err(e) = client_sink.send(Message::Text(text)).await { error!("Failed to send text message to client: {}", e); break; } } Ok(TungsteniteMessage::Binary(data)) => { if let Err(e) = client_sink.send(Message::Binary(data)).await { error!("Failed to send binary message to client: {}", e); break; } } Ok(TungsteniteMessage::Ping(data)) => { if let Err(e) = client_sink.send(Message::Ping(data)).await { error!("Failed to send ping to client: {}", e); break; } } Ok(TungsteniteMessage::Pong(data)) => { if let Err(e) = client_sink.send(Message::Pong(data)).await { error!("Failed to send pong to client: {}", e); break; } } Ok(TungsteniteMessage::Close(_)) => { debug!("Target closed WebSocket connection"); let _ = client_sink.send(Message::Close(None)).await; break; } Ok(TungsteniteMessage::Frame(_)) => { // Frame messages are low-level and should be handled automatically debug!("Received frame message, ignoring"); } Err(e) => { warn!("WebSocket error from target: {}", e); break; } } } }; // Run both forwarding tasks concurrently tokio::select! { _ = client_to_target => { debug!("Client to target forwarding finished"); } _ = target_to_client => { debug!("Target to client forwarding finished"); } } debug!("WebSocket proxy connection closed"); }