grafos_batch/
lib.rs

1//! grafos-batch — Task graph executor for DAG-structured batch jobs.
2//!
3//! This crate provides a mini-Spark/Airflow-style batch job executor built on
4//! fabricBIOS. Tasks are organized into a directed acyclic graph (DAG) with
5//! explicit data dependencies. The executor computes a topological sort,
6//! groups independent tasks into execution waves, and runs them with automatic
7//! retry and failure propagation.
8//!
9//! # Quick start
10//!
11//! ```rust
12//! use grafos_batch::*;
13//!
14//! let mut graph = TaskGraph::new();
15//! let a = graph.add_task(TaskDef {
16//!     name: "produce".into(),
17//!     task_fn: Box::new(|_ctx| {
18//!         let mut data = std::collections::HashMap::new();
19//!         data.insert("out".into(), b"hello".to_vec());
20//!         Ok(TaskOutput { data })
21//!     }),
22//!     resource_req: ResourceReq::default(),
23//!     inputs: vec![],
24//!     outputs: vec![DataRef { name: "out".into(), format: DataFormat::RawBytes }],
25//!     retries: 0,
26//! });
27//! let b = graph.add_task(TaskDef {
28//!     name: "consume".into(),
29//!     task_fn: Box::new(|ctx| {
30//!         assert_eq!(ctx.inputs.get("out").unwrap(), b"hello");
31//!         Ok(TaskOutput { data: std::collections::HashMap::new() })
32//!     }),
33//!     resource_req: ResourceReq::default(),
34//!     inputs: vec![DataRef { name: "out".into(), format: DataFormat::RawBytes }],
35//!     outputs: vec![],
36//!     retries: 0,
37//! });
38//! graph.add_dependency(a, b).unwrap();
39//! let plan = graph.build().unwrap();
40//! assert_eq!(plan.waves().len(), 2);
41//! ```
42
43#![cfg_attr(not(feature = "std"), no_std)]
44
45extern crate alloc;
46
47use alloc::boxed::Box;
48use alloc::collections::BTreeMap;
49use alloc::collections::BTreeSet;
50use alloc::string::String;
51use alloc::vec::Vec;
52use core::fmt;
53
54#[cfg(feature = "std")]
55use std::collections::HashMap;
56
57#[cfg(not(feature = "std"))]
58use alloc::collections::BTreeMap as HashMap;
59
60use grafos_std::error::FabricError;
61
62// ── TaskId ──────────────────────────────────────────────────────────────
63
/// Unique identifier for a task within a [`TaskGraph`].
///
/// A thin newtype over the raw `u64`; comparison, ordering and hashing are
/// all derived from the wrapped value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TaskId(pub u64);

impl fmt::Display for TaskId {
    /// Renders the identifier as `Task(<n>)`.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TaskId(raw) = *self;
        out.write_fmt(format_args!("Task({})", raw))
    }
}
73
74// ── DataRef / DataFormat ────────────────────────────────────────────────
75
/// Format of data exchanged between tasks.
///
/// This is a field-less tag, so it is `Copy` and `Hash` in addition to being
/// comparable: it can be passed by value and used as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataFormat {
    /// Structured data serialized via postcard.
    Postcard,
    /// Opaque byte blob.
    RawBytes,
}
84
/// Reference to a named data artifact produced or consumed by a task.
///
/// Producers list these in [`TaskDef::outputs`], consumers in
/// [`TaskDef::inputs`]. Within this file, outputs and inputs are paired up
/// purely by `name`; `format` is carried along but not compared.
#[derive(Debug, Clone)]
pub struct DataRef {
    /// Name used to match outputs to inputs across tasks.
    pub name: String,
    /// Serialization format.
    pub format: DataFormat,
}
93
94// ── ResourceReq ─────────────────────────────────────────────────────────
95
/// Resource requirements declared by a task.
///
/// The executor visible in this file carries these values along with the
/// task definition but does not read them.
#[derive(Debug, Clone)]
pub struct ResourceReq {
    /// Minimum memory in bytes.
    pub min_memory: u64,
    /// Minimum CPU cores.
    pub min_cpu_cores: u32,
    /// Minimum block storage in bytes.
    pub min_block: u64,
    /// Fuel budget for execution.
    pub fuel: u64,
}

impl Default for ResourceReq {
    /// One CPU core, a fuel budget of 1,000,000, and no memory or block
    /// storage minimum.
    fn default() -> Self {
        Self {
            fuel: 1_000_000,
            min_cpu_cores: 1,
            min_memory: 0,
            min_block: 0,
        }
    }
}
119
120// ── TaskContext / TaskOutput ─────────────────────────────────────────────
121
/// Execution context provided to a task function.
///
/// The executor builds this for a task and hands it to
/// [`TaskDef::task_fn`] by shared reference.
pub struct TaskContext {
    /// The ID of the currently executing task.
    pub task_id: TaskId,
    /// Input data keyed by [`DataRef::name`].
    ///
    /// Only declared inputs that are present in the executor's data store
    /// when the task starts are included; a declared input that nobody
    /// actually produced is simply absent from this map.
    pub inputs: HashMap<String, Vec<u8>>,
}
129
/// Output produced by a task function.
pub struct TaskOutput {
    /// Output data keyed by [`DataRef::name`].
    pub data: HashMap<String, Vec<u8>>,
}

