AI-generate first working version of the app

This commit is contained in:
2025-08-11 00:20:14 +02:00
parent 90e7e26f79
commit b5556a78ac
13 changed files with 5232 additions and 3 deletions

14
.env.example Normal file
View File

@@ -0,0 +1,14 @@
# SFTP Configuration for Jellyfin Database Access
SFTP_HOST=your.jellyfin.server.com
SFTP_PORT=22
SFTP_USER=your_username
SFTP_PASSWORD=your_password
SFTP_PATH=/var/lib/jellyfin/data/jellyfin.db
# Jellyfin Server Configuration
JELLYFIN_URL=http://your.jellyfin.server.com:8096
JELLYFIN_API_KEY=your_api_key_here
# Power Management Commands
JELLYFIN_POWER_ON_COMMAND=wakeonlan aa:bb:cc:dd:ee:ff
JELLYFIN_HIBERNATE_COMMAND=ssh user@server "sudo systemctl hibernate"

2485
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,30 @@
[package]
name = "AutoJellyProxy"
version = "0.1.0"
edition = "2024"
edition = "2021"
[dependencies]
tokio = { version = "1.0", features = ["full"] }
axum = { version = "0.7", features = ["ws"] }
tower = "0.4"
tower-http = { version = "0.5", features = ["fs", "cors"] }
tokio-tungstenite = "0.21"
futures-util = "0.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
reqwest = { version = "0.11", features = ["json"] }
rusqlite = { version = "0.31", features = ["bundled"] }
ssh2 = "0.9"
sha2 = "0.10"
hex = "0.4"
pbkdf2 = { version = "0.12", features = ["simple"] }
hmac = "0.12"
dotenv = "0.15"
tracing = "0.1"
tracing-subscriber = "0.3"
anyhow = "1.0"
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.0", features = ["v4"] }
base64 = "0.22"
url = "2.5"
urlencoding = "2.1"

664
login.html Normal file
View File

@@ -0,0 +1,664 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Jellyfin</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Helvetica Neue', Arial, sans-serif;
background: #0d0d0d;
color: white;
height: 100vh;
overflow: hidden;
}
.hero-background {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-image:
linear-gradient(to right, rgba(0,0,0,0.8) 30%, rgba(0,0,0,0.4) 60%, rgba(0,0,0,0.8) 100%),
linear-gradient(to bottom, rgba(0,0,0,0.2) 0%, rgba(0,0,0,0.6) 100%),
url('https://cdn.pixabay.com/photo/2021/12/12/20/00/popcorn-6865976_1280.jpg');
background-size: cover;
background-position: center;
animation: slowZoom 20s ease-in-out infinite alternate;
}
@keyframes slowZoom {
0% { transform: scale(1); }
100% { transform: scale(1.05); }
}
.content-overlay {
position: relative;
z-index: 10;
height: 100vh;
display: flex;
align-items: center;
justify-content: space-between;
padding: 0 4%;
}
.branding-section {
flex: 1;
max-width: 600px;
padding-right: 60px;
}
.logo-container {
margin-bottom: 40px;
}
.logo {
display: flex;
align-items: center;
margin-bottom: 20px;
}
.logo-icon {
width: 50px;
height: 50px;
background: #e50914;
border-radius: 8px;
display: flex;
align-items: center;
justify-content: center;
margin-right: 15px;
}
.logo-icon svg {
width: 28px;
height: 28px;
fill: white;
}
.logo-text {
font-size: 36px;
font-weight: 700;
color: white;
letter-spacing: -1px;
}
.hero-title {
font-size: 48px;
font-weight: 700;
line-height: 1.1;
margin-bottom: 20px;
color: white;
}
.hero-subtitle {
font-size: 24px;
color: #cccccc;
margin-bottom: 30px;
font-weight: 400;
}
.feature-list {
list-style: none;
margin-bottom: 40px;
}
.feature-list li {
padding: 8px 0;
font-size: 16px;
color: #b3b3b3;
position: relative;
padding-left: 30px;
}
.feature-list li::before {
content: '✓';
position: absolute;
left: 0;
color: #00D4FF;
font-weight: bold;
font-size: 18px;
}
.login-section {
width: 450px;
background: rgba(0, 0, 0, 0.75);
border-radius: 4px;
padding: 60px 68px 40px;
}
.login-title {
font-size: 32px;
font-weight: 700;
margin-bottom: 28px;
color: white;
}
.form-group {
margin-bottom: 16px;
position: relative;
}
.form-input {
width: 100%;
height: 50px;
background: #333333;
border: none;
border-radius: 4px;
padding: 16px 20px;
font-size: 16px;
color: white;
transition: all 0.2s ease;
}
.form-input::placeholder {
color: #8c8c8c;
}
.form-input:focus {
outline: none;
background: #454545;
}
.form-input.error {
background: #e87c03;
}
.login-button {
width: 100%;
height: 48px;
background: #e50914;
border: none;
border-radius: 4px;
color: white;
font-size: 16px;
font-weight: 700;
cursor: pointer;
margin-top: 24px;
margin-bottom: 12px;
transition: background-color 0.2s ease;
position: relative;
overflow: hidden;
}
.login-button:hover {
background: #f40612;
}
.login-button:active {
background: #d40812;
}
.login-button.loading {
background: #e50914;
cursor: not-allowed;
}
.login-button.loading::after {
content: '';
position: absolute;
width: 20px;
height: 20px;
margin: auto;
border: 2px solid transparent;
border-top-color: white;
border-radius: 50%;
animation: spin 1s ease infinite;
top: 0;
left: 0;
bottom: 0;
right: 0;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.error-message {
background: #e87c03;
color: white;
padding: 10px 16px;
border-radius: 4px;
margin-bottom: 16px;
font-size: 14px;
opacity: 0;
transform: translateY(-10px);
transition: all 0.3s ease;
}
.error-message.show {
opacity: 1;
transform: translateY(0);
}
.success-message {
background: rgba(46, 125, 50, 0.1);
border: 1px solid #4caf50;
color: #81c784;
padding: 16px;
border-radius: 4px;
margin-bottom: 16px;
font-size: 14px;
opacity: 0;
transform: translateY(-10px);
transition: all 0.3s ease;
}
.success-message.show {
opacity: 1;
transform: translateY(0);
}
.help-text {
color: #737373;
font-size: 13px;
margin-top: 16px;
}
.help-text a {
color: white;
text-decoration: none;
}
.help-text a:hover {
text-decoration: underline;
}
.floating-particles {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
pointer-events: none;
z-index: 1;
}
.particle {
position: absolute;
width: 2px;
height: 2px;
background: rgba(255, 255, 255, 0.3);
border-radius: 50%;
animation: float 8s infinite linear;
}
@keyframes float {
0% {
transform: translateY(100vh) rotate(0deg);
opacity: 0;
}
10% {
opacity: 1;
}
90% {
opacity: 1;
}
100% {
transform: translateY(-10vh) rotate(360deg);
opacity: 0;
}
}
@media (max-width: 950px) {
.content-overlay {
flex-direction: column;
justify-content: center;
padding: 20px;
}
.branding-section {
display: none;
}
.login-section {
width: 100%;
max-width: 450px;
padding: 48px 32px;
}
}
@media (max-width: 480px) {
.login-section {
padding: 40px 24px;
}
.hero-title {
font-size: 36px;
}
}
</style>
</head>
<body>
<div class="hero-background"></div>
<div class="floating-particles" id="particles"></div>
<div class="content-overlay">
<div class="branding-section">
<div class="logo-container">
<div class="logo">
<div class="logo-icon">
<svg viewBox="0 0 24 24">
<path d="M14,3V5H17.59L7.76,14.83L9.17,16.24L19,6.41V10H21V3M19,19H5V5H12V3H5C3.89,3 3,3.9 3,5V19A2,2 0 0,0 5,21H19A2,2 0 0,0 21,19V12H19V19Z"/>
</svg>
</div>
<div class="logo-text">Jellyfin</div>
</div>
</div>
<h1 class="hero-title">Your personal media universe</h1>
<p class="hero-subtitle" id="serverStatusText">Server is offline - will start automatically after login</p>
</div>
<div class="login-section">
<h2 class="login-title">Sign In</h2>
<form id="loginForm">
<div class="error-message" id="errorMessage"></div>
<div class="success-message" id="successMessage"></div>
<div class="form-group">
<input type="text" class="form-input" id="username" placeholder="Username" required>
</div>
<div class="form-group">
<input type="password" class="form-input" id="password" placeholder="Password" required>
</div>
<button type="submit" class="login-button" id="loginButton">
<span id="buttonText">Sign In</span>
</button>
</form>
<div class="help-text">
Secure access to your personal media server
</div>
</div>
</div>
<script>
// Create floating particles
function createParticles() {
const container = document.getElementById('particles');
const particleCount = 30;
for (let i = 0; i < particleCount; i++) {
const particle = document.createElement('div');
particle.className = 'particle';
particle.style.left = Math.random() * 100 + '%';
particle.style.animationDelay = Math.random() * 8 + 's';
particle.style.animationDuration = (Math.random() * 4 + 6) + 's';
container.appendChild(particle);
}
}
// Server status management
let serverStatus = 'offline';
const serverStatusText = document.getElementById('serverStatusText');
function updateServerStatus(status) {
serverStatus = status;
switch(status) {
case 'offline':
serverStatusText.textContent = 'Server is offline - will start automatically after login';
break;
case 'starting':
serverStatusText.textContent = 'Starting media server - please wait...';
break;
case 'running':
serverStatusText.textContent = 'Server is running - redirecting to your media library';
break;
}
}
// Generate device ID based on browser fingerprint
function generateDeviceId() {
const userAgent = navigator.userAgent;
const timestamp = Date.now().toString();
const combined = userAgent + '|' + timestamp;
// Base64 encode the combined string
return btoa(combined);
}
// Get browser/device information
function getBrowserInfo() {
const userAgent = navigator.userAgent;
let browser = 'Unknown';
let device = 'Unknown';
if (userAgent.includes('Chrome')) {
browser = 'Jellyfin Web';
device = 'Chrome';
} else if (userAgent.includes('Firefox')) {
browser = 'Jellyfin Web';
device = 'Firefox';
} else if (userAgent.includes('Safari')) {
browser = 'Jellyfin Web';
device = 'Safari';
} else if (userAgent.includes('Edge')) {
browser = 'Jellyfin Web';
device = 'Edge';
}
return { browser, device };
}
// Form handling
document.getElementById('loginForm').addEventListener('submit', async function(e) {
e.preventDefault();
const username = document.getElementById('username').value;
const password = document.getElementById('password').value;
const loginButton = document.getElementById('loginButton');
const buttonText = document.getElementById('buttonText');
const errorMessage = document.getElementById('errorMessage');
const successMessage = document.getElementById('successMessage');
const usernameInput = document.getElementById('username');
const passwordInput = document.getElementById('password');
// Clear previous states
errorMessage.classList.remove('show');
successMessage.classList.remove('show');
usernameInput.classList.remove('error');
passwordInput.classList.remove('error');
// Validate inputs
if (!username || !password) {
errorMessage.textContent = 'Please enter both username and password.';
errorMessage.classList.add('show');
if (!username) usernameInput.classList.add('error');
if (!password) passwordInput.classList.add('error');
return;
}
// Show loading state
loginButton.classList.add('loading');
buttonText.style.opacity = '0';
try {
// Generate authorization header
const { browser, device } = getBrowserInfo();
const deviceId = generateDeviceId();
const version = "10.10.5";
const authHeader = `MediaBrowser Client="${browser}", Device="${device}", DeviceId="${deviceId}", Version="${version}"`;
// Authenticate with backend
const response = await fetch('/Users/AuthenticateByName', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': authHeader
},
body: JSON.stringify({
Username: username,
Pw: password
})
});
if (!response.ok) {
throw new Error('Authentication failed. Please check your credentials.');
}
const authData = await response.json();
// Store credentials in localStorage in Jellyfin Web Client format
// Use the same device info that was sent in the request
const jellyfinCredentials = {
"Servers": [{
"ManualAddress": "http://localhost:8096",
"manualAddressOnly": true,
"AccessToken": authData.AccessToken,
"UserId": authData.User.Id,
"DeviceId": deviceId,
"Client": browser,
"Device": device,
"Version": version
}]
};
localStorage.setItem('jellyfin_credentials', JSON.stringify(jellyfinCredentials));
// Show success
successMessage.textContent = 'Authentication successful! Initializing media server...';
successMessage.classList.add('show');
// Update server status
updateServerStatus('starting');
// Wait a moment then check if server is ready
setTimeout(() => {
updateServerStatus('running');
successMessage.textContent = 'Welcome to your media library! Redirecting...';
// Redirect to root which should now proxy to Jellyfin
setTimeout(() => {
window.location.href = '/';
}, 1500);
}, 2000);
} catch (error) {
// Show error
errorMessage.textContent = error.message || 'Authentication failed. Please check your credentials.';
errorMessage.classList.add('show');
usernameInput.classList.add('error');
passwordInput.classList.add('error');
// Reset button
loginButton.classList.remove('loading');
buttonText.style.opacity = '1';
}
});
// Check for existing authentication
async function checkExistingAuth() {
const credentials = localStorage.getItem('jellyfin_credentials');
if (credentials) {
try {
const parsedCreds = JSON.parse(credentials);
if (parsedCreds.Servers && parsedCreds.Servers.length > 0) {
const server = parsedCreds.Servers[0];
if (server.AccessToken && server.DeviceId) {
// Show that we're checking authentication
successMessage.textContent = 'Found existing credentials, validating...';
successMessage.classList.add('show');
updateServerStatus('starting');
// Use stored device information
const authHeaderWithToken = `MediaBrowser Client="${server.Client || 'Jellyfin Web'}", Device="${server.Device || 'Chrome'}", DeviceId="${server.DeviceId}", Version="${server.Version || '10.10.5'}", Token="${server.AccessToken}"`;
// Test if the token is still valid by making any authenticated request
// This will trigger server startup if needed and validate the token
const testResponse = await fetch('/', {
headers: {
'Authorization': authHeaderWithToken
}
});
if (testResponse.ok || testResponse.status === 503) {
// Token is valid (or server is starting up), wait for full startup
successMessage.textContent = 'Authentication successful! Starting media server...';
updateServerStatus('starting');
// Wait for server to fully start up and then redirect
setTimeout(async () => {
// Check if we can access the main page now
try {
const finalCheck = await fetch('/', {
headers: {
'Authorization': authHeaderWithToken
}
});
if (finalCheck.ok) {
successMessage.textContent = 'Media server ready! Redirecting...';
updateServerStatus('running');
setTimeout(() => {
window.location.href = '/';
}, 1000);
} else {
// Server might still be starting, try one more time
setTimeout(() => {
window.location.href = '/';
}, 2000);
}
} catch (e) {
// Network error, just try to redirect anyway
setTimeout(() => {
window.location.href = '/';
}, 1000);
}
}, 3000);
} else if (testResponse.status === 401) {
// Token is invalid, clear it and show login form
console.log('Token invalid, clearing credentials');
localStorage.removeItem('jellyfin_credentials');
successMessage.classList.remove('show');
updateServerStatus('offline');
} else {
// Other error, maybe server startup failed, but keep credentials for now
console.log('Server error, status:', testResponse.status);
successMessage.textContent = 'Server error occurred. Please try again.';
successMessage.classList.remove('show');
updateServerStatus('offline');
}
}
}
} catch (e) {
// Invalid credentials, clear them
console.log('Error parsing credentials:', e);
localStorage.removeItem('jellyfin_credentials');
successMessage.classList.remove('show');
updateServerStatus('offline');
}
}
}
// Initialize
createParticles();
updateServerStatus('offline');
// Check if user is already authenticated
checkExistingAuth();
// Add input focus effects
document.querySelectorAll('.form-input').forEach(input => {
input.addEventListener('focus', function() {
this.classList.remove('error');
});
});
</script>
</body>
</html>

