use std::{ any::{Any, TypeId}, collections::{HashMap, HashSet}, marker::PhantomData, sync::{Arc, Mutex}, }; #[derive(Default, Copy, Clone, PartialOrd, PartialEq, Eq, Hash, Debug)] struct TransactionId(u64); impl From for TransactionId { fn from(id: u64) -> Self { Self(id) } } #[derive(Debug, Copy, Clone)] pub struct Id { id: u64, _type: PhantomData, } impl Id { pub fn new(id: u64) -> Self { Self { id, _type: PhantomData, } } } impl std::cmp::Eq for Id {} impl std::cmp::PartialEq for Id { fn eq(&self, other: &Self) -> bool { self.id == other.id } } impl std::hash::Hash for Id { fn hash(&self, state: &mut H) { self.id.hash(state); } } #[derive(Default, Copy, Clone, PartialEq, Eq, Hash, Debug)] struct TypelessId(u64); impl From> for TypelessId { fn from(id: Id) -> Self { Self(id.id) } } impl From<&Id> for TypelessId { fn from(id: &Id) -> Self { Self(id.id) } } impl From for TypelessId { fn from(id: u64) -> Self { Self(id) } } #[derive(Debug, Default)] pub struct Database { log: Vec, snapshots: HashMap, } impl Database { fn build_snapshot(&mut self, timestamp: TransactionId) -> &Snapshot { let mut snapshot = Snapshot { timestamp, cached: Default::default(), }; for t in self.transactions_to(timestamp) { for write in &t.write_set.0 { match &write.operation { DocumentWriteOperation::New(data) => { snapshot.cached.insert(write.id, data.cloned()); } DocumentWriteOperation::Modified(fields) => { match snapshot.get_typeless_mut(write.id) { Some(data) => { for field in fields { data.write_modification(field); } } None => { panic!( "somehow found a modified write for id {:?} at timestamp {:?} that doesn't exist in snapshot at timestamp {:?}", write.id, t.id, snapshot.timestamp ) } } } } } } self.snapshots.insert(timestamp, snapshot); self.snapshots.get(×tamp).unwrap() } fn transactions_to(&mut self, timestamp: TransactionId) -> impl Iterator { self.log.iter().take_while(move |t| t.id <= timestamp) } fn transactions_from( &mut self, timestamp: TransactionId, ) -> impl Iterator { self.log.iter().skip_while(move |t| t.id < timestamp) } fn apply_transaction(&mut self, candidate: TransactionCandidate) { let timestamp = (self.log.len() as u64).into(); println!("applying transaction to {:?}: {:?}", timestamp, &candidate); self.log.push(Transaction::commit(timestamp, candidate)); self.build_snapshot(timestamp); } } #[derive(Clone)] pub struct DatabaseAccessor { db: Arc>, } impl DatabaseAccessor { pub fn new(db: Database) -> Self { Self { db: Arc::new(Mutex::new(db)), } } } impl DatabaseAccessor { // TODO pub fn get_all(&self) -> Vec { vec![] } pub fn transact(&self, f: F) -> O where F: Fn(&mut PendingTransaction) -> O, { let mut me = self.clone(); for i in 0..5 { println!("Starting transaction attempt {i}"); let current_timestamp = self.db.lock().unwrap().log.len() as u64; let (candidate, result): (TransactionCandidate, O) = { let mut pending = PendingTransaction::new(&mut me, current_timestamp.into()); let result = f(&mut pending); (pending.into(), result) }; println!("transaction read set: {:?}", candidate.read_set); let mut db = self.db.lock().unwrap(); if db .transactions_from(candidate.timestamp) .any(|t| candidate.read_set.overlaps_with(&t.write_set)) { println!( "Read set overlaps with write set: {:#?}", candidate.read_set ); continue; } db.apply_transaction(candidate); return result; } panic!("failed to apply transaction after 5 attempts"); } fn get( &mut self, id: Id, timestamp: TransactionId, ) -> Option { let db = &mut self.db.lock().unwrap(); let snapshot = { if db.snapshots.contains_key(×tamp) { db.snapshots.get(×tamp).unwrap() } else { db.build_snapshot(timestamp) } }; Some(snapshot.get(id).unwrap()) } } #[derive(Debug)] struct Transaction { id: TransactionId, write_set: WriteSet, } impl Transaction { fn commit(id: TransactionId, candidate: TransactionCandidate) -> Self { Self { id, write_set: candidate.write_set, } } } #[derive(Debug, Default)] struct Snapshot { timestamp: TransactionId, cached: HashMap>, } impl Snapshot { fn get(&self, id: Id) -> Option { let readable = self.cached.get(&id.into())?; if (**readable).type_id() != TypeId::of::() { panic!( "Type mismatch: expected {:?}, got {:?}", TypeId::of::(), (**readable).type_id() ); // None } else if let Some(data) = (&**readable as &dyn Any).downcast_ref::() { Some((*data).clone()) } else { panic!( "Type mismatch: expected {:?}, got {:?}", TypeId::of::(), (**readable).type_id() ); } } fn get_typeless_mut( &mut self, id: TypelessId, ) -> Option<&mut Box> { self.cached.get_mut(&id) } } #[derive(Debug)] struct TransactionCandidate { timestamp: TransactionId, read_set: ReadSet, write_set: WriteSet, } pub struct PendingTransaction<'a> { accessor: &'a mut DatabaseAccessor, timestamp: TransactionId, read_set: ReadSet, write_set: WriteSet, } impl<'a> From> for TransactionCandidate { fn from(pending: PendingTransaction) -> Self { Self { timestamp: pending.timestamp, read_set: pending.read_set, write_set: pending.write_set, } } } impl<'a> PendingTransaction<'a> { fn new(accessor: &'a mut DatabaseAccessor, timestamp: TransactionId) -> Self { Self { accessor, timestamp, read_set: ReadSet::default(), write_set: WriteSet::default(), } } pub fn insert(&mut self, id: Id, data: T) { // FIXME: decide on a way to generate unqiue ids self.write_set.push(DocumentWrite::new(id.into(), data)); } pub fn modify( &mut self, id: Id, transform: impl Fn(T) -> T, ) { println!("attempting modification in timestamp {:?}", self.timestamp); let old_data = self.accessor.get(id.clone(), self.timestamp).unwrap(); let new_data = transform(old_data.clone()); let modifications = old_data.modifications_between(&new_data); for modification in &modifications { self.read_set.push(DocumentRead::Field(modification.field)); } self.write_set .push(DocumentWrite::modification(id, modifications)); } pub fn get(&mut self, id: Id) -> Option { self.read_set.push(DocumentRead::Complete); self.accessor.get(id, self.timestamp) } pub fn get_field( &mut self, id: Id, field: &'static str, ) -> Option { self.read_set.push(DocumentRead::Field(field)); self.accessor .get(id, self.timestamp) .and_then(|data| data.field(field)) } } // TODO: make these actual sets? #[derive(Debug, Default)] struct ReadSet(Vec); #[derive(Debug, Default)] struct WriteSet(Vec); impl ReadSet { fn push(&mut self, read: DocumentRead) { self.0.push(read); } } impl WriteSet { fn push(&mut self, write: DocumentWrite) { self.0.push(write); } } impl ReadSet { fn overlaps_with(&self, write_set: &WriteSet) -> bool { if !write_set.0.is_empty() && self .0 .iter() .any(|read| matches!(read, DocumentRead::Complete)) { return true; } let writes = write_set .0 .iter() .filter_map(|write| match &write.operation { DocumentWriteOperation::New(_) => None, DocumentWriteOperation::Modified(fields) => Some(fields.iter()), }) .flatten() .map(|field| field.field) .collect::>(); let reads = self .0 .iter() .filter_map(|read| match read { DocumentRead::Complete => None, DocumentRead::Field(field) => Some(*field), }) .collect::>(); reads.intersection(&writes).count() > 0 } } pub trait Storable: Any + std::fmt::Debug { fn cloned(&self) -> Box; fn write_modification(&mut self, field: &DocumentField); fn field(&self, field: &'static str) -> Option; } pub trait Modifiable { fn modifications_between(&self, other: &Self) -> Vec; } #[derive(Debug)] enum DocumentRead { Complete, Field(&'static str), } #[derive(Debug)] struct DocumentWrite { id: TypelessId, type_id: TypeId, operation: DocumentWriteOperation, } impl DocumentWrite { fn new(id: TypelessId, data: T) -> Self { Self { id, type_id: TypeId::of::(), operation: DocumentWriteOperation::New(Box::new(data)), } } fn modification(id: Id, operation: Vec) -> Self { Self { id: id.into(), type_id: TypeId::of::(), operation: DocumentWriteOperation::Modified(operation), } } } #[derive(Debug)] pub enum DocumentWriteOperation { New(Box), Modified(Vec), } #[derive(Debug, Clone)] pub struct DocumentField { field: &'static str, value: Value, } impl DocumentField { fn of(field: &'static str, value: Value) -> Self { Self { field, value } } } #[derive(Debug, Clone)] pub enum Value { String(String), Int(i64), Float(f64), Bool(bool), Array(Vec), } impl Value { fn as_string(&self) -> Option<&String> { match self { Value::String(s) => Some(s), _ => None, } } fn as_int(&self) -> Option<&i64> { match self { Value::Int(i) => Some(i), _ => None, } } fn as_array(&self) -> Option<&Vec> { match self { Value::Array(a) => Some(a), _ => None, } } } #[cfg(test)] mod tests { use std::collections::HashSet; use super::*; #[derive(Debug, Clone)] struct MyTestData { name: Name, age: i64, contacts: HashSet>, } #[derive(Debug, Clone)] struct Name { first: String, last: String, } // TODO: make this a derive macro impl Storable for MyTestData { fn cloned(&self) -> Box { Box::new(self.clone()) } fn write_modification(&mut self, field: &DocumentField) { match field.field { "name.first" => self.name.first = field.value.as_string().unwrap().to_string(), "name.last" => self.name.last = field.value.as_string().unwrap().to_string(), "age" => self.age = *field.value.as_int().unwrap(), "contacts" => { self.contacts = field .value .as_array() .unwrap() .iter() .map(|id| Id::::new(id.0)) .collect(); } _ => panic!("unknown field"), } } fn field(&self, field: &'static str) -> Option { match field { "name.first" => Some(Value::String(self.name.first.clone())), "name.last" => Some(Value::String(self.name.last.clone())), "age" => Some(Value::Int(self.age)), "contacts" => Some(Value::Array( self.contacts.iter().map(|id| id.into()).collect(), )), _ => None, } } } // TODO: make this a derive macro impl Modifiable for MyTestData { fn modifications_between(&self, other: &Self) -> Vec { let mut modifications = vec![]; if self.name.first != other.name.first { modifications.push(DocumentField::of( "name.first", Value::String(other.name.first.clone()), )); } if self.name.last != other.name.last { modifications.push(DocumentField::of( "name.last", Value::String(other.name.last.clone()), )); } if self.age != other.age { modifications.push(DocumentField::of("age", Value::Int(other.age))); } let contacts_diff = self.contacts.difference(&other.contacts); if contacts_diff.count() > 0 { modifications.push(DocumentField::of( "contacts", Value::Array(other.contacts.iter().map(|id| id.into()).collect()), )); } modifications } } #[test] fn test() { let db = Database::default(); let mut accessor = DatabaseAccessor::new(db); let data = MyTestData { name: Name { first: "John".to_string(), last: "Doe".to_string(), }, age: 42, contacts: HashSet::new(), }; println!("initial state: {:#?}", accessor.db.lock().unwrap()); accessor.transact(move |t| { t.insert(Id::new(1), data.clone()); t.insert( Id::new(2), MyTestData { name: Name { first: "Oldy".to_string(), last: "McOlderton".to_string(), }, age: 69, contacts: vec![Id::new(1)].into_iter().collect(), }, ); }); let mut cloned_accessor = accessor.clone(); let handle = std::thread::spawn(move || { cloned_accessor.transact(|t| { let Some(Value::String(last_name_of_oldy)) = t.get_field(Id::::new(2), "name.last") else { panic!("expected a string") }; t.modify(Id::::new(1), |mut data| { data.name.last = format!("Gearbox {}", last_name_of_oldy); data }); }); }); let id = Id::::new(1); accessor.transact(|t| { t.modify(id.clone(), |mut data| { data.name.last = "Not McOlderton Gearbox".to_string(); data }); }); let oldys_contacts = accessor.transact(|t| { let oldy = t.get(Id::::new(2)).unwrap(); oldy.contacts .into_iter() .filter_map(|id| t.get(id)) .map(|data| format!("{} {}", data.name.first, data.name.last)) .collect::>() }); handle.join().unwrap(); let data0 = accessor.get(Id::::new(1), 0.into()); let data1 = accessor.get(Id::::new(1), 1.into()); let data2 = accessor.get(Id::::new(1), 2.into()); let data3 = accessor.get(Id::::new(1), 3.into()); println!("data at timestamp 0: {:#?}", data0); println!("data at timestamp 1: {:#?}", data1); println!("data at timestamp 2: {:#?}", data2); println!("data at timestamp 3: {:#?}", data3); panic!("oldys contacts: {:#?}", oldys_contacts); } }