From 4e38b13faafd4e2d1828699ed6c1a0aa1fc32f7c Mon Sep 17 00:00:00 2001 From: Mathias Wagner Date: Tue, 9 Sep 2025 21:02:37 +0200 Subject: [PATCH] Add sync test using AI --- server/Cargo.lock | 93 +++++ server/Cargo.toml | 9 +- server/src/main.rs | 17 +- server/src/sync/meta.rs | 582 +++++++++++++++++++++++++++++++ server/src/sync/mod.rs | 8 + server/src/sync/protocol.rs | 620 ++++++++++++++++++++++++++++++++++ server/src/sync/server.rs | 463 +++++++++++++++++++++++++ server/src/sync/session.rs | 344 +++++++++++++++++++ server/src/sync/storage.rs | 399 ++++++++++++++++++++++ server/src/sync/validation.rs | 233 +++++++++++++ 10 files changed, 2766 insertions(+), 2 deletions(-) create mode 100644 server/src/sync/meta.rs create mode 100644 server/src/sync/mod.rs create mode 100644 server/src/sync/protocol.rs create mode 100644 server/src/sync/server.rs create mode 100644 server/src/sync/session.rs create mode 100644 server/src/sync/storage.rs create mode 100644 server/src/sync/validation.rs diff --git a/server/Cargo.lock b/server/Cargo.lock index b8dd143..c075466 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -38,6 +38,18 @@ version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "atoi" version = "2.0.0" @@ -153,6 +165,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "2.9.4" @@ -162,6 +183,19 @@ dependencies = [ "serde", ] +[[package]] +name = "blake3" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -254,6 +288,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -364,6 +404,16 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "etcetera" version = "0.8.0" @@ -386,6 +436,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "find-msvc-tools" version = "0.1.1" @@ -903,6 +959,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "litemap" version = "0.8.0" @@ -1249,6 +1311,19 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + [[package]] name = "rustls" version = "0.23.31" @@ -1362,11 +1437,16 @@ dependencies = [ "anyhow", "axum", "bcrypt", + "bincode", + "blake3", + "bytes", "chrono", + "hex", "rand", "serde", "serde_json", "sqlx", + "tempfile", "tokio", "tower-http", "uuid", @@ -1712,6 +1792,19 @@ dependencies = [ "syn", ] +[[package]] +name = "tempfile" +version = "3.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84fa4d11fadde498443cca10fd3ac23c951f0dc59e080e9f4b93d4df4e4eea53" +dependencies = [ + "fastrand", + "getrandom 0.3.3", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "thiserror" version = "2.0.16" diff --git a/server/Cargo.toml b/server/Cargo.toml index 6cabc78..3715a4b 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -14,4 +14,11 @@ uuid = { version = "1.0", features = ["v4", "serde"] } chrono = { version = "0.4", features = ["serde"] } tower-http = { version = "0.6.6", features = ["cors", "fs"] } anyhow = "1.0" -rand = "0.8" \ No newline at end of file +rand = "0.8" +blake3 = "1.5" +bytes = "1.0" +bincode = "1.3" +hex = "0.4" + +[dev-dependencies] +tempfile = "3.0" \ No newline at end of file diff --git a/server/src/main.rs b/server/src/main.rs index aef6044..e80de09 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,6 +1,7 @@ mod controllers; mod routes; mod utils; +mod sync; use anyhow::Result; use axum::{ @@ -15,10 +16,14 @@ use tower_http::{ services::{ServeDir, ServeFile}, }; use utils::init_database; +use sync::{SyncServer, server::SyncServerConfig}; #[tokio::main] async fn main() -> Result<()> { let pool = init_database().await?; + + let sync_pool = pool.clone(); + let api_routes = Router::new() .route("/setup/status", get(setup::get_setup_status)) .route("/setup/init", post(setup::init_setup)) @@ -51,8 +56,18 @@ async fn main() -> Result<()> { println!("Warning: dist directory not found at {}", dist_path); } + let sync_config = SyncServerConfig::default(); + let sync_server = SyncServer::new(sync_config.clone(), sync_pool); + + tokio::spawn(async move { + if let Err(e) = sync_server.start().await { + eprintln!("Sync server error: {}", e); + } + }); + let listener = tokio::net::TcpListener::bind("0.0.0.0:8379").await?; - println!("Server running on http://0.0.0.0:8379"); + println!("HTTP server running on http://0.0.0.0:8379"); + println!("Sync server running on {}:{}", sync_config.bind_address, sync_config.port); axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) diff --git a/server/src/sync/meta.rs b/server/src/sync/meta.rs new file mode 100644 index 0000000..bdf2252 --- /dev/null +++ b/server/src/sync/meta.rs @@ -0,0 +1,582 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::io::{Error, ErrorKind, Result}; +use crate::sync::protocol::{Hash, MetaType}; + +/// Filesystem type codes +#[repr(u32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FsType { + Ext = 1, + Ntfs = 2, + Fat32 = 3, + Unknown = 0, +} + +impl From for FsType { + fn from(value: u32) -> Self { + match value { + 1 => FsType::Ext, + 2 => FsType::Ntfs, + 3 => FsType::Fat32, + _ => FsType::Unknown, + } + } +} + +/// Directory entry types +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EntryType { + File = 0, + Dir = 1, + Symlink = 2, +} + +impl TryFrom for EntryType { + type Error = Error; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(EntryType::File), + 1 => Ok(EntryType::Dir), + 2 => Ok(EntryType::Symlink), + _ => Err(Error::new(ErrorKind::InvalidData, "Unknown entry type")), + } + } +} + +/// File metadata object +#[derive(Debug, Clone)] +pub struct FileObj { + pub version: u8, + pub fs_type_code: FsType, + pub size: u64, + pub mode: u32, + pub uid: u32, + pub gid: u32, + pub mtime_unixsec: u64, + pub chunk_hashes: Vec, +} + +impl FileObj { + pub fn new( + fs_type_code: FsType, + size: u64, + mode: u32, + uid: u32, + gid: u32, + mtime_unixsec: u64, + chunk_hashes: Vec, + ) -> Self { + Self { + version: 1, + fs_type_code, + size, + mode, + uid, + gid, + mtime_unixsec, + chunk_hashes, + } + } + + pub fn serialize(&self) -> Result { + let mut buf = BytesMut::new(); + + buf.put_u8(self.version); + buf.put_u32_le(self.fs_type_code as u32); + buf.put_u64_le(self.size); + buf.put_u32_le(self.mode); + buf.put_u32_le(self.uid); + buf.put_u32_le(self.gid); + buf.put_u64_le(self.mtime_unixsec); + buf.put_u32_le(self.chunk_hashes.len() as u32); + + for hash in &self.chunk_hashes { + buf.put_slice(hash); + } + + Ok(buf.freeze()) + } + + pub fn deserialize(mut data: Bytes) -> Result { + if data.remaining() < 41 { + return Err(Error::new(ErrorKind::UnexpectedEof, "FileObj data too short")); + } + + let version = data.get_u8(); + if version != 1 { + return Err(Error::new(ErrorKind::InvalidData, "Unsupported FileObj version")); + } + + let fs_type_code = FsType::from(data.get_u32_le()); + let size = data.get_u64_le(); + let mode = data.get_u32_le(); + let uid = data.get_u32_le(); + let gid = data.get_u32_le(); + let mtime_unixsec = data.get_u64_le(); + let chunk_count = data.get_u32_le() as usize; + + if data.remaining() < chunk_count * 32 { + return Err(Error::new(ErrorKind::UnexpectedEof, "FileObj chunk hashes too short")); + } + + let mut chunk_hashes = Vec::with_capacity(chunk_count); + for _ in 0..chunk_count { + let mut hash = [0u8; 32]; + data.copy_to_slice(&mut hash); + chunk_hashes.push(hash); + } + + Ok(Self { + version, + fs_type_code, + size, + mode, + uid, + gid, + mtime_unixsec, + chunk_hashes, + }) + } + + pub fn compute_hash(&self) -> Result { + let serialized = self.serialize()?; + Ok(blake3::hash(&serialized).into()) + } +} + +/// Directory entry +#[derive(Debug, Clone)] +pub struct DirEntry { + pub entry_type: EntryType, + pub name: String, + pub target_meta_hash: Hash, +} + +/// Directory metadata object +#[derive(Debug, Clone)] +pub struct DirObj { + pub version: u8, + pub entries: Vec, +} + +impl DirObj { + pub fn new(entries: Vec) -> Self { + Self { + version: 1, + entries, + } + } + + pub fn serialize(&self) -> Result { + let mut buf = BytesMut::new(); + + buf.put_u8(self.version); + buf.put_u32_le(self.entries.len() as u32); + + for entry in &self.entries { + buf.put_u8(entry.entry_type as u8); + let name_bytes = entry.name.as_bytes(); + buf.put_u16_le(name_bytes.len() as u16); + buf.put_slice(name_bytes); + buf.put_slice(&entry.target_meta_hash); + } + + Ok(buf.freeze()) + } + + pub fn deserialize(mut data: Bytes) -> Result { + if data.remaining() < 5 { + return Err(Error::new(ErrorKind::UnexpectedEof, "DirObj data too short")); + } + + let version = data.get_u8(); + if version != 1 { + return Err(Error::new(ErrorKind::InvalidData, "Unsupported DirObj version")); + } + + let entry_count = data.get_u32_le() as usize; + let mut entries = Vec::with_capacity(entry_count); + + for _ in 0..entry_count { + if data.remaining() < 35 { + return Err(Error::new(ErrorKind::UnexpectedEof, "DirObj entry too short")); + } + + let entry_type = EntryType::try_from(data.get_u8())?; + let name_len = data.get_u16_le() as usize; + + if data.remaining() < name_len + 32 { + return Err(Error::new(ErrorKind::UnexpectedEof, "DirObj entry name/hash too short")); + } + + let name = String::from_utf8(data.copy_to_bytes(name_len).to_vec()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in entry name"))?; + + let mut target_meta_hash = [0u8; 32]; + data.copy_to_slice(&mut target_meta_hash); + + entries.push(DirEntry { + entry_type, + name, + target_meta_hash, + }); + } + + Ok(Self { + version, + entries, + }) + } + + pub fn compute_hash(&self) -> Result { + let serialized = self.serialize()?; + Ok(blake3::hash(&serialized).into()) + } +} + +/// Partition metadata object +#[derive(Debug, Clone)] +pub struct PartitionObj { + pub version: u8, + pub fs_type_code: FsType, + pub root_dir_hash: Hash, + pub start_lba: u64, + pub end_lba: u64, + pub type_guid: [u8; 16], +} + +impl PartitionObj { + pub fn new( + fs_type_code: FsType, + root_dir_hash: Hash, + start_lba: u64, + end_lba: u64, + type_guid: [u8; 16], + ) -> Self { + Self { + version: 1, + fs_type_code, + root_dir_hash, + start_lba, + end_lba, + type_guid, + } + } + + pub fn serialize(&self) -> Result { + let mut buf = BytesMut::new(); + + buf.put_u8(self.version); + buf.put_u32_le(self.fs_type_code as u32); + buf.put_slice(&self.root_dir_hash); + buf.put_u64_le(self.start_lba); + buf.put_u64_le(self.end_lba); + buf.put_slice(&self.type_guid); + + Ok(buf.freeze()) + } + + pub fn deserialize(mut data: Bytes) -> Result { + if data.remaining() < 69 { + return Err(Error::new(ErrorKind::UnexpectedEof, "PartitionObj data too short")); + } + + let version = data.get_u8(); + if version != 1 { + return Err(Error::new(ErrorKind::InvalidData, "Unsupported PartitionObj version")); + } + + let fs_type_code = FsType::from(data.get_u32_le()); + + let mut root_dir_hash = [0u8; 32]; + data.copy_to_slice(&mut root_dir_hash); + + let start_lba = data.get_u64_le(); + let end_lba = data.get_u64_le(); + + let mut type_guid = [0u8; 16]; + data.copy_to_slice(&mut type_guid); + + Ok(Self { + version, + fs_type_code, + root_dir_hash, + start_lba, + end_lba, + type_guid, + }) + } + + pub fn compute_hash(&self) -> Result { + let serialized = self.serialize()?; + Ok(blake3::hash(&serialized).into()) + } +} + +/// Disk metadata object +#[derive(Debug, Clone)] +pub struct DiskObj { + pub version: u8, + pub partition_hashes: Vec, + pub disk_size_bytes: u64, + pub serial: String, +} + +impl DiskObj { + pub fn new(partition_hashes: Vec, disk_size_bytes: u64, serial: String) -> Self { + Self { + version: 1, + partition_hashes, + disk_size_bytes, + serial, + } + } + + pub fn serialize(&self) -> Result { + let mut buf = BytesMut::new(); + + buf.put_u8(self.version); + buf.put_u32_le(self.partition_hashes.len() as u32); + + for hash in &self.partition_hashes { + buf.put_slice(hash); + } + + buf.put_u64_le(self.disk_size_bytes); + + let serial_bytes = self.serial.as_bytes(); + buf.put_u16_le(serial_bytes.len() as u16); + buf.put_slice(serial_bytes); + + Ok(buf.freeze()) + } + + pub fn deserialize(mut data: Bytes) -> Result { + if data.remaining() < 15 { + return Err(Error::new(ErrorKind::UnexpectedEof, "DiskObj data too short")); + } + + let version = data.get_u8(); + if version != 1 { + return Err(Error::new(ErrorKind::InvalidData, "Unsupported DiskObj version")); + } + + let partition_count = data.get_u32_le() as usize; + + if data.remaining() < partition_count * 32 + 10 { + return Err(Error::new(ErrorKind::UnexpectedEof, "DiskObj partitions too short")); + } + + let mut partition_hashes = Vec::with_capacity(partition_count); + for _ in 0..partition_count { + let mut hash = [0u8; 32]; + data.copy_to_slice(&mut hash); + partition_hashes.push(hash); + } + + let disk_size_bytes = data.get_u64_le(); + let serial_len = data.get_u16_le() as usize; + + if data.remaining() < serial_len { + return Err(Error::new(ErrorKind::UnexpectedEof, "DiskObj serial too short")); + } + + let serial = String::from_utf8(data.copy_to_bytes(serial_len).to_vec()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in serial"))?; + + Ok(Self { + version, + partition_hashes, + disk_size_bytes, + serial, + }) + } + + pub fn compute_hash(&self) -> Result { + let serialized = self.serialize()?; + Ok(blake3::hash(&serialized).into()) + } +} + +/// Snapshot metadata object +#[derive(Debug, Clone)] +pub struct SnapshotObj { + pub version: u8, + pub created_at_unixsec: u64, + pub disk_hashes: Vec, +} + +impl SnapshotObj { + pub fn new(created_at_unixsec: u64, disk_hashes: Vec) -> Self { + Self { + version: 1, + created_at_unixsec, + disk_hashes, + } + } + + pub fn serialize(&self) -> Result { + let mut buf = BytesMut::new(); + + buf.put_u8(self.version); + buf.put_u64_le(self.created_at_unixsec); + buf.put_u32_le(self.disk_hashes.len() as u32); + + for hash in &self.disk_hashes { + buf.put_slice(hash); + } + + Ok(buf.freeze()) + } + + pub fn deserialize(mut data: Bytes) -> Result { + if data.remaining() < 13 { + return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotObj data too short")); + } + + let version = data.get_u8(); + if version != 1 { + return Err(Error::new(ErrorKind::InvalidData, "Unsupported SnapshotObj version")); + } + + let created_at_unixsec = data.get_u64_le(); + let disk_count = data.get_u32_le() as usize; + + if data.remaining() < disk_count * 32 { + return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotObj disk hashes too short")); + } + + let mut disk_hashes = Vec::with_capacity(disk_count); + for _ in 0..disk_count { + let mut hash = [0u8; 32]; + data.copy_to_slice(&mut hash); + disk_hashes.push(hash); + } + + Ok(Self { + version, + created_at_unixsec, + disk_hashes, + }) + } + + pub fn compute_hash(&self) -> Result { + let serialized = self.serialize()?; + Ok(blake3::hash(&serialized).into()) + } +} + +/// Meta object wrapper +#[derive(Debug, Clone)] +pub enum MetaObj { + File(FileObj), + Dir(DirObj), + Partition(PartitionObj), + Disk(DiskObj), + Snapshot(SnapshotObj), +} + +impl MetaObj { + pub fn meta_type(&self) -> MetaType { + match self { + MetaObj::File(_) => MetaType::File, + MetaObj::Dir(_) => MetaType::Dir, + MetaObj::Partition(_) => MetaType::Partition, + MetaObj::Disk(_) => MetaType::Disk, + MetaObj::Snapshot(_) => MetaType::Snapshot, + } + } + + pub fn serialize(&self) -> Result { + match self { + MetaObj::File(obj) => obj.serialize(), + MetaObj::Dir(obj) => obj.serialize(), + MetaObj::Partition(obj) => obj.serialize(), + MetaObj::Disk(obj) => obj.serialize(), + MetaObj::Snapshot(obj) => obj.serialize(), + } + } + + pub fn deserialize(meta_type: MetaType, data: Bytes) -> Result { + match meta_type { + MetaType::File => Ok(MetaObj::File(FileObj::deserialize(data)?)), + MetaType::Dir => Ok(MetaObj::Dir(DirObj::deserialize(data)?)), + MetaType::Partition => Ok(MetaObj::Partition(PartitionObj::deserialize(data)?)), + MetaType::Disk => Ok(MetaObj::Disk(DiskObj::deserialize(data)?)), + MetaType::Snapshot => Ok(MetaObj::Snapshot(SnapshotObj::deserialize(data)?)), + } + } + + pub fn compute_hash(&self) -> Result { + match self { + MetaObj::File(obj) => obj.compute_hash(), + MetaObj::Dir(obj) => obj.compute_hash(), + MetaObj::Partition(obj) => obj.compute_hash(), + MetaObj::Disk(obj) => obj.compute_hash(), + MetaObj::Snapshot(obj) => obj.compute_hash(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_file_obj_serialization() { + let obj = FileObj::new( + FsType::Ext, + 1024, + 0o644, + 1000, + 1000, + 1234567890, + vec![[1; 32], [2; 32]], + ); + + let serialized = obj.serialize().unwrap(); + let deserialized = FileObj::deserialize(serialized).unwrap(); + + assert_eq!(obj.fs_type_code, deserialized.fs_type_code); + assert_eq!(obj.size, deserialized.size); + assert_eq!(obj.chunk_hashes, deserialized.chunk_hashes); + } + + #[test] + fn test_dir_obj_serialization() { + let entries = vec![ + DirEntry { + entry_type: EntryType::File, + name: "test.txt".to_string(), + target_meta_hash: [1; 32], + }, + DirEntry { + entry_type: EntryType::Dir, + name: "subdir".to_string(), + target_meta_hash: [2; 32], + }, + ]; + + let obj = DirObj::new(entries); + let serialized = obj.serialize().unwrap(); + let deserialized = DirObj::deserialize(serialized).unwrap(); + + assert_eq!(obj.entries.len(), deserialized.entries.len()); + assert_eq!(obj.entries[0].name, deserialized.entries[0].name); + assert_eq!(obj.entries[1].entry_type, deserialized.entries[1].entry_type); + } + + #[test] + fn test_hash_computation() { + let obj = FileObj::new(FsType::Ext, 1024, 0o644, 1000, 1000, 1234567890, vec![]); + let hash1 = obj.compute_hash().unwrap(); + let hash2 = obj.compute_hash().unwrap(); + assert_eq!(hash1, hash2); + + let obj2 = FileObj::new(FsType::Ext, 1025, 0o644, 1000, 1000, 1234567890, vec![]); + let hash3 = obj2.compute_hash().unwrap(); + assert_ne!(hash1, hash3); + } +} diff --git a/server/src/sync/mod.rs b/server/src/sync/mod.rs new file mode 100644 index 0000000..cc3927d --- /dev/null +++ b/server/src/sync/mod.rs @@ -0,0 +1,8 @@ +pub mod protocol; +pub mod server; +pub mod storage; +pub mod session; +pub mod meta; +pub mod validation; + +pub use server::SyncServer; diff --git a/server/src/sync/protocol.rs b/server/src/sync/protocol.rs new file mode 100644 index 0000000..b0fdeaf --- /dev/null +++ b/server/src/sync/protocol.rs @@ -0,0 +1,620 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::io::{Error, ErrorKind, Result}; + +/// Command codes for the sync protocol +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Command { + Hello = 0x01, + HelloOk = 0x02, + AuthUserPass = 0x10, + AuthCode = 0x11, + AuthOk = 0x12, + AuthFail = 0x13, + BatchCheckChunk = 0x20, + CheckChunkResp = 0x21, + SendChunk = 0x22, + ChunkOk = 0x23, + ChunkFail = 0x24, + BatchCheckMeta = 0x30, + CheckMetaResp = 0x31, + SendMeta = 0x32, + MetaOk = 0x33, + MetaFail = 0x34, + SendSnapshot = 0x40, + SnapshotOk = 0x41, + SnapshotFail = 0x42, + Close = 0xFF, +} + +impl TryFrom for Command { + type Error = Error; + + fn try_from(value: u8) -> Result { + match value { + 0x01 => Ok(Command::Hello), + 0x02 => Ok(Command::HelloOk), + 0x10 => Ok(Command::AuthUserPass), + 0x11 => Ok(Command::AuthCode), + 0x12 => Ok(Command::AuthOk), + 0x13 => Ok(Command::AuthFail), + 0x20 => Ok(Command::BatchCheckChunk), + 0x21 => Ok(Command::CheckChunkResp), + 0x22 => Ok(Command::SendChunk), + 0x23 => Ok(Command::ChunkOk), + 0x24 => Ok(Command::ChunkFail), + 0x30 => Ok(Command::BatchCheckMeta), + 0x31 => Ok(Command::CheckMetaResp), + 0x32 => Ok(Command::SendMeta), + 0x33 => Ok(Command::MetaOk), + 0x34 => Ok(Command::MetaFail), + 0x40 => Ok(Command::SendSnapshot), + 0x41 => Ok(Command::SnapshotOk), + 0x42 => Ok(Command::SnapshotFail), + 0xFF => Ok(Command::Close), + _ => Err(Error::new(ErrorKind::InvalidData, "Unknown command code")), + } + } +} + +/// Message header structure (24 bytes fixed) +#[derive(Debug, Clone)] +pub struct MessageHeader { + pub cmd: Command, + pub flags: u8, + pub reserved: [u8; 2], + pub session_id: [u8; 16], + pub payload_len: u32, +} + +impl MessageHeader { + pub const SIZE: usize = 24; + + pub fn new(cmd: Command, session_id: [u8; 16], payload_len: u32) -> Self { + Self { + cmd, + flags: 0, + reserved: [0; 2], + session_id, + payload_len, + } + } + + pub fn serialize(&self) -> [u8; Self::SIZE] { + let mut buf = [0u8; Self::SIZE]; + buf[0] = self.cmd as u8; + buf[1] = self.flags; + buf[2..4].copy_from_slice(&self.reserved); + buf[4..20].copy_from_slice(&self.session_id); + buf[20..24].copy_from_slice(&self.payload_len.to_le_bytes()); + buf + } + + pub fn deserialize(buf: &[u8]) -> Result { + if buf.len() < Self::SIZE { + return Err(Error::new(ErrorKind::UnexpectedEof, "Header too short")); + } + + let cmd = Command::try_from(buf[0])?; + let flags = buf[1]; + let reserved = [buf[2], buf[3]]; + let mut session_id = [0u8; 16]; + session_id.copy_from_slice(&buf[4..20]); + let payload_len = u32::from_le_bytes([buf[20], buf[21], buf[22], buf[23]]); + + Ok(Self { + cmd, + flags, + reserved, + session_id, + payload_len, + }) + } +} + +/// A 32-byte BLAKE3 hash +pub type Hash = [u8; 32]; + +/// Meta object types +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum MetaType { + File = 1, + Dir = 2, + Partition = 3, + Disk = 4, + Snapshot = 5, +} + +impl TryFrom for MetaType { + type Error = Error; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(MetaType::File), + 2 => Ok(MetaType::Dir), + 3 => Ok(MetaType::Partition), + 4 => Ok(MetaType::Disk), + 5 => Ok(MetaType::Snapshot), + _ => Err(Error::new(ErrorKind::InvalidData, "Unknown meta type")), + } + } +} + +/// Protocol message types +#[derive(Debug, Clone)] +pub enum Message { + Hello { + client_type: u8, + auth_type: u8, + }, + HelloOk, + AuthUserPass { + username: String, + password: String, + machine_id: i64, + }, + AuthCode { + code: String, + }, + AuthOk { + session_id: [u8; 16], + }, + AuthFail { + reason: String, + }, + BatchCheckChunk { + hashes: Vec, + }, + CheckChunkResp { + missing_hashes: Vec, + }, + SendChunk { + hash: Hash, + data: Bytes, + }, + ChunkOk, + ChunkFail { + reason: String, + }, + BatchCheckMeta { + items: Vec<(MetaType, Hash)>, + }, + CheckMetaResp { + missing_items: Vec<(MetaType, Hash)>, + }, + SendMeta { + meta_type: MetaType, + meta_hash: Hash, + body: Bytes, + }, + MetaOk, + MetaFail { + reason: String, + }, + SendSnapshot { + snapshot_hash: Hash, + body: Bytes, + }, + SnapshotOk { + snapshot_id: String, + }, + SnapshotFail { + missing_chunks: Vec, + missing_metas: Vec<(MetaType, Hash)>, + }, + Close, +} + +impl Message { + /// Serialize message payload to bytes + pub fn serialize_payload(&self) -> Result { + let mut buf = BytesMut::new(); + + match self { + Message::Hello { client_type, auth_type } => { + buf.put_u8(*client_type); + buf.put_u8(*auth_type); + } + Message::HelloOk => { + // No payload + } + Message::AuthUserPass { username, password, machine_id } => { + let username_bytes = username.as_bytes(); + let password_bytes = password.as_bytes(); + buf.put_u16_le(username_bytes.len() as u16); + buf.put_slice(username_bytes); + buf.put_u16_le(password_bytes.len() as u16); + buf.put_slice(password_bytes); + buf.put_i64_le(*machine_id); + } + Message::AuthCode { code } => { + let code_bytes = code.as_bytes(); + buf.put_u16_le(code_bytes.len() as u16); + buf.put_slice(code_bytes); + } + Message::AuthOk { session_id } => { + buf.put_slice(session_id); + } + Message::AuthFail { reason } => { + let reason_bytes = reason.as_bytes(); + buf.put_u16_le(reason_bytes.len() as u16); + buf.put_slice(reason_bytes); + } + Message::BatchCheckChunk { hashes } => { + buf.put_u32_le(hashes.len() as u32); + for hash in hashes { + buf.put_slice(hash); + } + } + Message::CheckChunkResp { missing_hashes } => { + buf.put_u32_le(missing_hashes.len() as u32); + for hash in missing_hashes { + buf.put_slice(hash); + } + } + Message::SendChunk { hash, data } => { + buf.put_slice(hash); + buf.put_u32_le(data.len() as u32); + buf.put_slice(data); + } + Message::ChunkOk => { + // No payload + } + Message::ChunkFail { reason } => { + let reason_bytes = reason.as_bytes(); + buf.put_u16_le(reason_bytes.len() as u16); + buf.put_slice(reason_bytes); + } + Message::BatchCheckMeta { items } => { + buf.put_u32_le(items.len() as u32); + for (meta_type, hash) in items { + buf.put_u8(*meta_type as u8); + buf.put_slice(hash); + } + } + Message::CheckMetaResp { missing_items } => { + buf.put_u32_le(missing_items.len() as u32); + for (meta_type, hash) in missing_items { + buf.put_u8(*meta_type as u8); + buf.put_slice(hash); + } + } + Message::SendMeta { meta_type, meta_hash, body } => { + buf.put_u8(*meta_type as u8); + buf.put_slice(meta_hash); + buf.put_u32_le(body.len() as u32); + buf.put_slice(body); + } + Message::MetaOk => { + // No payload + } + Message::MetaFail { reason } => { + let reason_bytes = reason.as_bytes(); + buf.put_u16_le(reason_bytes.len() as u16); + buf.put_slice(reason_bytes); + } + Message::SendSnapshot { snapshot_hash, body } => { + buf.put_slice(snapshot_hash); + buf.put_u32_le(body.len() as u32); + buf.put_slice(body); + } + Message::SnapshotOk { snapshot_id } => { + let id_bytes = snapshot_id.as_bytes(); + buf.put_u16_le(id_bytes.len() as u16); + buf.put_slice(id_bytes); + } + Message::SnapshotFail { missing_chunks, missing_metas } => { + buf.put_u32_le(missing_chunks.len() as u32); + for hash in missing_chunks { + buf.put_slice(hash); + } + buf.put_u32_le(missing_metas.len() as u32); + for (meta_type, hash) in missing_metas { + buf.put_u8(*meta_type as u8); + buf.put_slice(hash); + } + } + Message::Close => { + // No payload + } + } + + Ok(buf.freeze()) + } + + /// Deserialize message payload from bytes + pub fn deserialize_payload(cmd: Command, mut payload: Bytes) -> Result { + match cmd { + Command::Hello => { + if payload.remaining() < 2 { + return Err(Error::new(ErrorKind::UnexpectedEof, "Hello payload too short")); + } + let client_type = payload.get_u8(); + let auth_type = payload.get_u8(); + Ok(Message::Hello { client_type, auth_type }) + } + Command::HelloOk => Ok(Message::HelloOk), + Command::AuthUserPass => { + if payload.remaining() < 12 { // 4 bytes for lengths + at least 8 bytes for machine_id + return Err(Error::new(ErrorKind::UnexpectedEof, "AuthUserPass payload too short")); + } + let username_len = payload.get_u16_le() as usize; + if payload.remaining() < username_len + 10 { // 2 bytes for password len + 8 bytes for machine_id + return Err(Error::new(ErrorKind::UnexpectedEof, "AuthUserPass username too short")); + } + let username = String::from_utf8(payload.copy_to_bytes(username_len).to_vec()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in username"))?; + let password_len = payload.get_u16_le() as usize; + if payload.remaining() < password_len + 8 { // 8 bytes for machine_id + return Err(Error::new(ErrorKind::UnexpectedEof, "AuthUserPass password too short")); + } + let password = String::from_utf8(payload.copy_to_bytes(password_len).to_vec()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in password"))?; + let machine_id = payload.get_i64_le(); + Ok(Message::AuthUserPass { username, password, machine_id }) + } + Command::AuthCode => { + if payload.remaining() < 2 { + return Err(Error::new(ErrorKind::UnexpectedEof, "AuthCode payload too short")); + } + let code_len = payload.get_u16_le() as usize; + if payload.remaining() < code_len { + return Err(Error::new(ErrorKind::UnexpectedEof, "AuthCode code too short")); + } + let code = String::from_utf8(payload.copy_to_bytes(code_len).to_vec()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in code"))?; + Ok(Message::AuthCode { code }) + } + Command::AuthOk => { + if payload.remaining() < 16 { + return Err(Error::new(ErrorKind::UnexpectedEof, "AuthOk payload too short")); + } + let mut session_id = [0u8; 16]; + payload.copy_to_slice(&mut session_id); + Ok(Message::AuthOk { session_id }) + } + Command::AuthFail => { + if payload.remaining() < 2 { + return Err(Error::new(ErrorKind::UnexpectedEof, "AuthFail payload too short")); + } + let reason_len = payload.get_u16_le() as usize; + if payload.remaining() < reason_len { + return Err(Error::new(ErrorKind::UnexpectedEof, "AuthFail reason too short")); + } + let reason = String::from_utf8(payload.copy_to_bytes(reason_len).to_vec()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in reason"))?; + Ok(Message::AuthFail { reason }) + } + Command::BatchCheckChunk => { + if payload.remaining() < 4 { + return Err(Error::new(ErrorKind::UnexpectedEof, "BatchCheckChunk payload too short")); + } + let count = payload.get_u32_le() as usize; + if payload.remaining() < count * 32 { + return Err(Error::new(ErrorKind::UnexpectedEof, "BatchCheckChunk hashes too short")); + } + let mut hashes = Vec::with_capacity(count); + for _ in 0..count { + let mut hash = [0u8; 32]; + payload.copy_to_slice(&mut hash); + hashes.push(hash); + } + Ok(Message::BatchCheckChunk { hashes }) + } + Command::CheckChunkResp => { + if payload.remaining() < 4 { + return Err(Error::new(ErrorKind::UnexpectedEof, "CheckChunkResp payload too short")); + } + let count = payload.get_u32_le() as usize; + if payload.remaining() < count * 32 { + return Err(Error::new(ErrorKind::UnexpectedEof, "CheckChunkResp hashes too short")); + } + let mut missing_hashes = Vec::with_capacity(count); + for _ in 0..count { + let mut hash = [0u8; 32]; + payload.copy_to_slice(&mut hash); + missing_hashes.push(hash); + } + Ok(Message::CheckChunkResp { missing_hashes }) + } + Command::SendChunk => { + if payload.remaining() < 36 { + return Err(Error::new(ErrorKind::UnexpectedEof, "SendChunk payload too short")); + } + let mut hash = [0u8; 32]; + payload.copy_to_slice(&mut hash); + let size = payload.get_u32_le() as usize; + if payload.remaining() < size { + return Err(Error::new(ErrorKind::UnexpectedEof, "SendChunk data too short")); + } + let data = payload.copy_to_bytes(size); + Ok(Message::SendChunk { hash, data }) + } + Command::ChunkOk => Ok(Message::ChunkOk), + Command::ChunkFail => { + if payload.remaining() < 2 { + return Err(Error::new(ErrorKind::UnexpectedEof, "ChunkFail payload too short")); + } + let reason_len = payload.get_u16_le() as usize; + if payload.remaining() < reason_len { + return Err(Error::new(ErrorKind::UnexpectedEof, "ChunkFail reason too short")); + } + let reason = String::from_utf8(payload.copy_to_bytes(reason_len).to_vec()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in reason"))?; + Ok(Message::ChunkFail { reason }) + } + Command::BatchCheckMeta => { + if payload.remaining() < 4 { + return Err(Error::new(ErrorKind::UnexpectedEof, "BatchCheckMeta payload too short")); + } + let count = payload.get_u32_le() as usize; + if payload.remaining() < count * 33 { + return Err(Error::new(ErrorKind::UnexpectedEof, "BatchCheckMeta items too short")); + } + let mut items = Vec::with_capacity(count); + for _ in 0..count { + let meta_type = MetaType::try_from(payload.get_u8())?; + let mut hash = [0u8; 32]; + payload.copy_to_slice(&mut hash); + items.push((meta_type, hash)); + } + Ok(Message::BatchCheckMeta { items }) + } + Command::CheckMetaResp => { + if payload.remaining() < 4 { + return Err(Error::new(ErrorKind::UnexpectedEof, "CheckMetaResp payload too short")); + } + let count = payload.get_u32_le() as usize; + if payload.remaining() < count * 33 { + return Err(Error::new(ErrorKind::UnexpectedEof, "CheckMetaResp items too short")); + } + let mut missing_items = Vec::with_capacity(count); + for _ in 0..count { + let meta_type = MetaType::try_from(payload.get_u8())?; + let mut hash = [0u8; 32]; + payload.copy_to_slice(&mut hash); + missing_items.push((meta_type, hash)); + } + Ok(Message::CheckMetaResp { missing_items }) + } + Command::SendMeta => { + if payload.remaining() < 37 { + return Err(Error::new(ErrorKind::UnexpectedEof, "SendMeta payload too short")); + } + let meta_type = MetaType::try_from(payload.get_u8())?; + let mut meta_hash = [0u8; 32]; + payload.copy_to_slice(&mut meta_hash); + let body_len = payload.get_u32_le() as usize; + if payload.remaining() < body_len { + return Err(Error::new(ErrorKind::UnexpectedEof, "SendMeta body too short")); + } + let body = payload.copy_to_bytes(body_len); + Ok(Message::SendMeta { meta_type, meta_hash, body }) + } + Command::MetaOk => Ok(Message::MetaOk), + Command::MetaFail => { + if payload.remaining() < 2 { + return Err(Error::new(ErrorKind::UnexpectedEof, "MetaFail payload too short")); + } + let reason_len = payload.get_u16_le() as usize; + if payload.remaining() < reason_len { + return Err(Error::new(ErrorKind::UnexpectedEof, "MetaFail reason too short")); + } + let reason = String::from_utf8(payload.copy_to_bytes(reason_len).to_vec()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in reason"))?; + Ok(Message::MetaFail { reason }) + } + Command::SendSnapshot => { + if payload.remaining() < 36 { + return Err(Error::new(ErrorKind::UnexpectedEof, "SendSnapshot payload too short")); + } + let mut snapshot_hash = [0u8; 32]; + payload.copy_to_slice(&mut snapshot_hash); + let body_len = payload.get_u32_le() as usize; + if payload.remaining() < body_len { + return Err(Error::new(ErrorKind::UnexpectedEof, "SendSnapshot body too short")); + } + let body = payload.copy_to_bytes(body_len); + Ok(Message::SendSnapshot { snapshot_hash, body }) + } + Command::SnapshotOk => { + if payload.remaining() < 2 { + return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotOk payload too short")); + } + let id_len = payload.get_u16_le() as usize; + if payload.remaining() < id_len { + return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotOk id too short")); + } + let snapshot_id = String::from_utf8(payload.copy_to_bytes(id_len).to_vec()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in snapshot_id"))?; + Ok(Message::SnapshotOk { snapshot_id }) + } + Command::SnapshotFail => { + if payload.remaining() < 8 { + return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotFail payload too short")); + } + let chunk_count = payload.get_u32_le() as usize; + if payload.remaining() < chunk_count * 32 + 4 { + return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotFail chunks too short")); + } + let mut missing_chunks = Vec::with_capacity(chunk_count); + for _ in 0..chunk_count { + let mut hash = [0u8; 32]; + payload.copy_to_slice(&mut hash); + missing_chunks.push(hash); + } + let meta_count = payload.get_u32_le() as usize; + if payload.remaining() < meta_count * 33 { + return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotFail metas too short")); + } + let mut missing_metas = Vec::with_capacity(meta_count); + for _ in 0..meta_count { + let meta_type = MetaType::try_from(payload.get_u8())?; + let mut hash = [0u8; 32]; + payload.copy_to_slice(&mut hash); + missing_metas.push((meta_type, hash)); + } + Ok(Message::SnapshotFail { missing_chunks, missing_metas }) + } + Command::Close => Ok(Message::Close), + } + } + + /// Get the command for this message + pub fn command(&self) -> Command { + match self { + Message::Hello { .. } => Command::Hello, + Message::HelloOk => Command::HelloOk, + Message::AuthUserPass { .. } => Command::AuthUserPass, + Message::AuthCode { .. } => Command::AuthCode, + Message::AuthOk { .. } => Command::AuthOk, + Message::AuthFail { .. } => Command::AuthFail, + Message::BatchCheckChunk { .. } => Command::BatchCheckChunk, + Message::CheckChunkResp { .. } => Command::CheckChunkResp, + Message::SendChunk { .. } => Command::SendChunk, + Message::ChunkOk => Command::ChunkOk, + Message::ChunkFail { .. } => Command::ChunkFail, + Message::BatchCheckMeta { .. } => Command::BatchCheckMeta, + Message::CheckMetaResp { .. } => Command::CheckMetaResp, + Message::SendMeta { .. } => Command::SendMeta, + Message::MetaOk => Command::MetaOk, + Message::MetaFail { .. } => Command::MetaFail, + Message::SendSnapshot { .. } => Command::SendSnapshot, + Message::SnapshotOk { .. } => Command::SnapshotOk, + Message::SnapshotFail { .. } => Command::SnapshotFail, + Message::Close => Command::Close, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_header_serialization() { + let header = MessageHeader::new(Command::Hello, [1; 16], 42); + let serialized = header.serialize(); + let deserialized = MessageHeader::deserialize(&serialized).unwrap(); + + assert_eq!(deserialized.cmd, Command::Hello); + assert_eq!(deserialized.session_id, [1; 16]); + assert_eq!(deserialized.payload_len, 42); + } + + #[test] + fn test_hello_message() { + let msg = Message::Hello { client_type: 1, auth_type: 2 }; + let payload = msg.serialize_payload().unwrap(); + let deserialized = Message::deserialize_payload(Command::Hello, payload).unwrap(); + + match deserialized { + Message::Hello { client_type, auth_type } => { + assert_eq!(client_type, 1); + assert_eq!(auth_type, 2); + } + _ => panic!("Wrong message type"), + } + } +} diff --git a/server/src/sync/server.rs b/server/src/sync/server.rs new file mode 100644 index 0000000..0b9c43d --- /dev/null +++ b/server/src/sync/server.rs @@ -0,0 +1,463 @@ +use anyhow::{Context, Result}; +use bytes::Bytes; +use sqlx::SqlitePool; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use uuid::Uuid; + +use crate::sync::protocol::{Command, Message, MessageHeader, MetaType}; +use crate::sync::session::{SessionManager, session_cleanup_task}; +use crate::sync::storage::Storage; +use crate::sync::validation::SnapshotValidator; + +/// Configuration for the sync server +#[derive(Debug, Clone)] +pub struct SyncServerConfig { + pub bind_address: String, + pub port: u16, + pub data_dir: String, + pub max_connections: usize, + pub chunk_size_limit: usize, + pub meta_size_limit: usize, + pub batch_limit: usize, +} + +impl Default for SyncServerConfig { + fn default() -> Self { + Self { + bind_address: "0.0.0.0".to_string(), + port: 8380, + data_dir: "./data".to_string(), + max_connections: 100, + chunk_size_limit: 4 * 1024 * 1024, // 4 MiB + meta_size_limit: 1024 * 1024, // 1 MiB + batch_limit: 1000, + } + } +} + +/// Main sync server +pub struct SyncServer { + config: SyncServerConfig, + storage: Storage, + session_manager: Arc, + validator: SnapshotValidator, +} + +impl SyncServer { + pub fn new(config: SyncServerConfig, db_pool: SqlitePool) -> Self { + let storage = Storage::new(&config.data_dir); + let session_manager = Arc::new(SessionManager::new(db_pool)); + let validator = SnapshotValidator::new(storage.clone()); + + Self { + config, + storage, + session_manager, + validator, + } + } + + /// Start the sync server + pub async fn start(&self) -> Result<()> { + // Initialize storage + self.storage.init().await + .context("Failed to initialize storage")?; + + let bind_addr = format!("{}:{}", self.config.bind_address, self.config.port); + let listener = TcpListener::bind(&bind_addr).await + .with_context(|| format!("Failed to bind to {}", bind_addr))?; + + println!("Sync server listening on {}", bind_addr); + + // Start session cleanup task + let session_manager_clone = Arc::clone(&self.session_manager); + tokio::spawn(async move { + session_cleanup_task(session_manager_clone).await; + }); + + // Accept connections + loop { + match listener.accept().await { + Ok((stream, addr)) => { + println!("New sync connection from {}", addr); + + let handler = ConnectionHandler::new( + stream, + self.storage.clone(), + Arc::clone(&self.session_manager), + self.validator.clone(), + self.config.clone(), + ); + + tokio::spawn(async move { + if let Err(e) = handler.handle().await { + eprintln!("Connection error from {}: {}", addr, e); + } + }); + } + Err(e) => { + eprintln!("Failed to accept connection: {}", e); + } + } + } + } +} + +/// Connection handler for individual sync clients +struct ConnectionHandler { + stream: TcpStream, + storage: Storage, + session_manager: Arc, + validator: SnapshotValidator, + config: SyncServerConfig, + session_id: Option<[u8; 16]>, + machine_id: Option, +} + +impl ConnectionHandler { + fn new( + stream: TcpStream, + storage: Storage, + session_manager: Arc, + validator: SnapshotValidator, + config: SyncServerConfig, + ) -> Self { + Self { + stream, + storage, + session_manager, + validator, + config, + session_id: None, + machine_id: None, + } + } + + /// Handle the connection + async fn handle(mut self) -> Result<()> { + loop { + // Read message header + let header = self.read_header().await?; + + // Read payload + let payload = if header.payload_len > 0 { + self.read_payload(header.payload_len).await? + } else { + Bytes::new() + }; + + // Parse message + let message = Message::deserialize_payload(header.cmd, payload) + .context("Failed to deserialize message")?; + + // Handle message + let response = self.handle_message(message).await?; + + // Send response + if let Some(response_msg) = response { + self.send_message(response_msg).await?; + } + + // Close connection if requested + if header.cmd == Command::Close { + break; + } + } + + // Clean up session + if let Some(session_id) = self.session_id { + self.session_manager.remove_session(&session_id).await; + } + + Ok(()) + } + + /// Read message header + async fn read_header(&mut self) -> Result { + let mut header_buf = [0u8; MessageHeader::SIZE]; + self.stream.read_exact(&mut header_buf).await + .context("Failed to read message header")?; + + MessageHeader::deserialize(&header_buf) + .context("Failed to parse message header") + } + + /// Read message payload + async fn read_payload(&mut self, len: u32) -> Result { + if len as usize > self.config.meta_size_limit { + return Err(anyhow::anyhow!("Payload too large: {} bytes", len)); + } + + let mut payload_buf = vec![0u8; len as usize]; + self.stream.read_exact(&mut payload_buf).await + .context("Failed to read message payload")?; + + Ok(Bytes::from(payload_buf)) + } + + /// Send a message + async fn send_message(&mut self, message: Message) -> Result<()> { + let session_id = self.session_id.unwrap_or([0u8; 16]); + let payload = message.serialize_payload()?; + + let header = MessageHeader::new(message.command(), session_id, payload.len() as u32); + let header_bytes = header.serialize(); + + self.stream.write_all(&header_bytes).await + .context("Failed to write message header")?; + + if !payload.is_empty() { + self.stream.write_all(&payload).await + .context("Failed to write message payload")?; + } + + self.stream.flush().await + .context("Failed to flush stream")?; + + Ok(()) + } + + /// Handle a received message + async fn handle_message(&mut self, message: Message) -> Result> { + match message { + Message::Hello { client_type: _, auth_type: _ } => { + Ok(Some(Message::HelloOk)) + } + + Message::AuthUserPass { username, password, machine_id } => { + match self.session_manager.authenticate_userpass(&username, &password, machine_id).await { + Ok(session) => { + self.session_id = Some(session.session_id); + self.machine_id = Some(session.machine_id); + Ok(Some(Message::AuthOk { session_id: session.session_id })) + } + Err(e) => { + Ok(Some(Message::AuthFail { reason: e.to_string() })) + } + } + } + + Message::AuthCode { code } => { + match self.session_manager.authenticate_code(&code).await { + Ok(session) => { + self.session_id = Some(session.session_id); + self.machine_id = Some(session.machine_id); + Ok(Some(Message::AuthOk { session_id: session.session_id })) + } + Err(e) => { + Ok(Some(Message::AuthFail { reason: e.to_string() })) + } + } + } + + Message::BatchCheckChunk { hashes } => { + self.require_auth()?; + + if hashes.len() > self.config.batch_limit { + return Err(anyhow::anyhow!("Batch size exceeds limit: {}", hashes.len())); + } + + let missing_hashes = self.validator.validate_chunk_batch(&hashes).await?; + Ok(Some(Message::CheckChunkResp { missing_hashes })) + } + + Message::SendChunk { hash, data } => { + self.require_auth()?; + + if data.len() > self.config.chunk_size_limit { + return Ok(Some(Message::ChunkFail { + reason: format!("Chunk too large: {} bytes", data.len()) + })); + } + + match self.storage.store_chunk(&hash, &data).await { + Ok(()) => Ok(Some(Message::ChunkOk)), + Err(e) => Ok(Some(Message::ChunkFail { reason: e.to_string() })), + } + } + + Message::BatchCheckMeta { items } => { + self.require_auth()?; + + if items.len() > self.config.batch_limit { + return Err(anyhow::anyhow!("Batch size exceeds limit: {}", items.len())); + } + + let missing_items = self.validator.validate_meta_batch(&items).await?; + Ok(Some(Message::CheckMetaResp { missing_items })) + } + + Message::SendMeta { meta_type, meta_hash, body } => { + self.require_auth()?; + + if body.len() > self.config.meta_size_limit { + return Ok(Some(Message::MetaFail { + reason: format!("Meta object too large: {} bytes", body.len()) + })); + } + + match self.storage.store_meta(meta_type, &meta_hash, &body).await { + Ok(()) => Ok(Some(Message::MetaOk)), + Err(e) => Ok(Some(Message::MetaFail { reason: e.to_string() })), + } + } + + Message::SendSnapshot { snapshot_hash, body } => { + self.require_auth()?; + + if body.len() > self.config.meta_size_limit { + return Ok(Some(Message::SnapshotFail { + missing_chunks: vec![], + missing_metas: vec![], + })); + } + + // Validate snapshot + match self.validator.validate_snapshot(&snapshot_hash, &body).await { + Ok(validation_result) => { + if validation_result.is_valid { + // Store snapshot meta + if let Err(_e) = self.storage.store_meta(MetaType::Snapshot, &snapshot_hash, &body).await { + return Ok(Some(Message::SnapshotFail { + missing_chunks: vec![], + missing_metas: vec![], + })); + } + + // Create snapshot reference + let snapshot_id = Uuid::new_v4().to_string(); + let machine_id = self.machine_id.as_ref().unwrap(); + let created_at = chrono::Utc::now().timestamp() as u64; + + if let Err(_e) = self.storage.store_snapshot_ref( + machine_id, + &snapshot_id, + &snapshot_hash, + created_at + ).await { + return Ok(Some(Message::SnapshotFail { + missing_chunks: vec![], + missing_metas: vec![], + })); + } + + // Store snapshot in database + let machine_id_num: i64 = machine_id.parse().unwrap_or(0); + let snapshot_hash_hex = hex::encode(snapshot_hash); + if let Err(_e) = sqlx::query!( + "INSERT INTO snapshots (machine_id, snapshot_hash) VALUES (?, ?)", + machine_id_num, + snapshot_hash_hex + ) + .execute(self.session_manager.get_db_pool()) + .await { + return Ok(Some(Message::SnapshotFail { + missing_chunks: vec![], + missing_metas: vec![], + })); + } + + Ok(Some(Message::SnapshotOk { snapshot_id })) + } else { + Ok(Some(Message::SnapshotFail { + missing_chunks: validation_result.missing_chunks, + missing_metas: validation_result.missing_metas, + })) + } + } + Err(_e) => { + Ok(Some(Message::SnapshotFail { + missing_chunks: vec![], + missing_metas: vec![], + })) + } + } + } + + Message::Close => { + Ok(None) // No response needed + } + + // These are response messages that shouldn't be received by the server + Message::HelloOk | Message::AuthOk { .. } | Message::AuthFail { .. } | + Message::CheckChunkResp { .. } | Message::ChunkOk | Message::ChunkFail { .. } | + Message::CheckMetaResp { .. } | Message::MetaOk | Message::MetaFail { .. } | + Message::SnapshotOk { .. } | Message::SnapshotFail { .. } => { + Err(anyhow::anyhow!("Unexpected response message from client")) + } + } + } + + /// Require authentication for protected operations + fn require_auth(&self) -> Result<()> { + if self.session_id.is_none() { + return Err(anyhow::anyhow!("Authentication required")); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + use sqlx::sqlite::SqlitePoolOptions; + + async fn setup_test_server() -> (SyncServer, TempDir) { + let temp_dir = TempDir::new().unwrap(); + + let pool = SqlitePoolOptions::new() + .connect(":memory:") + .await + .unwrap(); + + // Create required tables + sqlx::query!( + r#" + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + active INTEGER DEFAULT 1 + ) + "# + ) + .execute(&pool) + .await + .unwrap(); + + sqlx::query!( + r#" + CREATE TABLE provisioning_codes ( + id INTEGER PRIMARY KEY, + code TEXT UNIQUE NOT NULL, + created_by INTEGER NOT NULL, + expires_at TEXT NOT NULL, + used INTEGER DEFAULT 0, + used_at TEXT, + FOREIGN KEY (created_by) REFERENCES users (id) + ) + "# + ) + .execute(&pool) + .await + .unwrap(); + + let config = SyncServerConfig { + data_dir: temp_dir.path().to_string_lossy().to_string(), + ..Default::default() + }; + + (SyncServer::new(config, pool), temp_dir) + } + + #[tokio::test] + async fn test_server_creation() { + let (server, _temp_dir) = setup_test_server().await; + + // Initialize storage to verify everything works + server.storage.init().await.unwrap(); + } +} diff --git a/server/src/sync/session.rs b/server/src/sync/session.rs new file mode 100644 index 0000000..abfa9a9 --- /dev/null +++ b/server/src/sync/session.rs @@ -0,0 +1,344 @@ +use anyhow::{Context, Result}; +use rand::RngCore; +use sqlx::SqlitePool; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use uuid::Uuid; + +/// Session information +#[derive(Debug, Clone)] +pub struct Session { + pub session_id: [u8; 16], + pub machine_id: String, + pub user_id: i64, + pub created_at: chrono::DateTime, +} + +/// Session manager for sync connections +#[derive(Debug)] +pub struct SessionManager { + sessions: Arc>>, + db_pool: SqlitePool, +} + +impl SessionManager { + pub fn new(db_pool: SqlitePool) -> Self { + Self { + sessions: Arc::new(RwLock::new(HashMap::new())), + db_pool, + } + } + + /// Get database pool reference + pub fn get_db_pool(&self) -> &SqlitePool { + &self.db_pool + } + + /// Generate a new session ID + fn generate_session_id() -> [u8; 16] { + let mut session_id = [0u8; 16]; + rand::thread_rng().fill_bytes(&mut session_id); + session_id + } + + /// Authenticate with username/password and validate machine ownership + pub async fn authenticate_userpass(&self, username: &str, password: &str, machine_id: i64) -> Result { + // Query user from database + let user = sqlx::query!( + "SELECT id, username, password_hash FROM users WHERE username = ?", + username + ) + .fetch_optional(&self.db_pool) + .await + .context("Failed to query user")?; + + let user = user.ok_or_else(|| anyhow::anyhow!("Invalid credentials"))?; + + // Verify password + if !bcrypt::verify(password, &user.password_hash) + .context("Failed to verify password")? { + return Err(anyhow::anyhow!("Invalid credentials")); + } + + let user_id = user.id.unwrap_or(0) as i64; + + // Validate machine ownership + let machine = sqlx::query!( + "SELECT id, user_id FROM machines WHERE id = ?", + machine_id + ) + .fetch_optional(&self.db_pool) + .await + .context("Failed to query machine")?; + + let machine = machine.ok_or_else(|| anyhow::anyhow!("Machine not found"))?; + + let machine_user_id = machine.user_id; + if machine_user_id != user_id { + return Err(anyhow::anyhow!("Machine does not belong to user")); + } + + // Create session + let session_id = Self::generate_session_id(); + let machine_id_str = machine_id.to_string(); + let session = Session { + session_id, + machine_id: machine_id_str, + user_id, + created_at: chrono::Utc::now(), + }; + + // Store session + let mut sessions = self.sessions.write().await; + sessions.insert(session_id, session.clone()); + + Ok(session) + } + + /// Authenticate with provisioning code + pub async fn authenticate_code(&self, code: &str) -> Result { + // Query provisioning code from database + let provisioning_code = sqlx::query!( + r#" + SELECT pc.id, pc.code, pc.expires_at, pc.used, m.user_id, u.username + FROM provisioning_codes pc + JOIN machines m ON pc.machine_id = m.id + JOIN users u ON m.user_id = u.id + WHERE pc.code = ? AND pc.used = 0 + "#, + code + ) + .fetch_optional(&self.db_pool) + .await + .context("Failed to query provisioning code")?; + + let provisioning_code = provisioning_code + .ok_or_else(|| anyhow::anyhow!("Invalid or used provisioning code"))?; + + // Check if code is expired + let expires_at: chrono::DateTime = chrono::DateTime::from_naive_utc_and_offset( + provisioning_code.expires_at, + chrono::Utc + ); + + if chrono::Utc::now() > expires_at { + return Err(anyhow::anyhow!("Provisioning code expired")); + } + + // Mark code as used + sqlx::query!( + "UPDATE provisioning_codes SET used = 1 WHERE id = ?", + provisioning_code.id + ) + .execute(&self.db_pool) + .await + .context("Failed to mark provisioning code as used")?; + + // Create session + let session_id = Self::generate_session_id(); + let machine_id = format!("machine-{}", Uuid::new_v4()); + let session = Session { + session_id, + machine_id, + user_id: provisioning_code.user_id as i64, + created_at: chrono::Utc::now(), + }; + + // Store session + let mut sessions = self.sessions.write().await; + sessions.insert(session_id, session.clone()); + + Ok(session) + } + + /// Get session by session ID + pub async fn get_session(&self, session_id: &[u8; 16]) -> Option { + let sessions = self.sessions.read().await; + sessions.get(session_id).cloned() + } + + /// Validate session and return associated machine ID + pub async fn validate_session(&self, session_id: &[u8; 16]) -> Result { + let session = self.get_session(session_id).await + .ok_or_else(|| anyhow::anyhow!("Invalid session"))?; + + // Check if session is too old (24 hours) + let session_age = chrono::Utc::now() - session.created_at; + if session_age > chrono::Duration::hours(24) { + // Remove expired session + let mut sessions = self.sessions.write().await; + sessions.remove(session_id); + return Err(anyhow::anyhow!("Session expired")); + } + + Ok(session.machine_id) + } + + /// Remove session + pub async fn remove_session(&self, session_id: &[u8; 16]) { + let mut sessions = self.sessions.write().await; + sessions.remove(session_id); + } + + /// Clean up expired sessions + pub async fn cleanup_expired_sessions(&self) { + let mut sessions = self.sessions.write().await; + let now = chrono::Utc::now(); + + sessions.retain(|_, session| { + let age = now - session.created_at; + age <= chrono::Duration::hours(24) + }); + } + + /// Get active session count + pub async fn active_session_count(&self) -> usize { + let sessions = self.sessions.read().await; + sessions.len() + } + + /// List active sessions + pub async fn list_active_sessions(&self) -> Vec { + let sessions = self.sessions.read().await; + sessions.values().cloned().collect() + } +} + +/// Periodic cleanup task for expired sessions +pub async fn session_cleanup_task(session_manager: Arc) { + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(3600)); // Every hour + + loop { + interval.tick().await; + session_manager.cleanup_expired_sessions().await; + println!("Cleaned up expired sync sessions. Active sessions: {}", + session_manager.active_session_count().await); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use sqlx::sqlite::SqlitePoolOptions; + + async fn setup_test_db() -> SqlitePool { + let pool = SqlitePoolOptions::new() + .connect(":memory:") + .await + .unwrap(); + + // Create tables + sqlx::query!( + r#" + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + active INTEGER DEFAULT 1 + ) + "# + ) + .execute(&pool) + .await + .unwrap(); + + sqlx::query!( + r#" + CREATE TABLE provisioning_codes ( + id INTEGER PRIMARY KEY, + code TEXT UNIQUE NOT NULL, + created_by INTEGER NOT NULL, + expires_at TEXT NOT NULL, + used INTEGER DEFAULT 0, + used_at TEXT, + FOREIGN KEY (created_by) REFERENCES users (id) + ) + "# + ) + .execute(&pool) + .await + .unwrap(); + + // Insert test user + let password_hash = bcrypt::hash("password123", bcrypt::DEFAULT_COST).unwrap(); + sqlx::query!( + "INSERT INTO users (username, password_hash) VALUES (?, ?)", + "testuser", + password_hash + ) + .execute(&pool) + .await + .unwrap(); + + pool + } + + #[tokio::test] + async fn test_authenticate_userpass() { + let pool = setup_test_db().await; + let session_manager = SessionManager::new(pool); + + let session = session_manager + .authenticate_userpass("testuser", "password123") + .await + .unwrap(); + + assert_eq!(session.user_id, 1); + assert!(!session.machine_id.is_empty()); + } + + #[tokio::test] + async fn test_authenticate_userpass_invalid() { + let pool = setup_test_db().await; + let session_manager = SessionManager::new(pool); + + let result = session_manager + .authenticate_userpass("testuser", "wrongpassword") + .await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_session_validation() { + let pool = setup_test_db().await; + let session_manager = SessionManager::new(pool); + + let session = session_manager + .authenticate_userpass("testuser", "password123") + .await + .unwrap(); + + let machine_id = session_manager + .validate_session(&session.session_id) + .await + .unwrap(); + + assert_eq!(machine_id, session.machine_id); + } + + #[tokio::test] + async fn test_session_cleanup() { + let pool = setup_test_db().await; + let session_manager = SessionManager::new(pool); + + let session = session_manager + .authenticate_userpass("testuser", "password123") + .await + .unwrap(); + + assert_eq!(session_manager.active_session_count().await, 1); + + // Manually expire the session + { + let mut sessions = session_manager.sessions.write().await; + if let Some(mut session) = sessions.get_mut(&session.session_id) { + session.created_at = chrono::Utc::now() - chrono::Duration::hours(25); + } + } + + session_manager.cleanup_expired_sessions().await; + assert_eq!(session_manager.active_session_count().await, 0); + } +} diff --git a/server/src/sync/storage.rs b/server/src/sync/storage.rs new file mode 100644 index 0000000..bf2d664 --- /dev/null +++ b/server/src/sync/storage.rs @@ -0,0 +1,399 @@ +use anyhow::{Context, Result}; +use bytes::Bytes; +use std::collections::HashSet; +use std::path::{Path, PathBuf}; +use tokio::fs; +use crate::sync::protocol::{Hash, MetaType}; +use crate::sync::meta::MetaObj; + +/// Storage backend for chunks and meta objects +#[derive(Debug, Clone)] +pub struct Storage { + data_dir: PathBuf, +} + +impl Storage { + pub fn new>(data_dir: P) -> Self { + Self { + data_dir: data_dir.as_ref().to_path_buf(), + } + } + + /// Initialize storage directories + pub async fn init(&self) -> Result<()> { + let chunks_dir = self.data_dir.join("sync").join("chunks"); + let meta_dir = self.data_dir.join("sync").join("meta"); + let machines_dir = self.data_dir.join("sync").join("machines"); + + fs::create_dir_all(&chunks_dir).await + .context("Failed to create chunks directory")?; + + fs::create_dir_all(&meta_dir).await + .context("Failed to create meta directory")?; + + fs::create_dir_all(&machines_dir).await + .context("Failed to create machines directory")?; + + // Create subdirectories for each meta type + for meta_type in &["files", "dirs", "partitions", "disks", "snapshots"] { + fs::create_dir_all(meta_dir.join(meta_type)).await + .with_context(|| format!("Failed to create meta/{} directory", meta_type))?; + } + + Ok(()) + } + + /// Get chunk storage path for a hash + fn chunk_path(&self, hash: &Hash) -> PathBuf { + let hex = hex::encode(hash); + let ab = &hex[0..2]; + let cd = &hex[2..4]; + let filename = format!("{}.chk", hex); + + self.data_dir + .join("sync") + .join("chunks") + .join(ab) + .join(cd) + .join(filename) + } + + /// Get meta storage path for a hash and type + fn meta_path(&self, meta_type: MetaType, hash: &Hash) -> PathBuf { + let hex = hex::encode(hash); + let ab = &hex[0..2]; + let cd = &hex[2..4]; + let filename = format!("{}.meta", hex); + + let type_dir = match meta_type { + MetaType::File => "files", + MetaType::Dir => "dirs", + MetaType::Partition => "partitions", + MetaType::Disk => "disks", + MetaType::Snapshot => "snapshots", + }; + + self.data_dir + .join("sync") + .join("meta") + .join(type_dir) + .join(ab) + .join(cd) + .join(filename) + } + + /// Check if a chunk exists + pub async fn chunk_exists(&self, hash: &Hash) -> bool { + let path = self.chunk_path(hash); + path.exists() + } + + /// Check if multiple chunks exist + pub async fn chunks_exist(&self, hashes: &[Hash]) -> Result> { + let mut existing = HashSet::new(); + + for hash in hashes { + if self.chunk_exists(hash).await { + existing.insert(*hash); + } + } + + Ok(existing) + } + + /// Store a chunk + pub async fn store_chunk(&self, hash: &Hash, data: &[u8]) -> Result<()> { + // Verify hash + let computed_hash = blake3::hash(data); + if computed_hash.as_bytes() != hash { + return Err(anyhow::anyhow!("Chunk hash mismatch")); + } + + let path = self.chunk_path(hash); + + // Create parent directories + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).await + .context("Failed to create chunk directory")?; + } + + // Write to temporary file first, then rename (atomic write) + let temp_path = path.with_extension("tmp"); + fs::write(&temp_path, data).await + .context("Failed to write chunk to temporary file")?; + + fs::rename(&temp_path, &path).await + .context("Failed to rename chunk file")?; + + Ok(()) + } + + /// Load a chunk + pub async fn load_chunk(&self, hash: &Hash) -> Result> { + let path = self.chunk_path(hash); + + if !path.exists() { + return Ok(None); + } + + let data = fs::read(&path).await + .context("Failed to read chunk file")?; + + // Verify hash + let computed_hash = blake3::hash(&data); + if computed_hash.as_bytes() != hash { + return Err(anyhow::anyhow!("Stored chunk hash mismatch")); + } + + Ok(Some(Bytes::from(data))) + } + + /// Check if a meta object exists + pub async fn meta_exists(&self, meta_type: MetaType, hash: &Hash) -> bool { + let path = self.meta_path(meta_type, hash); + path.exists() + } + + /// Check if multiple meta objects exist + pub async fn metas_exist(&self, items: &[(MetaType, Hash)]) -> Result> { + let mut existing = HashSet::new(); + + for &(meta_type, hash) in items { + if self.meta_exists(meta_type, &hash).await { + existing.insert((meta_type, hash)); + } + } + + Ok(existing) + } + + /// Store a meta object + pub async fn store_meta(&self, meta_type: MetaType, hash: &Hash, body: &[u8]) -> Result<()> { + // Verify hash + let computed_hash = blake3::hash(body); + if computed_hash.as_bytes() != hash { + return Err(anyhow::anyhow!("Meta object hash mismatch")); + } + + let path = self.meta_path(meta_type, hash); + + // Create parent directories + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).await + .context("Failed to create meta directory")?; + } + + // Write to temporary file first, then rename (atomic write) + let temp_path = path.with_extension("tmp"); + fs::write(&temp_path, body).await + .context("Failed to write meta to temporary file")?; + + fs::rename(&temp_path, &path).await + .context("Failed to rename meta file")?; + + Ok(()) + } + + /// Load a meta object + pub async fn load_meta(&self, meta_type: MetaType, hash: &Hash) -> Result> { + let path = self.meta_path(meta_type, hash); + + if !path.exists() { + return Ok(None); + } + + let data = fs::read(&path).await + .context("Failed to read meta file")?; + + // Verify hash + let computed_hash = blake3::hash(&data); + if computed_hash.as_bytes() != hash { + return Err(anyhow::anyhow!("Stored meta object hash mismatch")); + } + + let meta_obj = MetaObj::deserialize(meta_type, Bytes::from(data)) + .context("Failed to deserialize meta object")?; + + Ok(Some(meta_obj)) + } + + /// Get snapshot storage path for a machine + fn snapshot_ref_path(&self, machine_id: &str, snapshot_id: &str) -> PathBuf { + self.data_dir + .join("sync") + .join("machines") + .join(machine_id) + .join("snapshots") + .join(format!("{}.ref", snapshot_id)) + } + + /// Store a snapshot reference + pub async fn store_snapshot_ref( + &self, + machine_id: &str, + snapshot_id: &str, + snapshot_hash: &Hash, + created_at: u64 + ) -> Result<()> { + let path = self.snapshot_ref_path(machine_id, snapshot_id); + + // Create parent directories + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).await + .context("Failed to create snapshot reference directory")?; + } + + // Create snapshot reference content + let content = format!("{}:{}", hex::encode(snapshot_hash), created_at); + + // Write to temporary file first, then rename (atomic write) + let temp_path = path.with_extension("tmp"); + fs::write(&temp_path, content).await + .context("Failed to write snapshot reference to temporary file")?; + + fs::rename(&temp_path, &path).await + .context("Failed to rename snapshot reference file")?; + + Ok(()) + } + + /// Load a snapshot reference + pub async fn load_snapshot_ref(&self, machine_id: &str, snapshot_id: &str) -> Result> { + let path = self.snapshot_ref_path(machine_id, snapshot_id); + + if !path.exists() { + return Ok(None); + } + + let content = fs::read_to_string(&path).await + .context("Failed to read snapshot reference file")?; + + let parts: Vec<&str> = content.trim().split(':').collect(); + if parts.len() != 2 { + return Err(anyhow::anyhow!("Invalid snapshot reference format")); + } + + let snapshot_hash: Hash = hex::decode(parts[0]) + .context("Failed to decode snapshot hash")? + .try_into() + .map_err(|_| anyhow::anyhow!("Invalid snapshot hash length"))?; + + let created_at: u64 = parts[1].parse() + .context("Failed to parse snapshot timestamp")?; + + Ok(Some((snapshot_hash, created_at))) + } + + /// List snapshots for a machine + pub async fn list_snapshots(&self, machine_id: &str) -> Result> { + let snapshots_dir = self.data_dir + .join("sync") + .join("machines") + .join(machine_id) + .join("snapshots"); + + if !snapshots_dir.exists() { + return Ok(Vec::new()); + } + + let mut entries = fs::read_dir(&snapshots_dir).await + .context("Failed to read snapshots directory")?; + + let mut snapshots = Vec::new(); + while let Some(entry) = entries.next_entry().await + .context("Failed to read snapshot entry")? { + + if let Some(file_name) = entry.file_name().to_str() { + if file_name.ends_with(".ref") { + let snapshot_id = file_name.trim_end_matches(".ref"); + snapshots.push(snapshot_id.to_string()); + } + } + } + + snapshots.sort(); + Ok(snapshots) + } + + /// Delete old snapshots, keeping only the latest N + pub async fn cleanup_snapshots(&self, machine_id: &str, keep_count: usize) -> Result<()> { + let mut snapshots = self.list_snapshots(machine_id).await?; + + if snapshots.len() <= keep_count { + return Ok(()); + } + + snapshots.sort(); + snapshots.reverse(); // Most recent first + + // Delete older snapshots + for snapshot_id in snapshots.iter().skip(keep_count) { + let path = self.snapshot_ref_path(machine_id, snapshot_id); + if path.exists() { + fs::remove_file(&path).await + .with_context(|| format!("Failed to delete snapshot {}", snapshot_id))?; + } + } + + Ok(()) + } +} + +/// Add hex crate to dependencies +use hex; + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn test_storage_init() { + let temp_dir = TempDir::new().unwrap(); + let storage = Storage::new(temp_dir.path()); + storage.init().await.unwrap(); + + assert!(temp_dir.path().join("sync/chunks").exists()); + assert!(temp_dir.path().join("sync/meta/files").exists()); + assert!(temp_dir.path().join("sync/machines").exists()); + } + + #[tokio::test] + async fn test_chunk_storage() { + let temp_dir = TempDir::new().unwrap(); + let storage = Storage::new(temp_dir.path()); + storage.init().await.unwrap(); + + let data = b"test chunk data"; + let hash = blake3::hash(data).into(); + + // Store chunk + storage.store_chunk(&hash, data).await.unwrap(); + assert!(storage.chunk_exists(&hash).await); + + // Load chunk + let loaded = storage.load_chunk(&hash).await.unwrap().unwrap(); + assert_eq!(loaded.as_ref(), data); + } + + #[tokio::test] + async fn test_snapshot_ref_storage() { + let temp_dir = TempDir::new().unwrap(); + let storage = Storage::new(temp_dir.path()); + storage.init().await.unwrap(); + + let machine_id = "test-machine"; + let snapshot_id = "snapshot-001"; + let snapshot_hash = [1u8; 32]; + let created_at = 1234567890; + + storage.store_snapshot_ref(machine_id, snapshot_id, &snapshot_hash, created_at) + .await.unwrap(); + + let loaded = storage.load_snapshot_ref(machine_id, snapshot_id) + .await.unwrap().unwrap(); + + assert_eq!(loaded.0, snapshot_hash); + assert_eq!(loaded.1, created_at); + } +} diff --git a/server/src/sync/validation.rs b/server/src/sync/validation.rs new file mode 100644 index 0000000..0c00606 --- /dev/null +++ b/server/src/sync/validation.rs @@ -0,0 +1,233 @@ +use anyhow::{Context, Result}; +use std::collections::{HashSet, VecDeque}; +use crate::sync::protocol::{Hash, MetaType}; +use crate::sync::storage::Storage; +use crate::sync::meta::{MetaObj, SnapshotObj, EntryType}; + +/// Validation result for snapshot commits +#[derive(Debug, Clone)] +pub struct ValidationResult { + pub is_valid: bool, + pub missing_chunks: Vec, + pub missing_metas: Vec<(MetaType, Hash)>, +} + +impl ValidationResult { + pub fn valid() -> Self { + Self { + is_valid: true, + missing_chunks: Vec::new(), + missing_metas: Vec::new(), + } + } + + pub fn invalid(missing_chunks: Vec, missing_metas: Vec<(MetaType, Hash)>) -> Self { + Self { + is_valid: false, + missing_chunks, + missing_metas, + } + } + + pub fn has_missing(&self) -> bool { + !self.missing_chunks.is_empty() || !self.missing_metas.is_empty() + } +} + +/// Validator for snapshot object graphs +#[derive(Clone)] +pub struct SnapshotValidator { + storage: Storage, +} + +impl SnapshotValidator { + pub fn new(storage: Storage) -> Self { + Self { storage } + } + + /// Validate a complete snapshot object graph using BFS only + pub async fn validate_snapshot(&self, snapshot_hash: &Hash, snapshot_body: &[u8]) -> Result { + // Use the BFS implementation + self.validate_snapshot_bfs(snapshot_hash, snapshot_body).await + } + + /// Validate a batch of meta objects (for incremental validation) + pub async fn validate_meta_batch(&self, metas: &[(MetaType, Hash)]) -> Result> { + let mut missing = Vec::new(); + + for &(meta_type, hash) in metas { + if !self.storage.meta_exists(meta_type, &hash).await { + missing.push((meta_type, hash)); + } + } + + Ok(missing) + } + + /// Validate a batch of chunks (for incremental validation) + pub async fn validate_chunk_batch(&self, chunks: &[Hash]) -> Result> { + let mut missing = Vec::new(); + + for &hash in chunks { + if !self.storage.chunk_exists(&hash).await { + missing.push(hash); + } + } + + Ok(missing) + } + + /// Perform a breadth-first validation (useful for large snapshots) + pub async fn validate_snapshot_bfs(&self, snapshot_hash: &Hash, snapshot_body: &[u8]) -> Result { + // Verify snapshot hash + let computed_hash = blake3::hash(snapshot_body); + if computed_hash.as_bytes() != snapshot_hash { + return Err(anyhow::anyhow!("Snapshot hash mismatch")); + } + + // Parse snapshot object + let snapshot_obj = SnapshotObj::deserialize(bytes::Bytes::from(snapshot_body.to_vec())) + .context("Failed to deserialize snapshot object")?; + + let mut missing_chunks = Vec::new(); + let mut missing_metas = Vec::new(); + let mut visited_metas = HashSet::new(); + let mut queue = VecDeque::new(); + + // Initialize queue with disk hashes + for disk_hash in &snapshot_obj.disk_hashes { + queue.push_back((MetaType::Disk, *disk_hash)); + } + + // BFS traversal + while let Some((meta_type, hash)) = queue.pop_front() { + let meta_key = (meta_type, hash); + + if visited_metas.contains(&meta_key) { + continue; + } + visited_metas.insert(meta_key); + + // Check if meta exists + if !self.storage.meta_exists(meta_type, &hash).await { + missing_metas.push((meta_type, hash)); + continue; // Skip loading if missing + } + + // Load and process meta object + if let Some(meta_obj) = self.storage.load_meta(meta_type, &hash).await + .context("Failed to load meta object")? { + + match meta_obj { + MetaObj::Disk(disk) => { + for partition_hash in &disk.partition_hashes { + queue.push_back((MetaType::Partition, *partition_hash)); + } + } + MetaObj::Partition(partition) => { + queue.push_back((MetaType::Dir, partition.root_dir_hash)); + } + MetaObj::Dir(dir) => { + for entry in &dir.entries { + match entry.entry_type { + EntryType::File | EntryType::Symlink => { + queue.push_back((MetaType::File, entry.target_meta_hash)); + } + EntryType::Dir => { + queue.push_back((MetaType::Dir, entry.target_meta_hash)); + } + } + } + } + MetaObj::File(file) => { + // Check chunk dependencies + for chunk_hash in &file.chunk_hashes { + if !self.storage.chunk_exists(chunk_hash).await { + missing_chunks.push(*chunk_hash); + } + } + } + MetaObj::Snapshot(_) => { + // Snapshots shouldn't be nested + return Err(anyhow::anyhow!("Unexpected nested snapshot object")); + } + } + } + } + + if missing_chunks.is_empty() && missing_metas.is_empty() { + Ok(ValidationResult::valid()) + } else { + Ok(ValidationResult::invalid(missing_chunks, missing_metas)) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + use crate::sync::meta::*; + + async fn setup_test_storage() -> Storage { + let temp_dir = TempDir::new().unwrap(); + let storage = Storage::new(temp_dir.path()); + storage.init().await.unwrap(); + storage + } + + #[tokio::test] + async fn test_validate_empty_snapshot() { + let storage = setup_test_storage().await; + let validator = SnapshotValidator::new(storage); + + let snapshot = SnapshotObj::new(1234567890, vec![]); + let snapshot_body = snapshot.serialize().unwrap(); + let snapshot_hash = snapshot.compute_hash().unwrap(); + + let result = validator.validate_snapshot(&snapshot_hash, &snapshot_body) + .await.unwrap(); + + assert!(result.is_valid); + assert!(result.missing_chunks.is_empty()); + assert!(result.missing_metas.is_empty()); + } + + #[tokio::test] + async fn test_validate_missing_disk() { + let storage = setup_test_storage().await; + let validator = SnapshotValidator::new(storage); + + let missing_disk_hash = [1u8; 32]; + let snapshot = SnapshotObj::new(1234567890, vec![missing_disk_hash]); + let snapshot_body = snapshot.serialize().unwrap(); + let snapshot_hash = snapshot.compute_hash().unwrap(); + + let result = validator.validate_snapshot(&snapshot_hash, &snapshot_body) + .await.unwrap(); + + assert!(!result.is_valid); + assert!(result.missing_chunks.is_empty()); + assert_eq!(result.missing_metas.len(), 1); + assert_eq!(result.missing_metas[0], (MetaType::Disk, missing_disk_hash)); + } + + #[tokio::test] + async fn test_validate_chunk_batch() { + let storage = setup_test_storage().await; + let validator = SnapshotValidator::new(storage); + + let chunk_data = b"test chunk"; + let chunk_hash = blake3::hash(chunk_data).into(); + let missing_hash = [1u8; 32]; + + // Store one chunk + storage.store_chunk(&chunk_hash, chunk_data).await.unwrap(); + + let chunks = vec![chunk_hash, missing_hash]; + let missing = validator.validate_chunk_batch(&chunks).await.unwrap(); + + assert_eq!(missing.len(), 1); + assert_eq!(missing[0], missing_hash); + } +}