390
src/auth.rs Normal file
View File

@@ -0,0 +1,390 @@
use anyhow::{anyhow, Result};
use pbkdf2::{
password_hash::{PasswordHash, PasswordVerifier},
Pbkdf2,
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize)]
pub struct AuthRequest {
#[serde(rename = "Username")]
pub username: String,
#[serde(rename = "Pw")]
pub pw: String,
}
#[derive(Debug, Serialize)]
pub struct AuthResponse {
#[serde(rename = "User")]
pub user: AuthUser,
#[serde(rename = "SessionInfo")]
pub session_info: SessionInfo,
#[serde(rename = "AccessToken")]
pub access_token: String,
#[serde(rename = "ServerId")]
pub server_id: String,
}
#[derive(Debug, Serialize)]
pub struct AuthUser {
#[serde(rename = "Name")]
pub name: String,
#[serde(rename = "ServerId")]
pub server_id: String,
#[serde(rename = "Id")]
pub id: String,
#[serde(rename = "HasPassword")]
pub has_password: bool,
#[serde(rename = "HasConfiguredPassword")]
pub has_configured_password: bool,
#[serde(rename = "HasConfiguredEasyPassword")]
pub has_configured_easy_password: bool,
#[serde(rename = "EnableAutoLogin")]
pub enable_auto_login: bool,
#[serde(rename = "LastLoginDate")]
pub last_login_date: Option<String>,
#[serde(rename = "LastActivityDate")]
pub last_activity_date: Option<String>,
#[serde(rename = "Configuration")]
pub configuration: UserConfiguration,
#[serde(rename = "Policy")]
pub policy: UserPolicy,
}
#[derive(Debug, Serialize)]
pub struct UserConfiguration {
#[serde(rename = "PlayDefaultAudioTrack")]
pub play_default_audio_track: bool,
#[serde(rename = "SubtitleLanguagePreference")]
pub subtitle_language_preference: String,
#[serde(rename = "DisplayMissingEpisodes")]
pub display_missing_episodes: bool,
#[serde(rename = "GroupedFolders")]
pub grouped_folders: Vec<String>,
#[serde(rename = "SubtitleMode")]
pub subtitle_mode: String,
#[serde(rename = "DisplayCollectionsView")]
pub display_collections_view: bool,
#[serde(rename = "EnableLocalPassword")]
pub enable_local_password: bool,
#[serde(rename = "OrderedViews")]
pub ordered_views: Vec<String>,
#[serde(rename = "LatestItemsExcludes")]
pub latest_items_excludes: Vec<String>,
#[serde(rename = "MyMediaExcludes")]
pub my_media_excludes: Vec<String>,
#[serde(rename = "HidePlayedInLatest")]
pub hide_played_in_latest: bool,
#[serde(rename = "RememberAudioSelections")]
pub remember_audio_selections: bool,
#[serde(rename = "RememberSubtitleSelections")]
pub remember_subtitle_selections: bool,
#[serde(rename = "EnableNextEpisodeAutoPlay")]
pub enable_next_episode_auto_play: bool,
#[serde(rename = "CastReceiverId")]
pub cast_receiver_id: Option<String>,
}
#[derive(Debug, Serialize)]
pub struct UserPolicy {
#[serde(rename = "IsAdministrator")]
pub is_administrator: bool,
#[serde(rename = "IsHidden")]
pub is_hidden: bool,
#[serde(rename = "EnableCollectionManagement")]
pub enable_collection_management: bool,
#[serde(rename = "EnableSubtitleManagement")]
pub enable_subtitle_management: bool,
#[serde(rename = "EnableLyricManagement")]
pub enable_lyric_management: bool,
#[serde(rename = "IsDisabled")]
pub is_disabled: bool,
#[serde(rename = "MaxParentalRating")]
pub max_parental_rating: Option<i32>,
#[serde(rename = "BlockedTags")]
pub blocked_tags: Vec<String>,
#[serde(rename = "AllowedTags")]
pub allowed_tags: Vec<String>,
#[serde(rename = "EnableUserPreferenceAccess")]
pub enable_user_preference_access: bool,
#[serde(rename = "AccessSchedules")]
pub access_schedules: Vec<String>,
#[serde(rename = "BlockUnratedItems")]
pub block_unrated_items: Vec<String>,
#[serde(rename = "EnableRemoteControlOfOtherUsers")]
pub enable_remote_control_of_other_users: bool,
#[serde(rename = "EnableSharedDeviceControl")]
pub enable_shared_device_control: bool,
#[serde(rename = "EnableRemoteAccess")]
pub enable_remote_access: bool,
#[serde(rename = "EnableLiveTvManagement")]
pub enable_live_tv_management: bool,
#[serde(rename = "EnableLiveTvAccess")]
pub enable_live_tv_access: bool,
#[serde(rename = "EnableMediaPlayback")]
pub enable_media_playback: bool,
#[serde(rename = "EnableAudioPlaybackTranscoding")]
pub enable_audio_playback_transcoding: bool,
#[serde(rename = "EnableVideoPlaybackTranscoding")]
pub enable_video_playback_transcoding: bool,
#[serde(rename = "EnablePlaybackRemuxing")]
pub enable_playback_remuxing: bool,
#[serde(rename = "ForceRemoteSourceTranscoding")]
pub force_remote_source_transcoding: bool,
#[serde(rename = "EnableContentDeletion")]
pub enable_content_deletion: bool,
#[serde(rename = "EnableContentDeletionFromFolders")]
pub enable_content_deletion_from_folders: Vec<String>,
#[serde(rename = "EnableContentDownloading")]
pub enable_content_downloading: bool,
#[serde(rename = "EnableSyncTranscoding")]
pub enable_sync_transcoding: bool,
#[serde(rename = "EnableMediaConversion")]
pub enable_media_conversion: bool,
#[serde(rename = "EnabledDevices")]
pub enabled_devices: Vec<String>,
#[serde(rename = "EnableAllDevices")]
pub enable_all_devices: bool,
#[serde(rename = "EnabledChannels")]
pub enabled_channels: Vec<String>,
#[serde(rename = "EnableAllChannels")]
pub enable_all_channels: bool,
#[serde(rename = "EnabledFolders")]
pub enabled_folders: Vec<String>,
#[serde(rename = "EnableAllFolders")]
pub enable_all_folders: bool,
#[serde(rename = "InvalidLoginAttemptCount")]
pub invalid_login_attempt_count: i32,
#[serde(rename = "LoginAttemptsBeforeLockout")]
pub login_attempts_before_lockout: i32,
#[serde(rename = "MaxActiveSessions")]
pub max_active_sessions: i32,
#[serde(rename = "EnablePublicSharing")]
pub enable_public_sharing: bool,
#[serde(rename = "BlockedMediaFolders")]
pub blocked_media_folders: Vec<String>,
#[serde(rename = "BlockedChannels")]
pub blocked_channels: Vec<String>,
#[serde(rename = "RemoteClientBitrateLimit")]
pub remote_client_bitrate_limit: i32,
#[serde(rename = "AuthenticationProviderId")]
pub authentication_provider_id: String,
#[serde(rename = "PasswordResetProviderId")]
pub password_reset_provider_id: String,
#[serde(rename = "SyncPlayAccess")]
pub sync_play_access: String,
}
#[derive(Debug, Serialize)]
pub struct SessionInfo {
#[serde(rename = "PlayState")]
pub play_state: PlayState,
#[serde(rename = "AdditionalUsers")]
pub additional_users: Vec<String>,
#[serde(rename = "Capabilities")]
pub capabilities: Capabilities,
#[serde(rename = "RemoteEndPoint")]
pub remote_end_point: String,
#[serde(rename = "Id")]
pub id: String,
#[serde(rename = "UserId")]
pub user_id: String,
#[serde(rename = "UserName")]
pub user_name: String,
#[serde(rename = "Client")]
pub client: String,
#[serde(rename = "LastActivityDate")]
pub last_activity_date: String,
#[serde(rename = "LastPlaybackCheckIn")]
pub last_playback_check_in: String,
#[serde(rename = "DeviceName")]
pub device_name: String,
#[serde(rename = "DeviceType")]
pub device_type: String,
#[serde(rename = "NowPlayingItem")]
pub now_playing_item: Option<String>,
#[serde(rename = "DeviceId")]
pub device_id: String,
#[serde(rename = "ApplicationVersion")]
pub application_version: String,
#[serde(rename = "IsActive")]
pub is_active: bool,
#[serde(rename = "SupportsMediaControl")]
pub supports_media_control: bool,
#[serde(rename = "SupportsRemoteControl")]
pub supports_remote_control: bool,
#[serde(rename = "HasCustomDeviceName")]
pub has_custom_device_name: bool,
#[serde(rename = "ServerId")]
pub server_id: String,
#[serde(rename = "SupportedCommands")]
pub supported_commands: Vec<String>,
}
#[derive(Debug, Serialize)]
pub struct PlayState {
#[serde(rename = "CanSeek")]
pub can_seek: bool,
#[serde(rename = "IsPaused")]
pub is_paused: bool,
#[serde(rename = "IsMuted")]
pub is_muted: bool,
#[serde(rename = "RepeatMode")]
pub repeat_mode: String,
#[serde(rename = "ShuffleMode")]
pub shuffle_mode: String,
}
#[derive(Debug, Serialize)]
pub struct Capabilities {
#[serde(rename = "PlayableMediaTypes")]
pub playable_media_types: Vec<String>,
#[serde(rename = "SupportedCommands")]
pub supported_commands: Vec<String>,
#[serde(rename = "SupportsMediaControl")]
pub supports_media_control: bool,
#[serde(rename = "SupportsContentUploading")]
pub supports_content_uploading: bool,
#[serde(rename = "SupportsPersistentIdentifier")]
pub supports_persistent_identifier: bool,
#[serde(rename = "SupportsSync")]
pub supports_sync: bool,
}
pub fn verify_password(password: &str, stored_hash: &str) -> Result<bool> {
// Handle PBKDF2-SHA512 format: $PBKDF2-SHA512$iterations=210000$salt$hash
if stored_hash.starts_with("$PBKDF2-SHA512$") {
let parts: Vec<&str> = stored_hash.split('$').collect();
if parts.len() != 5 {
return Err(anyhow!("Invalid password hash format"));
}
let iterations_part = parts[2];
let salt_part = parts[3];
let hash_part = parts[4];
let iterations: u32 = iterations_part
.strip_prefix("iterations=")
.ok_or_else(|| anyhow!("Invalid iterations format"))?
.parse()?;
let salt = hex::decode(salt_part)?;
let expected_hash = hex::decode(hash_part)?;
let mut result = vec![0u8; expected_hash.len()];
pbkdf2::pbkdf2_hmac::<sha2::Sha512>(password.as_bytes(), &salt, iterations, &mut result);
Ok(result == expected_hash)
} else {
// Fallback for other hash formats
match PasswordHash::new(stored_hash) {
Ok(parsed_hash) => Ok(Pbkdf2.verify_password(password.as_bytes(), &parsed_hash).is_ok()),
Err(_) => Ok(false),
}
}
}
pub fn parse_authorization_header(auth_header: &str) -> Option<(String, String, String, String)> {
// Parse MediaBrowser authorization header
// Format: MediaBrowser Client="...", Version="...", DeviceId="...", Device="...", Token="..."
if !auth_header.starts_with("MediaBrowser ") {
return None;
}
let params_part = &auth_header[12..]; // Remove "MediaBrowser "
let mut client = String::new();
let mut version = String::new();
let mut device_id = String::new();
let mut device = String::new();
for param in params_part.split(", ") {
if let Some((key, value)) = param.split_once('=') {
let value = value.trim_matches('"');
match key {
"Client" => client = value.replace('+', " "),
"Version" => version = value.to_string(),
"DeviceId" => device_id = value.to_string(),
"Device" => device = value.replace('+', " "),
_ => {}
}
}
}
if !client.is_empty() && !version.is_empty() && !device_id.is_empty() && !device.is_empty() {
Some((client, version, device_id, device))
} else {
None
}
}
impl Default for UserConfiguration {
fn default() -> Self {
Self {
play_default_audio_track: true,
subtitle_language_preference: String::new(),
display_missing_episodes: false,
grouped_folders: Vec::new(),
subtitle_mode: "Default".to_string(),
display_collections_view: false,
enable_local_password: false,
ordered_views: Vec::new(),
latest_items_excludes: Vec::new(),
my_media_excludes: Vec::new(),
hide_played_in_latest: true,
remember_audio_selections: true,
remember_subtitle_selections: true,
enable_next_episode_auto_play: true,
cast_receiver_id: None,
}
}
}
impl Default for UserPolicy {
fn default() -> Self {
Self {
is_administrator: true,
is_hidden: false,
enable_collection_management: true,
enable_subtitle_management: true,
enable_lyric_management: true,
is_disabled: false,
max_parental_rating: None,
blocked_tags: Vec::new(),
allowed_tags: Vec::new(),
enable_user_preference_access: true,
access_schedules: Vec::new(),
block_unrated_items: Vec::new(),
enable_remote_control_of_other_users: true,
enable_shared_device_control: true,
enable_remote_access: true,
enable_live_tv_management: true,
enable_live_tv_access: true,
enable_media_playback: true,
enable_audio_playback_transcoding: true,
enable_video_playback_transcoding: true,
enable_playback_remuxing: true,
force_remote_source_transcoding: false,
enable_content_deletion: true,
enable_content_deletion_from_folders: Vec::new(),
enable_content_downloading: true,
enable_sync_transcoding: true,
enable_media_conversion: true,
enabled_devices: Vec::new(),
enable_all_devices: true,
enabled_channels: Vec::new(),
enable_all_channels: true,
enabled_folders: Vec::new(),
enable_all_folders: true,
invalid_login_attempt_count: 0,
login_attempts_before_lockout: -1,
max_active_sessions: 0,
enable_public_sharing: true,
blocked_media_folders: Vec::new(),
blocked_channels: Vec::new(),
remote_client_bitrate_limit: 0,
authentication_provider_id: "Jellyfin.Server.Implementations.Users.DefaultAuthenticationProvider".to_string(),
password_reset_provider_id: "Jellyfin.Server.Implementations.Users.DefaultPasswordResetProvider".to_string(),
sync_play_access: "CreateAndJoinGroups".to_string(),
}
}
}

