Create utility functions

This commit is contained in:
2025-09-08 21:17:03 +02:00
parent f31d10b6e1
commit f03a6935d5
6 changed files with 545 additions and 0 deletions

70
server/src/utils/auth.rs Normal file
View File

@@ -0,0 +1,70 @@
use crate::controllers::auth::AuthController;
use crate::utils::{models::*, DbPool};
use axum::{
async_trait,
extract::FromRequestParts,
http::{header::AUTHORIZATION, request::Parts, StatusCode},
};
#[derive(Clone)]
pub struct AuthUser {
pub user: User,
}
#[async_trait]
impl<S> FromRequestParts<S> for AuthUser
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let auth_header = parts
.headers
.get(AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if !auth_header.starts_with("Bearer ") {
return Err(StatusCode::UNAUTHORIZED);
}
let token = &auth_header[7..];
let pool = parts
.extensions
.get::<DbPool>()
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
let user = AuthController::authenticate_user(pool, token)
.await
.map_err(|_| StatusCode::UNAUTHORIZED)?;
Ok(AuthUser { user })
}
}
#[derive(Clone)]
pub struct AdminUser {
#[allow(dead_code)]
pub user: User,
}
#[async_trait]
impl<S> FromRequestParts<S> for AdminUser
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let auth_user = AuthUser::from_request_parts(parts, _state).await?;
if auth_user.user.role != UserRole::Admin {
return Err(StatusCode::FORBIDDEN);
}
Ok(AdminUser {
user: auth_user.user,
})
}
}

View File

@@ -0,0 +1,111 @@
use crate::utils::{ensure_data_directories, get_database_path, AppResult};
use sqlx::{sqlite::SqlitePool, Pool, Row, Sqlite};
use std::path::Path;
pub type DbPool = Pool<Sqlite>;
pub async fn init_database() -> AppResult<DbPool> {
ensure_data_directories()?;
let db_path = get_database_path()?;
if !Path::new(&db_path).exists() {
std::fs::File::create(&db_path)?;
}
let database_url = format!("sqlite://{}", db_path);
let pool = SqlitePool::connect(&database_url).await?;
run_migrations(&pool).await?;
Ok(pool)
}
async fn run_migrations(pool: &DbPool) -> AppResult<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
role TEXT CHECK(role IN ('admin','user')) NOT NULL,
storage_limit_gb INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token TEXT UNIQUE NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS machines (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
uuid TEXT UNIQUE NOT NULL,
name TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS provisioning_codes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
code TEXT UNIQUE NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME NOT NULL,
used BOOLEAN DEFAULT 0,
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
machine_id INTEGER NOT NULL,
snapshot_hash TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(machine_id) REFERENCES machines(id) ON DELETE CASCADE
)
"#,
)
.execute(pool)
.await?;
Ok(())
}
pub async fn check_first_user_exists(pool: &DbPool) -> AppResult<bool> {
let row = sqlx::query("SELECT COUNT(*) as count FROM users")
.fetch_one(pool)
.await?;
let count: i64 = row.get("count");
Ok(count > 0)
}

View File

@@ -0,0 +1,76 @@
use crate::utils::error::{internal_error, AppResult};
use std::fs;
pub fn get_database_path() -> AppResult<String> {
let db_dir = "data/db";
let db_path = format!("{}/arkendro.db", db_dir);
if let Err(e) = fs::create_dir_all(db_dir) {
return Err(internal_error(&format!(
"Failed to create database directory: {}",
e
)));
}
Ok(db_path)
}
pub fn ensure_data_directories() -> AppResult<()> {
let directories = ["data", "data/db", "data/backups", "data/logs"];
for dir in directories.iter() {
if let Err(e) = fs::create_dir_all(dir) {
return Err(internal_error(&format!(
"Failed to create directory '{}': {}",
dir, e
)));
}
}
Ok(())
}
pub fn get_data_path(filename: &str) -> String {
format!("data/{}", filename)
}
pub fn get_backup_path(filename: &str) -> String {
format!("data/backups/{}", filename)
}
pub fn get_log_path(filename: &str) -> String {
format!("data/logs/{}", filename)
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use std::path::Path;
#[test]
fn test_database_path_creation() {
let _ = fs::remove_dir_all("data");
let db_path = get_database_path().expect("Should create database path");
assert_eq!(db_path, "data/db/arkendro.db");
assert!(Path::new("data/db").exists());
let _ = fs::remove_dir_all("data");
}
#[test]
fn test_ensure_data_directories() {
let _ = fs::remove_dir_all("data");
ensure_data_directories().expect("Should create all directories");
assert!(Path::new("data").exists());
assert!(Path::new("data/db").exists());
assert!(Path::new("data/backups").exists());
assert!(Path::new("data/logs").exists());
let _ = fs::remove_dir_all("data");
}
}

142
server/src/utils/error.rs Normal file
View File

