pub mod write_set; use std::{ any::{Any, TypeId}, collections::{HashMap, HashSet}, marker::PhantomData, sync::{Arc, Mutex}, }; use crate::outbound::db_custom::write_set::{Diffable, Storable, Write, WriteOperation, WriteSet}; #[derive(Default, Copy, Clone, PartialOrd, PartialEq, Eq, Hash, Debug)] struct TransactionId(u64); impl From for TransactionId { fn from(id: u64) -> Self { Self(id) } } #[derive(Debug, Copy, Clone)] pub struct Id { id: u64, _type: PhantomData, } impl Id { pub fn new(id: u64) -> Self { Self { id, _type: PhantomData, } } } impl bincode::Encode for Id { fn encode( &self, encoder: &mut E, ) -> Result<(), bincode::error::EncodeError> { bincode::enc::Encode::encode(&self.id, encoder)?; Ok(()) } } impl bincode::Decode for Id { fn decode>( decoder: &mut D, ) -> Result { let id = bincode::de::Decode::decode(decoder)?; Ok(Self { id, _type: PhantomData, }) } } impl<'de, T, C> bincode::BorrowDecode<'de, C> for Id { fn borrow_decode>( decoder: &mut D, ) -> Result { let id = bincode::de::Decode::decode(decoder)?; Ok(Self { id, _type: PhantomData, }) } } impl std::cmp::Eq for Id {} impl std::cmp::PartialEq for Id { fn eq(&self, other: &Self) -> bool { self.id == other.id } } impl std::hash::Hash for Id { fn hash(&self, state: &mut H) { self.id.hash(state); } } #[derive(Default, Copy, Clone, PartialEq, Eq, Hash, Debug, bincode::Encode, bincode::Decode)] struct TypelessId(u64); impl From> for TypelessId { fn from(id: Id) -> Self { Self(id.id) } } impl From<&Id> for TypelessId { fn from(id: &Id) -> Self { Self(id.id) } } impl From for TypelessId { fn from(id: u64) -> Self { Self(id) } } #[derive(Debug, Default)] pub struct Database { log: Vec, snapshots: HashMap, registered_types: HashMap Box>, } impl Database { fn register_type( &mut self, name: impl ToString, decode_fn: fn(&[u8]) -> Box, ) { self.registered_types.insert(name.to_string(), decode_fn); } fn decode(&self, data: &[u8]) -> Result, anyhow::Error> { let (name, len): (String, usize) = bincode::decode_from_slice(data, bincode::config::standard()) .expect("decode MUST succeed"); let create_storable = self.registered_types.get(&name).expect("aalkdjhfl"); Ok(create_storable(&data[len..])) } fn build_snapshot(&mut self, timestamp: TransactionId) -> &Snapshot { let mut snapshot = Snapshot { timestamp, cached: Default::default(), }; for t in self.transactions_to(timestamp) { for write in &t.write_set.writes { match &write.operation { WriteOperation::Full(data) => { snapshot .cached .insert(write.id, self.decode(data).expect("decode MUST succeed")); } WriteOperation::Partial(partial) => match snapshot.get_typeless_mut(write.id) { Some(data) => { for op in partial { data.apply_partial_write(op); } } None => { panic!( "somehow found a partial write for id {:?} at timestamp {:?} that doesn't exist in snapshot at timestamp {:?}", write.id, t.id, snapshot.timestamp ) } }, } } } self.snapshots.insert(timestamp, snapshot); self.snapshots.get(×tamp).unwrap() } fn transactions_to(&self, timestamp: TransactionId) -> impl Iterator { self.log.iter().take_while(move |t| t.id <= timestamp) } fn transactions_from(&self, timestamp: TransactionId) -> impl Iterator { self.log.iter().skip_while(move |t| t.id < timestamp) } fn apply_transaction(&mut self, candidate: TransactionCandidate) { let timestamp = (self.log.len() as u64).into(); println!("applying transaction to {:?}: {:?}", timestamp, &candidate); self.log.push(Transaction::commit(timestamp, candidate)); self.build_snapshot(timestamp); } } #[derive(Clone)] pub struct DatabaseAccessor { db: Arc>, } impl DatabaseAccessor { pub fn new(db: Database) -> Self { Self { db: Arc::new(Mutex::new(db)), } } } impl DatabaseAccessor { // TODO pub fn get_all(&self) -> Vec { vec![] } pub fn transact(&self, f: F) -> O where F: Fn(&mut PendingTransaction) -> O, { let mut me = self.clone(); for i in 0..5 { println!("Starting transaction attempt {i}"); let current_timestamp = self.db.lock().unwrap().log.len() as u64; let (candidate, result): (TransactionCandidate, O) = { let mut pending = PendingTransaction::new(&mut me, current_timestamp.into()); let result = f(&mut pending); (pending.into(), result) }; println!("transaction read set: {:?}", candidate.read_set); let mut db = self.db.lock().unwrap(); if db .transactions_from(candidate.timestamp) .any(|t| candidate.read_set.overlaps_with(&t.write_set)) { println!( "Read set overlaps with write set: {:#?}", candidate.read_set ); continue; } db.apply_transaction(candidate); return result; } panic!("failed to apply transaction after 5 attempts"); } fn get( &mut self, id: Id, timestamp: TransactionId, ) -> Option { let db = &mut self.db.lock().unwrap(); let snapshot = { if db.snapshots.contains_key(×tamp) { db.snapshots.get(×tamp).unwrap() } else { db.build_snapshot(timestamp) } }; Some(snapshot.get(id).unwrap()) } } #[derive(Debug)] struct Transaction { id: TransactionId, write_set: WriteSet, } impl Transaction { fn commit(id: TransactionId, candidate: TransactionCandidate) -> Self { Self { id, write_set: candidate.write_set, } } } #[derive(Debug, Default)] struct Snapshot { timestamp: TransactionId, cached: HashMap>, } impl Snapshot { fn get(&self, id: Id) -> Option { let readable = self.cached.get(&id.into())?; if (**readable).type_id() != TypeId::of::() { panic!( "Type mismatch: expected {:?}, got {:?}", TypeId::of::(), (**readable).type_id() ); // None } else if let Some(data) = (&**readable as &dyn Any).downcast_ref::() { Some((*data).clone()) } else { panic!( "Type mismatch: expected {:?}, got {:?}", TypeId::of::(), (**readable).type_id() ); } } fn get_typeless_mut( &mut self, id: TypelessId, ) -> Option<&mut Box> { self.cached.get_mut(&id) } } #[derive(Debug)] struct TransactionCandidate { timestamp: TransactionId, read_set: ReadSet, write_set: WriteSet, } pub struct PendingTransaction<'a> { accessor: &'a mut DatabaseAccessor, timestamp: TransactionId, read_set: ReadSet, write_set: WriteSet, } impl<'a> From> for TransactionCandidate { fn from(pending: PendingTransaction) -> Self { Self { timestamp: pending.timestamp, read_set: pending.read_set, write_set: pending.write_set, } } } impl<'a> PendingTransaction<'a> { fn new(accessor: &'a mut DatabaseAccessor, timestamp: TransactionId) -> Self { Self { accessor, timestamp, read_set: ReadSet::default(), write_set: WriteSet::default(), } } pub fn insert(&mut self, id: Id, data: T) { // FIXME: decide on a way to generate unqiue ids self.write_set.writes.push(Write { id: id.into(), operation: data.as_full_write_op(), }); } pub fn modify( &mut self, id: Id, transform: impl Fn(T) -> T, ) { println!("attempting modification in timestamp {:?}", self.timestamp); let old_data = self.accessor.get(id.clone(), self.timestamp).unwrap(); let new_data = transform(old_data.clone()); let operation = old_data.as_partial_write_op(&new_data); match &operation { WriteOperation::Full(_) => self.read_set.push(DocumentRead::Complete), WriteOperation::Partial(partial_writes) => self.read_set.0.extend( partial_writes .iter() .map(|w| DocumentRead::Field(w.field_ident.clone())), ), } self.write_set.writes.push(Write { id: (&id).into(), operation, }); } pub fn get(&mut self, id: Id) -> Option { self.read_set.push(DocumentRead::Complete); self.accessor.get(id, self.timestamp) } pub fn get_field( &mut self, id: Id, field: &'static str, ) -> Option { self.read_set.push(DocumentRead::Field(field.to_string())); self.accessor .get(id, self.timestamp) .and_then(|data| data.field(field)) } } // TODO: make these actual sets? #[derive(Debug, Default)] struct ReadSet(Vec); impl ReadSet { fn push(&mut self, read: DocumentRead) { self.0.push(read); } } impl ReadSet { fn overlaps_with(&self, write_set: &WriteSet) -> bool { if !write_set.writes.is_empty() && self .0 .iter() .any(|read| matches!(read, DocumentRead::Complete)) { return true; } let writes = write_set .writes .iter() .filter_map(|write| match &write.operation { WriteOperation::Full(_) => None, WriteOperation::Partial(partial) => Some(partial.iter()), }) .flatten() .map(|field| field.field_ident.as_str()) .collect::>(); let reads = self .0 .iter() .filter_map(|read| match read { DocumentRead::Complete => None, DocumentRead::Field(field) => Some(field.as_ref()), }) .collect::>(); reads.intersection(&writes).count() > 0 } } #[derive(Debug)] enum DocumentRead { Complete, Field(String), } #[derive(Debug, Clone, bincode::Encode, bincode::Decode)] pub enum Value { String(String), Int(i64), Float(f64), Bool(bool), Array(Vec), } impl Value { fn as_string(&self) -> Option<&String> { match self { Value::String(s) => Some(s), _ => None, } } fn as_int(&self) -> Option<&i64> { match self { Value::Int(i) => Some(i), _ => None, } } fn as_array(&self) -> Option<&Vec> { match self { Value::Array(a) => Some(a), _ => None, } } fn as_bytes(&self) -> Vec { bincode::encode_to_vec(self, bincode::config::standard()).expect("encode MUST succeed") } fn from_bytes(data: &[u8]) -> Self { let (value, _): (Self, _) = bincode::decode_from_slice(data, bincode::config::standard()) .expect("decode MUST succeed"); value } } #[cfg(test)] mod tests { use std::collections::HashSet; use crate::outbound::db_custom::write_set::FieldDiff; use super::*; #[derive(Debug, Clone, bincode::Encode, bincode::Decode)] struct MyTestData { name: Name, age: i64, contacts: HashSet>, } #[derive(Debug, Clone, bincode::Encode, bincode::Decode)] struct Name { first: String, last: String, } #[derive(bincode::Encode)] struct A { name: String, data: T, } impl A { fn with(name: impl ToString, data: T) -> Self { Self { name: name.to_string(), data, } } } impl Storable for MyTestData { fn as_full_write_op(&self) -> WriteOperation where Self: Sized, { let data = bincode::encode_to_vec(A::with("MyTestData", self), bincode::config::standard()) .expect("encode MUST succeed"); WriteOperation::Full(data) } fn as_partial_write_op(&self, other: &Self) -> WriteOperation where Self: Sized, { WriteOperation::Partial(self.diff(other).into_iter().map(Into::into).collect()) } fn apply_partial_write(&mut self, op: &write_set::PartialWrite) { let value = Value::from_bytes(&op.data); self.set_field(&op.field_ident, value); } fn set_field(&mut self, field_ident: &str, value: Value) where Self: Sized, { match field_ident { "name.first" => { let Value::String(name) = value else { panic!("expected 'name.first' to be a 'String'"); }; self.name.first = name; } "name.last" => { let Value::String(name) = value else { panic!("expected 'name.last' to be a 'String'"); }; self.name.last = name; } "age" => { let Value::Int(age) = value else { panic!("expected 'age' to be a 'Int'"); }; self.age = age; } "contacts" => { let Value::Array(contacts) = value else { panic!("expected 'contacts' to be a 'Vec'"); }; self.contacts = contacts.into_iter().map(|id| Id::<_>::new(id.0)).collect(); } _ => panic!("invalid field '{field_ident}'"), }; } fn field(&self, field_ident: &str) -> Option where Self: Sized, { match field_ident { "name.first" => Some(Value::String(self.name.first.clone())), "name.last" => Some(Value::String(self.name.last.clone())), "age" => Some(Value::Int(self.age)), "contacts" => Some(Value::Array( self.contacts.iter().map(|id| id.into()).collect(), )), _ => None, } } } impl Diffable for MyTestData { fn diff(&self, other: &Self) -> Vec where Self: Sized, { let mut modifications = vec![]; if self.name.first != other.name.first { modifications.push(FieldDiff::of( "name.first", Value::String(other.name.first.clone()), )); } if self.name.last != other.name.last { modifications.push(FieldDiff::of( "name.last", Value::String(other.name.last.clone()), )); } if self.age != other.age { modifications.push(FieldDiff::of("age", Value::Int(other.age))); } let contacts_diff = self.contacts.difference(&other.contacts); if contacts_diff.count() > 0 { modifications.push(FieldDiff::of( "contacts", Value::Array(other.contacts.iter().map(|id| id.into()).collect()), )); } modifications } } #[test] fn test() { let mut db = Database::default(); db.register_type("MyTestData", |data| { let (data, _): (MyTestData, usize) = bincode::decode_from_slice(data, bincode::config::standard()) .expect("decode MUST succeed"); Box::new(data) }); let mut accessor = DatabaseAccessor::new(db); let data = MyTestData { name: Name { first: "John".to_string(), last: "Doe".to_string(), }, age: 42, contacts: HashSet::new(), }; println!("initial state: {:#?}", accessor.db.lock().unwrap()); accessor.transact(move |t| { t.insert(Id::new(1), data.clone()); t.insert( Id::new(2), MyTestData { name: Name { first: "Oldy".to_string(), last: "McOlderton".to_string(), }, age: 69, contacts: vec![Id::new(1)].into_iter().collect(), }, ); }); let cloned_accessor = accessor.clone(); let handle = std::thread::spawn(move || { cloned_accessor.transact(|t| { let Some(Value::String(last_name_of_oldy)) = t.get_field(Id::::new(2), "name.last") else { panic!("expected a string") }; t.modify(Id::::new(1), |mut data| { data.name.last = format!("Gearbox {}", last_name_of_oldy); data }); }); }); let id = Id::::new(1); accessor.transact(|t| { t.modify(id.clone(), |mut data| { data.name.last = "Not McOlderton Gearbox".to_string(); data }); }); let oldys_contacts = accessor.transact(|t| { let oldy = t.get(Id::::new(2)).unwrap(); oldy.contacts .into_iter() .filter_map(|id| t.get(id)) .map(|data| format!("{} {}", data.name.first, data.name.last)) .collect::>() }); handle.join().unwrap(); let data0 = accessor.get(Id::::new(1), 0.into()); let data1 = accessor.get(Id::::new(1), 1.into()); let data2 = accessor.get(Id::::new(1), 2.into()); let data3 = accessor.get(Id::::new(1), 3.into()); println!("data at timestamp 0: {:#?}", data0); println!("data at timestamp 1: {:#?}", data1); println!("data at timestamp 2: {:#?}", data2); println!("data at timestamp 3: {:#?}", data3); panic!("oldys contacts: {:#?}", oldys_contacts); } }