Files
AutoJellyProxy/src/websocket.rs

187 lines
6.5 KiB
Rust

use axum::{
extract::ws::{Message, WebSocket},
http::HeaderMap,
};
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::{
connect_async,
tungstenite::{Message as TungsteniteMessage, client::IntoClientRequest}
};
use tracing::{debug, error, warn};
use url::Url;
pub async fn proxy_websocket(
socket: WebSocket,
target_url: &str,
path: &str,
query: Option<&str>,
headers: &HeaderMap,
) {
// Build the target WebSocket URL
let ws_url = if target_url.starts_with("http://") {
target_url.replacen("http://", "ws://", 1)
} else if target_url.starts_with("https://") {
target_url.replacen("https://", "wss://", 1)
} else {
format!("ws://{}", target_url)
};
let mut full_url = format!("{}{}", ws_url, path);
if let Some(q) = query {
full_url = format!("{}?{}", full_url, q);
}
debug!("Proxying WebSocket connection to: {}", full_url);
// Parse the URL
let url = match Url::parse(&full_url) {
Ok(url) => url,
Err(e) => {
error!("Invalid WebSocket URL {}: {}", full_url, e);
let _ = socket.close().await;
return;
}
};
// Create request with headers
let mut request = url.into_client_request().unwrap();
// Copy relevant headers from the original request
for (name, value) in headers {
// Skip headers that shouldn't be forwarded
let header_name = name.as_str().to_lowercase();
if header_name == "host"
|| header_name == "connection"
|| header_name == "upgrade"
|| header_name == "sec-websocket-key"
|| header_name == "sec-websocket-version"
|| header_name == "sec-websocket-protocol"
|| header_name == "sec-websocket-extensions" {
continue;
}
if let Ok(_header_value) = value.to_str() {
request.headers_mut().insert(name, value.clone());
}
}
// Connect to the target WebSocket
let (target_ws, _) = match connect_async(request).await {
Ok(result) => result,
Err(e) => {
error!("Failed to connect to target WebSocket: {}", e);
let _ = socket.close().await;
return;
}
};
debug!("Connected to target WebSocket");
let (target_sink, target_stream) = target_ws.split();
// Spawn task to forward messages from client to target
let (client_sink, client_stream) = socket.split();
let client_to_target = async move {
let mut target_sink = target_sink;
let mut client_stream = client_stream;
while let Some(msg) = client_stream.next().await {
match msg {
Ok(Message::Text(text)) => {
if let Err(e) = target_sink.send(TungsteniteMessage::Text(text)).await {
error!("Failed to send text message to target: {}", e);
break;
}
}
Ok(Message::Binary(data)) => {
if let Err(e) = target_sink.send(TungsteniteMessage::Binary(data)).await {
error!("Failed to send binary message to target: {}", e);
break;
}
}
Ok(Message::Ping(data)) => {
if let Err(e) = target_sink.send(TungsteniteMessage::Ping(data)).await {
error!("Failed to send ping to target: {}", e);
break;
}
}
Ok(Message::Pong(data)) => {
if let Err(e) = target_sink.send(TungsteniteMessage::Pong(data)).await {
error!("Failed to send pong to target: {}", e);
break;
}
}
Ok(Message::Close(_)) => {
debug!("Client closed WebSocket connection");
let _ = target_sink.send(TungsteniteMessage::Close(None)).await;
break;
}
Err(e) => {
warn!("WebSocket error from client: {}", e);
break;
}
}
}
};
let target_to_client = async move {
let mut client_sink = client_sink;
let mut target_stream = target_stream;
while let Some(msg) = target_stream.next().await {
match msg {
Ok(TungsteniteMessage::Text(text)) => {
if let Err(e) = client_sink.send(Message::Text(text)).await {
error!("Failed to send text message to client: {}", e);
break;
}
}
Ok(TungsteniteMessage::Binary(data)) => {
if let Err(e) = client_sink.send(Message::Binary(data)).await {
error!("Failed to send binary message to client: {}", e);
break;
}
}
Ok(TungsteniteMessage::Ping(data)) => {
if let Err(e) = client_sink.send(Message::Ping(data)).await {
error!("Failed to send ping to client: {}", e);
break;
}
}
Ok(TungsteniteMessage::Pong(data)) => {
if let Err(e) = client_sink.send(Message::Pong(data)).await {
error!("Failed to send pong to client: {}", e);
break;
}
}
Ok(TungsteniteMessage::Close(_)) => {
debug!("Target closed WebSocket connection");
let _ = client_sink.send(Message::Close(None)).await;
break;
}
Ok(TungsteniteMessage::Frame(_)) => {
// Frame messages are low-level and should be handled automatically
debug!("Received frame message, ignoring");
}
Err(e) => {
warn!("WebSocket error from target: {}", e);
break;
}
}
}
};
// Run both forwarding tasks concurrently
tokio::select! {
_ = client_to_target => {
debug!("Client to target forwarding finished");
}
_ = target_to_client => {
debug!("Target to client forwarding finished");
}
}
debug!("WebSocket proxy connection closed");
}