@@ -0,0 +1,142 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Json, Response},
};
use serde_json::json;
use std::fmt;
#[derive(Debug)]
pub enum AppError {
DatabaseError(String),
ValidationError(String),
AuthenticationError(String),
AuthorizationError(String),
NotFoundError(String),
ConflictError(String),
InternalError(String),
}
impl fmt::Display for AppError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AppError::DatabaseError(msg) => write!(f, "Database error: {}", msg),
AppError::ValidationError(msg) => write!(f, "Validation error: {}", msg),
AppError::AuthenticationError(msg) => write!(f, "Authentication error: {}", msg),
AppError::AuthorizationError(msg) => write!(f, "Authorization error: {}", msg),
AppError::NotFoundError(msg) => write!(f, "Not found: {}", msg),
AppError::ConflictError(msg) => write!(f, "Conflict: {}", msg),
AppError::InternalError(msg) => write!(f, "Internal error: {}", msg),
}
}
}
impl std::error::Error for AppError {}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AppError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Database error"),
AppError::ValidationError(ref msg) => (StatusCode::BAD_REQUEST, msg.as_str()),
AppError::AuthenticationError(ref msg) => (StatusCode::UNAUTHORIZED, msg.as_str()),
AppError::AuthorizationError(ref msg) => (StatusCode::FORBIDDEN, msg.as_str()),
AppError::NotFoundError(ref msg) => (StatusCode::NOT_FOUND, msg.as_str()),
AppError::ConflictError(ref msg) => (StatusCode::CONFLICT, msg.as_str()),
AppError::InternalError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error")
}
};
let body = Json(json!({
"error": error_message
}));
(status, body).into_response()
}
}
impl From<anyhow::Error> for AppError {
fn from(err: anyhow::Error) -> Self {
if let Some(sqlx_err) = err.downcast_ref::<sqlx::Error>() {
match sqlx_err {
sqlx::Error::RowNotFound => {
AppError::NotFoundError("Resource not found".to_string())
}
sqlx::Error::Database(db_err) => {
if db_err.message().contains("UNIQUE constraint failed") {
AppError::ConflictError("Resource already exists".to_string())
} else {
AppError::DatabaseError(db_err.message().to_string())
}
}
_ => AppError::DatabaseError("Database operation failed".to_string()),
}
} else {
AppError::InternalError(err.to_string())
}
}
}
impl From<bcrypt::BcryptError> for AppError {
fn from(_: bcrypt::BcryptError) -> Self {
AppError::InternalError("Password hashing error".to_string())
}
}
impl From<sqlx::Error> for AppError {
fn from(err: sqlx::Error) -> Self {
match err {
sqlx::Error::RowNotFound => AppError::NotFoundError("Resource not found".to_string()),
sqlx::Error::Database(db_err) => {
if db_err.message().contains("UNIQUE constraint failed") {
AppError::ConflictError("Resource already exists".to_string())
} else {
AppError::DatabaseError(db_err.message().to_string())
}
}
_ => AppError::DatabaseError("Database operation failed".to_string()),
}
}
}
impl From<std::io::Error> for AppError {
fn from(err: std::io::Error) -> Self {
AppError::InternalError(format!("IO error: {}", err))
}
}
pub type AppResult<T> = Result<T, AppError>;
pub fn validation_error(msg: &str) -> AppError {
AppError::ValidationError(msg.to_string())
}
pub fn auth_error(msg: &str) -> AppError {
AppError::AuthenticationError(msg.to_string())
}
pub fn forbidden_error(msg: &str) -> AppError {
AppError::AuthorizationError(msg.to_string())
}
pub fn not_found_error(msg: &str) -> AppError {
AppError::NotFoundError(msg.to_string())
}
pub fn conflict_error(msg: &str) -> AppError {
AppError::ConflictError(msg.to_string())
}
pub fn internal_error(msg: &str) -> AppError {
AppError::InternalError(msg.to_string())
}
pub fn success_response<T>(data: T) -> Json<T>
where
T: serde::Serialize,
{
Json(data)
}
pub fn success_message(msg: &str) -> Json<serde_json::Value> {
Json(json!({ "message": msg }))
}

9
server/src/utils/mod.rs Normal file
View File

@@ -0,0 +1,9 @@
pub mod auth;
pub mod database;
pub mod db_path;
pub mod error;
pub mod models;
pub use database::*;
pub use db_path::*;
pub use error::*;

137
server/src/utils/models.rs Normal file
View File

@@ -0,0 +1,137 @@
use serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
use uuid::Uuid;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct User {
pub id: i64,
pub username: String,
#[serde(skip_serializing)]
pub password_hash: String,
pub role: UserRole,
pub storage_limit_gb: i64,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum UserRole {
Admin,
User,
}
impl std::fmt::Display for UserRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
UserRole::Admin => write!(f, "admin"),
UserRole::User => write!(f, "user"),
}
}
}
impl std::str::FromStr for UserRole {
type Err = &'static str;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"admin" => Ok(UserRole::Admin),
"user" => Ok(UserRole::User),
_ => Err("Invalid role"),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateUserRequest {
pub username: String,
pub password: String,
pub role: Option<UserRole>,
pub storage_limit_gb: Option<i64>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct UpdateUserRequest {
pub username: Option<String>,
pub password: Option<String>,
pub role: Option<UserRole>,
pub storage_limit_gb: Option<i64>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Session {
pub id: i64,
pub user_id: i64,
pub token: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginRequest {
pub username: String,
pub password: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResponse {
pub token: String,
pub role: UserRole,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Machine {
pub id: i64,
pub user_id: i64,
pub uuid: Uuid,
pub name: String,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterMachineRequest {
pub code: String,
pub uuid: Uuid,
pub name: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ProvisioningCode {
pub id: i64,
pub user_id: i64,
pub code: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
pub used: bool,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Snapshot {
pub id: i64,
pub machine_id: i64,
pub snapshot_hash: String,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SetupStatusResponse {
pub first_user_exists: bool,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct InitSetupRequest {
pub username: String,
pub password: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorResponse {
pub error: String,
}
impl ErrorResponse {
pub fn new(message: &str) -> Self {
Self {
error: message.to_string(),
}
}
}