45
src/config.rs Normal file
View File

@@ -0,0 +1,45 @@
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::env;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub sftp_host: String,
pub sftp_port: u16,
pub sftp_user: String,
pub sftp_password: String,
pub sftp_path: String,
pub jellyfin_url: String,
pub jellyfin_api_key: String,
pub jellyfin_power_on_command: String,
pub jellyfin_hibernate_command: String,
}
impl Config {
pub fn load() -> Result<Self> {
dotenv::dotenv().ok();
Ok(Config {
sftp_host: env::var("SFTP_HOST")
.map_err(|_| anyhow!("SFTP_HOST environment variable not set"))?,
sftp_port: env::var("SFTP_PORT")
.unwrap_or_else(|_| "22".to_string())
.parse()
.map_err(|_| anyhow!("Invalid SFTP_PORT"))?,
sftp_user: env::var("SFTP_USER")
.map_err(|_| anyhow!("SFTP_USER environment variable not set"))?,
sftp_password: env::var("SFTP_PASSWORD")
.map_err(|_| anyhow!("SFTP_PASSWORD environment variable not set"))?,
sftp_path: env::var("SFTP_PATH")
.unwrap_or_else(|_| "/var/lib/jellyfin/data/jellyfin.db".to_string()),
jellyfin_url: env::var("JELLYFIN_URL")
.map_err(|_| anyhow!("JELLYFIN_URL environment variable not set"))?,
jellyfin_api_key: env::var("JELLYFIN_API_KEY")
.map_err(|_| anyhow!("JELLYFIN_API_KEY environment variable not set"))?,
jellyfin_power_on_command: env::var("JELLYFIN_POWER_ON_COMMAND")
.map_err(|_| anyhow!("JELLYFIN_POWER_ON_COMMAND environment variable not set"))?,
jellyfin_hibernate_command: env::var("JELLYFIN_HIBERNATE_COMMAND")
.map_err(|_| anyhow!("JELLYFIN_HIBERNATE_COMMAND environment variable not set"))?,
})
}
}