impl fmt::Debug for TaskOutput {
    /// Prints only the output names (as `data_keys`), never the payload
    /// bytes.
    ///
    /// The names are sorted so the rendering is deterministic even when the
    /// backing map is a hash map with randomized iteration order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut data_keys: Vec<&String> = self.data.keys().collect();
        data_keys.sort_unstable();
        f.debug_struct("TaskOutput")
            .field("data_keys", &data_keys)
            .finish()
    }
}
143
144// ── TaskDef ─────────────────────────────────────────────────────────────
145
/// Definition of a task within the graph.
///
/// The `task_fn` closure uses `FnMut` so it can be retried on failure.
/// For tasks with `retries > 0`, the closure is called again with the
/// same task ID and inputs on each retry attempt, so any state the closure
/// captures persists across attempts.
// The boxed closure type is spelled out in place (no type alias), which is
// what trips clippy's `type_complexity` lint.
#[allow(clippy::type_complexity)]
pub struct TaskDef {
    /// Human-readable task name.
    pub name: String,
    /// The function to execute. Called with a [`TaskContext`] on each
    /// attempt (initial + retries).
    pub task_fn: Box<dyn FnMut(&TaskContext) -> Result<TaskOutput, FabricError>>,
    /// Resource requirements.
    pub resource_req: ResourceReq,
    /// Named data inputs this task consumes.
    pub inputs: Vec<DataRef>,
    /// Named data outputs this task produces. Only names listed here are
    /// forwarded from the task's [`TaskOutput`] to downstream tasks.
    pub outputs: Vec<DataRef>,
    /// Number of retry attempts on failure (0 = no retries).
    pub retries: u32,
}
167
168// ── TaskGraph ───────────────────────────────────────────────────────────
169
/// Builder for constructing a task DAG.
///
/// Tasks are added with [`add_task`](TaskGraph::add_task) and dependency
/// edges with [`add_dependency`](TaskGraph::add_dependency). Call
/// [`build`](TaskGraph::build) to validate the graph and compute an
/// [`ExecutionPlan`].
pub struct TaskGraph {
    /// Task definitions keyed by the ID handed out in `add_task`.
    tasks: BTreeMap<TaskId, TaskDef>,
    /// Forward edges: from → set of to (meaning "to" depends on "from").
    deps: BTreeMap<TaskId, BTreeSet<TaskId>>,
    /// Reverse edges: to → set of from (meaning "to" is blocked by "from").
    rev_deps: BTreeMap<TaskId, BTreeSet<TaskId>>,
    /// Raw value of the next [`TaskId`] to assign; starts at 0 and is
    /// incremented once per added task.
    next_id: u64,
}
184
185impl TaskGraph {
186    /// Create an empty task graph.
187    pub fn new() -> Self {
188        TaskGraph {
189            tasks: BTreeMap::new(),
190            deps: BTreeMap::new(),
191            rev_deps: BTreeMap::new(),
192            next_id: 0,
193        }
194    }
195
196    /// Add a task and return its assigned [`TaskId`].
197    pub fn add_task(&mut self, def: TaskDef) -> TaskId {
198        let id = TaskId(self.next_id);
199        self.next_id += 1;
200        self.tasks.insert(id, def);
201        id
202    }
203
204    /// Declare that `to` depends on `from` (i.e. `from` must complete before
205    /// `to` can start).
206    ///
207    /// # Errors
208    ///
209    /// Returns `FabricError::IoError(-100)` if either task ID does not
210    /// exist in the graph.
211    pub fn add_dependency(&mut self, from: TaskId, to: TaskId) -> Result<(), FabricError> {
212        if !self.tasks.contains_key(&from) {
213            return Err(FabricError::IoError(-100));
214        }
215        if !self.tasks.contains_key(&to) {
216            return Err(FabricError::IoError(-100));
217        }
218        self.deps.entry(from).or_default().insert(to);
219        self.rev_deps.entry(to).or_default().insert(from);
220        Ok(())
221    }
222
223    /// Validate the graph: check for cycles and unresolved data references.
224    ///
225    /// # Errors
226    ///
227    /// Returns `FabricError::IoError(-101)` if a cycle is detected.
228    /// Returns `FabricError::IoError(-102)` if a task has an input with no
229    /// upstream producer.
230    pub fn validate(&self) -> Result<(), FabricError> {
231        self.detect_cycle()?;
232        self.check_data_refs()?;
233        Ok(())
234    }
235
236    /// Validate and compute an [`ExecutionPlan`] with topologically sorted
237    /// waves. Consumes the graph.
238    ///
239    /// # Errors
240    ///
241    /// Returns the same errors as [`validate`](TaskGraph::validate).
242    pub fn build(self) -> Result<ExecutionPlan, FabricError> {
243        self.validate()?;
244        let waves = self.compute_waves();
245        Ok(ExecutionPlan {
246            waves,
247            tasks: self.tasks,
248            rev_deps: self.rev_deps,
249        })
250    }
251
252    fn detect_cycle(&self) -> Result<(), FabricError> {
253        const WHITE: u8 = 0;
254
255        let mut color: BTreeMap<TaskId, u8> = BTreeMap::new();
256        for &id in self.tasks.keys() {
257            color.insert(id, WHITE);
258        }
259
260        for &id in self.tasks.keys() {
261            if color[&id] == WHITE && self.dfs_has_cycle(id, &mut color) {
262                return Err(FabricError::IoError(-101));
263            }
264        }
265        Ok(())
266    }
267
268    fn dfs_has_cycle(&self, node: TaskId, color: &mut BTreeMap<TaskId, u8>) -> bool {
269        color.insert(node, 1); // GRAY
270        if let Some(neighbors) = self.deps.get(&node) {
271            for &next in neighbors {
272                match color.get(&next) {
273                    Some(1) => return true, // back edge = cycle
274                    Some(0) => {
275                        if self.dfs_has_cycle(next, color) {
276                            return true;
277                        }
278                    }
279                    _ => {}
280                }
281            }
282        }
283        color.insert(node, 2); // BLACK
284        false
285    }
286
287    fn check_data_refs(&self) -> Result<(), FabricError> {
288        let mut all_outputs: BTreeSet<&str> = BTreeSet::new();
289        for def in self.tasks.values() {
290            for out_ref in &def.outputs {
291                all_outputs.insert(&out_ref.name);
292            }
293        }
294        for def in self.tasks.values() {
295            for in_ref in &def.inputs {
296                if !all_outputs.contains(in_ref.name.as_str()) {
297                    return Err(FabricError::IoError(-102));
298                }
299            }
300        }
301        Ok(())
302    }
303
304    fn compute_waves(&self) -> Vec<Vec<TaskId>> {
305        // Kahn's algorithm producing level-grouped topological order.
306        let mut in_degree: BTreeMap<TaskId, usize> = BTreeMap::new();
307        for &id in self.tasks.keys() {
308            in_degree.insert(id, 0);
309        }
310        for dependents in self.deps.values() {
311            for &dep in dependents {
312                *in_degree.entry(dep).or_insert(0) += 1;
313            }
314        }
315
316        let mut waves = Vec::new();
317        let mut ready: Vec<TaskId> = in_degree
318            .iter()
319            .filter(|(_, &deg)| deg == 0)
320            .map(|(&id, _)| id)
321            .collect();
322        ready.sort();
323
324        while !ready.is_empty() {
325            waves.push(ready.clone());
326            let mut next_ready = Vec::new();
327            for &id in &ready {
328                if let Some(dependents) = self.deps.get(&id) {
329                    for &dep in dependents {
330                        let deg = in_degree.get_mut(&dep).unwrap();
331                        *deg -= 1;
332                        if *deg == 0 {
333                            next_ready.push(dep);
334                        }
335                    }
336                }
337            }
338            next_ready.sort();
339            ready = next_ready;
340        }
341
342        waves
343    }
344}
345
346impl Default for TaskGraph {
347    fn default() -> Self {
348        Self::new()
349    }
350}
351
352// ── ExecutionPlan ───────────────────────────────────────────────────────
353
/// A validated execution plan with tasks grouped into waves.
///
/// Wave N contains only tasks whose dependencies all appear in waves < N.
/// Tasks within the same wave are independent and could be executed in
/// parallel (though the current executor runs them sequentially).
pub struct ExecutionPlan {
    /// Task IDs grouped by wave, in execution order.
    waves: Vec<Vec<TaskId>>,
    /// Definitions of the tasks named in `waves`.
    tasks: BTreeMap<TaskId, TaskDef>,
    /// Reverse edges (task → its direct dependencies), used by the executor
    /// to decide whether a task must be skipped.
    rev_deps: BTreeMap<TaskId, BTreeSet<TaskId>>,
}
364
365impl fmt::Debug for ExecutionPlan {
366    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
367        f.debug_struct("ExecutionPlan")
368            .field("waves", &self.waves)
369            .field("task_count", &self.task_count())
370            .finish()
371    }
372}
373
374impl ExecutionPlan {
375    /// Get the execution waves. Each wave is a list of independent task IDs.
376    pub fn waves(&self) -> &[Vec<TaskId>] {
377        &self.waves
378    }
379
380    /// Total number of tasks across all waves.
381    pub fn task_count(&self) -> usize {
382        self.waves.iter().map(|w| w.len()).sum()
383    }
384}
385
386// ── TaskResult / TaskStatus / ExecutionResult ──────────────────────────
387
/// Status of a completed task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
    /// Task completed successfully.
    Success,
    /// Task failed after exhausting retries. The string is the `Display`
    /// rendering of the error returned by the last attempt.
    Failed(String),
    /// Task was skipped because an upstream dependency failed or was itself
    /// skipped; the task function was never called.
    Skipped,
}
398
/// Result of executing a single task.
#[derive(Debug)]
pub struct TaskResult {
    /// Final status.
    pub status: TaskStatus,
    /// Number of retries consumed before success or final failure, i.e. the
    /// zero-based index of the last attempt. Always 0 for skipped tasks.
    pub retries_used: u32,
    /// Output data if the task succeeded; `None` for failed and skipped
    /// tasks.
    pub output: Option<TaskOutput>,
}
409
410/// Aggregate result of executing an entire [`ExecutionPlan`].
411pub struct ExecutionResult {
412    /// Per-task results keyed by [`TaskId`].
413    pub task_results: BTreeMap<TaskId, TaskResult>,
414    /// Total tasks in the plan.
415    pub total_tasks: u32,
416    /// Tasks that completed successfully.
417    pub succeeded: u32,
418    /// Tasks that failed after exhausting retries.
419    pub failed: u32,
420    /// Tasks that were skipped due to upstream failures.
421    pub skipped: u32,
422}
423
424// ── Executor ────────────────────────────────────────────────────────────
425
/// Single-threaded task graph executor.
///
/// Executes an [`ExecutionPlan`] wave by wave, propagating outputs from
/// completed tasks to downstream consumers via [`DataRef`] name matching.
/// Failed tasks are retried up to [`TaskDef::retries`] times; if retries
/// are exhausted, all transitive downstream tasks are marked
/// [`TaskStatus::Skipped`].
///
/// The type holds no state; [`Executor::run`] is an associated function and
/// does not need an instance.
pub struct Executor;
434
435impl Executor {
436    /// Create a new executor.
437    pub fn new() -> Self {
438        Executor
439    }
440
441    /// Execute the plan, consuming it.
442    ///
443    /// Tasks are executed sequentially within each wave. Outputs from
444    /// completed tasks are stored and made available as inputs to downstream
445    /// tasks matched by [`DataRef::name`].
446    pub fn run(mut plan: ExecutionPlan) -> Result<ExecutionResult, FabricError> {
447        let total_tasks = plan.task_count() as u32;
448        let mut results: BTreeMap<TaskId, TaskResult> = BTreeMap::new();
449        let mut data_store: HashMap<String, Vec<u8>> = HashMap::new();
450        let mut failed_ids: BTreeSet<TaskId> = BTreeSet::new();
451
452        let mut succeeded: u32 = 0;
453        let mut failed: u32 = 0;
454        let mut skipped: u32 = 0;
455
456        let waves: Vec<Vec<TaskId>> = plan.waves.clone();
457
458        for wave in &waves {
459            for &task_id in wave {
460                // Check if any upstream dependency failed or was skipped.
461                let should_skip = plan
462                    .rev_deps
463                    .get(&task_id)
464                    .map(|upstream| upstream.iter().any(|u| failed_ids.contains(u)))
465                    .unwrap_or(false);
466
467                if should_skip {
468                    failed_ids.insert(task_id);
469                    results.insert(
470                        task_id,
471                        TaskResult {
472                            status: TaskStatus::Skipped,
473                            retries_used: 0,
474                            output: None,
475                        },
476                    );
477                    skipped += 1;
478                    continue;
479                }
480
481                let mut def = plan.tasks.remove(&task_id).unwrap();
482
483                // Gather inputs from data store.
484                let mut inputs = HashMap::new();
485                for in_ref in &def.inputs {
486                    if let Some(data) = data_store.get(&in_ref.name) {
487                        inputs.insert(in_ref.name.clone(), data.clone());
488                    }
489                }
490
491                let max_attempts = def.retries + 1;
492                let mut last_err = None;
493                let mut retries_used = 0u32;
494
495                for attempt in 0..max_attempts {
496                    let ctx = TaskContext {
497                        task_id,
498                        inputs: inputs.clone(),
499                    };
500
501                    match (def.task_fn)(&ctx) {
502                        Ok(output) => {
503                            for out_ref in &def.outputs {
504                                if let Some(data) = output.data.get(&out_ref.name) {
505                                    data_store.insert(out_ref.name.clone(), data.clone());
506                                }
507                            }
508                            retries_used = attempt;
509                            results.insert(
510                                task_id,
511                                TaskResult {
512                                    status: TaskStatus::Success,
513                                    retries_used,
514                                    output: Some(output),
515                                },
516                            );
517                            succeeded += 1;
518                            last_err = None;
519                            break;
520                        }
521                        Err(e) => {
522                            last_err = Some(e);
523                            retries_used = attempt;
524                        }
525                    }
526                }
527
528                if let Some(e) = last_err {
529                    failed_ids.insert(task_id);
530                    results.insert(
531                        task_id,
532                        TaskResult {
533                            status: TaskStatus::Failed(alloc::format!("{}", e)),
534                            retries_used,
535                            output: None,
536                        },
537                    );
538                    failed += 1;
539                }
540            }
541        }
542
543        Ok(ExecutionResult {
544            task_results: results,
545            total_tasks,
546            succeeded,
547            failed,
548            skipped,
549        })
550    }
551}
552
553impl Default for Executor {
554    fn default() -> Self {
555        Self::new()
556    }
557}
558
559#[cfg(test)]
560mod tests {
561    use super::*;
562
563    #[test]
564    fn empty_graph() {
565        let graph = TaskGraph::new();
566        let plan = graph.build().unwrap();
567        assert_eq!(plan.task_count(), 0);
568        assert!(plan.waves().is_empty());
569        let result = Executor::run(plan).unwrap();
570        assert_eq!(result.total_tasks, 0);
571        assert_eq!(result.succeeded, 0);
572        assert_eq!(result.failed, 0);
573        assert_eq!(result.skipped, 0);
574    }
575
576    #[test]
577    fn single_task() {
578        let mut graph = TaskGraph::new();
579        graph.add_task(TaskDef {
580            name: "solo".into(),
581            task_fn: Box::new(|_ctx| {
582                Ok(TaskOutput {
583                    data: HashMap::new(),
584                })
585            }),
586            resource_req: ResourceReq::default(),
587            inputs: vec![],
588            outputs: vec![],
589            retries: 0,
590        });
591
592        let plan = graph.build().unwrap();
593        assert_eq!(plan.task_count(), 1);
594        assert_eq!(plan.waves().len(), 1);
595
596        let result = Executor::run(plan).unwrap();
597        assert_eq!(result.succeeded, 1);
598    }
599
    /// A → B → C chain: each task lands in its own wave, and each consumer
    /// sees exactly the bytes its predecessor produced.
    #[test]
    fn linear_dag_a_b_c() {
        let mut graph = TaskGraph::new();
        // A: no inputs, produces "a_out".
        let a = graph.add_task(TaskDef {
            name: "A".into(),
            task_fn: Box::new(|_ctx| {
                let mut data = HashMap::new();
                data.insert("a_out".into(), b"from_a".to_vec());
                Ok(TaskOutput { data })
            }),
            resource_req: ResourceReq::default(),
            inputs: vec![],
            outputs: vec![DataRef {
                name: "a_out".into(),
                format: DataFormat::RawBytes,
            }],
            retries: 0,
        });
        // B: checks A's payload, then produces "b_out".
        let b = graph.add_task(TaskDef {
            name: "B".into(),
            task_fn: Box::new(|ctx| {
                let a_data = ctx.inputs.get("a_out").unwrap();
                assert_eq!(a_data, b"from_a");
                let mut data = HashMap::new();
                data.insert("b_out".into(), b"from_b".to_vec());
                Ok(TaskOutput { data })
            }),
            resource_req: ResourceReq::default(),
            inputs: vec![DataRef {
                name: "a_out".into(),
                format: DataFormat::RawBytes,
            }],
            outputs: vec![DataRef {
                name: "b_out".into(),
                format: DataFormat::RawBytes,
            }],
            retries: 0,
        });
        // C: checks B's payload and produces nothing.
        let c = graph.add_task(TaskDef {
            name: "C".into(),
            task_fn: Box::new(|ctx| {
                let b_data = ctx.inputs.get("b_out").unwrap();
                assert_eq!(b_data, b"from_b");
                Ok(TaskOutput {
                    data: HashMap::new(),
                })
            }),
            resource_req: ResourceReq::default(),
            inputs: vec![DataRef {
                name: "b_out".into(),
                format: DataFormat::RawBytes,
            }],
            outputs: vec![],
            retries: 0,
        });

        graph.add_dependency(a, b).unwrap();
        graph.add_dependency(b, c).unwrap();

        // A strict chain yields one task per wave, in dependency order.
        let plan = graph.build().unwrap();
        assert_eq!(plan.waves().len(), 3);
        assert_eq!(plan.waves()[0], vec![a]);
        assert_eq!(plan.waves()[1], vec![b]);
        assert_eq!(plan.waves()[2], vec![c]);

        let result = Executor::run(plan).unwrap();
        assert_eq!(result.succeeded, 3);
        assert_eq!(result.failed, 0);
        assert_eq!(result.skipped, 0);
    }
670
    /// Diamond A → {B, C} → D: B and C share the middle wave, and D
    /// receives the outputs of both.
    #[test]
    fn diamond_dag() {
        let mut graph = TaskGraph::new();
        // A: the single source; its "shared" output fans out to B and C.
        let a = graph.add_task(TaskDef {
            name: "A".into(),
            task_fn: Box::new(|_ctx| {
                let mut data = HashMap::new();
                data.insert("shared".into(), b"from_a".to_vec());
                Ok(TaskOutput { data })
            }),
            resource_req: ResourceReq::default(),
            inputs: vec![],
            outputs: vec![DataRef {
                name: "shared".into(),
                format: DataFormat::RawBytes,
            }],
            retries: 0,
        });
        // B: consumes "shared", produces "b_out".
        let b = graph.add_task(TaskDef {
            name: "B".into(),
            task_fn: Box::new(|ctx| {
                assert!(ctx.inputs.contains_key("shared"));
                let mut data = HashMap::new();
                data.insert("b_out".into(), b"from_b".to_vec());
                Ok(TaskOutput { data })
            }),
            resource_req: ResourceReq::default(),
            inputs: vec![DataRef {
                name: "shared".into(),
                format: DataFormat::RawBytes,
            }],
            outputs: vec![DataRef {
                name: "b_out".into(),
                format: DataFormat::RawBytes,
            }],
            retries: 0,
        });
        // C: consumes "shared", produces "c_out".
        let c = graph.add_task(TaskDef {
            name: "C".into(),
            task_fn: Box::new(|ctx| {
                assert!(ctx.inputs.contains_key("shared"));
                let mut data = HashMap::new();
                data.insert("c_out".into(), b"from_c".to_vec());
                Ok(TaskOutput { data })
            }),
            resource_req: ResourceReq::default(),
            inputs: vec![DataRef {
                name: "shared".into(),
                format: DataFormat::RawBytes,
            }],
            outputs: vec![DataRef {
                name: "c_out".into(),
                format: DataFormat::RawBytes,
            }],
            retries: 0,
        });
        // D: the join; needs both branch outputs.
        let d = graph.add_task(TaskDef {
            name: "D".into(),
            task_fn: Box::new(|ctx| {
                assert!(ctx.inputs.contains_key("b_out"));
                assert!(ctx.inputs.contains_key("c_out"));
                Ok(TaskOutput {
                    data: HashMap::new(),
                })
            }),
            resource_req: ResourceReq::default(),
            inputs: vec![
                DataRef {
                    name: "b_out".into(),
                    format: DataFormat::RawBytes,
                },
                DataRef {
                    name: "c_out".into(),
                    format: DataFormat::RawBytes,
                },
            ],
            outputs: vec![],
            retries: 0,
        });

        graph.add_dependency(a, b).unwrap();
        graph.add_dependency(a, c).unwrap();
        graph.add_dependency(b, d).unwrap();
        graph.add_dependency(c, d).unwrap();

        // Expected layout: [A], [B, C], [D].
        let plan = graph.build().unwrap();
        assert_eq!(plan.waves().len(), 3);
        assert_eq!(plan.waves()[0], vec![a]);
        assert!(plan.waves()[1].contains(&b));
        assert!(plan.waves()[1].contains(&c));
        assert_eq!(plan.waves()[2], vec![d]);

        let result = Executor::run(plan).unwrap();
        assert_eq!(result.succeeded, 4);
    }
766
767    #[test]
768    fn cycle_detection() {
769        let mut graph = TaskGraph::new();
770        let a = graph.add_task(TaskDef {
771            name: "A".into(),
772            task_fn: Box::new(|_| {
773                Ok(TaskOutput {
774                    data: HashMap::new(),
775                })
776            }),
777            resource_req: ResourceReq::default(),
778            inputs: vec![],
779            outputs: vec![],
780            retries: 0,
781        });
782        let b = graph.add_task(TaskDef {
783            name: "B".into(),
784            task_fn: Box::new(|_| {
785                Ok(TaskOutput {
786                    data: HashMap::new(),
787                })
788            }),
789            resource_req: ResourceReq::default(),
790            inputs: vec![],
791            outputs: vec![],
792            retries: 0,
793        });
794        let c = graph.add_task(TaskDef {
795            name: "C".into(),
796            task_fn: Box::new(|_| {
797                Ok(TaskOutput {
798                    data: HashMap::new(),
799                })
800            }),
801            resource_req: ResourceReq::default(),
802            inputs: vec![],
803            outputs: vec![],
804            retries: 0,
805        });
806
807        graph.add_dependency(a, b).unwrap();
808        graph.add_dependency(b, c).unwrap();
809        graph.add_dependency(c, a).unwrap();
810
811        let err = graph.build().unwrap_err();
812        assert_eq!(err, FabricError::IoError(-101));
813    }
814
815    #[test]
816    fn task_failure_skips_downstream() {
817        let mut graph = TaskGraph::new();
818        let a = graph.add_task(TaskDef {
819            name: "A-fails".into(),
820            task_fn: Box::new(|_ctx| Err(FabricError::Disconnected)),
821            resource_req: ResourceReq::default(),
822            inputs: vec![],
823            outputs: vec![DataRef {
824                name: "a_out".into(),
825                format: DataFormat::RawBytes,
826            }],
827            retries: 0,
828        });
829        let b = graph.add_task(TaskDef {
830            name: "B-skipped".into(),
831            task_fn: Box::new(|_ctx| panic!("should not be called")),
832            resource_req: ResourceReq::default(),
833            inputs: vec![DataRef {
834                name: "a_out".into(),
835                format: DataFormat::RawBytes,
836            }],
837            outputs: vec![],
838            retries: 0,
839        });
840
841        graph.add_dependency(a, b).unwrap();
842
843        let plan = graph.build().unwrap();
844        let result = Executor::run(plan).unwrap();
845
846        assert_eq!(result.failed, 1);
847        assert_eq!(result.skipped, 1);
848        assert_eq!(result.succeeded, 0);
849
850        assert!(matches!(
851            result.task_results.get(&a).unwrap().status,
852            TaskStatus::Failed(_)
853        ));
854        assert_eq!(
855            result.task_results.get(&b).unwrap().status,
856            TaskStatus::Skipped
857        );
858    }
859
860    #[test]
861    fn task_retry_succeeds() {
862        use std::cell::Cell;
863
864        let counter = std::rc::Rc::new(Cell::new(0u32));
865        let counter_clone = counter.clone();
866
867        let mut graph = TaskGraph::new();
868        graph.add_task(TaskDef {
869            name: "flaky".into(),
870            task_fn: Box::new(move |_ctx| {
871                let attempt = counter_clone.get();
872                counter_clone.set(attempt + 1);
873                if attempt == 0 {
874                    Err(FabricError::Disconnected)
875                } else {
876                    Ok(TaskOutput {
877                        data: HashMap::new(),
878                    })
879                }
880            }),
881            resource_req: ResourceReq::default(),
882            inputs: vec![],
883            outputs: vec![],
884            retries: 2,
885        });
886
887        let plan = graph.build().unwrap();
888        let result = Executor::run(plan).unwrap();
889        assert_eq!(result.succeeded, 1);
890        assert_eq!(result.failed, 0);
891        assert_eq!(counter.get(), 2); // called twice: fail then succeed
892        let task_result = result.task_results.values().next().unwrap();
893        assert_eq!(task_result.status, TaskStatus::Success);
894        assert_eq!(task_result.retries_used, 1);
895    }
896
897    #[test]
898    fn retries_exhausted_marks_failed_and_skips_downstream() {
899        let mut graph = TaskGraph::new();
900        let a = graph.add_task(TaskDef {
901            name: "always-fails".into(),
902            task_fn: Box::new(|_ctx| Err(FabricError::Disconnected)),
903            resource_req: ResourceReq::default(),
904            inputs: vec![],
905            outputs: vec![],
906            retries: 2,
907        });
908        let b = graph.add_task(TaskDef {
909            name: "downstream".into(),
910            task_fn: Box::new(|_ctx| panic!("should not run")),
911            resource_req: ResourceReq::default(),
912            inputs: vec![],
913            outputs: vec![],
914            retries: 0,
915        });
916
917        graph.add_dependency(a, b).unwrap();
918
919        let plan = graph.build().unwrap();
920        let result = Executor::run(plan).unwrap();
921        assert_eq!(result.failed, 1);
922        assert_eq!(result.skipped, 1);
923
924        let a_result = result.task_results.get(&a).unwrap();
925        assert!(matches!(a_result.status, TaskStatus::Failed(_)));
926        assert_eq!(a_result.retries_used, 2); // 0, 1, 2 = 3 attempts, last index is 2
927    }
928
929    #[test]
930    fn data_passing_raw_bytes() {
931        let mut graph = TaskGraph::new();
932        let a = graph.add_task(TaskDef {
933            name: "producer".into(),
934            task_fn: Box::new(|_ctx| {
935                let mut data = HashMap::new();
936                data.insert("payload".into(), vec![1, 2, 3, 4, 5]);
937                Ok(TaskOutput { data })
938            }),
939            resource_req: ResourceReq::default(),
940            inputs: vec![],
941            outputs: vec![DataRef {
942                name: "payload".into(),
943                format: DataFormat::RawBytes,
944            }],
945            retries: 0,
946        });
947        let b = graph.add_task(TaskDef {
948            name: "consumer".into(),
949            task_fn: Box::new(|ctx| {
950                let payload = ctx.inputs.get("payload").unwrap();
951                assert_eq!(payload, &[1, 2, 3, 4, 5]);
952                Ok(TaskOutput {
953                    data: HashMap::new(),
954                })
955            }),
956            resource_req: ResourceReq::default(),
957            inputs: vec![DataRef {
958                name: "payload".into(),
959                format: DataFormat::RawBytes,
960            }],
961            outputs: vec![],
962            retries: 0,
963        });
964
965        graph.add_dependency(a, b).unwrap();
966        let plan = graph.build().unwrap();
967        let result = Executor::run(plan).unwrap();
968        assert_eq!(result.succeeded, 2);
969    }
970
971    #[test]
972    fn unknown_task_id_in_dependency() {
973        let mut graph = TaskGraph::new();
974        let a = graph.add_task(TaskDef {
975            name: "A".into(),
976            task_fn: Box::new(|_| {
977                Ok(TaskOutput {
978                    data: HashMap::new(),
979                })
980            }),
981            resource_req: ResourceReq::default(),
982            inputs: vec![],
983            outputs: vec![],
984            retries: 0,
985        });
986
987        let err = graph.add_dependency(a, TaskId(999));
988        assert_eq!(err.unwrap_err(), FabricError::IoError(-100));
989    }
990
991    #[test]
992    fn unresolved_input_ref() {
993        let mut graph = TaskGraph::new();
994        graph.add_task(TaskDef {
995            name: "needs-missing".into(),
996            task_fn: Box::new(|_| {
997                Ok(TaskOutput {
998                    data: HashMap::new(),
999                })
1000            }),
1001            resource_req: ResourceReq::default(),
1002            inputs: vec![DataRef {
1003                name: "nonexistent".into(),
1004                format: DataFormat::RawBytes,
1005            }],
1006            outputs: vec![],
1007            retries: 0,
1008        });
1009
1010        let err = graph.build().unwrap_err();
1011        assert_eq!(err, FabricError::IoError(-102));
1012    }
1013
1014    #[test]
1015    fn task_id_display() {
1016        let id = TaskId(42);
1017        assert_eq!(alloc::format!("{}", id), "Task(42)");
1018    }
1019
1020    #[test]
1021    fn resource_req_default() {
1022        let req = ResourceReq::default();
1023        assert_eq!(req.min_cpu_cores, 1);
1024        assert_eq!(req.min_memory, 0);
1025        assert_eq!(req.min_block, 0);
1026        assert_eq!(req.fuel, 1_000_000);
1027    }
1028
1029    #[test]
1030    fn wide_parallel_wave() {
1031        let mut graph = TaskGraph::new();
1032        for i in 0..5 {
1033            graph.add_task(TaskDef {
1034                name: alloc::format!("task-{}", i),
1035                task_fn: Box::new(|_| {
1036                    Ok(TaskOutput {
1037                        data: HashMap::new(),
1038                    })
1039                }),
1040                resource_req: ResourceReq::default(),
1041                inputs: vec![],
1042                outputs: vec![],
1043                retries: 0,
1044            });
1045        }
1046
1047        let plan = graph.build().unwrap();
1048        assert_eq!(plan.waves().len(), 1);
1049        assert_eq!(plan.waves()[0].len(), 5);
1050
1051        let result = Executor::run(plan).unwrap();
1052        assert_eq!(result.succeeded, 5);
1053    }
1054
1055    #[test]
1056    fn transitive_skip_propagation() {
1057        let mut graph = TaskGraph::new();
1058        let a = graph.add_task(TaskDef {
1059            name: "A-fails".into(),
1060            task_fn: Box::new(|_| Err(FabricError::Disconnected)),
1061            resource_req: ResourceReq::default(),
1062            inputs: vec![],
1063            outputs: vec![],
1064            retries: 0,
1065        });
1066        let b = graph.add_task(TaskDef {
1067            name: "B".into(),
1068            task_fn: Box::new(|_| panic!("should not run")),
1069            resource_req: ResourceReq::default(),
1070            inputs: vec![],
1071            outputs: vec![],
1072            retries: 0,
1073        });
1074        let c = graph.add_task(TaskDef {
1075            name: "C".into(),
1076            task_fn: Box::new(|_| panic!("should not run")),
1077            resource_req: ResourceReq::default(),
1078            inputs: vec![],
1079            outputs: vec![],
1080            retries: 0,
1081        });
1082
1083        graph.add_dependency(a, b).unwrap();
1084        graph.add_dependency(b, c).unwrap();
1085
1086        let plan = graph.build().unwrap();
1087        let result = Executor::run(plan).unwrap();
1088        assert_eq!(result.failed, 1);
1089        assert_eq!(result.skipped, 2);
1090        assert_eq!(
1091            result.task_results.get(&a).unwrap().status,
1092            TaskStatus::Failed("disconnected".into())
1093        );
1094        assert_eq!(
1095            result.task_results.get(&b).unwrap().status,
1096            TaskStatus::Skipped
1097        );
1098        assert_eq!(
1099            result.task_results.get(&c).unwrap().status,
1100            TaskStatus::Skipped
1101        );
1102    }
1103}