peillute/
snapshot.rs

1//! Snapshot management for distributed state consistency
2//!
3//! This module implements a distributed snapshot algorithm for ensuring
4//! consistency across nodes in the distributed system. It handles snapshot
5//! creation, consistency checking, and persistence.
6
/// Compact description of a single transaction as stored in a snapshot.
///
/// The type is hashable and comparable, which lets the snapshot types of
/// this module keep summaries in hash sets and deduplicate them.
#[cfg(feature = "server")]
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct TxSummary {
    /// Lamport timestamp attached to the transaction.
    pub lamport_time: i64,
    /// Identifier of the node on which the transaction was created.
    pub source_node: String,
    /// User the money is taken from.
    pub from_user: String,
    /// User the money is given to.
    pub to_user: String,
    /// Amount of the transaction, expressed in cents.
    pub amount_in_cent: i64,
}
22
/// What should happen to the global snapshot once every expected local
/// snapshot has been received.
#[cfg(feature = "server")]
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum SnapshotMode {
    /// The global snapshot is built and saved to a file.
    FileMode,
    /// The global snapshot is built and sent to the network.
    NetworkMode,
    /// The global snapshot is built and applied to the local state.
    SyncMode,
}
34
#[cfg(feature = "server")]
impl From<&crate::db::Transaction> for TxSummary {
    /// Builds the snapshot summary of a stored transaction.
    ///
    /// The string fields are cloned from `tx`. The floating-point `amount`
    /// is converted to whole cents by rounding to the nearest cent.
    fn from(tx: &crate::db::Transaction) -> Self {
        Self {
            lamport_time: tx.lamport_time,
            source_node: tx.source_node.clone(),
            from_user: tx.from_user.clone(),
            to_user: tx.to_user.clone(),
            // A bare `as i64` cast truncates toward zero, so an amount such as
            // 0.29 (whose f64 product with 100.0 is 28.999999999999996) would
            // be recorded as 28 cents. Round first, then cast.
            amount_in_cent: (tx.amount * 100.0).round() as i64,
        }
    }
}
47
/// Snapshot of the state of a single node.
///
/// `Debug` is derived so that local snapshots can be logged and printed in
/// assertion messages.
#[cfg(feature = "server")]
#[derive(Clone, Debug)]
pub struct LocalSnapshot {
    /// Identifier of the node that took the snapshot.
    pub site_id: String,
    /// Vector clock of the node when the snapshot was taken, keyed by site identifier.
    pub vector_clock: std::collections::HashMap<String, i64>,
    /// Transactions known to the node when the snapshot was taken.
    pub tx_log: std::collections::HashSet<TxSummary>,
}
59
/// Result of merging the local snapshots of several nodes.
#[cfg(feature = "server")]
#[derive(serde::Serialize, Debug, Clone)]
pub struct GlobalSnapshot {
    /// Every transaction found in at least one local snapshot.
    pub all_transactions: std::collections::HashSet<TxSummary>,
    /// For each site identifier, the transactions of the union that its own
    /// snapshot does not contain. Sites that miss nothing have no entry.
    pub missing: std::collections::HashMap<String, std::collections::HashSet<TxSummary>>,
}
69
#[cfg(feature = "server")]
impl GlobalSnapshot {
    /// Tells whether a group of local snapshots forms a consistent cut.
    ///
    /// The group is consistent when no snapshot reports, for some site `j`,
    /// a clock value greater than the one `j` reports for itself. A pair is
    /// only compared when both clock entries exist; missing entries never
    /// cause a violation.
    pub fn is_consistent(snaps: &[LocalSnapshot]) -> bool {
        snaps.iter().all(|owner| {
            match owner.vector_clock.get(&owner.site_id) {
                // The site does not report its own entry: nothing to compare against.
                None => true,
                Some(&own_tick) => snaps.iter().all(|observer| {
                    observer
                        .vector_clock
                        .get(&owner.site_id)
                        .map_or(true, |&seen| seen <= own_tick)
                }),
            }
        })
    }
}
93
/// Collects local snapshots until enough of them are available to build a
/// global snapshot.
#[cfg(feature = "server")]
pub struct SnapshotManager {
    /// How many local snapshots must be received before a global snapshot is built.
    pub expected: usize,
    /// Local snapshots received so far.
    pub received: Vec<LocalSnapshot>,
    /// Location of the most recently saved snapshot file, if any.
    pub path: Option<std::path::PathBuf>,
    /// Mode of the snapshot currently being collected.
    pub mode: SnapshotMode,
}
106
#[cfg(feature = "server")]
impl SnapshotManager {
    /// Builds an empty manager that waits for `expected` local snapshots.
    ///
    /// The manager starts in [`SnapshotMode::FileMode`] with no saved path.
    pub fn new(expected: usize) -> Self {
        SnapshotManager {
            mode: SnapshotMode::FileMode,
            path: None,
            received: Vec::new(),
            expected,
        }
    }

