1use alloc::collections::BTreeMap;
4use alloc::string::String;
5use alloc::vec::Vec;
6use core::fmt;
7
8use grafos_std::error::FabricError;
9
10use crate::output_store::JobOutputStore;
11use crate::retry::{RetryPolicy, RetryableError};
12use crate::work_chunk::{ChunkId, WorkChunk};
13
/// Terminal outcome of a single work chunk within one job run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChunkStatus {
    /// The chunk executed and its output was persisted to the store.
    Success,
    /// All attempts failed; payload is the final error's `Display` text.
    Failed(String),
    /// Output was already present in the store, so execution was skipped.
    Skipped,
}
24
/// Per-chunk bookkeeping recorded by the coordinator's `run` loop.
#[derive(Debug, Clone)]
pub struct ChunkResult {
    /// Final status of this chunk (success, failure, or skipped).
    pub status: ChunkStatus,
    // Retries consumed: 0 means the outcome was decided on the first
    // attempt; skipped chunks always report 0.
    pub retries_used: u32,
}
33
/// Summary of a completed job run: per-chunk outcomes, the reduced
/// aggregate, and counters. For any run over `n` chunks,
/// `succeeded + failed + skipped == n`.
pub struct JobResult {
    /// Outcome for every chunk, keyed by chunk id (ordered map).
    pub chunk_results: BTreeMap<ChunkId, ChunkResult>,
    /// Output of the reduce function applied to all stored chunk outputs.
    pub aggregate: Vec<u8>,
    /// Number of chunks that executed successfully this run.
    pub succeeded: u32,
    /// Number of chunks that failed permanently or exhausted retries.
    pub failed: u32,
    /// Number of chunks skipped because their output was already stored.
    pub skipped: u32,
}
47
48impl fmt::Debug for JobResult {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_struct("JobResult")
51 .field("succeeded", &self.succeeded)
52 .field("failed", &self.failed)
53 .field("skipped", &self.skipped)
54 .field("aggregate_len", &self.aggregate.len())
55 .finish()
56 }
57}
58
/// Drives a map/reduce-style job: executes each chunk with retry
/// handling, persists outputs idempotently in a `JobOutputStore`, and
/// reduces all stored outputs into a single aggregate.
pub struct JobCoordinator {
    /// Retry/backoff configuration applied to every chunk.
    policy: RetryPolicy,
}
72
73impl JobCoordinator {
74 pub fn new(policy: RetryPolicy) -> Self {
76 JobCoordinator { policy }
77 }
78
79 pub fn run<F, R>(
96 &mut self,
97 chunks: &[Box<dyn WorkChunk>],
98 store: &mut dyn JobOutputStore,
99 mut execute_fn: F,
100 reduce_fn: R,
101 ) -> Result<JobResult, FabricError>
102 where
103 F: FnMut(&[u8]) -> Result<Vec<u8>, FabricError>,
104 R: FnOnce(&[(ChunkId, Vec<u8>)]) -> Vec<u8>,
105 {
106 let mut chunk_results: BTreeMap<ChunkId, ChunkResult> = BTreeMap::new();
107 let mut succeeded: u32 = 0;
108 let mut failed: u32 = 0;
109 let mut skipped: u32 = 0;
110
111 for chunk in chunks {
112 let cid = chunk.chunk_id();
113
114 if store.contains(cid) {
116 chunk_results.insert(
117 cid,
118 ChunkResult {
119 status: ChunkStatus::Skipped,
120 retries_used: 0,
121 },
122 );
123 skipped += 1;
124 continue;
125 }
126
127 let chunk_bytes = chunk.to_bytes();
128 let max_attempts = self.policy.max_retries + 1;
129 let mut last_err: Option<FabricError> = None;
130 let mut retries_used: u32 = 0;
131 let mut _backoff = self.policy.backoff();
132
133 for attempt in 0..max_attempts {
134 match execute_fn(&chunk_bytes) {
135 Ok(output) => {
136 store.put(cid, output)?;
137 retries_used = attempt;
138 last_err = None;
139 break;
140 }
141 Err(e) => {
142 retries_used = attempt;
143 if self.policy.classify(&e) == RetryableError::Permanent {
145 last_err = Some(e);
146 break;
147 }
148 last_err = Some(e);
149 let _delay = _backoff.next_delay();
154 }
155 }
156 }
157
158 if let Some(e) = last_err {
159 chunk_results.insert(
160 cid,
161 ChunkResult {
162 status: ChunkStatus::Failed(alloc::format!("{}", e)),
163 retries_used,
164 },
165 );
166 failed += 1;
167 } else {
168 chunk_results.insert(
169 cid,
170 ChunkResult {
171 status: ChunkStatus::Success,
172 retries_used,
173 },
174 );
175 succeeded += 1;
176 }
177 }
178
179 let mut outputs: Vec<(ChunkId, Vec<u8>)> = Vec::new();
181 for chunk in chunks {
182 let cid = chunk.chunk_id();
183 if let Some(data) = store.get(cid)? {
184 outputs.push((cid, data));
185 }
186 }
187
188 let aggregate = reduce_fn(&outputs);
189
190 Ok(JobResult {
191 chunk_results,
192 aggregate,
193 succeeded,
194 failed,
195 skipped,
196 })
197 }
198
199 pub fn teardown(&self, store: &mut dyn JobOutputStore) {
201 store.clear();
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MemoryOutputStore, RetryPolicy};

    /// Minimal serializable chunk fixture: squares of `value` are the
    /// expected per-chunk outputs throughout these tests.
    #[derive(Clone, serde::Serialize, serde::Deserialize)]
    struct TestChunk {
        id: u64,
        value: u64,
    }

    impl WorkChunk for TestChunk {
        fn chunk_id(&self) -> ChunkId {
            ChunkId(self.id)
        }
        fn to_bytes(&self) -> Vec<u8> {
            postcard::to_allocvec(self).unwrap()
        }
        fn from_bytes(bytes: &[u8]) -> Result<Self, FabricError> {
            postcard::from_bytes(bytes).map_err(|_| FabricError::IoError(-200))
        }
    }

    /// Builds boxed chunks with sequential ids 0..len from `values`.
    fn make_chunks(values: &[u64]) -> Vec<Box<dyn WorkChunk>> {
        values
            .iter()
            .enumerate()
            .map(|(i, &v)| -> Box<dyn WorkChunk> {
                Box::new(TestChunk {
                    id: i as u64,
                    value: v,
                })
            })
            .collect()
    }

    /// Execute function used by most tests: deserializes a TestChunk and
    /// returns its value squared, postcard-encoded.
    fn square_execute(bytes: &[u8]) -> Result<Vec<u8>, FabricError> {
        let chunk: TestChunk =
            postcard::from_bytes(bytes).map_err(|_| FabricError::IoError(-200))?;
        let result = chunk.value * chunk.value;
        Ok(postcard::to_allocvec(&result).unwrap())
    }

    /// Reduce function: sums all decoded u64 outputs (undecodable
    /// entries count as 0) and re-encodes the sum.
    fn sum_reduce(outputs: &[(ChunkId, Vec<u8>)]) -> Vec<u8> {
        let sum: u64 = outputs
            .iter()
            .map(|(_, v)| postcard::from_bytes::<u64>(v).unwrap_or(0))
            .sum();
        postcard::to_allocvec(&sum).unwrap()
    }

    /// Happy path: three chunks all succeed and the aggregate is the sum
    /// of their squares.
    #[test]
    fn basic_map_reduce() {
        let chunks = make_chunks(&[2, 3, 4]);
        let mut store = MemoryOutputStore::new();
        let mut coord = JobCoordinator::new(RetryPolicy::default());

        let result = coord
            .run(&chunks, &mut store, square_execute, sum_reduce)
            .unwrap();

        assert_eq!(result.succeeded, 3);
        assert_eq!(result.failed, 0);
        assert_eq!(result.skipped, 0);

        let total: u64 = postcard::from_bytes(&result.aggregate).unwrap();
        assert_eq!(total, 4 + 9 + 16);
    }

    /// A pre-populated store entry causes the chunk to be skipped, yet
    /// its stored output still participates in the reduce.
    #[test]
    fn idempotent_skip_on_retry() {
        let chunks = make_chunks(&[5]);
        let mut store = MemoryOutputStore::new();

        let expected: u64 = 25;
        store
            .put(ChunkId(0), postcard::to_allocvec(&expected).unwrap())
            .unwrap();

        let mut coord = JobCoordinator::new(RetryPolicy::default());
        let result = coord
            .run(&chunks, &mut store, square_execute, sum_reduce)
            .unwrap();

        assert_eq!(result.succeeded, 0);
        assert_eq!(result.skipped, 1);
        assert_eq!(result.failed, 0);

        let total: u64 = postcard::from_bytes(&result.aggregate).unwrap();
        assert_eq!(total, 25);
    }

    /// Two transient failures followed by a success: the chunk converges
    /// within max_retries and reports retries_used == 2 (attempt index of
    /// the successful call).
    #[test]
    fn retry_on_transient_failure_converges() {
        use std::cell::Cell;
        use std::rc::Rc;

        let fail_count = Rc::new(Cell::new(0u32));
        let fail_clone = fail_count.clone();

        let chunks = make_chunks(&[7]);
        let mut store = MemoryOutputStore::new();
        let mut coord = JobCoordinator::new(RetryPolicy {
            max_retries: 3,
            initial_backoff_secs: 1,
            max_backoff_secs: 4,
        });

        let result = coord
            .run(
                &chunks,
                &mut store,
                |bytes| {
                    let count = fail_clone.get();
                    fail_clone.set(count + 1);
                    if count < 2 {
                        // Disconnected is classified as transient.
                        Err(FabricError::Disconnected)
                    } else {
                        square_execute(bytes)
                    }
                },
                sum_reduce,
            )
            .unwrap();

        assert_eq!(result.succeeded, 1);
        assert_eq!(result.failed, 0);
        // Exactly 3 calls: two failures plus the final success.
        assert_eq!(fail_count.get(), 3);
        let cr = result.chunk_results.get(&ChunkId(0)).unwrap();
        assert_eq!(cr.status, ChunkStatus::Success);
        assert_eq!(cr.retries_used, 2);

        let total: u64 = postcard::from_bytes(&result.aggregate).unwrap();
        assert_eq!(total, 49);
    }

    /// A permanent error (Fenced) aborts retries immediately: the execute
    /// fn is called exactly once despite max_retries = 5.
    #[test]
    fn permanent_failure_not_retried() {
        use std::cell::Cell;
        use std::rc::Rc;

        let call_count = Rc::new(Cell::new(0u32));
        let call_clone = call_count.clone();

        let chunks = make_chunks(&[1]);
        let mut store = MemoryOutputStore::new();
        let mut coord = JobCoordinator::new(RetryPolicy {
            max_retries: 5,
            ..RetryPolicy::default()
        });

        let result = coord
            .run(
                &chunks,
                &mut store,
                |_bytes| {
                    call_clone.set(call_clone.get() + 1);
                    // Fenced is classified as permanent.
                    Err(FabricError::Fenced)
                },
                sum_reduce,
            )
            .unwrap();

        assert_eq!(result.failed, 1);
        assert_eq!(result.succeeded, 0);
        assert_eq!(call_count.get(), 1);

        let cr = result.chunk_results.get(&ChunkId(0)).unwrap();
        assert!(matches!(cr.status, ChunkStatus::Failed(_)));
    }

    /// Transient errors on every attempt: with max_retries = 2 the chunk
    /// fails after 3 attempts and reports retries_used == 2.
    #[test]
    fn retries_exhausted() {
        let chunks = make_chunks(&[1]);
        let mut store = MemoryOutputStore::new();
        let mut coord = JobCoordinator::new(RetryPolicy {
            max_retries: 2,
            initial_backoff_secs: 1,
            max_backoff_secs: 4,
        });

        let result = coord
            .run(
                &chunks,
                &mut store,
                |_bytes| Err(FabricError::LeaseExpired),
                sum_reduce,
            )
            .unwrap();

        assert_eq!(result.failed, 1);
        assert_eq!(result.succeeded, 0);

        let cr = result.chunk_results.get(&ChunkId(0)).unwrap();
        assert!(matches!(cr.status, ChunkStatus::Failed(_)));
        assert_eq!(cr.retries_used, 2);
    }

    /// Running the same job twice (fresh state each iteration, with one
    /// injected transient failure) must produce identical aggregates.
    #[test]
    fn determinism_same_inputs_same_outputs() {
        use std::cell::Cell;
        use std::rc::Rc;

        for _ in 0..2 {
            let fail_once = Rc::new(Cell::new(false));
            let fail_clone = fail_once.clone();

            let chunks = make_chunks(&[10, 20]);
            let mut store = MemoryOutputStore::new();
            let mut coord = JobCoordinator::new(RetryPolicy {
                max_retries: 2,
                initial_backoff_secs: 1,
                max_backoff_secs: 4,
            });

            let result = coord
                .run(
                    &chunks,
                    &mut store,
                    |bytes| {
                        if !fail_clone.get() {
                            // First call of each iteration fails once.
                            fail_clone.set(true);
                            Err(FabricError::Disconnected)
                        } else {
                            square_execute(bytes)
                        }
                    },
                    sum_reduce,
                )
                .unwrap();

            let total: u64 = postcard::from_bytes(&result.aggregate).unwrap();
            assert_eq!(total, 100 + 400);
            assert_eq!(result.succeeded, 2);
        }
    }

    /// teardown() clears the store, so previously stored outputs vanish.
    #[test]
    fn teardown_clears_store() {
        let chunks = make_chunks(&[3]);
        let mut store = MemoryOutputStore::new();
        let mut coord = JobCoordinator::new(RetryPolicy::default());

        let _ = coord
            .run(&chunks, &mut store, square_execute, sum_reduce)
            .unwrap();

        assert!(store.contains(ChunkId(0)));
        coord.teardown(&mut store);
        assert!(!store.contains(ChunkId(0)));
    }

    /// Zero chunks: all counters are 0 and the reduce output is empty.
    #[test]
    fn empty_job() {
        let chunks: Vec<Box<dyn WorkChunk>> = vec![];
        let mut store = MemoryOutputStore::new();
        let mut coord = JobCoordinator::new(RetryPolicy::default());

        let result = coord
            .run(&chunks, &mut store, square_execute, |_outputs| Vec::new())
            .unwrap();

        assert_eq!(result.succeeded, 0);
        assert_eq!(result.failed, 0);
        assert_eq!(result.skipped, 0);
        assert!(result.aggregate.is_empty());
    }

    /// One chunk fails permanently while the others succeed; the failed
    /// chunk contributes nothing to the aggregate.
    #[test]
    fn mixed_success_and_failure() {
        use std::cell::Cell;
        use std::rc::Rc;

        let call_idx = Rc::new(Cell::new(0u32));
        let call_clone = call_idx.clone();

        let chunks = make_chunks(&[2, 3, 4]);
        let mut store = MemoryOutputStore::new();
        let mut coord = JobCoordinator::new(RetryPolicy {
            max_retries: 0,
            ..RetryPolicy::default()
        });

        let result = coord
            .run(
                &chunks,
                &mut store,
                |bytes| {
                    let idx = call_clone.get();
                    call_clone.set(idx + 1);
                    if idx == 1 {
                        // Fail the second chunk (value 3) permanently.
                        Err(FabricError::Fenced)
                    } else {
                        square_execute(bytes)
                    }
                },
                sum_reduce,
            )
            .unwrap();

        assert_eq!(result.succeeded, 2);
        assert_eq!(result.failed, 1);

        let total: u64 = postcard::from_bytes(&result.aggregate).unwrap();
        assert_eq!(total, 4 + 16);
    }
}
521}