grafos_pipeline/
checkpoint.rs1extern crate alloc;
4use alloc::format;
5use alloc::vec::Vec;
6
7use grafos_std::block::{BlockLease, BLOCK_SIZE};
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::error::EdgeError;
11
/// Four-byte tag stored at the start of the header block to identify a checkpoint.
const MAGIC: [u8; 4] = [b'P', b'C', b'H', b'K'];
/// On-disk format version written after the magic bytes.
const VERSION: u32 = 1;
/// Index of the block that holds the checkpoint header.
const HEADER_BLOCK: u64 = 0;
/// Index of the first block that holds serialized payload (directly after the header).
const DATA_START_BLOCK: u64 = HEADER_BLOCK + 1;
17
18pub struct CheckpointedStageState<T> {
30 state: T,
31 block_lease: BlockLease,
32}
33
34impl<T: Serialize + DeserializeOwned> CheckpointedStageState<T> {
35 pub fn new(state: T, block_lease: BlockLease) -> Self {
40 Self { state, block_lease }
41 }
42
43 pub fn checkpoint(&mut self) -> Result<(), EdgeError> {
53 let data = postcard::to_allocvec(&self.state)
54 .map_err(|e| EdgeError::CheckpointFailed(format!("serialize: {e}")))?;
55
56 let data_len = data.len() as u64;
57
58 let mut header = [0u8; BLOCK_SIZE];
60 header[0..4].copy_from_slice(&MAGIC);
61 header[4..8].copy_from_slice(&VERSION.to_le_bytes());
62 header[8..16].copy_from_slice(&data_len.to_le_bytes());
63 self.block_lease
64 .block()
65 .write_block(HEADER_BLOCK, &header)
66 .map_err(|e| EdgeError::CheckpointFailed(format!("write header: {e}")))?;
67
68 let num_blocks = data.len().div_ceil(BLOCK_SIZE);
70 let available = self.block_lease.block().num_blocks() as usize;
71 if num_blocks + 1 > available {
72 return Err(EdgeError::CheckpointFailed(
73 "state too large for block lease".into(),
74 ));
75 }
76
77 for i in 0..num_blocks {
78 let start = i * BLOCK_SIZE;
79 let end = core::cmp::min(start + BLOCK_SIZE, data.len());
80 let mut block = [0u8; BLOCK_SIZE];
81 block[..end - start].copy_from_slice(&data[start..end]);
82 self.block_lease
83 .block()
84 .write_block(DATA_START_BLOCK + i as u64, &block)
85 .map_err(|e| EdgeError::CheckpointFailed(format!("write data block {i}: {e}")))?;
86 }
87
88 Ok(())
89 }
90
91 pub fn restore(block_lease: BlockLease) -> Result<Self, EdgeError> {
101 let header = block_lease
102 .block()
103 .read_block(HEADER_BLOCK)
104 .map_err(|e| EdgeError::CheckpointFailed(format!("read header: {e}")))?;
105
106 if header[0..4] != MAGIC {
107 return Err(EdgeError::CheckpointFailed("bad magic".into()));
108 }
109 let version = u32::from_le_bytes([header[4], header[5], header[6], header[7]]);
110 if version != VERSION {
111 return Err(EdgeError::CheckpointFailed(format!(
112 "unsupported version: {version}"
113 )));
114 }
115 let data_len = u64::from_le_bytes([
116 header[8], header[9], header[10], header[11], header[12], header[13], header[14],
117 header[15],
118 ]) as usize;
119
120 let num_blocks = data_len.div_ceil(BLOCK_SIZE);
121 let mut data = Vec::with_capacity(data_len);
122 for i in 0..num_blocks {
123 let block = block_lease
124 .block()
125 .read_block(DATA_START_BLOCK + i as u64)
126 .map_err(|e| EdgeError::CheckpointFailed(format!("read data block {i}: {e}")))?;
127 let remaining = data_len - data.len();
128 let take = core::cmp::min(BLOCK_SIZE, remaining);
129 data.extend_from_slice(&block[..take]);
130 }
131
132 let state: T = postcard::from_bytes(&data)
133 .map_err(|e| EdgeError::CheckpointFailed(format!("deserialize: {e}")))?;
134
135 Ok(Self { state, block_lease })
136 }
137
138 pub fn state(&self) -> &T {
140 &self.state
141 }
142
143 pub fn state_mut(&mut self) -> &mut T {
145 &mut self.state
146 }
147
148 pub fn into_block_lease(self) -> BlockLease {
154 self.block_lease
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EdgeCheckpoint;
    use grafos_std::block::BlockBuilder;
    use grafos_std::host;

    /// Resets the host mock, configures it with `num_blocks` blocks and acquires a lease.
    fn setup_block(num_blocks: u64) -> BlockLease {
        host::reset_mock();
        host::mock_set_fbbu_num_blocks(num_blocks);
        let lease = BlockBuilder::new().acquire();
        lease.expect("acquire")
    }

    /// A checkpoint written to a lease can be read back with identical field values.
    #[test]
    fn checkpoint_and_restore_roundtrip() {
        let initial = EdgeCheckpoint {
            input_offset: 42,
            output_commit: 99,
        };
        let mut writer = CheckpointedStageState::new(initial, setup_block(64));
        writer.checkpoint().expect("checkpoint");

        let reader = CheckpointedStageState::<EdgeCheckpoint>::restore(writer.into_block_lease())
            .expect("restore");
        let recovered = reader.state();
        assert_eq!(recovered.input_offset, 42);
        assert_eq!(recovered.output_commit, 99);
    }

    /// A header block whose magic bytes are wrong is rejected with "bad magic".
    #[test]
    fn restore_fails_on_corrupt_data() {
        let block_lease = setup_block(64);

        let mut garbage = [0u8; BLOCK_SIZE];
        garbage[0..4].copy_from_slice(b"BAAD");
        block_lease.block().write_block(0, &garbage).expect("write");

        let Err(err) = CheckpointedStageState::<EdgeCheckpoint>::restore(block_lease) else {
            panic!("expected error but got Ok");
        };
        match err {
            EdgeError::CheckpointFailed(msg) => assert_eq!(msg, "bad magic"),
            other => panic!("unexpected error: {other}"),
        }
    }

    /// Writes through `state_mut` are visible through `state`.
    #[test]
    fn state_mut_allows_modification() {
        let zeroed = EdgeCheckpoint {
            input_offset: 0,
            output_commit: 0,
        };
        let mut ckpt = CheckpointedStageState::new(zeroed, setup_block(64));

        {
            let inner = ckpt.state_mut();
            inner.input_offset = 100;
            inner.output_commit = 200;
        }

        let inner = ckpt.state();
        assert_eq!(inner.input_offset, 100);
        assert_eq!(inner.output_commit, 200);
    }

    /// A state with a string, an integer and a nested struct survives a round trip.
    #[test]
    fn checkpoint_preserves_complex_state() {
        #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
        struct StageState {
            name: alloc::string::String,
            items_processed: u64,
            checkpoint: EdgeCheckpoint,
        }

        let block_lease = setup_block(64);
        let expected = StageState {
            name: "transform-stage".into(),
            items_processed: 1024,
            checkpoint: EdgeCheckpoint {
                input_offset: 500,
                output_commit: 480,
            },
        };

        let mut writer = CheckpointedStageState::new(expected.clone(), block_lease);
        writer.checkpoint().expect("checkpoint");

        let reader = CheckpointedStageState::<StageState>::restore(writer.into_block_lease())
            .expect("restore");
        assert_eq!(reader.state(), &expected);
    }
}