    /// Records one snapshot response and, once enough have arrived, returns
    /// the resulting global snapshot.
    ///
    /// While fewer than `expected` snapshots are stored, `None` is returned.
    /// After that:
    ///
    /// * in `SyncMode` and `NetworkMode` the stored snapshots are merged as
    ///   they are, without any consistency check;
    /// * in `FileMode` they are merged as they are when they are consistent;
    ///   otherwise they are first trimmed back to the per-site minimum of all
    ///   reported clock values, and the trimmed snapshots are merged.
    pub fn push(&mut self, resp: crate::message::SnapshotResponse) -> Option<GlobalSnapshot> {
        log::debug!("Adding snapshot {} in the manager.", resp.site_id);

        let vector_clock = resp.clock.get_vector_clock_map().clone();
        let tx_log: std::collections::HashSet<TxSummary> = resp.tx_log.into_iter().collect();
        self.received.push(LocalSnapshot {
            site_id: resp.site_id,
            vector_clock,
            tx_log,
        });

        let collected = self.received.len();
        if collected < self.expected {
            log::debug!("{}/{} sites received.", collected, self.expected);
            return None;
        }

        log::debug!("All local snapshots received, processing snapshot.");

        // Sync and Network modes aggregate everything that was received:
        // enforcing consistency there would drop valid transactions when a
        // node joins and asks for a global snapshot, or when an intermediate
        // node aggregates the snapshots of its children.
        let skip_consistency_check = match self.mode {
            SnapshotMode::SyncMode | SnapshotMode::NetworkMode => true,
            SnapshotMode::FileMode => false,
        };
        if skip_consistency_check || GlobalSnapshot::is_consistent(&self.received) {
            return Some(self.build_snapshot(&self.received));
        }

        // Back-track: floor[j] = min over all snapshots i of Ci[j], where
        // Ci[j] is the clock value snapshot i holds for site j.
        let mut floor: std::collections::HashMap<String, i64> = std::collections::HashMap::new();
        for (site, &tick) in self
            .received
            .iter()
            .flat_map(|snap| snap.vector_clock.iter())
        {
            let lowest = floor.entry(site.clone()).or_insert(tick);
            if tick < *lowest {
                *lowest = tick;
            }
        }
        // Sites absent from every clock are treated as being at 0.
        let clock_floor = |site: &str| floor.get(site).copied().unwrap_or(0);

        // Work on copies so the received snapshots stay untouched.
        let trimmed: Vec<LocalSnapshot> = self
            .received
            .iter()
            .cloned()
            .map(|mut snap| {
                // Pull the site's own clock entry back to the floor.
                let own_limit = clock_floor(&snap.site_id);
                snap.vector_clock.insert(snap.site_id.clone(), own_limit);

                // Keep only the transactions whose Lamport time does not
                // exceed the floor of the node that created them.
                snap.tx_log
                    .retain(|tx| tx.lamport_time <= clock_floor(&tx.source_node));
                snap
            })
            .collect();

        Some(self.build_snapshot(&trimmed))
    }

