From f03a6935d56e2d5fed268b4b0c5eafd9cd6fadc4 Mon Sep 17 00:00:00 2001 From: Mathias Wagner Date: Mon, 8 Sep 2025 21:17:03 +0200 Subject: [PATCH] Create utility functions --- server/src/utils/auth.rs | 70 +++++++++++++++++ server/src/utils/database.rs | 111 +++++++++++++++++++++++++++ server/src/utils/db_path.rs | 76 +++++++++++++++++++ server/src/utils/error.rs | 142 +++++++++++++++++++++++++++++++++++ server/src/utils/mod.rs | 9 +++ server/src/utils/models.rs | 137 +++++++++++++++++++++++++++++++++ 6 files changed, 545 insertions(+) create mode 100644 server/src/utils/auth.rs create mode 100644 server/src/utils/database.rs create mode 100644 server/src/utils/db_path.rs create mode 100644 server/src/utils/error.rs create mode 100644 server/src/utils/mod.rs create mode 100644 server/src/utils/models.rs diff --git a/server/src/utils/auth.rs b/server/src/utils/auth.rs new file mode 100644 index 0000000..70baace --- /dev/null +++ b/server/src/utils/auth.rs @@ -0,0 +1,70 @@ +use crate::controllers::auth::AuthController; +use crate::utils::{models::*, DbPool}; +use axum::{ + async_trait, + extract::FromRequestParts, + http::{header::AUTHORIZATION, request::Parts, StatusCode}, +}; + +#[derive(Clone)] +pub struct AuthUser { + pub user: User, +} + +#[async_trait] +impl FromRequestParts for AuthUser +where + S: Send + Sync, +{ + type Rejection = StatusCode; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let auth_header = parts + .headers + .get(AUTHORIZATION) + .and_then(|header| header.to_str().ok()) + .ok_or(StatusCode::UNAUTHORIZED)?; + + if !auth_header.starts_with("Bearer ") { + return Err(StatusCode::UNAUTHORIZED); + } + + let token = &auth_header[7..]; + let pool = parts + .extensions + .get::() + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + + let user = AuthController::authenticate_user(pool, token) + .await + .map_err(|_| StatusCode::UNAUTHORIZED)?; + + Ok(AuthUser { user }) + } +} + +#[derive(Clone)] +pub struct AdminUser { + #[allow(dead_code)] + pub user: User, +} + +#[async_trait] +impl FromRequestParts for AdminUser +where + S: Send + Sync, +{ + type Rejection = StatusCode; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let auth_user = AuthUser::from_request_parts(parts, _state).await?; + + if auth_user.user.role != UserRole::Admin { + return Err(StatusCode::FORBIDDEN); + } + + Ok(AdminUser { + user: auth_user.user, + }) + } +} diff --git a/server/src/utils/database.rs b/server/src/utils/database.rs new file mode 100644 index 0000000..5e348e0 --- /dev/null +++ b/server/src/utils/database.rs @@ -0,0 +1,111 @@ +use crate::utils::{ensure_data_directories, get_database_path, AppResult}; +use sqlx::{sqlite::SqlitePool, Pool, Row, Sqlite}; +use std::path::Path; + +pub type DbPool = Pool; + +pub async fn init_database() -> AppResult { + ensure_data_directories()?; + + let db_path = get_database_path()?; + + if !Path::new(&db_path).exists() { + std::fs::File::create(&db_path)?; + } + + let database_url = format!("sqlite://{}", db_path); + + let pool = SqlitePool::connect(&database_url).await?; + + run_migrations(&pool).await?; + + Ok(pool) +} + +async fn run_migrations(pool: &DbPool) -> AppResult<()> { + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + role TEXT CHECK(role IN ('admin','user')) NOT NULL, + storage_limit_gb INTEGER NOT NULL DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + "#, + ) + .execute(pool) + .await?; + + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + token TEXT UNIQUE NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + expires_at DATETIME NOT NULL, + FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE + ) + "#, + ) + .execute(pool) + .await?; + + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS machines ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + uuid TEXT UNIQUE NOT NULL, + name TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE + ) + "#, + ) + .execute(pool) + .await?; + + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS provisioning_codes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + code TEXT UNIQUE NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + expires_at DATETIME NOT NULL, + used BOOLEAN DEFAULT 0, + FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE + ) + "#, + ) + .execute(pool) + .await?; + + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + machine_id INTEGER NOT NULL, + snapshot_hash TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY(machine_id) REFERENCES machines(id) ON DELETE CASCADE + ) + "#, + ) + .execute(pool) + .await?; + + Ok(()) +} + +pub async fn check_first_user_exists(pool: &DbPool) -> AppResult { + let row = sqlx::query("SELECT COUNT(*) as count FROM users") + .fetch_one(pool) + .await?; + + let count: i64 = row.get("count"); + Ok(count > 0) +} diff --git a/server/src/utils/db_path.rs b/server/src/utils/db_path.rs new file mode 100644 index 0000000..26d620f --- /dev/null +++ b/server/src/utils/db_path.rs @@ -0,0 +1,76 @@ +use crate::utils::error::{internal_error, AppResult}; +use std::fs; + +pub fn get_database_path() -> AppResult { + let db_dir = "data/db"; + let db_path = format!("{}/arkendro.db", db_dir); + + if let Err(e) = fs::create_dir_all(db_dir) { + return Err(internal_error(&format!( + "Failed to create database directory: {}", + e + ))); + } + + Ok(db_path) +} + +pub fn ensure_data_directories() -> AppResult<()> { + let directories = ["data", "data/db", "data/backups", "data/logs"]; + + for dir in directories.iter() { + if let Err(e) = fs::create_dir_all(dir) { + return Err(internal_error(&format!( + "Failed to create directory '{}': {}", + dir, e + ))); + } + } + + Ok(()) +} + +pub fn get_data_path(filename: &str) -> String { + format!("data/{}", filename) +} + +pub fn get_backup_path(filename: &str) -> String { + format!("data/backups/{}", filename) +} + +pub fn get_log_path(filename: &str) -> String { + format!("data/logs/{}", filename) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::path::Path; + + #[test] + fn test_database_path_creation() { + let _ = fs::remove_dir_all("data"); + + let db_path = get_database_path().expect("Should create database path"); + assert_eq!(db_path, "data/db/arkendro.db"); + + assert!(Path::new("data/db").exists()); + + let _ = fs::remove_dir_all("data"); + } + + #[test] + fn test_ensure_data_directories() { + let _ = fs::remove_dir_all("data"); + + ensure_data_directories().expect("Should create all directories"); + + assert!(Path::new("data").exists()); + assert!(Path::new("data/db").exists()); + assert!(Path::new("data/backups").exists()); + assert!(Path::new("data/logs").exists()); + + let _ = fs::remove_dir_all("data"); + } +} diff --git a/server/src/utils/error.rs b/server/src/utils/error.rs new file mode 100644 index 0000000..26442b5 --- /dev/null +++ b/server/src/utils/error.rs @@ -0,0 +1,142 @@ +use axum::{ + http::StatusCode, + response::{IntoResponse, Json, Response}, +}; +use serde_json::json; +use std::fmt; + +#[derive(Debug)] +pub enum AppError { + DatabaseError(String), + ValidationError(String), + AuthenticationError(String), + AuthorizationError(String), + NotFoundError(String), + ConflictError(String), + InternalError(String), +} + +impl fmt::Display for AppError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AppError::DatabaseError(msg) => write!(f, "Database error: {}", msg), + AppError::ValidationError(msg) => write!(f, "Validation error: {}", msg), + AppError::AuthenticationError(msg) => write!(f, "Authentication error: {}", msg), + AppError::AuthorizationError(msg) => write!(f, "Authorization error: {}", msg), + AppError::NotFoundError(msg) => write!(f, "Not found: {}", msg), + AppError::ConflictError(msg) => write!(f, "Conflict: {}", msg), + AppError::InternalError(msg) => write!(f, "Internal error: {}", msg), + } + } +} + +impl std::error::Error for AppError {} + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + let (status, error_message) = match self { + AppError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Database error"), + AppError::ValidationError(ref msg) => (StatusCode::BAD_REQUEST, msg.as_str()), + AppError::AuthenticationError(ref msg) => (StatusCode::UNAUTHORIZED, msg.as_str()), + AppError::AuthorizationError(ref msg) => (StatusCode::FORBIDDEN, msg.as_str()), + AppError::NotFoundError(ref msg) => (StatusCode::NOT_FOUND, msg.as_str()), + AppError::ConflictError(ref msg) => (StatusCode::CONFLICT, msg.as_str()), + AppError::InternalError(_) => { + (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error") + } + }; + + let body = Json(json!({ + "error": error_message + })); + + (status, body).into_response() + } +} + +impl From for AppError { + fn from(err: anyhow::Error) -> Self { + if let Some(sqlx_err) = err.downcast_ref::() { + match sqlx_err { + sqlx::Error::RowNotFound => { + AppError::NotFoundError("Resource not found".to_string()) + } + sqlx::Error::Database(db_err) => { + if db_err.message().contains("UNIQUE constraint failed") { + AppError::ConflictError("Resource already exists".to_string()) + } else { + AppError::DatabaseError(db_err.message().to_string()) + } + } + _ => AppError::DatabaseError("Database operation failed".to_string()), + } + } else { + AppError::InternalError(err.to_string()) + } + } +} + +impl From for AppError { + fn from(_: bcrypt::BcryptError) -> Self { + AppError::InternalError("Password hashing error".to_string()) + } +} + +impl From for AppError { + fn from(err: sqlx::Error) -> Self { + match err { + sqlx::Error::RowNotFound => AppError::NotFoundError("Resource not found".to_string()), + sqlx::Error::Database(db_err) => { + if db_err.message().contains("UNIQUE constraint failed") { + AppError::ConflictError("Resource already exists".to_string()) + } else { + AppError::DatabaseError(db_err.message().to_string()) + } + } + _ => AppError::DatabaseError("Database operation failed".to_string()), + } + } +} + +impl From for AppError { + fn from(err: std::io::Error) -> Self { + AppError::InternalError(format!("IO error: {}", err)) + } +} + +pub type AppResult = Result; + +pub fn validation_error(msg: &str) -> AppError { + AppError::ValidationError(msg.to_string()) +} + +pub fn auth_error(msg: &str) -> AppError { + AppError::AuthenticationError(msg.to_string()) +} + +pub fn forbidden_error(msg: &str) -> AppError { + AppError::AuthorizationError(msg.to_string()) +} + +pub fn not_found_error(msg: &str) -> AppError { + AppError::NotFoundError(msg.to_string()) +} + +pub fn conflict_error(msg: &str) -> AppError { + AppError::ConflictError(msg.to_string()) +} + +pub fn internal_error(msg: &str) -> AppError { + AppError::InternalError(msg.to_string()) +} + +pub fn success_response(data: T) -> Json +where + T: serde::Serialize, +{ + Json(data) +} + +pub fn success_message(msg: &str) -> Json { + Json(json!({ "message": msg })) +} diff --git a/server/src/utils/mod.rs b/server/src/utils/mod.rs new file mode 100644 index 0000000..9664c9c --- /dev/null +++ b/server/src/utils/mod.rs @@ -0,0 +1,9 @@ +pub mod auth; +pub mod database; +pub mod db_path; +pub mod error; +pub mod models; + +pub use database::*; +pub use db_path::*; +pub use error::*; diff --git a/server/src/utils/models.rs b/server/src/utils/models.rs new file mode 100644 index 0000000..bbc2276 --- /dev/null +++ b/server/src/utils/models.rs @@ -0,0 +1,137 @@ +use serde::{Deserialize, Serialize}; +use chrono::{DateTime, Utc}; +use uuid::Uuid; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct User { + pub id: i64, + pub username: String, + #[serde(skip_serializing)] + pub password_hash: String, + pub role: UserRole, + pub storage_limit_gb: i64, + pub created_at: DateTime, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum UserRole { + Admin, + User, +} + +impl std::fmt::Display for UserRole { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UserRole::Admin => write!(f, "admin"), + UserRole::User => write!(f, "user"), + } + } +} + +impl std::str::FromStr for UserRole { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "admin" => Ok(UserRole::Admin), + "user" => Ok(UserRole::User), + _ => Err("Invalid role"), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CreateUserRequest { + pub username: String, + pub password: String, + pub role: Option, + pub storage_limit_gb: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UpdateUserRequest { + pub username: Option, + pub password: Option, + pub role: Option, + pub storage_limit_gb: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Session { + pub id: i64, + pub user_id: i64, + pub token: String, + pub created_at: DateTime, + pub expires_at: DateTime, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginRequest { + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginResponse { + pub token: String, + pub role: UserRole, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Machine { + pub id: i64, + pub user_id: i64, + pub uuid: Uuid, + pub name: String, + pub created_at: DateTime, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterMachineRequest { + pub code: String, + pub uuid: Uuid, + pub name: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ProvisioningCode { + pub id: i64, + pub user_id: i64, + pub code: String, + pub created_at: DateTime, + pub expires_at: DateTime, + pub used: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Snapshot { + pub id: i64, + pub machine_id: i64, + pub snapshot_hash: String, + pub created_at: DateTime, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SetupStatusResponse { + pub first_user_exists: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct InitSetupRequest { + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorResponse { + pub error: String, +} + +impl ErrorResponse { + pub fn new(message: &str) -> Self { + Self { + error: message.to_string(), + } + } +}