Add sync test using AI
93 server/Cargo.lock generated
@@ -38,6 +38,18 @@ version = "1.0.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100"

[[package]]
name = "arrayref"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"

[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"

[[package]]
name = "atoi"
version = "2.0.0"
@@ -153,6 +165,15 @@ dependencies = [
 "zeroize",
]

[[package]]
name = "bincode"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
dependencies = [
 "serde",
]

[[package]]
name = "bitflags"
version = "2.9.4"
@@ -162,6 +183,19 @@ dependencies = [
 "serde",
]

[[package]]
name = "blake3"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0"
dependencies = [
 "arrayref",
 "arrayvec",
 "cc",
 "cfg-if",
 "constant_time_eq",
]

[[package]]
name = "block-buffer"
version = "0.10.4"
@@ -254,6 +288,12 @@ version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8"

[[package]]
name = "constant_time_eq"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"

[[package]]
name = "core-foundation-sys"
version = "0.8.7"
@@ -364,6 +404,16 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"

[[package]]
name = "errno"
version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
 "libc",
 "windows-sys 0.59.0",
]

[[package]]
name = "etcetera"
version = "0.8.0"
@@ -386,6 +436,12 @@ dependencies = [
 "pin-project-lite",
]

[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"

[[package]]
name = "find-msvc-tools"
version = "0.1.1"
@@ -903,6 +959,12 @@ dependencies = [
 "vcpkg",
]

[[package]]
name = "linux-raw-sys"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"

[[package]]
name = "litemap"
version = "0.8.0"
@@ -1249,6 +1311,19 @@ version = "0.1.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace"

[[package]]
name = "rustix"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
dependencies = [
 "bitflags",
 "errno",
 "libc",
 "linux-raw-sys",
 "windows-sys 0.59.0",
]

[[package]]
name = "rustls"
version = "0.23.31"
@@ -1362,11 +1437,16 @@ dependencies = [
 "anyhow",
 "axum",
 "bcrypt",
 "bincode",
 "blake3",
 "bytes",
 "chrono",
 "hex",
 "rand",
 "serde",
 "serde_json",
 "sqlx",
 "tempfile",
 "tokio",
 "tower-http",
 "uuid",
@@ -1712,6 +1792,19 @@ dependencies = [
 "syn",
]

[[package]]
name = "tempfile"
version = "3.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84fa4d11fadde498443cca10fd3ac23c951f0dc59e080e9f4b93d4df4e4eea53"
dependencies = [
 "fastrand",
 "getrandom 0.3.3",
 "once_cell",
 "rustix",
 "windows-sys 0.59.0",
]

[[package]]
name = "thiserror"
version = "2.0.16"
@@ -14,4 +14,11 @@ uuid = { version = "1.0", features = ["v4", "serde"] }
chrono = { version = "0.4", features = ["serde"] }
tower-http = { version = "0.6.6", features = ["cors", "fs"] }
anyhow = "1.0"
rand = "0.8"
blake3 = "1.5"
bytes = "1.0"
bincode = "1.3"
hex = "0.4"

[dev-dependencies]
tempfile = "3.0"
@@ -1,6 +1,7 @@
mod controllers;
mod routes;
mod utils;
mod sync;

use anyhow::Result;
use axum::{
@@ -15,10 +16,14 @@ use tower_http::{
    services::{ServeDir, ServeFile},
};
use utils::init_database;
use sync::{SyncServer, server::SyncServerConfig};

#[tokio::main]
async fn main() -> Result<()> {
    let pool = init_database().await?;

    let sync_pool = pool.clone();

    let api_routes = Router::new()
        .route("/setup/status", get(setup::get_setup_status))
        .route("/setup/init", post(setup::init_setup))
@@ -51,8 +56,18 @@ async fn main() -> Result<()> {
        println!("Warning: dist directory not found at {}", dist_path);
    }

    let sync_config = SyncServerConfig::default();
    let sync_server = SyncServer::new(sync_config.clone(), sync_pool);

    tokio::spawn(async move {
        if let Err(e) = sync_server.start().await {
            eprintln!("Sync server error: {}", e);
        }
    });

    let listener = tokio::net::TcpListener::bind("0.0.0.0:8379").await?;
    println!("Server running on http://0.0.0.0:8379");
    println!("HTTP server running on http://0.0.0.0:8379");
    println!("Sync server running on {}:{}", sync_config.bind_address, sync_config.port);

    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
582 server/src/sync/meta.rs Normal file
@@ -0,0 +1,582 @@
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result};
use crate::sync::protocol::{Hash, MetaType};

/// Filesystem type codes
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FsType {
    Ext = 1,
    Ntfs = 2,
    Fat32 = 3,
    Unknown = 0,
}

impl From<u32> for FsType {
    fn from(value: u32) -> Self {
        match value {
            1 => FsType::Ext,
            2 => FsType::Ntfs,
            3 => FsType::Fat32,
            _ => FsType::Unknown,
        }
    }
}

/// Directory entry types
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EntryType {
    File = 0,
    Dir = 1,
    Symlink = 2,
}

impl TryFrom<u8> for EntryType {
    type Error = Error;

    fn try_from(value: u8) -> Result<Self> {
        match value {
            0 => Ok(EntryType::File),
            1 => Ok(EntryType::Dir),
            2 => Ok(EntryType::Symlink),
            _ => Err(Error::new(ErrorKind::InvalidData, "Unknown entry type")),
        }
    }
}

/// File metadata object
#[derive(Debug, Clone)]
pub struct FileObj {
    pub version: u8,
    pub fs_type_code: FsType,
    pub size: u64,
    pub mode: u32,
    pub uid: u32,
    pub gid: u32,
    pub mtime_unixsec: u64,
    pub chunk_hashes: Vec<Hash>,
}

impl FileObj {
    pub fn new(
        fs_type_code: FsType,
        size: u64,
        mode: u32,
        uid: u32,
        gid: u32,
        mtime_unixsec: u64,
        chunk_hashes: Vec<Hash>,
    ) -> Self {
        Self {
            version: 1,
            fs_type_code,
            size,
            mode,
            uid,
            gid,
            mtime_unixsec,
            chunk_hashes,
        }
    }

    pub fn serialize(&self) -> Result<Bytes> {
        let mut buf = BytesMut::new();

        buf.put_u8(self.version);
        buf.put_u32_le(self.fs_type_code as u32);
        buf.put_u64_le(self.size);
        buf.put_u32_le(self.mode);
        buf.put_u32_le(self.uid);
        buf.put_u32_le(self.gid);
        buf.put_u64_le(self.mtime_unixsec);
        buf.put_u32_le(self.chunk_hashes.len() as u32);

        for hash in &self.chunk_hashes {
            buf.put_slice(hash);
        }

        Ok(buf.freeze())
    }

    pub fn deserialize(mut data: Bytes) -> Result<Self> {
        if data.remaining() < 41 {
            return Err(Error::new(ErrorKind::UnexpectedEof, "FileObj data too short"));
        }

        let version = data.get_u8();
        if version != 1 {
            return Err(Error::new(ErrorKind::InvalidData, "Unsupported FileObj version"));
        }

        let fs_type_code = FsType::from(data.get_u32_le());
        let size = data.get_u64_le();
        let mode = data.get_u32_le();
        let uid = data.get_u32_le();
        let gid = data.get_u32_le();
        let mtime_unixsec = data.get_u64_le();
        let chunk_count = data.get_u32_le() as usize;

        if data.remaining() < chunk_count * 32 {
            return Err(Error::new(ErrorKind::UnexpectedEof, "FileObj chunk hashes too short"));
        }

        let mut chunk_hashes = Vec::with_capacity(chunk_count);
        for _ in 0..chunk_count {
            let mut hash = [0u8; 32];
            data.copy_to_slice(&mut hash);
            chunk_hashes.push(hash);
        }

        Ok(Self {
            version,
            fs_type_code,
            size,
            mode,
            uid,
            gid,
            mtime_unixsec,
            chunk_hashes,
        })
    }

    pub fn compute_hash(&self) -> Result<Hash> {
        let serialized = self.serialize()?;
        Ok(blake3::hash(&serialized).into())
    }
}

/// Directory entry
#[derive(Debug, Clone)]
pub struct DirEntry {
    pub entry_type: EntryType,
    pub name: String,
    pub target_meta_hash: Hash,
}

/// Directory metadata object
#[derive(Debug, Clone)]
pub struct DirObj {
    pub version: u8,
    pub entries: Vec<DirEntry>,
}

impl DirObj {
    pub fn new(entries: Vec<DirEntry>) -> Self {
        Self {
            version: 1,
            entries,
        }
    }

    pub fn serialize(&self) -> Result<Bytes> {
        let mut buf = BytesMut::new();

        buf.put_u8(self.version);
        buf.put_u32_le(self.entries.len() as u32);

        for entry in &self.entries {
            buf.put_u8(entry.entry_type as u8);
            let name_bytes = entry.name.as_bytes();
            buf.put_u16_le(name_bytes.len() as u16);
            buf.put_slice(name_bytes);
            buf.put_slice(&entry.target_meta_hash);
        }

        Ok(buf.freeze())
    }

    pub fn deserialize(mut data: Bytes) -> Result<Self> {
        if data.remaining() < 5 {
            return Err(Error::new(ErrorKind::UnexpectedEof, "DirObj data too short"));
        }

        let version = data.get_u8();
        if version != 1 {
            return Err(Error::new(ErrorKind::InvalidData, "Unsupported DirObj version"));
        }

        let entry_count = data.get_u32_le() as usize;
        let mut entries = Vec::with_capacity(entry_count);

        for _ in 0..entry_count {
            if data.remaining() < 35 {
                return Err(Error::new(ErrorKind::UnexpectedEof, "DirObj entry too short"));
            }

            let entry_type = EntryType::try_from(data.get_u8())?;
            let name_len = data.get_u16_le() as usize;

            if data.remaining() < name_len + 32 {
                return Err(Error::new(ErrorKind::UnexpectedEof, "DirObj entry name/hash too short"));
            }

            let name = String::from_utf8(data.copy_to_bytes(name_len).to_vec())
                .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in entry name"))?;

            let mut target_meta_hash = [0u8; 32];
            data.copy_to_slice(&mut target_meta_hash);

            entries.push(DirEntry {
                entry_type,
                name,
                target_meta_hash,
            });
        }

        Ok(Self {
            version,
            entries,
        })
    }

    pub fn compute_hash(&self) -> Result<Hash> {
        let serialized = self.serialize()?;
        Ok(blake3::hash(&serialized).into())
    }
}

/// Partition metadata object
#[derive(Debug, Clone)]
pub struct PartitionObj {
    pub version: u8,
    pub fs_type_code: FsType,
    pub root_dir_hash: Hash,
    pub start_lba: u64,
    pub end_lba: u64,
    pub type_guid: [u8; 16],
}

impl PartitionObj {
    pub fn new(
        fs_type_code: FsType,
        root_dir_hash: Hash,
        start_lba: u64,
        end_lba: u64,
        type_guid: [u8; 16],
    ) -> Self {
        Self {
            version: 1,
            fs_type_code,
            root_dir_hash,
            start_lba,
            end_lba,
            type_guid,
        }
    }

    pub fn serialize(&self) -> Result<Bytes> {
        let mut buf = BytesMut::new();

        buf.put_u8(self.version);
        buf.put_u32_le(self.fs_type_code as u32);
        buf.put_slice(&self.root_dir_hash);
        buf.put_u64_le(self.start_lba);
        buf.put_u64_le(self.end_lba);
        buf.put_slice(&self.type_guid);

        Ok(buf.freeze())
    }

    pub fn deserialize(mut data: Bytes) -> Result<Self> {
        if data.remaining() < 69 {
            return Err(Error::new(ErrorKind::UnexpectedEof, "PartitionObj data too short"));
        }

        let version = data.get_u8();
        if version != 1 {
            return Err(Error::new(ErrorKind::InvalidData, "Unsupported PartitionObj version"));
        }

        let fs_type_code = FsType::from(data.get_u32_le());

        let mut root_dir_hash = [0u8; 32];
        data.copy_to_slice(&mut root_dir_hash);

        let start_lba = data.get_u64_le();
        let end_lba = data.get_u64_le();

        let mut type_guid = [0u8; 16];
        data.copy_to_slice(&mut type_guid);

        Ok(Self {
            version,
            fs_type_code,
            root_dir_hash,
            start_lba,
            end_lba,
            type_guid,
        })
    }

    pub fn compute_hash(&self) -> Result<Hash> {
        let serialized = self.serialize()?;
        Ok(blake3::hash(&serialized).into())
    }
}

/// Disk metadata object
#[derive(Debug, Clone)]
pub struct DiskObj {
    pub version: u8,
    pub partition_hashes: Vec<Hash>,
    pub disk_size_bytes: u64,
    pub serial: String,
}

impl DiskObj {
    pub fn new(partition_hashes: Vec<Hash>, disk_size_bytes: u64, serial: String) -> Self {
        Self {
            version: 1,
            partition_hashes,
            disk_size_bytes,
            serial,
        }
    }

    pub fn serialize(&self) -> Result<Bytes> {
        let mut buf = BytesMut::new();

        buf.put_u8(self.version);
        buf.put_u32_le(self.partition_hashes.len() as u32);

        for hash in &self.partition_hashes {
            buf.put_slice(hash);
        }

        buf.put_u64_le(self.disk_size_bytes);

        let serial_bytes = self.serial.as_bytes();
        buf.put_u16_le(serial_bytes.len() as u16);
        buf.put_slice(serial_bytes);

        Ok(buf.freeze())
    }

    pub fn deserialize(mut data: Bytes) -> Result<Self> {
        if data.remaining() < 15 {
            return Err(Error::new(ErrorKind::UnexpectedEof, "DiskObj data too short"));
        }

        let version = data.get_u8();
        if version != 1 {
            return Err(Error::new(ErrorKind::InvalidData, "Unsupported DiskObj version"));
        }

        let partition_count = data.get_u32_le() as usize;

        if data.remaining() < partition_count * 32 + 10 {
            return Err(Error::new(ErrorKind::UnexpectedEof, "DiskObj partitions too short"));
        }

        let mut partition_hashes = Vec::with_capacity(partition_count);
        for _ in 0..partition_count {
            let mut hash = [0u8; 32];
            data.copy_to_slice(&mut hash);
            partition_hashes.push(hash);
        }

        let disk_size_bytes = data.get_u64_le();
        let serial_len = data.get_u16_le() as usize;

        if data.remaining() < serial_len {
            return Err(Error::new(ErrorKind::UnexpectedEof, "DiskObj serial too short"));
        }

        let serial = String::from_utf8(data.copy_to_bytes(serial_len).to_vec())
            .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in serial"))?;

        Ok(Self {
            version,
            partition_hashes,
            disk_size_bytes,
            serial,
        })
    }

    pub fn compute_hash(&self) -> Result<Hash> {
        let serialized = self.serialize()?;
        Ok(blake3::hash(&serialized).into())
    }
}

/// Snapshot metadata object
#[derive(Debug, Clone)]
pub struct SnapshotObj {
    pub version: u8,
    pub created_at_unixsec: u64,
    pub disk_hashes: Vec<Hash>,
}

impl SnapshotObj {
    pub fn new(created_at_unixsec: u64, disk_hashes: Vec<Hash>) -> Self {
        Self {
            version: 1,
            created_at_unixsec,
            disk_hashes,
        }
    }

    pub fn serialize(&self) -> Result<Bytes> {
        let mut buf = BytesMut::new();

        buf.put_u8(self.version);
        buf.put_u64_le(self.created_at_unixsec);
        buf.put_u32_le(self.disk_hashes.len() as u32);

        for hash in &self.disk_hashes {
            buf.put_slice(hash);
        }

        Ok(buf.freeze())
    }

    pub fn deserialize(mut data: Bytes) -> Result<Self> {
        if data.remaining() < 13 {
            return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotObj data too short"));
        }

        let version = data.get_u8();
        if version != 1 {
            return Err(Error::new(ErrorKind::InvalidData, "Unsupported SnapshotObj version"));
        }

        let created_at_unixsec = data.get_u64_le();
        let disk_count = data.get_u32_le() as usize;

        if data.remaining() < disk_count * 32 {
            return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotObj disk hashes too short"));
        }

        let mut disk_hashes = Vec::with_capacity(disk_count);
        for _ in 0..disk_count {
            let mut hash = [0u8; 32];
            data.copy_to_slice(&mut hash);
            disk_hashes.push(hash);
        }

        Ok(Self {
            version,
            created_at_unixsec,
            disk_hashes,
        })
    }

    pub fn compute_hash(&self) -> Result<Hash> {
        let serialized = self.serialize()?;
        Ok(blake3::hash(&serialized).into())
    }
}

/// Meta object wrapper
#[derive(Debug, Clone)]
pub enum MetaObj {
    File(FileObj),
    Dir(DirObj),
    Partition(PartitionObj),
    Disk(DiskObj),
    Snapshot(SnapshotObj),
}

impl MetaObj {
    pub fn meta_type(&self) -> MetaType {
        match self {
            MetaObj::File(_) => MetaType::File,
            MetaObj::Dir(_) => MetaType::Dir,
            MetaObj::Partition(_) => MetaType::Partition,
            MetaObj::Disk(_) => MetaType::Disk,
            MetaObj::Snapshot(_) => MetaType::Snapshot,
        }
    }

    pub fn serialize(&self) -> Result<Bytes> {
        match self {
            MetaObj::File(obj) => obj.serialize(),
            MetaObj::Dir(obj) => obj.serialize(),
            MetaObj::Partition(obj) => obj.serialize(),
            MetaObj::Disk(obj) => obj.serialize(),
            MetaObj::Snapshot(obj) => obj.serialize(),
        }
    }

    pub fn deserialize(meta_type: MetaType, data: Bytes) -> Result<Self> {
        match meta_type {
            MetaType::File => Ok(MetaObj::File(FileObj::deserialize(data)?)),
            MetaType::Dir => Ok(MetaObj::Dir(DirObj::deserialize(data)?)),
            MetaType::Partition => Ok(MetaObj::Partition(PartitionObj::deserialize(data)?)),
            MetaType::Disk => Ok(MetaObj::Disk(DiskObj::deserialize(data)?)),
            MetaType::Snapshot => Ok(MetaObj::Snapshot(SnapshotObj::deserialize(data)?)),
        }
    }

    pub fn compute_hash(&self) -> Result<Hash> {
        match self {
            MetaObj::File(obj) => obj.compute_hash(),
            MetaObj::Dir(obj) => obj.compute_hash(),
            MetaObj::Partition(obj) => obj.compute_hash(),
            MetaObj::Disk(obj) => obj.compute_hash(),
            MetaObj::Snapshot(obj) => obj.compute_hash(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_file_obj_serialization() {
        let obj = FileObj::new(
            FsType::Ext,
            1024,
            0o644,
            1000,
            1000,
            1234567890,
            vec![[1; 32], [2; 32]],
        );

        let serialized = obj.serialize().unwrap();
        let deserialized = FileObj::deserialize(serialized).unwrap();

        assert_eq!(obj.fs_type_code, deserialized.fs_type_code);
        assert_eq!(obj.size, deserialized.size);
        assert_eq!(obj.chunk_hashes, deserialized.chunk_hashes);
    }

    #[test]
    fn test_dir_obj_serialization() {
        let entries = vec![
            DirEntry {
                entry_type: EntryType::File,
                name: "test.txt".to_string(),
                target_meta_hash: [1; 32],
            },
            DirEntry {
                entry_type: EntryType::Dir,
                name: "subdir".to_string(),
                target_meta_hash: [2; 32],
            },
        ];

        let obj = DirObj::new(entries);
        let serialized = obj.serialize().unwrap();
        let deserialized = DirObj::deserialize(serialized).unwrap();

        assert_eq!(obj.entries.len(), deserialized.entries.len());
        assert_eq!(obj.entries[0].name, deserialized.entries[0].name);
        assert_eq!(obj.entries[1].entry_type, deserialized.entries[1].entry_type);
    }

    #[test]
    fn test_hash_computation() {
        let obj = FileObj::new(FsType::Ext, 1024, 0o644, 1000, 1000, 1234567890, vec![]);
        let hash1 = obj.compute_hash().unwrap();
        let hash2 = obj.compute_hash().unwrap();
        assert_eq!(hash1, hash2);

        let obj2 = FileObj::new(FsType::Ext, 1025, 0o644, 1000, 1000, 1234567890, vec![]);
        let hash3 = obj2.compute_hash().unwrap();
        assert_ne!(hash1, hash3);
    }
}
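Taken together, these objects form a content-addressed (Merkle-style) tree: each parent stores the BLAKE3 hashes of its children's serialized metadata, so a snapshot hash pins the entire disk state. A minimal sketch of that composition, assuming the module path crate::sync::meta used by this commit (the example itself is not part of the diff):

use crate::sync::meta::{DirEntry, DirObj, EntryType, FileObj, FsType};

fn dir_hash_example() -> std::io::Result<[u8; 32]> {
    // A file whose contents are two content-addressed chunks.
    let file = FileObj::new(FsType::Ext, 1024, 0o644, 1000, 1000, 1_234_567_890,
                            vec![[1u8; 32], [2u8; 32]]);

    // The directory references the file by the hash of its serialized
    // metadata, so any change to the file also changes the directory hash.
    let dir = DirObj::new(vec![DirEntry {
        entry_type: EntryType::File,
        name: "test.txt".to_string(),
        target_meta_hash: file.compute_hash()?,
    }]);
    dir.compute_hash()
}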
8 server/src/sync/mod.rs Normal file
@@ -0,0 +1,8 @@
pub mod protocol;
pub mod server;
pub mod storage;
pub mod session;
pub mod meta;
pub mod validation;

pub use server::SyncServer;
620 server/src/sync/protocol.rs Normal file
@@ -0,0 +1,620 @@
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result};

/// Command codes for the sync protocol
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Command {
    Hello = 0x01,
    HelloOk = 0x02,
    AuthUserPass = 0x10,
    AuthCode = 0x11,
    AuthOk = 0x12,
    AuthFail = 0x13,
    BatchCheckChunk = 0x20,
    CheckChunkResp = 0x21,
    SendChunk = 0x22,
    ChunkOk = 0x23,
    ChunkFail = 0x24,
    BatchCheckMeta = 0x30,
    CheckMetaResp = 0x31,
    SendMeta = 0x32,
    MetaOk = 0x33,
    MetaFail = 0x34,
    SendSnapshot = 0x40,
    SnapshotOk = 0x41,
    SnapshotFail = 0x42,
    Close = 0xFF,
}

impl TryFrom<u8> for Command {
    type Error = Error;

    fn try_from(value: u8) -> Result<Self> {
        match value {
            0x01 => Ok(Command::Hello),
            0x02 => Ok(Command::HelloOk),
            0x10 => Ok(Command::AuthUserPass),
            0x11 => Ok(Command::AuthCode),
            0x12 => Ok(Command::AuthOk),
            0x13 => Ok(Command::AuthFail),
            0x20 => Ok(Command::BatchCheckChunk),
            0x21 => Ok(Command::CheckChunkResp),
            0x22 => Ok(Command::SendChunk),
            0x23 => Ok(Command::ChunkOk),
            0x24 => Ok(Command::ChunkFail),
            0x30 => Ok(Command::BatchCheckMeta),
            0x31 => Ok(Command::CheckMetaResp),
            0x32 => Ok(Command::SendMeta),
            0x33 => Ok(Command::MetaOk),
            0x34 => Ok(Command::MetaFail),
            0x40 => Ok(Command::SendSnapshot),
            0x41 => Ok(Command::SnapshotOk),
            0x42 => Ok(Command::SnapshotFail),
            0xFF => Ok(Command::Close),
            _ => Err(Error::new(ErrorKind::InvalidData, "Unknown command code")),
        }
    }
}

/// Message header structure (24 bytes fixed)
#[derive(Debug, Clone)]
pub struct MessageHeader {
    pub cmd: Command,
    pub flags: u8,
    pub reserved: [u8; 2],
    pub session_id: [u8; 16],
    pub payload_len: u32,
}

impl MessageHeader {
    pub const SIZE: usize = 24;

    pub fn new(cmd: Command, session_id: [u8; 16], payload_len: u32) -> Self {
        Self {
            cmd,
            flags: 0,
            reserved: [0; 2],
            session_id,
            payload_len,
        }
    }

    pub fn serialize(&self) -> [u8; Self::SIZE] {
        let mut buf = [0u8; Self::SIZE];
        buf[0] = self.cmd as u8;
        buf[1] = self.flags;
        buf[2..4].copy_from_slice(&self.reserved);
        buf[4..20].copy_from_slice(&self.session_id);
        buf[20..24].copy_from_slice(&self.payload_len.to_le_bytes());
        buf
    }

    pub fn deserialize(buf: &[u8]) -> Result<Self> {
        if buf.len() < Self::SIZE {
            return Err(Error::new(ErrorKind::UnexpectedEof, "Header too short"));
        }

        let cmd = Command::try_from(buf[0])?;
        let flags = buf[1];
        let reserved = [buf[2], buf[3]];
        let mut session_id = [0u8; 16];
        session_id.copy_from_slice(&buf[4..20]);
        let payload_len = u32::from_le_bytes([buf[20], buf[21], buf[22], buf[23]]);

        Ok(Self {
            cmd,
            flags,
            reserved,
            session_id,
            payload_len,
        })
    }
}

/// A 32-byte BLAKE3 hash
pub type Hash = [u8; 32];

/// Meta object types
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetaType {
    File = 1,
    Dir = 2,
    Partition = 3,
    Disk = 4,
    Snapshot = 5,
}

impl TryFrom<u8> for MetaType {
    type Error = Error;

    fn try_from(value: u8) -> Result<Self> {
        match value {
            1 => Ok(MetaType::File),
            2 => Ok(MetaType::Dir),
            3 => Ok(MetaType::Partition),
            4 => Ok(MetaType::Disk),
            5 => Ok(MetaType::Snapshot),
            _ => Err(Error::new(ErrorKind::InvalidData, "Unknown meta type")),
        }
    }
}

/// Protocol message types
#[derive(Debug, Clone)]
pub enum Message {
    Hello {
        client_type: u8,
        auth_type: u8,
    },
    HelloOk,
    AuthUserPass {
        username: String,
        password: String,
        machine_id: i64,
    },
    AuthCode {
        code: String,
    },
    AuthOk {
        session_id: [u8; 16],
    },
    AuthFail {
        reason: String,
    },
    BatchCheckChunk {
        hashes: Vec<Hash>,
    },
    CheckChunkResp {
        missing_hashes: Vec<Hash>,
    },
    SendChunk {
        hash: Hash,
        data: Bytes,
    },
    ChunkOk,
    ChunkFail {
        reason: String,
    },
    BatchCheckMeta {
        items: Vec<(MetaType, Hash)>,
    },
    CheckMetaResp {
        missing_items: Vec<(MetaType, Hash)>,
    },
    SendMeta {
        meta_type: MetaType,
        meta_hash: Hash,
        body: Bytes,
    },
    MetaOk,
    MetaFail {
        reason: String,
    },
    SendSnapshot {
        snapshot_hash: Hash,
        body: Bytes,
    },
    SnapshotOk {
        snapshot_id: String,
    },
    SnapshotFail {
        missing_chunks: Vec<Hash>,
        missing_metas: Vec<(MetaType, Hash)>,
    },
    Close,
}

impl Message {
    /// Serialize message payload to bytes
    pub fn serialize_payload(&self) -> Result<Bytes> {
        let mut buf = BytesMut::new();

        match self {
            Message::Hello { client_type, auth_type } => {
                buf.put_u8(*client_type);
                buf.put_u8(*auth_type);
            }
            Message::HelloOk => {
                // No payload
            }
            Message::AuthUserPass { username, password, machine_id } => {
                let username_bytes = username.as_bytes();
                let password_bytes = password.as_bytes();
                buf.put_u16_le(username_bytes.len() as u16);
                buf.put_slice(username_bytes);
                buf.put_u16_le(password_bytes.len() as u16);
                buf.put_slice(password_bytes);
                buf.put_i64_le(*machine_id);
            }
            Message::AuthCode { code } => {
                let code_bytes = code.as_bytes();
                buf.put_u16_le(code_bytes.len() as u16);
                buf.put_slice(code_bytes);
            }
            Message::AuthOk { session_id } => {
                buf.put_slice(session_id);
            }
            Message::AuthFail { reason } => {
                let reason_bytes = reason.as_bytes();
                buf.put_u16_le(reason_bytes.len() as u16);
                buf.put_slice(reason_bytes);
            }
            Message::BatchCheckChunk { hashes } => {
                buf.put_u32_le(hashes.len() as u32);
                for hash in hashes {
                    buf.put_slice(hash);
                }
            }
            Message::CheckChunkResp { missing_hashes } => {
                buf.put_u32_le(missing_hashes.len() as u32);
                for hash in missing_hashes {
                    buf.put_slice(hash);
                }
            }
            Message::SendChunk { hash, data } => {
                buf.put_slice(hash);
                buf.put_u32_le(data.len() as u32);
                buf.put_slice(data);
            }
            Message::ChunkOk => {
                // No payload
            }
            Message::ChunkFail { reason } => {
                let reason_bytes = reason.as_bytes();
                buf.put_u16_le(reason_bytes.len() as u16);
                buf.put_slice(reason_bytes);
            }
            Message::BatchCheckMeta { items } => {
                buf.put_u32_le(items.len() as u32);
                for (meta_type, hash) in items {
                    buf.put_u8(*meta_type as u8);
                    buf.put_slice(hash);
                }
            }
            Message::CheckMetaResp { missing_items } => {
                buf.put_u32_le(missing_items.len() as u32);
                for (meta_type, hash) in missing_items {
                    buf.put_u8(*meta_type as u8);
                    buf.put_slice(hash);
                }
            }
            Message::SendMeta { meta_type, meta_hash, body } => {
                buf.put_u8(*meta_type as u8);
                buf.put_slice(meta_hash);
                buf.put_u32_le(body.len() as u32);
                buf.put_slice(body);
            }
            Message::MetaOk => {
                // No payload
            }
            Message::MetaFail { reason } => {
                let reason_bytes = reason.as_bytes();
                buf.put_u16_le(reason_bytes.len() as u16);
                buf.put_slice(reason_bytes);
            }
            Message::SendSnapshot { snapshot_hash, body } => {
                buf.put_slice(snapshot_hash);
                buf.put_u32_le(body.len() as u32);
                buf.put_slice(body);
            }
            Message::SnapshotOk { snapshot_id } => {
                let id_bytes = snapshot_id.as_bytes();
                buf.put_u16_le(id_bytes.len() as u16);
                buf.put_slice(id_bytes);
            }
            Message::SnapshotFail { missing_chunks, missing_metas } => {
                buf.put_u32_le(missing_chunks.len() as u32);
                for hash in missing_chunks {
                    buf.put_slice(hash);
                }
                buf.put_u32_le(missing_metas.len() as u32);
                for (meta_type, hash) in missing_metas {
                    buf.put_u8(*meta_type as u8);
                    buf.put_slice(hash);
                }
            }
            Message::Close => {
                // No payload
            }
        }

        Ok(buf.freeze())
    }

    /// Deserialize message payload from bytes
    pub fn deserialize_payload(cmd: Command, mut payload: Bytes) -> Result<Self> {
        match cmd {
            Command::Hello => {
                if payload.remaining() < 2 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "Hello payload too short"));
                }
                let client_type = payload.get_u8();
                let auth_type = payload.get_u8();
                Ok(Message::Hello { client_type, auth_type })
            }
            Command::HelloOk => Ok(Message::HelloOk),
            Command::AuthUserPass => {
                if payload.remaining() < 12 { // 4 bytes for lengths + at least 8 bytes for machine_id
                    return Err(Error::new(ErrorKind::UnexpectedEof, "AuthUserPass payload too short"));
                }
                let username_len = payload.get_u16_le() as usize;
                if payload.remaining() < username_len + 10 { // 2 bytes for password len + 8 bytes for machine_id
                    return Err(Error::new(ErrorKind::UnexpectedEof, "AuthUserPass username too short"));
                }
                let username = String::from_utf8(payload.copy_to_bytes(username_len).to_vec())
                    .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in username"))?;
                let password_len = payload.get_u16_le() as usize;
                if payload.remaining() < password_len + 8 { // 8 bytes for machine_id
                    return Err(Error::new(ErrorKind::UnexpectedEof, "AuthUserPass password too short"));
                }
                let password = String::from_utf8(payload.copy_to_bytes(password_len).to_vec())
                    .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in password"))?;
                let machine_id = payload.get_i64_le();
                Ok(Message::AuthUserPass { username, password, machine_id })
            }
            Command::AuthCode => {
                if payload.remaining() < 2 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "AuthCode payload too short"));
                }
                let code_len = payload.get_u16_le() as usize;
                if payload.remaining() < code_len {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "AuthCode code too short"));
                }
                let code = String::from_utf8(payload.copy_to_bytes(code_len).to_vec())
                    .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in code"))?;
                Ok(Message::AuthCode { code })
            }
            Command::AuthOk => {
                if payload.remaining() < 16 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "AuthOk payload too short"));
                }
                let mut session_id = [0u8; 16];
                payload.copy_to_slice(&mut session_id);
                Ok(Message::AuthOk { session_id })
            }
            Command::AuthFail => {
                if payload.remaining() < 2 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "AuthFail payload too short"));
                }
                let reason_len = payload.get_u16_le() as usize;
                if payload.remaining() < reason_len {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "AuthFail reason too short"));
                }
                let reason = String::from_utf8(payload.copy_to_bytes(reason_len).to_vec())
                    .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in reason"))?;
                Ok(Message::AuthFail { reason })
            }
            Command::BatchCheckChunk => {
                if payload.remaining() < 4 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "BatchCheckChunk payload too short"));
                }
                let count = payload.get_u32_le() as usize;
                if payload.remaining() < count * 32 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "BatchCheckChunk hashes too short"));
                }
                let mut hashes = Vec::with_capacity(count);
                for _ in 0..count {
                    let mut hash = [0u8; 32];
                    payload.copy_to_slice(&mut hash);
                    hashes.push(hash);
                }
                Ok(Message::BatchCheckChunk { hashes })
            }
            Command::CheckChunkResp => {
                if payload.remaining() < 4 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "CheckChunkResp payload too short"));
                }
                let count = payload.get_u32_le() as usize;
                if payload.remaining() < count * 32 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "CheckChunkResp hashes too short"));
                }
                let mut missing_hashes = Vec::with_capacity(count);
                for _ in 0..count {
                    let mut hash = [0u8; 32];
                    payload.copy_to_slice(&mut hash);
                    missing_hashes.push(hash);
                }
                Ok(Message::CheckChunkResp { missing_hashes })
            }
            Command::SendChunk => {
                if payload.remaining() < 36 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SendChunk payload too short"));
                }
                let mut hash = [0u8; 32];
                payload.copy_to_slice(&mut hash);
                let size = payload.get_u32_le() as usize;
                if payload.remaining() < size {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SendChunk data too short"));
                }
                let data = payload.copy_to_bytes(size);
                Ok(Message::SendChunk { hash, data })
            }
            Command::ChunkOk => Ok(Message::ChunkOk),
            Command::ChunkFail => {
                if payload.remaining() < 2 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "ChunkFail payload too short"));
                }
                let reason_len = payload.get_u16_le() as usize;
                if payload.remaining() < reason_len {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "ChunkFail reason too short"));
                }
                let reason = String::from_utf8(payload.copy_to_bytes(reason_len).to_vec())
                    .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in reason"))?;
                Ok(Message::ChunkFail { reason })
            }
            Command::BatchCheckMeta => {
                if payload.remaining() < 4 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "BatchCheckMeta payload too short"));
                }
                let count = payload.get_u32_le() as usize;
                if payload.remaining() < count * 33 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "BatchCheckMeta items too short"));
                }
                let mut items = Vec::with_capacity(count);
                for _ in 0..count {
                    let meta_type = MetaType::try_from(payload.get_u8())?;
                    let mut hash = [0u8; 32];
                    payload.copy_to_slice(&mut hash);
                    items.push((meta_type, hash));
                }
                Ok(Message::BatchCheckMeta { items })
            }
            Command::CheckMetaResp => {
                if payload.remaining() < 4 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "CheckMetaResp payload too short"));
                }
                let count = payload.get_u32_le() as usize;
                if payload.remaining() < count * 33 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "CheckMetaResp items too short"));
                }
                let mut missing_items = Vec::with_capacity(count);
                for _ in 0..count {
                    let meta_type = MetaType::try_from(payload.get_u8())?;
                    let mut hash = [0u8; 32];
                    payload.copy_to_slice(&mut hash);
                    missing_items.push((meta_type, hash));
                }
                Ok(Message::CheckMetaResp { missing_items })
            }
            Command::SendMeta => {
                if payload.remaining() < 37 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SendMeta payload too short"));
                }
                let meta_type = MetaType::try_from(payload.get_u8())?;
                let mut meta_hash = [0u8; 32];
                payload.copy_to_slice(&mut meta_hash);
                let body_len = payload.get_u32_le() as usize;
                if payload.remaining() < body_len {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SendMeta body too short"));
                }
                let body = payload.copy_to_bytes(body_len);
                Ok(Message::SendMeta { meta_type, meta_hash, body })
            }
            Command::MetaOk => Ok(Message::MetaOk),
            Command::MetaFail => {
                if payload.remaining() < 2 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "MetaFail payload too short"));
                }
                let reason_len = payload.get_u16_le() as usize;
                if payload.remaining() < reason_len {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "MetaFail reason too short"));
                }
                let reason = String::from_utf8(payload.copy_to_bytes(reason_len).to_vec())
                    .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in reason"))?;
                Ok(Message::MetaFail { reason })
            }
            Command::SendSnapshot => {
                if payload.remaining() < 36 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SendSnapshot payload too short"));
                }
                let mut snapshot_hash = [0u8; 32];
                payload.copy_to_slice(&mut snapshot_hash);
                let body_len = payload.get_u32_le() as usize;
                if payload.remaining() < body_len {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SendSnapshot body too short"));
                }
                let body = payload.copy_to_bytes(body_len);
                Ok(Message::SendSnapshot { snapshot_hash, body })
            }
            Command::SnapshotOk => {
                if payload.remaining() < 2 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotOk payload too short"));
                }
                let id_len = payload.get_u16_le() as usize;
                if payload.remaining() < id_len {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotOk id too short"));
                }
                let snapshot_id = String::from_utf8(payload.copy_to_bytes(id_len).to_vec())
                    .map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8 in snapshot_id"))?;
                Ok(Message::SnapshotOk { snapshot_id })
            }
            Command::SnapshotFail => {
                if payload.remaining() < 8 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotFail payload too short"));
                }
                let chunk_count = payload.get_u32_le() as usize;
                if payload.remaining() < chunk_count * 32 + 4 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotFail chunks too short"));
                }
                let mut missing_chunks = Vec::with_capacity(chunk_count);
                for _ in 0..chunk_count {
                    let mut hash = [0u8; 32];
                    payload.copy_to_slice(&mut hash);
                    missing_chunks.push(hash);
                }
                let meta_count = payload.get_u32_le() as usize;
                if payload.remaining() < meta_count * 33 {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "SnapshotFail metas too short"));
                }
                let mut missing_metas = Vec::with_capacity(meta_count);
                for _ in 0..meta_count {
                    let meta_type = MetaType::try_from(payload.get_u8())?;
                    let mut hash = [0u8; 32];
                    payload.copy_to_slice(&mut hash);
                    missing_metas.push((meta_type, hash));
                }
                Ok(Message::SnapshotFail { missing_chunks, missing_metas })
            }
            Command::Close => Ok(Message::Close),
        }
    }

    /// Get the command for this message
    pub fn command(&self) -> Command {
        match self {
            Message::Hello { .. } => Command::Hello,
            Message::HelloOk => Command::HelloOk,
            Message::AuthUserPass { .. } => Command::AuthUserPass,
            Message::AuthCode { .. } => Command::AuthCode,
            Message::AuthOk { .. } => Command::AuthOk,
            Message::AuthFail { .. } => Command::AuthFail,
            Message::BatchCheckChunk { .. } => Command::BatchCheckChunk,
            Message::CheckChunkResp { .. } => Command::CheckChunkResp,
            Message::SendChunk { .. } => Command::SendChunk,
            Message::ChunkOk => Command::ChunkOk,
            Message::ChunkFail { .. } => Command::ChunkFail,
            Message::BatchCheckMeta { .. } => Command::BatchCheckMeta,
            Message::CheckMetaResp { .. } => Command::CheckMetaResp,
            Message::SendMeta { .. } => Command::SendMeta,
            Message::MetaOk => Command::MetaOk,
            Message::MetaFail { .. } => Command::MetaFail,
            Message::SendSnapshot { .. } => Command::SendSnapshot,
            Message::SnapshotOk { .. } => Command::SnapshotOk,
            Message::SnapshotFail { .. } => Command::SnapshotFail,
            Message::Close => Command::Close,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_header_serialization() {
        let header = MessageHeader::new(Command::Hello, [1; 16], 42);
        let serialized = header.serialize();
        let deserialized = MessageHeader::deserialize(&serialized).unwrap();

        assert_eq!(deserialized.cmd, Command::Hello);
        assert_eq!(deserialized.session_id, [1; 16]);
        assert_eq!(deserialized.payload_len, 42);
    }

    #[test]
    fn test_hello_message() {
        let msg = Message::Hello { client_type: 1, auth_type: 2 };
        let payload = msg.serialize_payload().unwrap();
        let deserialized = Message::deserialize_payload(Command::Hello, payload).unwrap();

        match deserialized {
            Message::Hello { client_type, auth_type } => {
                assert_eq!(client_type, 1);
                assert_eq!(auth_type, 2);
            }
            _ => panic!("Wrong message type"),
        }
    }
}
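On the wire, every message is this fixed 24-byte header followed by the command-specific payload. A minimal framing sketch using only the types from this file (buffer assembly only; writing the bytes to a socket is the caller's job, and the example is not part of the diff):

use crate::sync::protocol::{Message, MessageHeader};

fn frame(msg: &Message) -> std::io::Result<Vec<u8>> {
    let payload = msg.serialize_payload()?;
    // The session id stays all zeros until the server hands one back in AuthOk.
    let header = MessageHeader::new(msg.command(), [0u8; 16], payload.len() as u32);

    let mut frame = Vec::with_capacity(MessageHeader::SIZE + payload.len());
    frame.extend_from_slice(&header.serialize());
    frame.extend_from_slice(&payload);
    Ok(frame)
}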
463 server/src/sync/server.rs Normal file
@@ -0,0 +1,463 @@
use anyhow::{Context, Result};
use bytes::Bytes;
use sqlx::SqlitePool;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use uuid::Uuid;

use crate::sync::protocol::{Command, Message, MessageHeader, MetaType};
use crate::sync::session::{SessionManager, session_cleanup_task};
use crate::sync::storage::Storage;
use crate::sync::validation::SnapshotValidator;

/// Configuration for the sync server
#[derive(Debug, Clone)]
pub struct SyncServerConfig {
    pub bind_address: String,
    pub port: u16,
    pub data_dir: String,
    pub max_connections: usize,
    pub chunk_size_limit: usize,
    pub meta_size_limit: usize,
    pub batch_limit: usize,
}

impl Default for SyncServerConfig {
    fn default() -> Self {
        Self {
            bind_address: "0.0.0.0".to_string(),
            port: 8380,
            data_dir: "./data".to_string(),
            max_connections: 100,
            chunk_size_limit: 4 * 1024 * 1024, // 4 MiB
            meta_size_limit: 1024 * 1024, // 1 MiB
            batch_limit: 1000,
        }
    }
}

/// Main sync server
pub struct SyncServer {
    config: SyncServerConfig,
    storage: Storage,
    session_manager: Arc<SessionManager>,
    validator: SnapshotValidator,
}

impl SyncServer {
    pub fn new(config: SyncServerConfig, db_pool: SqlitePool) -> Self {
        let storage = Storage::new(&config.data_dir);
        let session_manager = Arc::new(SessionManager::new(db_pool));
        let validator = SnapshotValidator::new(storage.clone());

        Self {
            config,
            storage,
            session_manager,
            validator,
        }
    }

    /// Start the sync server
    pub async fn start(&self) -> Result<()> {
        // Initialize storage
        self.storage.init().await
            .context("Failed to initialize storage")?;

        let bind_addr = format!("{}:{}", self.config.bind_address, self.config.port);
        let listener = TcpListener::bind(&bind_addr).await
            .with_context(|| format!("Failed to bind to {}", bind_addr))?;

        println!("Sync server listening on {}", bind_addr);

        // Start session cleanup task
        let session_manager_clone = Arc::clone(&self.session_manager);
        tokio::spawn(async move {
            session_cleanup_task(session_manager_clone).await;
        });

        // Accept connections
        loop {
            match listener.accept().await {
                Ok((stream, addr)) => {
                    println!("New sync connection from {}", addr);

                    let handler = ConnectionHandler::new(
                        stream,
                        self.storage.clone(),
                        Arc::clone(&self.session_manager),
                        self.validator.clone(),
                        self.config.clone(),
                    );

                    tokio::spawn(async move {
                        if let Err(e) = handler.handle().await {
                            eprintln!("Connection error from {}: {}", addr, e);
                        }
                    });
                }
                Err(e) => {
                    eprintln!("Failed to accept connection: {}", e);
                }
            }
        }
    }
}

/// Connection handler for individual sync clients
struct ConnectionHandler {
    stream: TcpStream,
    storage: Storage,
    session_manager: Arc<SessionManager>,
    validator: SnapshotValidator,
    config: SyncServerConfig,
    session_id: Option<[u8; 16]>,
    machine_id: Option<String>,
}

impl ConnectionHandler {
    fn new(
        stream: TcpStream,
        storage: Storage,
        session_manager: Arc<SessionManager>,
        validator: SnapshotValidator,
        config: SyncServerConfig,
    ) -> Self {
        Self {
            stream,
            storage,
            session_manager,
            validator,
            config,
            session_id: None,
            machine_id: None,
        }
    }

    /// Handle the connection
    async fn handle(mut self) -> Result<()> {
        loop {
            // Read message header
            let header = self.read_header().await?;

            // Read payload
            let payload = if header.payload_len > 0 {
                self.read_payload(header.payload_len).await?
            } else {
                Bytes::new()
            };

            // Parse message
            let message = Message::deserialize_payload(header.cmd, payload)
                .context("Failed to deserialize message")?;

            // Handle message
            let response = self.handle_message(message).await?;

            // Send response
            if let Some(response_msg) = response {
                self.send_message(response_msg).await?;
            }

            // Close connection if requested
            if header.cmd == Command::Close {
                break;
            }
        }

        // Clean up session
        if let Some(session_id) = self.session_id {
            self.session_manager.remove_session(&session_id).await;
        }

        Ok(())
    }

    /// Read message header
    async fn read_header(&mut self) -> Result<MessageHeader> {
        let mut header_buf = [0u8; MessageHeader::SIZE];
        self.stream.read_exact(&mut header_buf).await
            .context("Failed to read message header")?;

        MessageHeader::deserialize(&header_buf)
            .context("Failed to parse message header")
    }

    /// Read message payload
    async fn read_payload(&mut self, len: u32) -> Result<Bytes> {
        if len as usize > self.config.meta_size_limit {
            return Err(anyhow::anyhow!("Payload too large: {} bytes", len));
        }

        let mut payload_buf = vec![0u8; len as usize];
        self.stream.read_exact(&mut payload_buf).await
            .context("Failed to read message payload")?;

        Ok(Bytes::from(payload_buf))
    }

    /// Send a message
    async fn send_message(&mut self, message: Message) -> Result<()> {
        let session_id = self.session_id.unwrap_or([0u8; 16]);
        let payload = message.serialize_payload()?;

        let header = MessageHeader::new(message.command(), session_id, payload.len() as u32);
        let header_bytes = header.serialize();

        self.stream.write_all(&header_bytes).await
            .context("Failed to write message header")?;

        if !payload.is_empty() {
            self.stream.write_all(&payload).await
                .context("Failed to write message payload")?;
        }

        self.stream.flush().await
            .context("Failed to flush stream")?;

        Ok(())
    }

/// Handle a received message
|
||||
async fn handle_message(&mut self, message: Message) -> Result<Option<Message>> {
|
||||
match message {
|
||||
Message::Hello { client_type: _, auth_type: _ } => {
|
||||
Ok(Some(Message::HelloOk))
|
||||
}
|
||||
|
||||
Message::AuthUserPass { username, password, machine_id } => {
|
||||
match self.session_manager.authenticate_userpass(&username, &password, machine_id).await {
|
||||
Ok(session) => {
|
||||
self.session_id = Some(session.session_id);
|
||||
self.machine_id = Some(session.machine_id);
|
||||
Ok(Some(Message::AuthOk { session_id: session.session_id }))
|
||||
}
|
||||
Err(e) => {
|
||||
Ok(Some(Message::AuthFail { reason: e.to_string() }))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Message::AuthCode { code } => {
|
||||
match self.session_manager.authenticate_code(&code).await {
|
||||
                    Ok(session) => {
                        self.session_id = Some(session.session_id);
                        self.machine_id = Some(session.machine_id);
                        Ok(Some(Message::AuthOk { session_id: session.session_id }))
                    }
                    Err(e) => {
                        Ok(Some(Message::AuthFail { reason: e.to_string() }))
                    }
                }
            }

            Message::BatchCheckChunk { hashes } => {
                self.require_auth()?;

                if hashes.len() > self.config.batch_limit {
                    return Err(anyhow::anyhow!("Batch size exceeds limit: {}", hashes.len()));
                }

                let missing_hashes = self.validator.validate_chunk_batch(&hashes).await?;
                Ok(Some(Message::CheckChunkResp { missing_hashes }))
            }

            Message::SendChunk { hash, data } => {
                self.require_auth()?;

                if data.len() > self.config.chunk_size_limit {
                    return Ok(Some(Message::ChunkFail {
                        reason: format!("Chunk too large: {} bytes", data.len())
                    }));
                }

                match self.storage.store_chunk(&hash, &data).await {
                    Ok(()) => Ok(Some(Message::ChunkOk)),
                    Err(e) => Ok(Some(Message::ChunkFail { reason: e.to_string() })),
                }
            }

            Message::BatchCheckMeta { items } => {
                self.require_auth()?;

                if items.len() > self.config.batch_limit {
                    return Err(anyhow::anyhow!("Batch size exceeds limit: {}", items.len()));
                }

                let missing_items = self.validator.validate_meta_batch(&items).await?;
                Ok(Some(Message::CheckMetaResp { missing_items }))
            }

            Message::SendMeta { meta_type, meta_hash, body } => {
                self.require_auth()?;

                if body.len() > self.config.meta_size_limit {
                    return Ok(Some(Message::MetaFail {
                        reason: format!("Meta object too large: {} bytes", body.len())
                    }));
                }

                match self.storage.store_meta(meta_type, &meta_hash, &body).await {
                    Ok(()) => Ok(Some(Message::MetaOk)),
                    Err(e) => Ok(Some(Message::MetaFail { reason: e.to_string() })),
                }
            }

            Message::SendSnapshot { snapshot_hash, body } => {
                self.require_auth()?;

                if body.len() > self.config.meta_size_limit {
                    return Ok(Some(Message::SnapshotFail {
                        missing_chunks: vec![],
                        missing_metas: vec![],
                    }));
                }

                // Validate snapshot
                match self.validator.validate_snapshot(&snapshot_hash, &body).await {
                    Ok(validation_result) => {
                        if validation_result.is_valid {
                            // Store snapshot meta
                            if let Err(_e) = self.storage.store_meta(MetaType::Snapshot, &snapshot_hash, &body).await {
                                return Ok(Some(Message::SnapshotFail {
                                    missing_chunks: vec![],
                                    missing_metas: vec![],
                                }));
                            }

                            // Create snapshot reference
                            let snapshot_id = Uuid::new_v4().to_string();
                            let machine_id = self.machine_id.as_ref().unwrap();
                            let created_at = chrono::Utc::now().timestamp() as u64;

                            if let Err(_e) = self.storage.store_snapshot_ref(
                                machine_id,
                                &snapshot_id,
                                &snapshot_hash,
                                created_at
                            ).await {
                                return Ok(Some(Message::SnapshotFail {
                                    missing_chunks: vec![],
                                    missing_metas: vec![],
                                }));
                            }

                            // Store snapshot in database
                            let machine_id_num: i64 = machine_id.parse().unwrap_or(0);
                            let snapshot_hash_hex = hex::encode(snapshot_hash);
                            if let Err(_e) = sqlx::query!(
                                "INSERT INTO snapshots (machine_id, snapshot_hash) VALUES (?, ?)",
                                machine_id_num,
                                snapshot_hash_hex
                            )
                            .execute(self.session_manager.get_db_pool())
                            .await {
                                return Ok(Some(Message::SnapshotFail {
                                    missing_chunks: vec![],
                                    missing_metas: vec![],
                                }));
                            }

                            Ok(Some(Message::SnapshotOk { snapshot_id }))
                        } else {
                            Ok(Some(Message::SnapshotFail {
                                missing_chunks: validation_result.missing_chunks,
                                missing_metas: validation_result.missing_metas,
                            }))
                        }
                    }
                    Err(_e) => {
                        Ok(Some(Message::SnapshotFail {
                            missing_chunks: vec![],
                            missing_metas: vec![],
                        }))
                    }
                }
            }

            Message::Close => {
                Ok(None) // No response needed
            }

            // These are response messages that shouldn't be received by the server
            Message::HelloOk | Message::AuthOk { .. } | Message::AuthFail { .. } |
            Message::CheckChunkResp { .. } | Message::ChunkOk | Message::ChunkFail { .. } |
            Message::CheckMetaResp { .. } | Message::MetaOk | Message::MetaFail { .. } |
            Message::SnapshotOk { .. } | Message::SnapshotFail { .. } => {
                Err(anyhow::anyhow!("Unexpected response message from client"))
            }
        }
    }

    /// Require authentication for protected operations
    fn require_auth(&self) -> Result<()> {
        if self.session_id.is_none() {
            return Err(anyhow::anyhow!("Authentication required"));
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use sqlx::sqlite::SqlitePoolOptions;

    async fn setup_test_server() -> (SyncServer, TempDir) {
        let temp_dir = TempDir::new().unwrap();

        let pool = SqlitePoolOptions::new()
            .connect("sqlite::memory:")
            .await
            .unwrap();

        // Create required tables
        sqlx::query!(
            r#"
            CREATE TABLE users (
                id INTEGER PRIMARY KEY,
                username TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                active INTEGER DEFAULT 1
            )
            "#
        )
        .execute(&pool)
        .await
        .unwrap();

        sqlx::query!(
            r#"
            CREATE TABLE provisioning_codes (
                id INTEGER PRIMARY KEY,
                code TEXT UNIQUE NOT NULL,
                created_by INTEGER NOT NULL,
                expires_at TEXT NOT NULL,
                used INTEGER DEFAULT 0,
                used_at TEXT,
                FOREIGN KEY (created_by) REFERENCES users (id)
            )
            "#
        )
        .execute(&pool)
        .await
        .unwrap();

        let config = SyncServerConfig {
            data_dir: temp_dir.path().to_string_lossy().to_string(),
            ..Default::default()
        };

        (SyncServer::new(config, pool), temp_dir)
    }

    #[tokio::test]
    async fn test_server_creation() {
        let (server, _temp_dir) = setup_test_server().await;

        // Initialize storage to verify everything works
        server.storage.init().await.unwrap();
    }
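
    #[tokio::test]
    async fn test_require_auth_rejects_unauthenticated() {
        // Sketch: before any Auth message succeeds, session_id is unset, so
        // require_auth must fail. This assumes the handler impl above (which
        // uses the same fields test_server_creation already touches) is on
        // the type returned by setup_test_server.
        let (server, _temp_dir) = setup_test_server().await;

        assert!(server.require_auth().is_err());
    }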
}
344
server/src/sync/session.rs
Normal file
@@ -0,0 +1,344 @@
use anyhow::{Context, Result};
use rand::RngCore;
use sqlx::SqlitePool;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;

/// Session information
#[derive(Debug, Clone)]
pub struct Session {
    pub session_id: [u8; 16],
    pub machine_id: String,
    pub user_id: i64,
    pub created_at: chrono::DateTime<chrono::Utc>,
}

/// Session manager for sync connections
#[derive(Debug)]
pub struct SessionManager {
    sessions: Arc<RwLock<HashMap<[u8; 16], Session>>>,
    db_pool: SqlitePool,
}

impl SessionManager {
    pub fn new(db_pool: SqlitePool) -> Self {
        Self {
            sessions: Arc::new(RwLock::new(HashMap::new())),
            db_pool,
        }
    }

    /// Get database pool reference
    pub fn get_db_pool(&self) -> &SqlitePool {
        &self.db_pool
    }

    /// Generate a new session ID
    fn generate_session_id() -> [u8; 16] {
        let mut session_id = [0u8; 16];
        rand::thread_rng().fill_bytes(&mut session_id);
        session_id
    }

    /// Authenticate with username/password and validate machine ownership
    pub async fn authenticate_userpass(&self, username: &str, password: &str, machine_id: i64) -> Result<Session> {
        // Query user from database
        let user = sqlx::query!(
            "SELECT id, username, password_hash FROM users WHERE username = ?",
            username
        )
        .fetch_optional(&self.db_pool)
        .await
        .context("Failed to query user")?;

        let user = user.ok_or_else(|| anyhow::anyhow!("Invalid credentials"))?;

        // Verify password
        if !bcrypt::verify(password, &user.password_hash)
            .context("Failed to verify password")? {
            return Err(anyhow::anyhow!("Invalid credentials"));
        }

        let user_id = user.id.unwrap_or(0) as i64;

        // Validate machine ownership
        let machine = sqlx::query!(
            "SELECT id, user_id FROM machines WHERE id = ?",
            machine_id
        )
        .fetch_optional(&self.db_pool)
        .await
        .context("Failed to query machine")?;

        let machine = machine.ok_or_else(|| anyhow::anyhow!("Machine not found"))?;

        let machine_user_id = machine.user_id;
        if machine_user_id != user_id {
            return Err(anyhow::anyhow!("Machine does not belong to user"));
        }

        // Create session
        let session_id = Self::generate_session_id();
        let machine_id_str = machine_id.to_string();
        let session = Session {
            session_id,
            machine_id: machine_id_str,
            user_id,
            created_at: chrono::Utc::now(),
        };

        // Store session
        let mut sessions = self.sessions.write().await;
        sessions.insert(session_id, session.clone());

        Ok(session)
    }

    /// Authenticate with provisioning code
    pub async fn authenticate_code(&self, code: &str) -> Result<Session> {
        // Query provisioning code from database
        let provisioning_code = sqlx::query!(
            r#"
            SELECT pc.id, pc.code, pc.expires_at, pc.used, m.user_id, u.username
            FROM provisioning_codes pc
            JOIN machines m ON pc.machine_id = m.id
            JOIN users u ON m.user_id = u.id
            WHERE pc.code = ? AND pc.used = 0
            "#,
            code
        )
        .fetch_optional(&self.db_pool)
        .await
        .context("Failed to query provisioning code")?;

        let provisioning_code = provisioning_code
            .ok_or_else(|| anyhow::anyhow!("Invalid or used provisioning code"))?;

        // Check if code is expired
        let expires_at: chrono::DateTime<chrono::Utc> = chrono::DateTime::from_naive_utc_and_offset(
            provisioning_code.expires_at,
            chrono::Utc
        );

        if chrono::Utc::now() > expires_at {
            return Err(anyhow::anyhow!("Provisioning code expired"));
        }

        // Mark code as used
        sqlx::query!(
            "UPDATE provisioning_codes SET used = 1 WHERE id = ?",
            provisioning_code.id
        )
        .execute(&self.db_pool)
        .await
        .context("Failed to mark provisioning code as used")?;

        // Create session
        let session_id = Self::generate_session_id();
        let machine_id = format!("machine-{}", Uuid::new_v4());
        let session = Session {
            session_id,
            machine_id,
            user_id: provisioning_code.user_id as i64,
            created_at: chrono::Utc::now(),
        };

        // Store session
        let mut sessions = self.sessions.write().await;
        sessions.insert(session_id, session.clone());

        Ok(session)
    }

    /// Get session by session ID
    pub async fn get_session(&self, session_id: &[u8; 16]) -> Option<Session> {
        let sessions = self.sessions.read().await;
        sessions.get(session_id).cloned()
    }

    /// Validate session and return associated machine ID
    pub async fn validate_session(&self, session_id: &[u8; 16]) -> Result<String> {
        let session = self.get_session(session_id).await
            .ok_or_else(|| anyhow::anyhow!("Invalid session"))?;

        // Check if session is too old (24 hours)
        let session_age = chrono::Utc::now() - session.created_at;
        if session_age > chrono::Duration::hours(24) {
            // Remove expired session
            let mut sessions = self.sessions.write().await;
            sessions.remove(session_id);
            return Err(anyhow::anyhow!("Session expired"));
        }

        Ok(session.machine_id)
    }

    /// Remove session
    pub async fn remove_session(&self, session_id: &[u8; 16]) {
        let mut sessions = self.sessions.write().await;
        sessions.remove(session_id);
    }

    /// Clean up expired sessions
    pub async fn cleanup_expired_sessions(&self) {
        let mut sessions = self.sessions.write().await;
        let now = chrono::Utc::now();

        sessions.retain(|_, session| {
            let age = now - session.created_at;
            age <= chrono::Duration::hours(24)
        });
    }

    /// Get active session count
    pub async fn active_session_count(&self) -> usize {
        let sessions = self.sessions.read().await;
        sessions.len()
    }

    /// List active sessions
    pub async fn list_active_sessions(&self) -> Vec<Session> {
        let sessions = self.sessions.read().await;
        sessions.values().cloned().collect()
    }
}

/// Periodic cleanup task for expired sessions
pub async fn session_cleanup_task(session_manager: Arc<SessionManager>) {
    let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(3600)); // Every hour

    loop {
        interval.tick().await;
        session_manager.cleanup_expired_sessions().await;
        println!("Cleaned up expired sync sessions. Active sessions: {}",
            session_manager.active_session_count().await);
    }
}

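// A minimal usage sketch (assuming a `db_pool: SqlitePool` already exists at
// startup, e.g. in main). The cleanup task is an infinite loop, so it is
// meant to be spawned next to the accept loop rather than awaited directly:
//
//     let session_manager = Arc::new(SessionManager::new(db_pool.clone()));
//     tokio::spawn(session_cleanup_task(session_manager.clone()));
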
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::sqlite::SqlitePoolOptions;

    async fn setup_test_db() -> SqlitePool {
        let pool = SqlitePoolOptions::new()
            .connect("sqlite::memory:")
            .await
            .unwrap();

        // Create tables
        sqlx::query!(
            r#"
            CREATE TABLE users (
                id INTEGER PRIMARY KEY,
                username TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                active INTEGER DEFAULT 1
            )
            "#
        )
        .execute(&pool)
        .await
        .unwrap();

        // authenticate_userpass validates machine ownership, so the tests
        // also need a machines table (columns as queried above) and a
        // machine owned by the test user
        sqlx::query!(
            r#"
            CREATE TABLE machines (
                id INTEGER PRIMARY KEY,
                user_id INTEGER NOT NULL,
                FOREIGN KEY (user_id) REFERENCES users (id)
            )
            "#
        )
        .execute(&pool)
        .await
        .unwrap();

        sqlx::query!(
            r#"
            CREATE TABLE provisioning_codes (
                id INTEGER PRIMARY KEY,
                code TEXT UNIQUE NOT NULL,
                created_by INTEGER NOT NULL,
                expires_at TEXT NOT NULL,
                used INTEGER DEFAULT 0,
                used_at TEXT,
                FOREIGN KEY (created_by) REFERENCES users (id)
            )
            "#
        )
        .execute(&pool)
        .await
        .unwrap();

        // Insert test user
        let password_hash = bcrypt::hash("password123", bcrypt::DEFAULT_COST).unwrap();
        sqlx::query!(
            "INSERT INTO users (username, password_hash) VALUES (?, ?)",
            "testuser",
            password_hash
        )
        .execute(&pool)
        .await
        .unwrap();

        // Insert a machine owned by the test user
        sqlx::query!("INSERT INTO machines (id, user_id) VALUES (1, 1)")
            .execute(&pool)
            .await
            .unwrap();

        pool
    }

    #[tokio::test]
    async fn test_authenticate_userpass() {
        let pool = setup_test_db().await;
        let session_manager = SessionManager::new(pool);

        let session = session_manager
            .authenticate_userpass("testuser", "password123", 1)
            .await
            .unwrap();

        assert_eq!(session.user_id, 1);
        assert!(!session.machine_id.is_empty());
    }

    #[tokio::test]
    async fn test_authenticate_userpass_invalid() {
        let pool = setup_test_db().await;
        let session_manager = SessionManager::new(pool);

        let result = session_manager
            .authenticate_userpass("testuser", "wrongpassword", 1)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_session_validation() {
        let pool = setup_test_db().await;
        let session_manager = SessionManager::new(pool);

        let session = session_manager
            .authenticate_userpass("testuser", "password123", 1)
            .await
            .unwrap();

        let machine_id = session_manager
            .validate_session(&session.session_id)
            .await
            .unwrap();

        assert_eq!(machine_id, session.machine_id);
    }

    #[tokio::test]
    async fn test_session_cleanup() {
        let pool = setup_test_db().await;
        let session_manager = SessionManager::new(pool);

        let session = session_manager
            .authenticate_userpass("testuser", "password123", 1)
            .await
            .unwrap();

        assert_eq!(session_manager.active_session_count().await, 1);

        // Manually expire the session
        {
            let mut sessions = session_manager.sessions.write().await;
            if let Some(session) = sessions.get_mut(&session.session_id) {
                session.created_at = chrono::Utc::now() - chrono::Duration::hours(25);
            }
        }

        session_manager.cleanup_expired_sessions().await;
        assert_eq!(session_manager.active_session_count().await, 0);
    }
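
    #[tokio::test]
    async fn test_remove_session() {
        // Sketch: after remove_session, validate_session should fail. Uses
        // only the SessionManager API defined above.
        let pool = setup_test_db().await;
        let session_manager = SessionManager::new(pool);

        let session = session_manager
            .authenticate_userpass("testuser", "password123", 1)
            .await
            .unwrap();

        session_manager.remove_session(&session.session_id).await;
        assert!(session_manager
            .validate_session(&session.session_id)
            .await
            .is_err());
    }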
}
399
server/src/sync/storage.rs
Normal file
@@ -0,0 +1,399 @@
use anyhow::{Context, Result};
use bytes::Bytes;
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use tokio::fs;
use crate::sync::protocol::{Hash, MetaType};
use crate::sync::meta::MetaObj;

/// Storage backend for chunks and meta objects
#[derive(Debug, Clone)]
pub struct Storage {
    data_dir: PathBuf,
}

impl Storage {
    pub fn new<P: AsRef<Path>>(data_dir: P) -> Self {
        Self {
            data_dir: data_dir.as_ref().to_path_buf(),
        }
    }

    /// Initialize storage directories
    pub async fn init(&self) -> Result<()> {
        let chunks_dir = self.data_dir.join("sync").join("chunks");
        let meta_dir = self.data_dir.join("sync").join("meta");
        let machines_dir = self.data_dir.join("sync").join("machines");

        fs::create_dir_all(&chunks_dir).await
            .context("Failed to create chunks directory")?;

        fs::create_dir_all(&meta_dir).await
            .context("Failed to create meta directory")?;

        fs::create_dir_all(&machines_dir).await
            .context("Failed to create machines directory")?;

        // Create subdirectories for each meta type
        for meta_type in &["files", "dirs", "partitions", "disks", "snapshots"] {
            fs::create_dir_all(meta_dir.join(meta_type)).await
                .with_context(|| format!("Failed to create meta/{} directory", meta_type))?;
        }

        Ok(())
    }

    /// Get chunk storage path for a hash
    fn chunk_path(&self, hash: &Hash) -> PathBuf {
        let hex = hex::encode(hash);
        let ab = &hex[0..2];
        let cd = &hex[2..4];
        let filename = format!("{}.chk", hex);

        self.data_dir
            .join("sync")
            .join("chunks")
            .join(ab)
            .join(cd)
            .join(filename)
    }

    /// Get meta storage path for a hash and type
    fn meta_path(&self, meta_type: MetaType, hash: &Hash) -> PathBuf {
        let hex = hex::encode(hash);
        let ab = &hex[0..2];
        let cd = &hex[2..4];
        let filename = format!("{}.meta", hex);

        let type_dir = match meta_type {
            MetaType::File => "files",
            MetaType::Dir => "dirs",
            MetaType::Partition => "partitions",
            MetaType::Disk => "disks",
            MetaType::Snapshot => "snapshots",
        };

        self.data_dir
            .join("sync")
            .join("meta")
            .join(type_dir)
            .join(ab)
            .join(cd)
            .join(filename)
    }

    /// Check if a chunk exists
    pub async fn chunk_exists(&self, hash: &Hash) -> bool {
        let path = self.chunk_path(hash);
        path.exists()
    }

    /// Check if multiple chunks exist
    pub async fn chunks_exist(&self, hashes: &[Hash]) -> Result<HashSet<Hash>> {
        let mut existing = HashSet::new();

        for hash in hashes {
            if self.chunk_exists(hash).await {
                existing.insert(*hash);
            }
        }

        Ok(existing)
    }

    /// Store a chunk
    pub async fn store_chunk(&self, hash: &Hash, data: &[u8]) -> Result<()> {
        // Verify hash
        let computed_hash = blake3::hash(data);
        if computed_hash.as_bytes() != hash {
            return Err(anyhow::anyhow!("Chunk hash mismatch"));
        }

        let path = self.chunk_path(hash);

        // Create parent directories
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).await
                .context("Failed to create chunk directory")?;
        }

        // Write to temporary file first, then rename (atomic write)
        let temp_path = path.with_extension("tmp");
        fs::write(&temp_path, data).await
            .context("Failed to write chunk to temporary file")?;

        fs::rename(&temp_path, &path).await
            .context("Failed to rename chunk file")?;

        Ok(())
    }

    /// Load a chunk
    pub async fn load_chunk(&self, hash: &Hash) -> Result<Option<Bytes>> {
        let path = self.chunk_path(hash);

        if !path.exists() {
            return Ok(None);
        }

        let data = fs::read(&path).await
            .context("Failed to read chunk file")?;

        // Verify hash
        let computed_hash = blake3::hash(&data);
        if computed_hash.as_bytes() != hash {
            return Err(anyhow::anyhow!("Stored chunk hash mismatch"));
        }

        Ok(Some(Bytes::from(data)))
    }

    /// Check if a meta object exists
    pub async fn meta_exists(&self, meta_type: MetaType, hash: &Hash) -> bool {
        let path = self.meta_path(meta_type, hash);
        path.exists()
    }

    /// Check if multiple meta objects exist
    pub async fn metas_exist(&self, items: &[(MetaType, Hash)]) -> Result<HashSet<(MetaType, Hash)>> {
        let mut existing = HashSet::new();

        for &(meta_type, hash) in items {
            if self.meta_exists(meta_type, &hash).await {
                existing.insert((meta_type, hash));
            }
        }

        Ok(existing)
    }

    /// Store a meta object
    pub async fn store_meta(&self, meta_type: MetaType, hash: &Hash, body: &[u8]) -> Result<()> {
        // Verify hash
        let computed_hash = blake3::hash(body);
        if computed_hash.as_bytes() != hash {
            return Err(anyhow::anyhow!("Meta object hash mismatch"));
        }

        let path = self.meta_path(meta_type, hash);

        // Create parent directories
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).await
                .context("Failed to create meta directory")?;
        }

        // Write to temporary file first, then rename (atomic write)
        let temp_path = path.with_extension("tmp");
        fs::write(&temp_path, body).await
            .context("Failed to write meta to temporary file")?;

        fs::rename(&temp_path, &path).await
            .context("Failed to rename meta file")?;

        Ok(())
    }

    /// Load a meta object
    pub async fn load_meta(&self, meta_type: MetaType, hash: &Hash) -> Result<Option<MetaObj>> {
        let path = self.meta_path(meta_type, hash);

        if !path.exists() {
            return Ok(None);
        }

        let data = fs::read(&path).await
            .context("Failed to read meta file")?;

        // Verify hash
        let computed_hash = blake3::hash(&data);
        if computed_hash.as_bytes() != hash {
            return Err(anyhow::anyhow!("Stored meta object hash mismatch"));
        }

        let meta_obj = MetaObj::deserialize(meta_type, Bytes::from(data))
            .context("Failed to deserialize meta object")?;

        Ok(Some(meta_obj))
    }

    /// Get snapshot storage path for a machine
    fn snapshot_ref_path(&self, machine_id: &str, snapshot_id: &str) -> PathBuf {
        self.data_dir
            .join("sync")
            .join("machines")
            .join(machine_id)
            .join("snapshots")
            .join(format!("{}.ref", snapshot_id))
    }

    /// Store a snapshot reference
    pub async fn store_snapshot_ref(
        &self,
        machine_id: &str,
        snapshot_id: &str,
        snapshot_hash: &Hash,
        created_at: u64
    ) -> Result<()> {
        let path = self.snapshot_ref_path(machine_id, snapshot_id);

        // Create parent directories
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).await
                .context("Failed to create snapshot reference directory")?;
        }

        // Create snapshot reference content
        let content = format!("{}:{}", hex::encode(snapshot_hash), created_at);

        // Write to temporary file first, then rename (atomic write)
        let temp_path = path.with_extension("tmp");
        fs::write(&temp_path, content).await
            .context("Failed to write snapshot reference to temporary file")?;

        fs::rename(&temp_path, &path).await
            .context("Failed to rename snapshot reference file")?;

        Ok(())
    }

    /// Load a snapshot reference
    pub async fn load_snapshot_ref(&self, machine_id: &str, snapshot_id: &str) -> Result<Option<(Hash, u64)>> {
        let path = self.snapshot_ref_path(machine_id, snapshot_id);

        if !path.exists() {
            return Ok(None);
        }

        let content = fs::read_to_string(&path).await
            .context("Failed to read snapshot reference file")?;

        let parts: Vec<&str> = content.trim().split(':').collect();
        if parts.len() != 2 {
            return Err(anyhow::anyhow!("Invalid snapshot reference format"));
        }

        let snapshot_hash: Hash = hex::decode(parts[0])
            .context("Failed to decode snapshot hash")?
            .try_into()
            .map_err(|_| anyhow::anyhow!("Invalid snapshot hash length"))?;

        let created_at: u64 = parts[1].parse()
            .context("Failed to parse snapshot timestamp")?;

        Ok(Some((snapshot_hash, created_at)))
    }

    /// List snapshots for a machine
    pub async fn list_snapshots(&self, machine_id: &str) -> Result<Vec<String>> {
        let snapshots_dir = self.data_dir
            .join("sync")
            .join("machines")
            .join(machine_id)
            .join("snapshots");

        if !snapshots_dir.exists() {
            return Ok(Vec::new());
        }

        let mut entries = fs::read_dir(&snapshots_dir).await
            .context("Failed to read snapshots directory")?;

        let mut snapshots = Vec::new();
        while let Some(entry) = entries.next_entry().await
            .context("Failed to read snapshot entry")? {

            if let Some(file_name) = entry.file_name().to_str() {
                if file_name.ends_with(".ref") {
                    let snapshot_id = file_name.trim_end_matches(".ref");
                    snapshots.push(snapshot_id.to_string());
                }
            }
        }

        snapshots.sort();
        Ok(snapshots)
    }

    /// Delete old snapshots, keeping only the latest N
    pub async fn cleanup_snapshots(&self, machine_id: &str, keep_count: usize) -> Result<()> {
        let mut snapshots = self.list_snapshots(machine_id).await?;

        if snapshots.len() <= keep_count {
            return Ok(());
        }

        snapshots.sort();
        snapshots.reverse(); // Most recent first

        // Delete older snapshots
        for snapshot_id in snapshots.iter().skip(keep_count) {
            let path = self.snapshot_ref_path(machine_id, snapshot_id);
            if path.exists() {
                fs::remove_file(&path).await
                    .with_context(|| format!("Failed to delete snapshot {}", snapshot_id))?;
            }
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_storage_init() {
        let temp_dir = TempDir::new().unwrap();
        let storage = Storage::new(temp_dir.path());
        storage.init().await.unwrap();

        assert!(temp_dir.path().join("sync/chunks").exists());
        assert!(temp_dir.path().join("sync/meta/files").exists());
        assert!(temp_dir.path().join("sync/machines").exists());
    }

    #[tokio::test]
    async fn test_chunk_storage() {
        let temp_dir = TempDir::new().unwrap();
        let storage = Storage::new(temp_dir.path());
        storage.init().await.unwrap();

        let data = b"test chunk data";
        let hash = blake3::hash(data).into();

        // Store chunk
        storage.store_chunk(&hash, data).await.unwrap();
        assert!(storage.chunk_exists(&hash).await);

        // Load chunk
        let loaded = storage.load_chunk(&hash).await.unwrap().unwrap();
        assert_eq!(loaded.as_ref(), data);
    }
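
    #[tokio::test]
    async fn test_chunk_path_fanout() {
        // Sketch: chunks are sharded by the first two byte-pairs of their hex
        // digest, i.e. sync/chunks/ab/cd/<hex>.chk, mirroring `chunk_path`
        // above. This pins down the on-disk layout.
        let temp_dir = TempDir::new().unwrap();
        let storage = Storage::new(temp_dir.path());

        let hash = [0xABu8; 32];
        let hex = hex::encode(hash);
        let expected = temp_dir
            .path()
            .join("sync")
            .join("chunks")
            .join(&hex[0..2])
            .join(&hex[2..4])
            .join(format!("{}.chk", hex));

        assert_eq!(storage.chunk_path(&hash), expected);
    }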

    #[tokio::test]
    async fn test_snapshot_ref_storage() {
        let temp_dir = TempDir::new().unwrap();
        let storage = Storage::new(temp_dir.path());
        storage.init().await.unwrap();

        let machine_id = "test-machine";
        let snapshot_id = "snapshot-001";
        let snapshot_hash = [1u8; 32];
        let created_at = 1234567890;

        storage.store_snapshot_ref(machine_id, snapshot_id, &snapshot_hash, created_at)
            .await.unwrap();

        let loaded = storage.load_snapshot_ref(machine_id, snapshot_id)
            .await.unwrap().unwrap();

        assert_eq!(loaded.0, snapshot_hash);
        assert_eq!(loaded.1, created_at);
    }
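
    #[tokio::test]
    async fn test_cleanup_snapshots_keeps_latest() {
        // Sketch: with three refs and keep_count = 2, `cleanup_snapshots`
        // deletes the lexicographically oldest id (ids here sort by name).
        let temp_dir = TempDir::new().unwrap();
        let storage = Storage::new(temp_dir.path());
        storage.init().await.unwrap();

        let hash = [2u8; 32];
        for id in ["snapshot-001", "snapshot-002", "snapshot-003"] {
            storage.store_snapshot_ref("m1", id, &hash, 0).await.unwrap();
        }

        storage.cleanup_snapshots("m1", 2).await.unwrap();

        let remaining = storage.list_snapshots("m1").await.unwrap();
        assert_eq!(remaining, vec!["snapshot-002", "snapshot-003"]);
    }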
}
233
server/src/sync/validation.rs
Normal file
@@ -0,0 +1,233 @@
use anyhow::{Context, Result};
use std::collections::{HashSet, VecDeque};
use crate::sync::protocol::{Hash, MetaType};
use crate::sync::storage::Storage;
use crate::sync::meta::{MetaObj, SnapshotObj, EntryType};

/// Validation result for snapshot commits
#[derive(Debug, Clone)]
pub struct ValidationResult {
    pub is_valid: bool,
    pub missing_chunks: Vec<Hash>,
    pub missing_metas: Vec<(MetaType, Hash)>,
}

impl ValidationResult {
    pub fn valid() -> Self {
        Self {
            is_valid: true,
            missing_chunks: Vec::new(),
            missing_metas: Vec::new(),
        }
    }

    pub fn invalid(missing_chunks: Vec<Hash>, missing_metas: Vec<(MetaType, Hash)>) -> Self {
        Self {
            is_valid: false,
            missing_chunks,
            missing_metas,
        }
    }

    pub fn has_missing(&self) -> bool {
        !self.missing_chunks.is_empty() || !self.missing_metas.is_empty()
    }
}

/// Validator for snapshot object graphs
#[derive(Clone)]
pub struct SnapshotValidator {
    storage: Storage,
}

impl SnapshotValidator {
    pub fn new(storage: Storage) -> Self {
        Self { storage }
    }

    /// Validate a complete snapshot object graph using BFS only
    pub async fn validate_snapshot(&self, snapshot_hash: &Hash, snapshot_body: &[u8]) -> Result<ValidationResult> {
        // Use the BFS implementation
        self.validate_snapshot_bfs(snapshot_hash, snapshot_body).await
    }

    /// Validate a batch of meta objects (for incremental validation)
    pub async fn validate_meta_batch(&self, metas: &[(MetaType, Hash)]) -> Result<Vec<(MetaType, Hash)>> {
        let mut missing = Vec::new();

        for &(meta_type, hash) in metas {
            if !self.storage.meta_exists(meta_type, &hash).await {
                missing.push((meta_type, hash));
            }
        }

        Ok(missing)
    }

    /// Validate a batch of chunks (for incremental validation)
    pub async fn validate_chunk_batch(&self, chunks: &[Hash]) -> Result<Vec<Hash>> {
        let mut missing = Vec::new();

        for &hash in chunks {
            if !self.storage.chunk_exists(&hash).await {
                missing.push(hash);
            }
        }

        Ok(missing)
    }

    /// Perform a breadth-first validation (useful for large snapshots)
    pub async fn validate_snapshot_bfs(&self, snapshot_hash: &Hash, snapshot_body: &[u8]) -> Result<ValidationResult> {
        // Verify snapshot hash
        let computed_hash = blake3::hash(snapshot_body);
        if computed_hash.as_bytes() != snapshot_hash {
            return Err(anyhow::anyhow!("Snapshot hash mismatch"));
        }

        // Parse snapshot object
        let snapshot_obj = SnapshotObj::deserialize(bytes::Bytes::from(snapshot_body.to_vec()))
            .context("Failed to deserialize snapshot object")?;

        let mut missing_chunks = Vec::new();
        let mut missing_metas = Vec::new();
        let mut visited_metas = HashSet::new();
        let mut queue = VecDeque::new();

        // Initialize queue with disk hashes
        for disk_hash in &snapshot_obj.disk_hashes {
            queue.push_back((MetaType::Disk, *disk_hash));
        }

        // BFS traversal
        while let Some((meta_type, hash)) = queue.pop_front() {
            let meta_key = (meta_type, hash);

            if visited_metas.contains(&meta_key) {
                continue;
            }
            visited_metas.insert(meta_key);

            // Check if meta exists
            if !self.storage.meta_exists(meta_type, &hash).await {
                missing_metas.push((meta_type, hash));
                continue; // Skip loading if missing
            }

            // Load and process meta object
            if let Some(meta_obj) = self.storage.load_meta(meta_type, &hash).await
                .context("Failed to load meta object")? {

                match meta_obj {
                    MetaObj::Disk(disk) => {
                        for partition_hash in &disk.partition_hashes {
                            queue.push_back((MetaType::Partition, *partition_hash));
                        }
                    }
                    MetaObj::Partition(partition) => {
                        queue.push_back((MetaType::Dir, partition.root_dir_hash));
                    }
                    MetaObj::Dir(dir) => {
                        for entry in &dir.entries {
                            match entry.entry_type {
                                EntryType::File | EntryType::Symlink => {
                                    queue.push_back((MetaType::File, entry.target_meta_hash));
                                }
                                EntryType::Dir => {
                                    queue.push_back((MetaType::Dir, entry.target_meta_hash));
                                }
                            }
                        }
                    }
                    MetaObj::File(file) => {
                        // Check chunk dependencies
                        for chunk_hash in &file.chunk_hashes {
                            if !self.storage.chunk_exists(chunk_hash).await {
                                missing_chunks.push(*chunk_hash);
                            }
                        }
                    }
                    MetaObj::Snapshot(_) => {
                        // Snapshots shouldn't be nested
                        return Err(anyhow::anyhow!("Unexpected nested snapshot object"));
                    }
                }
            }
        }

        if missing_chunks.is_empty() && missing_metas.is_empty() {
            Ok(ValidationResult::valid())
        } else {
            Ok(ValidationResult::invalid(missing_chunks, missing_metas))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use crate::sync::meta::*;

    async fn setup_test_storage() -> (Storage, TempDir) {
        // The TempDir guard is returned too; dropping it here would delete
        // the directory while the test is still using it
        let temp_dir = TempDir::new().unwrap();
        let storage = Storage::new(temp_dir.path());
        storage.init().await.unwrap();
        (storage, temp_dir)
    }

    #[tokio::test]
    async fn test_validate_empty_snapshot() {
        let (storage, _temp_dir) = setup_test_storage().await;
        let validator = SnapshotValidator::new(storage);

        let snapshot = SnapshotObj::new(1234567890, vec![]);
        let snapshot_body = snapshot.serialize().unwrap();
        let snapshot_hash = snapshot.compute_hash().unwrap();

        let result = validator.validate_snapshot(&snapshot_hash, &snapshot_body)
            .await.unwrap();

        assert!(result.is_valid);
        assert!(result.missing_chunks.is_empty());
        assert!(result.missing_metas.is_empty());
    }

    #[tokio::test]
    async fn test_validate_missing_disk() {
        let (storage, _temp_dir) = setup_test_storage().await;
        let validator = SnapshotValidator::new(storage);

        let missing_disk_hash = [1u8; 32];
        let snapshot = SnapshotObj::new(1234567890, vec![missing_disk_hash]);
        let snapshot_body = snapshot.serialize().unwrap();
        let snapshot_hash = snapshot.compute_hash().unwrap();

        let result = validator.validate_snapshot(&snapshot_hash, &snapshot_body)
            .await.unwrap();

        assert!(!result.is_valid);
        assert!(result.missing_chunks.is_empty());
        assert_eq!(result.missing_metas.len(), 1);
        assert_eq!(result.missing_metas[0], (MetaType::Disk, missing_disk_hash));
    }

    #[tokio::test]
    async fn test_validate_chunk_batch() {
        let (storage, _temp_dir) = setup_test_storage().await;
        // storage is moved into the validator below, so clone it first for
        // the direct store_chunk call (Storage is Clone)
        let validator = SnapshotValidator::new(storage.clone());

        let chunk_data = b"test chunk";
        let chunk_hash = blake3::hash(chunk_data).into();
        let missing_hash = [1u8; 32];

        // Store one chunk
        storage.store_chunk(&chunk_hash, chunk_data).await.unwrap();

        let chunks = vec![chunk_hash, missing_hash];
        let missing = validator.validate_chunk_batch(&chunks).await.unwrap();

        assert_eq!(missing.len(), 1);
        assert_eq!(missing[0], missing_hash);
    }
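
    #[tokio::test]
    async fn test_validate_meta_batch() {
        // Sketch, mirroring test_validate_chunk_batch: store_meta only checks
        // the blake3 digest of the body, so arbitrary bytes with a matching
        // hash are enough to mark a meta object as present.
        let (storage, _temp_dir) = setup_test_storage().await;
        let validator = SnapshotValidator::new(storage.clone());

        let body = b"opaque meta bytes";
        let present_hash = blake3::hash(body).into();
        let missing_hash = [3u8; 32];

        storage.store_meta(MetaType::File, &present_hash, body).await.unwrap();

        let items = vec![(MetaType::File, present_hash), (MetaType::File, missing_hash)];
        let missing = validator.validate_meta_batch(&items).await.unwrap();

        assert_eq!(missing, vec![(MetaType::File, missing_hash)]);
    }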
}