113
src/database.rs Normal file
View File

@@ -0,0 +1,113 @@
use anyhow::Result;
use rusqlite::Connection;
use serde::{Deserialize, Serialize};
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub id: String,
pub username: String,
pub password: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Device {
pub id: i64,
pub user_id: String,
pub access_token: String,
pub app_name: String,
pub app_version: String,
pub device_name: String,
pub device_id: String,
pub is_active: bool,
pub date_created: String,
pub date_modified: String,
pub date_last_activity: String,
}
pub struct Database {
conn: Connection,
}
impl Database {
pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
let conn = Connection::open(path)?;
Ok(Database { conn })
}
pub fn get_user_by_username(&self, username: &str) -> Result<Option<User>> {
let mut stmt = self.conn.prepare(
"SELECT Id, Username, Password FROM Users WHERE Username = ?1 COLLATE NOCASE"
)?;
let user_result = stmt.query_row([username], |row| {
Ok(User {
id: row.get(0)?,
username: row.get(1)?,
password: row.get(2)?,
})
});
match user_result {
Ok(user) => Ok(Some(user)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e.into()),
}
}
pub fn get_device_by_access_token(&self, access_token: &str) -> Result<Option<Device>> {
let mut stmt = self.conn.prepare(
"SELECT Id, UserId, AccessToken, AppName, AppVersion, DeviceName, DeviceId, IsActive, DateCreated, DateModified, DateLastActivity FROM Devices WHERE AccessToken = ?1"
)?;
let device_result = stmt.query_row([access_token], |row| {
Ok(Device {
id: row.get(0)?,
user_id: row.get(1)?,
access_token: row.get(2)?,
app_name: row.get(3)?,
app_version: row.get(4)?,
device_name: row.get(5)?,
device_id: row.get(6)?,
is_active: row.get::<_, i64>(7)? != 0,
date_created: row.get(8)?,
date_modified: row.get(9)?,
date_last_activity: row.get(10)?,
})
});
match device_result {
Ok(device) => Ok(Some(device)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e.into()),
}
}
pub fn get_device_by_device_id(&self, device_id: &str) -> Result<Option<Device>> {
let mut stmt = self.conn.prepare(
"SELECT Id, UserId, AccessToken, AppName, AppVersion, DeviceName, DeviceId, IsActive, DateCreated, DateModified, DateLastActivity FROM Devices WHERE DeviceId = ?1"
)?;
let device_result = stmt.query_row([device_id], |row| {
Ok(Device {
id: row.get(0)?,
user_id: row.get(1)?,
access_token: row.get(2)?,
app_name: row.get(3)?,
app_version: row.get(4)?,
device_name: row.get(5)?,
device_id: row.get(6)?,
is_active: row.get::<_, i64>(7)? != 0,
date_created: row.get(8)?,
date_modified: row.get(9)?,
date_last_activity: row.get(10)?,
})
});
match device_result {
Ok(device) => Ok(Some(device)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e.into()),
}
}
}

106
src/jellyfin.rs Normal file
View File