    /// Merges local snapshots into a global one.
    ///
    /// The result holds the union of all transaction logs and, for every
    /// site whose log lacks part of that union, the set of transactions it
    /// is missing.
    fn build_snapshot(&self, snaps: &[LocalSnapshot]) -> GlobalSnapshot {
        let mut all_transactions: std::collections::HashSet<TxSummary> =
            std::collections::HashSet::new();
        for snap in snaps {
            log::info!(
                "Adding transactions from site {}, transaction : {:?}",
                snap.site_id,
                snap.tx_log
            );
            all_transactions.extend(snap.tx_log.iter().cloned());
        }

        let missing: std::collections::HashMap<String, std::collections::HashSet<TxSummary>> =
            snaps
                .iter()
                .filter_map(|snap| {
                    let absent: std::collections::HashSet<TxSummary> = all_transactions
                        .difference(&snap.tx_log)
                        .cloned()
                        .collect();
                    if absent.is_empty() {
                        None
                    } else {
                        Some((snap.site_id.clone(), absent))
                    }
                })
                .collect();

        GlobalSnapshot {
            all_transactions,
            missing,
        }
    }
}
223
#[cfg(feature = "server")]
/// Initiates a new snapshot process on this node.
///
/// The local transaction log is summarised, the shared snapshot manager is
/// reset for the given `mode`, and the node's own snapshot is pushed into it.
/// If that single snapshot already completes the collection:
///
/// * `FileMode`: the global snapshot is persisted and its path is remembered;
/// * `SyncMode`: nothing else has to be done;
/// * `NetworkMode`: an error is logged, since this is not supposed to happen.
///
/// # Errors
///
/// Returns an error when the local transaction log cannot be read or when
/// the snapshot file cannot be written.
pub async fn start_snapshot(mode: SnapshotMode) -> Result<(), Box<dyn std::error::Error>> {
    let local_txs = crate::db::get_local_transaction_log()?;
    let summaries: Vec<TxSummary> = local_txs.iter().map(TxSummary::from).collect();

    let (site_id, clock, expected) = {
        let st = crate::state::LOCAL_APP_STATE.lock().await;
        let neighbours = st.get_connected_nei_addr().len();
        let expected = match mode {
            // Network mode: one snapshot per connected neighbour, without the
            // extra one (original note: all connected peers except our parent).
            SnapshotMode::NetworkMode => neighbours,
            // Otherwise: one snapshot per connected neighbour, plus our own.
            SnapshotMode::FileMode | SnapshotMode::SyncMode => neighbours + 1,
        };
        (st.get_site_id(), st.get_clock(), expected)
    };

    {
        let mut mgr = LOCAL_SNAPSHOT_MANAGER.lock().await;
        mgr.expected = expected;
        mgr.received.clear();
        mgr.mode = mode.clone();

        let ready = mgr.push(crate::message::SnapshotResponse {
            site_id: site_id.clone(),
            clock,
            tx_log: summaries,
        });

        if let Some(gs) = ready {
            match mode {
                SnapshotMode::FileMode => {
                    log::info!(
                        "Global snapshot ready to be saved at start, hold per site : {:#?}",
                        gs.missing
                    );
                    // Propagate I/O failures to the caller instead of
                    // panicking while the manager lock is held.
                    let saved = crate::snapshot::persist(&gs, site_id).await?;
                    mgr.path = saved.parse().ok();
                }
                SnapshotMode::SyncMode => {
                    log::info!("No other site, synchronization done");
                }
                SnapshotMode::NetworkMode => {
                    log::error!(
                        "Start snapshot is not supposed to be called when there is no neighbours with network mode"
                    );
                }
            }
        }
    }

    Ok(())
}
281
#[cfg(feature = "server")]
/// Persists a global snapshot to disk.
///
/// The snapshot is written as pretty-printed JSON to a file named
/// `snapshot_<site_id>_<timestamp>.json` (local time, `%Y%m%d_%H%M%S`),
/// using a relative path. The file name is returned on success.
///
/// # Errors
///
/// Returns an error when the snapshot cannot be serialized or when the file
/// cannot be created or written.
pub async fn persist(snapshot: &GlobalSnapshot, site_id: String) -> std::io::Result<String> {
    // Serialize before touching the file system, so that a serialization
    // failure is reported as an error and cannot leave an empty file behind.
    let json = serde_json::to_string_pretty(snapshot)?;

    let ts = chrono::Local::now().format("%Y%m%d_%H%M%S");
    let filename = format!("snapshot_{}_{}.json", site_id, ts);

    std::fs::write(&filename, json)?;
    println!("📸 Snapshot completed successfully at {}", filename);

    Ok(filename)
}
299
#[cfg(feature = "server")]
lazy_static::lazy_static! {
    /// Snapshot manager shared by the whole process, behind an async mutex.
    ///
    /// It is created expecting zero snapshots; `start_snapshot` sets the real
    /// count before it pushes the first snapshot.
    pub static ref LOCAL_SNAPSHOT_MANAGER: tokio::sync::Mutex<SnapshotManager> =
        tokio::sync::Mutex::new(SnapshotManager::new(0));
}
305
#[cfg(test)]
#[cfg(feature = "server")]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    /// Builds a clock whose vector part holds the given `(site, value)` pairs.
    fn mk_clock(pairs: &[(&str, i64)]) -> crate::clock::Clock {
        let values: HashMap<String, i64> = pairs
            .iter()
            .map(|&(id, v)| (id.to_string(), v))
            .collect();
        crate::clock::Clock::new_with_values(0, values)
    }

    /// Builds the snapshot response of `site` with the given clock and log.
    fn resp(site: &str, vc: &[(&str, i64)], txs: &[TxSummary]) -> crate::message::SnapshotResponse {
        crate::message::SnapshotResponse {
            site_id: site.to_string(),
            clock: mk_clock(vc),
            tx_log: txs.to_vec(),
        }
    }

    /// Builds a transaction summary from its five components.
    fn tx(lamport_time: i64, source_node: &str, from_user: &str, to_user: &str, amount_in_cent: i64) -> TxSummary {
        TxSummary {
            lamport_time,
            source_node: source_node.to_string(),
            from_user: from_user.to_string(),
            to_user: to_user.to_string(),
            amount_in_cent,
        }
    }

    /// Builds a local snapshot with the given clock and an empty log.
    fn local(site: &str, vc: &[(&str, i64)]) -> LocalSnapshot {
        LocalSnapshot {
            site_id: site.to_string(),
            vector_clock: vc.iter().map(|&(id, v)| (id.to_string(), v)).collect(),
            tx_log: HashSet::new(),
        }
    }

    #[test]
    fn consistency_ok() {
        let snaps = [
            local("A", &[("A", 1), ("B", 0)]),
            local("B", &[("A", 1), ("B", 1)]),
        ];
        assert!(GlobalSnapshot::is_consistent(&snaps));
    }

    #[test]
    fn consistency_violation() {
        let snaps = [
            local("A", &[("A", 2), ("B", 2)]),
            local("B", &[("A", 1), ("B", 1)]),
        ];
        assert!(!GlobalSnapshot::is_consistent(&snaps));
    }

    #[test]
    fn push_waits_for_expected() {
        let mut mgr = SnapshotManager::new(2);
        let only_tx = tx(1, "A", "user1", "user2", 100);

        let first = resp("A", &[("A", 1)], &[only_tx]);
        assert!(mgr.push(first).is_none());
        assert_eq!(mgr.received.len(), 1);
    }

    #[test]
    fn push_detects_incoherence() {
        let mut mgr = SnapshotManager::new(2);
        let bad_a = resp("A", &[("A", 2), ("B", 2)], &[]);
        let bad_b = resp("B", &[("A", 1), ("B", 1)], &[]);

        assert!(mgr.push(bad_a).is_none());
        let snap = mgr.push(bad_b).expect("back-tracked snapshot");

        let dummy = LocalSnapshot {
            site_id: "dummy".into(),
            vector_clock: HashMap::new(),
            tx_log: snap.all_transactions.clone(),
        };
        assert!(GlobalSnapshot::is_consistent(&[dummy]));
        assert!(!snap.missing.contains_key("A"));
    }

    #[test]
    fn push_computes_missing_and_dedup() {
        let mut mgr = SnapshotManager::new(2);
        let t1 = tx(10, "A", "user1", "user2", 100);
        let t2 = tx(11, "B", "user3", "user4", 200);

        let from_a = resp("A", &[("A", 1)], &[t1.clone()]);
        let from_b = resp("B", &[("B", 1)], &[t1.clone(), t2.clone()]);

        let _ = mgr.push(from_a);
        let gs = mgr.push(from_b).expect("snapshot ready");

        assert_eq!(gs.all_transactions.len(), 2);
        assert_eq!(gs.missing["A"], HashSet::from([t2]));
        assert!(!gs.missing.contains_key("B"));
    }

    #[test]
    fn consistency_handles_missing_columns() {
        let snaps = [local("A", &[("A", 3)]), local("B", &[("B", 1)])];
        assert!(GlobalSnapshot::is_consistent(&snaps));
    }

    #[test]
    fn backtrack_trims_future_transactions() {
        let mut mgr = SnapshotManager::new(2);

        let t1 = tx(1, "A", "user1", "user2", 100);
        let t3 = tx(3, "A", "user1", "user2", 300);
        let t5 = tx(5, "A", "user1", "user2", 500);

        // A claims B is at 2 while B itself reports 1: this is the inconsistency.
        let from_a = resp(
            "A",
            &[("A", 5), ("B", 2)],
            &[t1.clone(), t3.clone(), t5.clone()],
        );
        let from_b = resp("B", &[("A", 2), ("B", 1)], &[]);

        let _ = mgr.push(from_a);
        let snap = mgr.push(from_b).expect("snapshot after back-track");

        assert!(snap.all_transactions.contains(&t1));
        assert!(!snap.all_transactions.contains(&t5));
    }

    #[test]
    fn union_is_deduplicated() {
        let mut mgr = SnapshotManager::new(2);
        let shared = tx(7, "A", "user1", "user2", 700);

        let from_a = resp("A", &[("A", 1)], &[shared.clone()]);
        let from_b = resp("B", &[("B", 1)], &[shared.clone()]);

        let _ = mgr.push(from_a);
        let gs = mgr.push(from_b).expect("snapshot ready");
        assert_eq!(gs.all_transactions.len(), 1);
    }
}