Add provisioning system to server

This commit is contained in:
2025-09-09 19:06:16 +02:00
parent 88e5f3d694
commit 7b3ae6bb6e
5 changed files with 91 additions and 61 deletions

View File

@@ -1,48 +1,65 @@
use crate::utils::{error::*, models::*, DbPool}; use crate::utils::{base62::Base62, config::ConfigManager, error::*, models::*, DbPool};
use chrono::Utc; use chrono::{Duration, Utc};
use rand::{distributions::Alphanumeric, Rng};
use sqlx::Row; use sqlx::Row;
use uuid::Uuid; use uuid::Uuid;
pub struct MachinesController; pub struct MachinesController;
impl MachinesController { impl MachinesController {
pub async fn register_machine( pub async fn register_machine(pool: &DbPool, user: &User, name: &str) -> AppResult<Machine> {
pool: &DbPool,
code: &str,
uuid: &Uuid,
name: &str,
) -> AppResult<Machine> {
Self::validate_machine_input(name)?; Self::validate_machine_input(name)?;
let provisioning_code = Self::get_provisioning_code(pool, code) let machine_uuid = Uuid::new_v4();
.await?
.ok_or_else(|| validation_error("Invalid provisioning code"))?;
if provisioning_code.used { let machine = Self::create_machine(pool, user.id, &machine_uuid, name).await?;
return Err(validation_error("Provisioning code already used"));
}
if provisioning_code.expires_at < Utc::now() {
return Err(validation_error("Provisioning code expired"));
}
if Self::machine_exists_by_uuid(pool, uuid).await? {
return Err(conflict_error("Machine with this UUID already exists"));
}
let machine = Self::create_machine(pool, provisioning_code.user_id, uuid, name).await?;
Self::mark_provisioning_code_used(pool, code).await?;
Ok(machine) Ok(machine)
} }
pub async fn get_machines_for_user(pool: &DbPool, user: &User) -> AppResult<Vec<Machine>> { pub async fn create_provisioning_code(
if user.role == UserRole::Admin { pool: &DbPool,
Self::get_all_machines(pool).await machine_id: i64,
} else { user: &User,
Self::get_machines_by_user_id(pool, user.id).await ) -> AppResult<ProvisioningCodeResponse> {
let machine = Self::get_machine_by_id(pool, machine_id).await?;
if user.role != UserRole::Admin && machine.user_id != user.id {
return Err(forbidden_error("Access denied"));
} }
let code: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(5)
.map(char::from)
.collect();
let external_url = ConfigManager::get_external_url(pool).await?;
let provisioning_string = format!("52?#{}/{}", external_url, code);
let encoded_code = Base62::encode(&provisioning_string);
let expires_at = Utc::now() + Duration::hours(1);
sqlx::query(
r#"
INSERT INTO provisioning_codes (machine_id, code, expires_at)
VALUES (?, ?, ?)
"#,
)
.bind(machine_id)
.bind(&code)
.bind(expires_at)
.execute(pool)
.await?;
Ok(ProvisioningCodeResponse {
code: encoded_code,
raw_code: code,
expires_at,
})
}
pub async fn get_machines_for_user(pool: &DbPool, user: &User) -> AppResult<Vec<Machine>> {
Self::get_machines_by_user_id(pool, user.id).await
} }
pub async fn delete_machine(pool: &DbPool, machine_id: i64, user: &User) -> AppResult<()> { pub async fn delete_machine(pool: &DbPool, machine_id: i64, user: &User) -> AppResult<()> {
@@ -75,30 +92,6 @@ impl MachinesController {
}) })
} }
async fn get_all_machines(pool: &DbPool) -> AppResult<Vec<Machine>> {
let rows = sqlx::query(
r#"
SELECT id, user_id, uuid, name, created_at
FROM machines ORDER BY created_at DESC
"#,
)
.fetch_all(pool)
.await?;
let mut machines = Vec::new();
for row in rows {
machines.push(Machine {
id: row.get("id"),
user_id: row.get("user_id"),
uuid: Uuid::parse_str(&row.get::<String, _>("uuid")).unwrap(),
name: row.get("name"),
created_at: row.get("created_at"),
});
}
Ok(machines)
}
async fn get_machines_by_user_id(pool: &DbPool, user_id: i64) -> AppResult<Vec<Machine>> { async fn get_machines_by_user_id(pool: &DbPool, user_id: i64) -> AppResult<Vec<Machine>> {
let rows = sqlx::query( let rows = sqlx::query(
r#" r#"
@@ -169,7 +162,7 @@ impl MachinesController {
) -> AppResult<Option<ProvisioningCode>> { ) -> AppResult<Option<ProvisioningCode>> {
let row = sqlx::query( let row = sqlx::query(
r#" r#"
SELECT id, user_id, code, created_at, expires_at, used SELECT id, machine_id, code, created_at, expires_at, used
FROM provisioning_codes WHERE code = ? FROM provisioning_codes WHERE code = ?
"#, "#,
) )
@@ -180,7 +173,7 @@ impl MachinesController {
if let Some(row) = row { if let Some(row) = row {
Ok(Some(ProvisioningCode { Ok(Some(ProvisioningCode {
id: row.get("id"), id: row.get("id"),
user_id: row.get("user_id"), machine_id: row.get("machine_id"),
code: row.get("code"), code: row.get("code"),
created_at: row.get("created_at"), created_at: row.get("created_at"),
expires_at: row.get("expires_at"), expires_at: row.get("expires_at"),

View File

@@ -7,7 +7,7 @@ use axum::{
routing::{delete, get, post, put}, routing::{delete, get, post, put},
Router, Router,
}; };
use routes::{accounts, admin, auth as auth_routes, machines, setup}; use routes::{accounts, admin, auth as auth_routes, config, machines, setup};
use std::path::Path; use std::path::Path;
use tokio::signal; use tokio::signal;
use tower_http::{ use tower_http::{
@@ -29,7 +29,11 @@ async fn main() -> Result<()> {
.route("/admin/users", post(admin::create_user_handler)) .route("/admin/users", post(admin::create_user_handler))
.route("/admin/users/{id}", put(admin::update_user_handler)) .route("/admin/users/{id}", put(admin::update_user_handler))
.route("/admin/users/{id}", delete(admin::delete_user_handler)) .route("/admin/users/{id}", delete(admin::delete_user_handler))
.route("/admin/config", get(config::get_all_configs))
.route("/admin/config", post(config::set_config))
.route("/admin/config/{key}", get(config::get_config))
.route("/machines/register", post(machines::register_machine)) .route("/machines/register", post(machines::register_machine))
.route("/machines/provisioning-code", post(machines::create_provisioning_code))
.route("/machines", get(machines::get_machines)) .route("/machines", get(machines::get_machines))
.route("/machines/{id}", delete(machines::delete_machine)) .route("/machines/{id}", delete(machines::delete_machine))
.layer(CorsLayer::permissive()) .layer(CorsLayer::permissive())

View File

@@ -6,13 +6,13 @@ use axum::{
}; };
pub async fn register_machine( pub async fn register_machine(
auth_user: AuthUser,
State(pool): State<DbPool>, State(pool): State<DbPool>,
Json(request): Json<RegisterMachineRequest>, Json(request): Json<RegisterMachineRequest>,
) -> Result<Json<Machine>, AppError> { ) -> Result<Json<Machine>, AppError> {
let machine = MachinesController::register_machine( let machine = MachinesController::register_machine(
&pool, &pool,
&request.code, &auth_user.user,
&request.uuid,
&request.name, &request.name,
) )
.await?; .await?;
@@ -20,6 +20,21 @@ pub async fn register_machine(
Ok(success_response(machine)) Ok(success_response(machine))
} }
pub async fn create_provisioning_code(
auth_user: AuthUser,
State(pool): State<DbPool>,
Json(request): Json<CreateProvisioningCodeRequest>,
) -> Result<Json<ProvisioningCodeResponse>, AppError> {
let response = MachinesController::create_provisioning_code(
&pool,
request.machine_id,
&auth_user.user,
)
.await?;
Ok(success_response(response))
}
pub async fn get_machines( pub async fn get_machines(
auth_user: AuthUser, auth_user: AuthUser,
State(pool): State<DbPool>, State(pool): State<DbPool>,

View File

@@ -1,5 +1,6 @@
pub mod admin; pub mod admin;
pub mod auth; pub mod auth;
pub mod config;
pub mod machines; pub mod machines;
pub mod setup; pub mod setup;
pub mod accounts; pub mod accounts;

View File

@@ -89,15 +89,32 @@ pub struct Machine {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct RegisterMachineRequest { pub struct RegisterMachineRequest {
pub name: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct UseProvisioningCodeRequest {
pub code: String, pub code: String,
pub uuid: Uuid, pub uuid: Uuid,
pub name: String, pub name: String,
} }
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateProvisioningCodeRequest {
pub machine_id: i64,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ProvisioningCodeResponse {
pub code: String,
pub raw_code: String,
pub expires_at: DateTime<Utc>,
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct ProvisioningCode { pub struct ProvisioningCode {
pub id: i64, pub id: i64,
pub user_id: i64, pub machine_id: i64,
pub code: String, pub code: String,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>, pub expires_at: DateTime<Utc>,