@@ -0,0 +1,106 @@
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::process::Command;
use tracing::{error, info, warn};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemInfo {
#[serde(rename = "LocalAddress")]
pub local_address: String,
#[serde(rename = "ServerName")]
pub server_name: String,
#[serde(rename = "Version")]
pub version: String,
#[serde(rename = "ProductName")]
pub product_name: String,
#[serde(rename = "OperatingSystem")]
pub operating_system: String,
#[serde(rename = "Id")]
pub id: String,
#[serde(rename = "StartupWizardCompleted")]
pub startup_wizard_completed: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BrandingConfig {
#[serde(rename = "LoginDisclaimer")]
pub login_disclaimer: String,
#[serde(rename = "CustomCss")]
pub custom_css: String,
#[serde(rename = "SplashscreenEnabled")]
pub splashscreen_enabled: bool,
}
pub struct JellyfinClient {
client: Client,
base_url: String,
}
impl JellyfinClient {
pub fn new(base_url: String, _api_key: String) -> Self {
Self {
client: Client::new(),
base_url,
}
}
pub fn get_base_url(&self) -> &str {
&self.base_url
}
pub async fn is_online(&self) -> bool {
let url = format!("{}/System/Info/Public", self.base_url);
match self.client.get(&url).send().await {
Ok(response) => {
let is_ok = response.status().is_success();
if is_ok {
info!("Jellyfin server is online");
} else {
warn!("Jellyfin server responded with status: {}", response.status());
}
is_ok
}
Err(e) => {
warn!("Failed to connect to Jellyfin server: {}", e);
false
}
}
}
pub async fn get_system_info(&self) -> Result<SystemInfo> {
let url = format!("{}/System/Info/Public", self.base_url);
let response = self.client.get(&url).send().await?;
let system_info: SystemInfo = response.json().await?;
Ok(system_info)
}
pub async fn get_branding_config(&self) -> Result<BrandingConfig> {
let url = format!("{}/Branding/Configuration", self.base_url);
let response = self.client.get(&url).send().await?;
let branding: BrandingConfig = response.json().await?;
Ok(branding)
}
}
pub async fn power_on_server(command: &str) -> Result<()> {
info!("Powering on server with command: {}", command);
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.is_empty() {
return Err(anyhow::anyhow!("Empty power on command"));
}
let output = Command::new(parts[0])
.args(&parts[1..])
.output()?;
if output.status.success() {
info!("Power on command executed successfully");
Ok(())
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
error!("Power on command failed: {}", stderr);
Err(anyhow::anyhow!("Power on command failed: {}", stderr))
}
}

View File

@@ -1,3 +1,62 @@
fn main() {
println!("Hello, world!");
mod auth;
mod config;
mod database;
mod jellyfin;
mod proxy;
mod server;
mod sftp;
mod websocket;
use anyhow::Result;
use config::Config;
use server::create_app;
use std::sync::Arc;
use tokio::time::{interval, Duration};
use tracing::{error, info};
#[tokio::main]
async fn main() -> Result<()> {
// Initialize tracing
tracing_subscriber::fmt::init();
// Load configuration
let config = Config::load()?;
info!("Loaded configuration");
// Initialize shared state
let app_state = server::AppState::new(config).await?;
info!("Initialized application state");
// Start background tasks
let state_clone = app_state.clone();
tokio::spawn(async move {
background_tasks(state_clone).await;
});
// Create and start the server
let app = create_app(app_state);
let listener = tokio::net::TcpListener::bind("0.0.0.0:8096").await?;
info!("Starting server on port 8096");
axum::serve(listener, app).await?;
Ok(())
}
async fn background_tasks(state: Arc<server::AppState>) {
let mut interval = interval(Duration::from_secs(15));
loop {
interval.tick().await;
// Check if Jellyfin server is online
if let Err(e) = state.update_jellyfin_status().await {
error!("Failed to update Jellyfin status: {}", e);
}
// Check for database updates
if let Err(e) = state.check_database_updates().await {
error!("Failed to check database updates: {}", e);
}
}
}

172
src/proxy.rs Normal file
View File

@@ -0,0 +1,172 @@
use crate::jellyfin::JellyfinClient;
use anyhow::Result;
use axum::{
body::Body,
http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode},
response::Response,
};
use tokio::time::{sleep, Duration};
use tracing::{debug, error, info, warn};
pub async fn proxy_to_jellyfin_with_retry<F>(
method: Method,
path: &str,
query: Option<String>,
headers: HeaderMap,
body: Vec<u8>,
jellyfin_client: &JellyfinClient,
status_updater: F,
) -> Result<Response<Body>, StatusCode>
where
F: Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> + Send + Sync,
{
let max_retries = 60; // 5 minutes with 5-second intervals
let retry_interval = Duration::from_secs(5);
for attempt in 0..max_retries {
match proxy_to_jellyfin_once(method.clone(), path, query.clone(), headers.clone(), body.clone(), jellyfin_client).await {
Ok(response) => return Ok(response),
Err(StatusCode::BAD_GATEWAY) => {
// Check if this is a connection error - update server status
let is_online = status_updater().await;
if !is_online {
if attempt == 0 {
info!("Jellyfin server is offline, waiting for it to come back online...");
}
if attempt < max_retries - 1 {
debug!("Attempt {}/{} - server still offline, retrying in {} seconds",
attempt + 1, max_retries, retry_interval.as_secs());
sleep(retry_interval).await;
continue;
} else {
error!("Server failed to come online after {} attempts", max_retries);
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
} else {
// Server is reported as online but request failed, try once more
warn!("Server reported as online but proxy failed, retrying once more...");
sleep(Duration::from_secs(1)).await;
return proxy_to_jellyfin_once(method, path, query, headers, body, jellyfin_client).await;
}
}
Err(other_error) => {
// For other errors, don't retry
return Err(other_error);
}
}
}
Err(StatusCode::SERVICE_UNAVAILABLE)
}
async fn proxy_to_jellyfin_once(
method: Method,
path: &str,
query: Option<String>,
headers: HeaderMap,
body: Vec<u8>,
jellyfin_client: &JellyfinClient,
) -> Result<Response<Body>, StatusCode> {
let query_str = query.as_deref();
debug!("Proxying {} {} to Jellyfin", method.as_str(), path);
// Convert axum headers to reqwest headers
let mut reqwest_headers = reqwest::header::HeaderMap::new();
for (name, value) in &headers {
// Skip certain headers that should be handled by the proxy
if name.as_str().to_lowercase() == "host" {
continue;
}
if let Ok(header_name) = reqwest::header::HeaderName::from_bytes(name.as_str().as_bytes()) {
if let Ok(header_value) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) {
reqwest_headers.insert(header_name, header_value);
}
}
}
// Create a client that doesn't follow redirects automatically
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let mut url = format!("{}{}", jellyfin_client.get_base_url(), path);
if let Some(q) = query_str {
url = format!("{}?{}", url, q);
}
let mut request = match method.as_str() {
"GET" => client.get(&url),
"POST" => client.post(&url),
"PUT" => client.put(&url),
"DELETE" => client.delete(&url),
"PATCH" => client.patch(&url),
"HEAD" => client.head(&url),
_ => return Err(StatusCode::METHOD_NOT_ALLOWED),
};
// Copy headers
for (name, value) in &reqwest_headers {
request = request.header(name, value);
}
// Add body if present
if !body.is_empty() {
request = request.body(body);
}
match request.send().await {
Ok(response) => {
let status_code = response.status().as_u16();
let mut response_headers = HeaderMap::new();
// Copy response headers
for (name, value) in response.headers() {
// Skip certain headers that might cause issues
let header_name_str = name.as_str().to_lowercase();
if header_name_str == "transfer-encoding" ||
header_name_str == "connection" ||
header_name_str == "upgrade" {
continue;
}
if let Ok(header_name) = HeaderName::try_from(name.as_str()) {
if let Ok(header_value) = HeaderValue::from_bytes(value.as_bytes()) {
response_headers.insert(header_name, header_value);
}
}
}
let body_bytes = match response.bytes().await {
Ok(bytes) => bytes,
Err(e) => {
error!("Failed to read response body: {}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
let mut response_builder = Response::builder()
.status(StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR));
// Set headers
if let Some(headers_mut) = response_builder.headers_mut() {
*headers_mut = response_headers;
}
match response_builder.body(Body::from(body_bytes)) {
Ok(response) => Ok(response),
Err(e) => {
error!("Failed to build response: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
Err(e) => {
error!("Failed to proxy request to Jellyfin: {}", e);
Err(StatusCode::BAD_GATEWAY)
}
}
}

898
src/server.rs Normal file
View File

@@ -0,0 +1,898 @@
use crate::{
auth::{parse_authorization_header, verify_password, AuthRequest},
config::Config,
database::{Database, Device},
jellyfin::{power_on_server, BrandingConfig, JellyfinClient, SystemInfo},
proxy::proxy_to_jellyfin_with_retry,
sftp::{calculate_local_file_hash, SftpClient},
websocket::proxy_websocket,
};
use anyhow::{anyhow, Result};
use axum::{
body::Body,
extract::{ws::WebSocketUpgrade, State},
http::{HeaderMap, Method, StatusCode, Uri},
response::{IntoResponse, Response},
routing::any,
Router,
};
use std::{
collections::HashMap,
path::Path as StdPath,
sync::{Arc, RwLock},
time::{Duration, Instant},
};
use tokio::time::sleep;
use tracing::{debug, error, info, warn};
const LOCAL_DB_PATH: &str = "./jellyfin.db";
const SYSTEM_INFO_PATH: &str = "./system_info.json";
// Embedded login HTML content at build time
const LOGIN_HTML: &str = include_str!("../login.html");
pub struct AppState {
config: Config,
jellyfin_client: JellyfinClient,
sftp_client: SftpClient,
cached_system_info: RwLock<Option<SystemInfo>>,
is_jellyfin_online: RwLock<bool>,
last_db_hash: RwLock<Option<String>>,
last_activity: RwLock<Option<Instant>>,
is_powering_on: RwLock<bool>,
power_on_start_time: RwLock<Option<Instant>>,
}
impl AppState {
pub async fn new(config: Config) -> Result<Arc<Self>> {
let jellyfin_client = JellyfinClient::new(config.jellyfin_url.clone(), config.jellyfin_api_key.clone());
let sftp_client = SftpClient::new(
config.sftp_host.clone(),
config.sftp_port,
config.sftp_user.clone(),
config.sftp_password.clone(),
);
// Try to load cached system info
let cached_system_info = if StdPath::new(SYSTEM_INFO_PATH).exists() {
match std::fs::read_to_string(SYSTEM_INFO_PATH) {
Ok(content) => serde_json::from_str(&content).ok(),
Err(_) => None,
}
} else {
None
};
// Initial database download
if let Err(e) = sftp_client.download_file(&config.sftp_path, LOCAL_DB_PATH).await {
warn!("Failed to download initial database: {}", e);
}
let app_state = Self {
config,
jellyfin_client,
sftp_client,
cached_system_info: RwLock::new(cached_system_info),
is_jellyfin_online: RwLock::new(false),
last_db_hash: RwLock::new(None),
last_activity: RwLock::new(None),
is_powering_on: RwLock::new(false),
power_on_start_time: RwLock::new(None),
};
// Initial status check
app_state.update_jellyfin_status().await?;
// Start background database update checker
let state_clone = Arc::new(app_state);
let checker_state = state_clone.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30)); // Check every 30 seconds
loop {
interval.tick().await;
if let Err(e) = checker_state.check_database_updates().await {
warn!("Failed to check database updates: {}", e);
}
}
});
Ok(state_clone)
}
pub async fn update_jellyfin_status(&self) -> Result<()> {
let is_online = self.jellyfin_client.is_online().await;
*self.is_jellyfin_online.write().unwrap() = is_online;
if is_online {
// Update system info cache
if let Ok(system_info) = self.jellyfin_client.get_system_info().await {
*self.cached_system_info.write().unwrap() = Some(system_info.clone());
// Save to file
if let Ok(json_str) = serde_json::to_string_pretty(&system_info) {
let _ = std::fs::write(SYSTEM_INFO_PATH, json_str);
}
}
}
Ok(())
}
pub async fn check_database_updates(&self) -> Result<()> {
match self.sftp_client.get_file_hash(&self.config.sftp_path).await {
Ok(remote_hash) => {
let current_hash = if StdPath::new(LOCAL_DB_PATH).exists() {
calculate_local_file_hash(LOCAL_DB_PATH).ok()
} else {
None
};
let last_hash = self.last_db_hash.read().unwrap().clone();
if Some(&remote_hash) != last_hash.as_ref() || current_hash.as_ref() != Some(&remote_hash) {
info!("Database hash changed, downloading new version");
if let Err(e) = self.sftp_client.download_file(&self.config.sftp_path, LOCAL_DB_PATH).await {
error!("Failed to download updated database: {}", e);
} else {
*self.last_db_hash.write().unwrap() = Some(remote_hash);
info!("Database updated successfully");
}
}
}
Err(e) => {
warn!("Failed to check remote database hash: {}", e);
}
}
Ok(())
}
pub fn is_online(&self) -> bool {
*self.is_jellyfin_online.read().unwrap()
}
pub fn is_powering_on(&self) -> bool {
*self.is_powering_on.read().unwrap()
}
pub fn is_power_on_timeout(&self) -> bool {
if let Some(start_time) = *self.power_on_start_time.read().unwrap() {
start_time.elapsed() > Duration::from_secs(300) // 5 minutes
} else {
false
}
}
pub async fn update_and_check_status(&self) -> bool {
let is_online = self.jellyfin_client.is_online().await;
*self.is_jellyfin_online.write().unwrap() = is_online;
if is_online {
// Reset power-on state if server came online
*self.is_powering_on.write().unwrap() = false;
*self.power_on_start_time.write().unwrap() = None;
}
is_online
}
pub fn update_activity(&self) {
*self.last_activity.write().unwrap() = Some(Instant::now());
}
// Check for database updates on every authentication-related operation
async fn get_database_with_update_check(&self) -> Result<Database> {
// Check for updates first
if let Err(e) = self.check_database_updates().await {
warn!("Failed to check database updates during access: {}", e);
}
if !StdPath::new(LOCAL_DB_PATH).exists() {
return Err(anyhow!("Local database not found"));
}
Database::new(LOCAL_DB_PATH)
}
async fn authenticate_user(&self, username: &str, password: &str) -> Result<Option<crate::database::User>> {
let db = self.get_database_with_update_check().await?;
if let Some(user) = db.get_user_by_username(username)? {
if verify_password(password, &user.password)? {
return Ok(Some(user));
}
}
Ok(None)
}
async fn validate_token(&self, token: &str) -> Result<Option<Device>> {
let db = self.get_database_with_update_check().await?;
db.get_device_by_access_token(token)
}
async fn validate_device_id(&self, device_id: &str) -> Result<Option<Device>> {
let db = self.get_database_with_update_check().await?;
db.get_device_by_device_id(device_id)
}
}
pub fn create_app(state: Arc<AppState>) -> Router {
Router::new()
.route("/", any(handle_root_request))
.route("/web", any(handle_web_request))
.route("/web/", any(handle_web_request))
.route("/web/*path", any(handle_web_request))
.route("/Users/AuthenticateByName", any(handle_auth_request))
.route("/System/Info/Public", any(handle_system_info_request))
.route("/Branding/Configuration", any(handle_branding_request))
.fallback(handle_fallback_request)
.with_state(state)
}
async fn handle_root_request(
State(state): State<Arc<AppState>>,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Result<Response<Body>, StatusCode> {
state.update_activity();
// Check authentication first
let is_authenticated = check_authentication(&state, &headers, &uri.query().map(|q| q.to_string())).await;
// If server is offline but user is authenticated, power it on and wait
if !state.is_online() && is_authenticated {
ensure_server_online_for_authenticated_request(&state).await?;
}
// If server is online (either was online or just came online), proxy the request to Jellyfin
if state.is_online() {
let path = uri.path();
let query = uri.query().map(|q| q.to_string());
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
error!("Failed to read request body: {}", e);
return Err(StatusCode::BAD_REQUEST);
}
};
return proxy_to_jellyfin_with_retry(
method,
path,
query,
headers,
body_bytes,
&state.jellyfin_client,
{
let state_clone = state.clone();
move || {
let state_clone = state_clone.clone();
Box::pin(async move {
state_clone.update_and_check_status().await
})
}
}
).await;
}
// If we reach here, server is offline and user is not authenticated
// Serve login page only for GET requests
if method == Method::GET {
Ok(Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/html")
.body(Body::from(LOGIN_HTML))
.unwrap())
} else {
Err(StatusCode::UNAUTHORIZED)
}
}
async fn handle_web_request(
State(state): State<Arc<AppState>>,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Result<Response<Body>, StatusCode> {
state.update_activity();
// Check authentication first
let is_authenticated = check_authentication(&state, &headers, &uri.query().map(|q| q.to_string())).await;
// If server is offline but user is authenticated, power it on and wait
if !state.is_online() && is_authenticated {
ensure_server_online_for_authenticated_request(&state).await?;
}
// If server is online (either was online or just came online), proxy the request to Jellyfin
if state.is_online() {
let path = uri.path();
let query = uri.query().map(|q| q.to_string());
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
error!("Failed to read request body: {}", e);
return Err(StatusCode::BAD_REQUEST);
}
};
return proxy_to_jellyfin_with_retry(
method,
path,
query,
headers,
body_bytes,
&state.jellyfin_client,
{
let state_clone = state.clone();
move || {
let state_clone = state_clone.clone();
Box::pin(async move {
state_clone.update_and_check_status().await
})
}
}
).await;
}
// If we reach here, server is offline and user is not authenticated
// Serve login page only for GET requests
if method == Method::GET {
Ok(Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/html")
.body(Body::from(LOGIN_HTML))
.unwrap())
} else {
Err(StatusCode::UNAUTHORIZED)
}
}
async fn handle_auth_request(
State(state): State<Arc<AppState>>,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Result<Response<Body>, StatusCode> {
state.update_activity();
// Check authentication first (for existing session tokens)
let is_authenticated = check_authentication(&state, &headers, &uri.query().map(|q| q.to_string())).await;
// If server is offline but user is authenticated, power it on and wait
if !state.is_online() && is_authenticated {
ensure_server_online_for_authenticated_request(&state).await?;
}
// If server is online, proxy the request to Jellyfin
if state.is_online() {
let path = uri.path();
let query = uri.query().map(|q| q.to_string());
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
error!("Failed to read request body: {}", e);
return Err(StatusCode::BAD_REQUEST);
}
};
return proxy_to_jellyfin_with_retry(
method,
path,
query,
headers,
body_bytes,
&state.jellyfin_client,
{
let state_clone = state.clone();
move || {
let state_clone = state_clone.clone();
Box::pin(async move {
state_clone.update_and_check_status().await
})
}
}
).await;
}
// If server is offline, handle authentication locally
if method == Method::POST {
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
error!("Failed to read request body: {}", e);
return Err(StatusCode::BAD_REQUEST);
}
};
let auth_request: AuthRequest = match serde_json::from_slice(&body_bytes) {
Ok(req) => req,
Err(e) => {
error!("Failed to parse auth request: {}", e);
return Err(StatusCode::BAD_REQUEST);
}
};
// Validate credentials locally first - don't start server if invalid
match state.authenticate_user(&auth_request.username, &auth_request.pw).await {
Ok(Some(_user)) => {
info!("Credentials validated locally, starting server and proxying to Jellyfin");
// User is valid, power on server and wait for it to come online
ensure_server_online_for_authenticated_request(&state).await?;
// Once server is online, proxy the original request to Jellyfin for real auth response
if state.is_online() {
let path = uri.path();
let query = uri.query().map(|q| q.to_string());
return proxy_to_jellyfin_with_retry(
method,
path,
query,
headers,
body_bytes,
&state.jellyfin_client,
{
let state_clone = state.clone();
move || {
let state_clone = state_clone.clone();
Box::pin(async move {
state_clone.update_and_check_status().await
})
}
}
).await;
} else {
error!("Server failed to come online after authentication");
Err(StatusCode::SERVICE_UNAVAILABLE)
}
}
Ok(None) => {
warn!("Authentication failed for user: {} - not starting server", auth_request.username);
Err(StatusCode::UNAUTHORIZED)
}
Err(e) => {
error!("Database error during authentication: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
} else {
Err(StatusCode::METHOD_NOT_ALLOWED)
}
}
async fn handle_system_info_request(
State(state): State<Arc<AppState>>,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Result<Response<Body>, StatusCode> {
state.update_activity();
// Check authentication first
let is_authenticated = check_authentication(&state, &headers, &uri.query().map(|q| q.to_string())).await;
// If server is offline but user is authenticated, power it on
if !state.is_online() && is_authenticated {
ensure_server_online_for_authenticated_request(&state).await?;
}
// If server is online, proxy the request to Jellyfin
if state.is_online() {
let path = uri.path();
let query = uri.query().map(|q| q.to_string());
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
error!("Failed to read request body: {}", e);
return Err(StatusCode::BAD_REQUEST);
}
};
return proxy_to_jellyfin_with_retry(
method,
path,
query,
headers,
body_bytes,
&state.jellyfin_client,
{
let state_clone = state.clone();
move || {
let state_clone = state_clone.clone();
Box::pin(async move {
state_clone.update_and_check_status().await
})
}
}
).await;
}
// If server is offline, return cached system info (this endpoint is usually public)
let system_info = get_system_info_impl(state).await;
let json_body = serde_json::to_string(&system_info).unwrap();
Ok(Response::builder()
.status(StatusCode::OK)
.header("content-type", "application/json")
.body(Body::from(json_body))
.unwrap())
}
async fn handle_branding_request(
State(state): State<Arc<AppState>>,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Result<Response<Body>, StatusCode> {
state.update_activity();
// Check authentication first
let is_authenticated = check_authentication(&state, &headers, &uri.query().map(|q| q.to_string())).await;
// If server is offline but user is authenticated, power it on
if !state.is_online() && is_authenticated {
ensure_server_online_for_authenticated_request(&state).await?;
}
// If server is online, proxy the request to Jellyfin
if state.is_online() {
let path = uri.path();
let query = uri.query().map(|q| q.to_string());
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
error!("Failed to read request body: {}", e);
return Err(StatusCode::BAD_REQUEST);
}
};
return proxy_to_jellyfin_with_retry(
method,
path,
query,
headers,
body_bytes,
&state.jellyfin_client,
{
let state_clone = state.clone();
move || {
let state_clone = state_clone.clone();
Box::pin(async move {
state_clone.update_and_check_status().await
})
}
}
).await;
}
// If server is offline, return offline branding config (this endpoint is usually public)
let branding_config = get_branding_config_impl(state).await;
let json_body = serde_json::to_string(&branding_config).unwrap();
Ok(Response::builder()
.status(StatusCode::OK)
.header("content-type", "application/json")
.body(Body::from(json_body))
.unwrap())
}
async fn get_system_info_impl(state: Arc<AppState>) -> SystemInfo {
if state.is_online() {
// Try to get fresh info from Jellyfin
if let Ok(info) = state.jellyfin_client.get_system_info().await {
return info;
}
}
// Return cached info or default
let cached_info = state.cached_system_info.read().unwrap();
if let Some(mut info) = cached_info.clone() {
if !state.is_online() {
info.server_name = format!("{} (Offline)", info.server_name);
}
info
} else {
SystemInfo {
local_address: "http://localhost:8096".to_string(),
server_name: "Jellyfin Server (Offline)".to_string(),
version: "10.10.6".to_string(),
product_name: "Jellyfin Server".to_string(),
operating_system: "".to_string(),
id: "unknown".to_string(),
startup_wizard_completed: true,
}
}
}
async fn get_branding_config_impl(state: Arc<AppState>) -> BrandingConfig {
if state.is_online() {
if let Ok(config) = state.jellyfin_client.get_branding_config().await {
return config;
}
}
BrandingConfig {
login_disclaimer: "This server is currently offline. Log-in to start the server.".to_string(),
custom_css: "".to_string(),
splashscreen_enabled: true,
}
}
async fn handle_fallback_request(
ws: Option<WebSocketUpgrade>,
State(state): State<Arc<AppState>>,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Result<Response<Body>, StatusCode> {
// Check if this is a WebSocket upgrade request
if let Some(ws_upgrade) = ws {
return handle_websocket_request(ws_upgrade, state, uri, headers).await;
}
// Handle as regular HTTP request
handle_proxy_request(None, State(state), method, uri, headers, body).await
}
async fn handle_websocket_request(
ws_upgrade: WebSocketUpgrade,
state: Arc<AppState>,
uri: Uri,
headers: HeaderMap,
) -> Result<Response<Body>, StatusCode> {
state.update_activity();
let path = uri.path().to_string();
let query = uri.query().map(|q| q.to_string());
// Check authentication for WebSocket connections
let is_authenticated = check_authentication(&state, &headers, &query).await;
// If server is offline but user is authenticated, power it on
if !state.is_online() && is_authenticated {
ensure_server_online_for_authenticated_request(&state).await?;
}
// Check if server is online for WebSocket connections
if !state.is_online() {
// For WebSocket connections when offline, we need authentication
if !is_authenticated {
return Err(StatusCode::UNAUTHORIZED);
}
}
// Handle WebSocket upgrade
let jellyfin_url = state.jellyfin_client.get_base_url().to_string();
let query_str = query.clone();
let headers_clone = headers.clone();
Ok(ws_upgrade.on_upgrade(move |socket| async move {
proxy_websocket(socket, &jellyfin_url, &path, query_str.as_deref(), &headers_clone).await;
}).into_response())
}
async fn handle_proxy_request(
_ws: Option<WebSocketUpgrade>,
State(state): State<Arc<AppState>>,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Result<Response<Body>, StatusCode> {
state.update_activity();
let path = uri.path();
let query = uri.query().map(|q| q.to_string());
// Check authentication for all requests
let is_authenticated = check_authentication(&state, &headers, &query).await;
// If server is offline but user is authenticated, power it on
if !state.is_online() && is_authenticated {
ensure_server_online_for_authenticated_request(&state).await?;
}
// Handle regular HTTP requests
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
error!("Failed to read request body: {}", e);
return Err(StatusCode::BAD_REQUEST);
}
};
if !state.is_online() && !is_authenticated {
return Err(StatusCode::UNAUTHORIZED);
}
// Proxy to Jellyfin
proxy_to_jellyfin_with_retry(
method,
path,
query,
headers,
body_bytes,
&state.jellyfin_client,
{
let state_clone = state.clone();
move || {
let state_clone = state_clone.clone();
Box::pin(async move {
state_clone.update_and_check_status().await
})
}
}
).await
}
async fn check_authentication(
state: &Arc<AppState>,
headers: &HeaderMap,
query: &Option<String>,
) -> bool {
// Check for API key in query parameters
if let Some(query_str) = query {
let params: HashMap<_, _> = url::form_urlencoded::parse(query_str.as_bytes()).collect();
if let Some(api_key) = params.get("api_key") {
if let Ok(Some(_)) = state.validate_token(api_key).await {
debug!("Valid API key found in query parameters");
return true;
}
}
if let Some(device_id) = params.get("deviceId").or_else(|| params.get("DeviceId")) {
if let Ok(Some(_)) = state.validate_device_id(device_id).await {
debug!("Valid device ID found in query parameters");
return true;
}
}
}
// Check authorization header
if let Some(auth_header) = headers.get("authorization") {
if let Ok(header_str) = auth_header.to_str() {
debug!("Checking authorization header: {}", header_str);
// First, try to extract token from MediaBrowser header
if header_str.starts_with("MediaBrowser ") && header_str.contains("Token=") {
if let Some(token_start) = header_str.find("Token=\"") {
let token_start = token_start + 7; // Skip 'Token="'
if let Some(token_end) = header_str[token_start..].find('"') {
let token = &header_str[token_start..token_start + token_end];
debug!("Extracted token from header: {}", token);
if let Ok(Some(_)) = state.validate_token(token).await {
debug!("Valid token found in authorization header");
return true;
} else {
debug!("Token validation failed for: {}", token);
}
}
} else if let Some(token_start) = header_str.find("Token=") {
// Handle case without quotes around token value
let token_start = token_start + 6; // Skip 'Token='
let token_end = header_str[token_start..].find(',').or_else(||
header_str[token_start..].find(' ')).unwrap_or(header_str.len() - token_start);
let token = &header_str[token_start..token_start + token_end].trim_matches('"');
debug!("Extracted token from header (no quotes): {}", token);
if let Ok(Some(_)) = state.validate_token(token).await {
debug!("Valid token found in authorization header (no quotes)");
return true;
} else {
debug!("Token validation failed for: {}", token);
}
}
}
// Also try to parse device ID from header and validate it
if let Some((_, _, device_id, _)) = parse_authorization_header(header_str) {
// URL decode the device ID since it might be encoded
if let Ok(decoded_device_id) = urlencoding::decode(&device_id) {
debug!("Checking device ID: {}", decoded_device_id);
if let Ok(Some(_)) = state.validate_device_id(&decoded_device_id).await {
debug!("Valid device ID found in authorization header");
return true;
}
}
// Also try the non-decoded version
if let Ok(Some(_)) = state.validate_device_id(&device_id).await {
debug!("Valid device ID found in authorization header (non-decoded)");
return true;
}
}
}
}
debug!("No valid authentication found");
false
}
async fn ensure_server_online_for_authenticated_request(state: &Arc<AppState>) -> Result<(), StatusCode> {
// If server is already online, nothing to do
if state.is_online() {
return Ok(());
}
// Check if we're already powering on
if state.is_powering_on() {
// Check if power-on has timed out
if state.is_power_on_timeout() {
error!("Server power-on timed out, resetting power-on state");
*state.is_powering_on.write().unwrap() = false;
*state.power_on_start_time.write().unwrap() = None;
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
// Wait for server to come online or timeout
let max_wait = Duration::from_secs(300); // 5 minutes
let start_check = Instant::now();
info!("Server is being powered on, waiting for it to come online...");
while start_check.elapsed() < max_wait {
// Check if server came online
if state.jellyfin_client.is_online().await {
*state.is_jellyfin_online.write().unwrap() = true;
*state.is_powering_on.write().unwrap() = false;
*state.power_on_start_time.write().unwrap() = None;
info!("Server came online successfully");
return Ok(());
}
// Check if power-on process timed out
if state.is_power_on_timeout() {
error!("Server power-on timed out while waiting");
*state.is_powering_on.write().unwrap() = false;
*state.power_on_start_time.write().unwrap() = None;
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
sleep(tokio::time::Duration::from_secs(5)).await;
}
error!("Timed out waiting for server to come online");
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
// Start power-on process
info!("Server is offline but user is authenticated, powering on");
*state.is_powering_on.write().unwrap() = true;
*state.power_on_start_time.write().unwrap() = Some(Instant::now());
if let Err(e) = power_on_server(&state.config.jellyfin_power_on_command).await {
error!("Failed to power on server: {}", e);
*state.is_powering_on.write().unwrap() = false;
*state.power_on_start_time.write().unwrap() = None;
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
// Wait for server to come online
let max_wait = Duration::from_secs(300); // 5 minutes
let start_time = Instant::now();
while start_time.elapsed() < max_wait {
if state.jellyfin_client.is_online().await {
*state.is_jellyfin_online.write().unwrap() = true;
*state.is_powering_on.write().unwrap() = false;
*state.power_on_start_time.write().unwrap() = None;
info!("Server came online successfully");
return Ok(());
}
sleep(tokio::time::Duration::from_secs(5)).await;
}
error!("Server failed to come online within timeout");
*state.is_powering_on.write().unwrap() = false;
*state.power_on_start_time.write().unwrap() = None;
Err(StatusCode::SERVICE_UNAVAILABLE)
}

73
src/sftp.rs Normal file
View File

@@ -0,0 +1,73 @@
use anyhow::Result;
use sha2::{Digest, Sha256};
use ssh2::Session;
use std::io::Read;
use std::net::TcpStream;
use std::path::Path;
use tracing::{debug, info};
pub struct SftpClient {
host: String,
port: u16,
username: String,
password: String,
}
impl SftpClient {
pub fn new(host: String, port: u16, username: String, password: String) -> Self {
Self {
host,
port,
username,
password,
}
}
pub async fn download_file(&self, remote_path: &str, local_path: &str) -> Result<()> {
let host_port = format!("{}:{}", self.host, self.port);
let tcp = TcpStream::connect(&host_port)?;
let mut sess = Session::new()?;
sess.set_tcp_stream(tcp);
sess.handshake()?;
sess.userauth_password(&self.username, &self.password)?;
let sftp = sess.sftp()?;
let mut remote_file = sftp.open(Path::new(remote_path))?;
let mut contents = Vec::new();
remote_file.read_to_end(&mut contents)?;
std::fs::write(local_path, contents)?;
info!("Downloaded {} to {}", remote_path, local_path);
Ok(())
}
pub async fn get_file_hash(&self, remote_path: &str) -> Result<String> {
let host_port = format!("{}:{}", self.host, self.port);
let tcp = TcpStream::connect(&host_port)?;
let mut sess = Session::new()?;
sess.set_tcp_stream(tcp);
sess.handshake()?;
sess.userauth_password(&self.username, &self.password)?;
let sftp = sess.sftp()?;
let mut remote_file = sftp.open(Path::new(remote_path))?;
let mut contents = Vec::new();
remote_file.read_to_end(&mut contents)?;
let mut hasher = Sha256::new();
hasher.update(&contents);
let hash = format!("{:x}", hasher.finalize());
debug!("File {} hash: {}", remote_path, hash);
Ok(hash)
}
}
pub fn calculate_local_file_hash(file_path: &str) -> Result<String> {
let contents = std::fs::read(file_path)?;
let mut hasher = Sha256::new();
hasher.update(&contents);
let hash = format!("{:x}", hasher.finalize());
Ok(hash)
}

186
src/websocket.rs Normal file
View File

@@ -0,0 +1,186 @@
use axum::{
extract::ws::{Message, WebSocket},
http::HeaderMap,
};
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::{
connect_async,
tungstenite::{Message as TungsteniteMessage, client::IntoClientRequest}
};
use tracing::{debug, error, warn};
use url::Url;
pub async fn proxy_websocket(
socket: WebSocket,
target_url: &str,
path: &str,
query: Option<&str>,
headers: &HeaderMap,
) {
// Build the target WebSocket URL
let ws_url = if target_url.starts_with("http://") {
target_url.replacen("http://", "ws://", 1)
} else if target_url.starts_with("https://") {
target_url.replacen("https://", "wss://", 1)
} else {
format!("ws://{}", target_url)
};
let mut full_url = format!("{}{}", ws_url, path);
if let Some(q) = query {
full_url = format!("{}?{}", full_url, q);
}
debug!("Proxying WebSocket connection to: {}", full_url);
// Parse the URL
let url = match Url::parse(&full_url) {
Ok(url) => url,
Err(e) => {
error!("Invalid WebSocket URL {}: {}", full_url, e);
let _ = socket.close().await;
return;
}
};
// Create request with headers
let mut request = url.into_client_request().unwrap();
// Copy relevant headers from the original request
for (name, value) in headers {
// Skip headers that shouldn't be forwarded
let header_name = name.as_str().to_lowercase();
if header_name == "host"
|| header_name == "connection"
|| header_name == "upgrade"
|| header_name == "sec-websocket-key"
|| header_name == "sec-websocket-version"
|| header_name == "sec-websocket-protocol"
|| header_name == "sec-websocket-extensions" {
continue;
}
if let Ok(_header_value) = value.to_str() {
request.headers_mut().insert(name, value.clone());
}
}
// Connect to the target WebSocket
let (target_ws, _) = match connect_async(request).await {
Ok(result) => result,
Err(e) => {
error!("Failed to connect to target WebSocket: {}", e);
let _ = socket.close().await;
return;
}
};
debug!("Connected to target WebSocket");
let (target_sink, target_stream) = target_ws.split();
// Spawn task to forward messages from client to target
let (client_sink, client_stream) = socket.split();
let client_to_target = async move {
let mut target_sink = target_sink;
let mut client_stream = client_stream;
while let Some(msg) = client_stream.next().await {
match msg {
Ok(Message::Text(text)) => {
if let Err(e) = target_sink.send(TungsteniteMessage::Text(text)).await {
error!("Failed to send text message to target: {}", e);
break;
}
}
Ok(Message::Binary(data)) => {
if let Err(e) = target_sink.send(TungsteniteMessage::Binary(data)).await {
error!("Failed to send binary message to target: {}", e);
break;
}
}
Ok(Message::Ping(data)) => {
if let Err(e) = target_sink.send(TungsteniteMessage::Ping(data)).await {
error!("Failed to send ping to target: {}", e);
break;
}
}
Ok(Message::Pong(data)) => {
if let Err(e) = target_sink.send(TungsteniteMessage::Pong(data)).await {
error!("Failed to send pong to target: {}", e);
break;
}
}
Ok(Message::Close(_)) => {
debug!("Client closed WebSocket connection");
let _ = target_sink.send(TungsteniteMessage::Close(None)).await;
break;
}
Err(e) => {
warn!("WebSocket error from client: {}", e);
break;
}
}
}
};
let target_to_client = async move {
let mut client_sink = client_sink;
let mut target_stream = target_stream;
while let Some(msg) = target_stream.next().await {
match msg {
Ok(TungsteniteMessage::Text(text)) => {
if let Err(e) = client_sink.send(Message::Text(text)).await {
error!("Failed to send text message to client: {}", e);
break;
}
}
Ok(TungsteniteMessage::Binary(data)) => {
if let Err(e) = client_sink.send(Message::Binary(data)).await {
error!("Failed to send binary message to client: {}", e);
break;
}
}
Ok(TungsteniteMessage::Ping(data)) => {
if let Err(e) = client_sink.send(Message::Ping(data)).await {
error!("Failed to send ping to client: {}", e);
break;
}
}
Ok(TungsteniteMessage::Pong(data)) => {
if let Err(e) = client_sink.send(Message::Pong(data)).await {
error!("Failed to send pong to client: {}", e);
break;
}
}
Ok(TungsteniteMessage::Close(_)) => {
debug!("Target closed WebSocket connection");
let _ = client_sink.send(Message::Close(None)).await;
break;
}
Ok(TungsteniteMessage::Frame(_)) => {
// Frame messages are low-level and should be handled automatically
debug!("Received frame message, ignoring");
}
Err(e) => {
warn!("WebSocket error from target: {}", e);
break;
}
}
}
};
// Run both forwarding tasks concurrently
tokio::select! {
_ = client_to_target => {
debug!("Client to target forwarding finished");
}
_ = target_to_client => {
debug!("Target to client forwarding finished");
}
}
debug!("WebSocket proxy connection closed");
}