1extern crate alloc;
19use alloc::boxed::Box;
20use alloc::vec::Vec;
21
22use grafos_std::error::FabricError;
23use serde::{de::DeserializeOwned, Serialize};
24
25use grafos_collections::queue::FabricQueue;
26
27use crate::placement::NodeConstraint;
28use crate::stage::{Sink, Source};
29
30type MapFn = Box<dyn FnMut(&[u8]) -> Result<Option<Vec<u8>>, FabricError>>;
32type FilterFn = Box<dyn FnMut(&[u8]) -> Result<bool, FabricError>>;
34
/// Entry point for building data pipelines; see [`Pipeline::from_source`].
///
/// This is a stateless marker type, so it is cheap to copy and debug-print.
#[derive(Debug, Clone, Copy, Default)]
pub struct Pipeline;
37
38impl Pipeline {
39 pub fn from_source<T>(source: impl Source<T> + 'static) -> PipelineBuilder<T>
41 where
42 T: Serialize + DeserializeOwned + 'static,
43 {
44 let mut source = source;
47 let erased_source: Box<dyn FnMut() -> Result<Option<Vec<u8>>, FabricError>> =
48 Box::new(move || match source.next()? {
49 Some(item) => {
50 let bytes =
51 postcard::to_allocvec(&item).map_err(|_| FabricError::IoError(-1))?;
52 Ok(Some(bytes))
53 }
54 None => Ok(None),
55 });
56
57 PipelineBuilder {
58 source: erased_source,
59 stages: Vec::new(),
60 buffer_size: 1024,
61 next_constraint: NodeConstraint::Any,
62 _marker: core::marker::PhantomData,
63 }
64 }
65}
66
67enum StageKind {
69 Map(MapFn),
70 Filter(FilterFn),
71}
72
73pub struct StageEntry {
75 kind: StageKind,
76 #[allow(dead_code)] constraint: NodeConstraint,
78 #[allow(dead_code)] buffer_size: usize,
80}
81
82pub struct PipelineBuilder<T: 'static> {
88 source: Box<dyn FnMut() -> Result<Option<Vec<u8>>, FabricError>>,
90 stages: Vec<StageEntry>,
91 buffer_size: usize,
92 next_constraint: NodeConstraint,
93 _marker: core::marker::PhantomData<T>,
94}
95
96impl<T: Serialize + DeserializeOwned + 'static> PipelineBuilder<T> {
97 pub fn map<U, F>(mut self, mut f: F) -> PipelineBuilder<U>
102 where
103 U: Serialize + DeserializeOwned + 'static,
104 F: FnMut(T) -> U + 'static,
105 {
106 let constraint = core::mem::replace(&mut self.next_constraint, NodeConstraint::Any);
107 let buf_size = self.buffer_size;
108
109 let erased: MapFn = Box::new(move |bytes: &[u8]| {
110 let item: T = postcard::from_bytes(bytes).map_err(|_| FabricError::IoError(-1))?;
111 let out = f(item);
112 let out_bytes = postcard::to_allocvec(&out).map_err(|_| FabricError::IoError(-1))?;
113 Ok(Some(out_bytes))
114 });
115
116 let mut stages = self.stages;
117 stages.push(StageEntry {
118 kind: StageKind::Map(erased),
119 constraint,
120 buffer_size: buf_size,
121 });
122
123 PipelineBuilder {
124 source: self.source,
125 stages,
126 buffer_size: 1024,
127 next_constraint: NodeConstraint::Any,
128 _marker: core::marker::PhantomData,
129 }
130 }
131
132 pub fn filter<F>(mut self, mut pred: F) -> PipelineBuilder<T>
135 where
136 F: FnMut(&T) -> bool + 'static,
137 {
138 let constraint = core::mem::replace(&mut self.next_constraint, NodeConstraint::Any);
139 let buf_size = self.buffer_size;
140
141 let erased: FilterFn = Box::new(move |bytes: &[u8]| {
142 let item: T = postcard::from_bytes(bytes).map_err(|_| FabricError::IoError(-1))?;
143 Ok(pred(&item))
144 });
145
146 let mut stages = self.stages;
147 stages.push(StageEntry {
148 kind: StageKind::Filter(erased),
149 constraint,
150 buffer_size: buf_size,
151 });
152
153 PipelineBuilder {
154 source: self.source,
155 stages,
156 buffer_size: 1024,
157 next_constraint: NodeConstraint::Any,
158 _marker: core::marker::PhantomData,
159 }
160 }
161
162 pub fn on_node(mut self, constraint: NodeConstraint) -> Self {
164 self.next_constraint = constraint;
165 self
166 }
167
168 pub fn buffer_size(mut self, n: usize) -> Self {
170 self.buffer_size = n;
171 self
172 }
173
174 pub fn sink<K: Sink<T> + 'static>(self, sink: K) -> SinkPipeline<T, K> {
176 SinkPipeline {
177 source: self.source,
178 stages: self.stages,
179 sink,
180 _marker: core::marker::PhantomData,
181 }
182 }
183
184 pub fn fold<Acc, F>(self, init: Acc, f: F) -> FoldPipeline<T, Acc, F>
186 where
187 F: FnMut(Acc, T) -> Acc,
188 {
189 FoldPipeline {
190 source: self.source,
191 stages: self.stages,
192 init: Some(init),
193 f,
194 _marker: core::marker::PhantomData,
195 }
196 }
197
198 pub fn collect(self) -> CollectPipeline<T> {
200 CollectPipeline {
201 source: self.source,
202 stages: self.stages,
203 _marker: core::marker::PhantomData,
204 }
205 }
206
207 pub fn build(self) -> Result<PipelineHandle<T>, FabricError> {
209 Ok(PipelineHandle {
210 source: self.source,
211 stages: self.stages,
212 _marker: core::marker::PhantomData,
213 })
214 }
215}
216
217fn run_stages(
220 source: &mut dyn FnMut() -> Result<Option<Vec<u8>>, FabricError>,
221 stages: &mut [StageEntry],
222) -> Result<Vec<Vec<u8>>, FabricError> {
223 let mut items: Vec<Vec<u8>> = Vec::new();
224 while let Some(bytes) = source()? {
225 items.push(bytes);
226 }
227
228 for stage in stages.iter_mut() {
229 let mut output = Vec::new();
230 match &mut stage.kind {
231 StageKind::Map(f) => {
232 for item_bytes in &items {
233 if let Some(out_bytes) = f(item_bytes)? {
234 output.push(out_bytes);
235 }
236 }
237 }
238 StageKind::Filter(pred) => {
239 for item_bytes in &items {
240 if pred(item_bytes)? {
241 output.push(item_bytes.clone());
242 }
243 }
244 }
245 }
246 items = output;
247 }
248
249 Ok(items)
250}
251
252fn deserialize_items<T: DeserializeOwned>(items: Vec<Vec<u8>>) -> Result<Vec<T>, FabricError> {
253 let mut result = Vec::with_capacity(items.len());
254 for item_bytes in items {
255 let item: T = postcard::from_bytes(&item_bytes).map_err(|_| FabricError::IoError(-1))?;
256 result.push(item);
257 }
258 Ok(result)
259}
260
261pub struct PipelineHandle<T: 'static> {
263 source: Box<dyn FnMut() -> Result<Option<Vec<u8>>, FabricError>>,
264 stages: Vec<StageEntry>,
265 _marker: core::marker::PhantomData<T>,
266}
267
268impl<T: Serialize + DeserializeOwned + 'static> PipelineHandle<T> {
269 pub fn run(&mut self) -> Result<(), FabricError> {
271 let _ = run_stages(self.source.as_mut(), &mut self.stages)?;
272 Ok(())
273 }
274}
275
276pub struct SinkPipeline<T: 'static, K: Sink<T>> {
278 source: Box<dyn FnMut() -> Result<Option<Vec<u8>>, FabricError>>,
279 stages: Vec<StageEntry>,
280 sink: K,
281 _marker: core::marker::PhantomData<T>,
282}
283
284impl<T, K> SinkPipeline<T, K>
285where
286 T: Serialize + DeserializeOwned + 'static,
287 K: Sink<T>,
288{
289 pub fn run(&mut self) -> Result<(), FabricError> {
291 let items = run_stages(self.source.as_mut(), &mut self.stages)?;
292 for item_bytes in items {
293 let item: T =
294 postcard::from_bytes(&item_bytes).map_err(|_| FabricError::IoError(-1))?;
295 self.sink.accept(item)?;
296 }
297 Ok(())
298 }
299}
300
301pub struct FoldPipeline<T: 'static, Acc, F: FnMut(Acc, T) -> Acc> {
303 source: Box<dyn FnMut() -> Result<Option<Vec<u8>>, FabricError>>,
304 stages: Vec<StageEntry>,
305 init: Option<Acc>,
306 f: F,
307 _marker: core::marker::PhantomData<T>,
308}
309
310impl<T, Acc, F> FoldPipeline<T, Acc, F>
311where
312 T: Serialize + DeserializeOwned + 'static,
313 F: FnMut(Acc, T) -> Acc,
314{
315 pub fn run(&mut self) -> Result<Acc, FabricError> {
317 let items = run_stages(self.source.as_mut(), &mut self.stages)?;
318 let mut acc = self.init.take().ok_or(FabricError::IoError(-10))?;
319 for item_bytes in items {
320 let item: T =
321 postcard::from_bytes(&item_bytes).map_err(|_| FabricError::IoError(-1))?;
322 acc = (self.f)(acc, item);
323 }
324 Ok(acc)
325 }
326}
327
328pub struct CollectPipeline<T: 'static> {
330 source: Box<dyn FnMut() -> Result<Option<Vec<u8>>, FabricError>>,
331 stages: Vec<StageEntry>,
332 _marker: core::marker::PhantomData<T>,
333}
334
335impl<T: Serialize + DeserializeOwned + 'static> CollectPipeline<T> {
336 pub fn run(&mut self) -> Result<Vec<T>, FabricError> {
338 let items = run_stages(self.source.as_mut(), &mut self.stages)?;
339 deserialize_items(items)
340 }
341}
342
343pub struct BufferedPipeline<T: 'static> {
349 source: Box<dyn FnMut() -> Result<Option<Vec<u8>>, FabricError>>,
350 stages: Vec<StageEntry>,
351 queue_capacity: usize,
352 queue_stride: usize,
353 _marker: core::marker::PhantomData<T>,
354}
355
356impl<T: Serialize + DeserializeOwned + 'static> BufferedPipeline<T> {
357 pub fn new(
359 source: impl Source<T> + 'static,
360 stages: Vec<StageEntry>,
361 capacity: usize,
362 stride: usize,
363 ) -> Self {
364 let mut source = source;
365 let erased: Box<dyn FnMut() -> Result<Option<Vec<u8>>, FabricError>> =
366 Box::new(move || match source.next()? {
367 Some(item) => {
368 let bytes =
369 postcard::to_allocvec(&item).map_err(|_| FabricError::IoError(-1))?;
370 Ok(Some(bytes))
371 }
372 None => Ok(None),
373 });
374
375 BufferedPipeline {
376 source: erased,
377 stages,
378 queue_capacity: capacity,
379 queue_stride: stride,
380 _marker: core::marker::PhantomData,
381 }
382 }
383
384 pub fn run(&mut self) -> Result<Vec<T>, FabricError> {
389 let mut queue: FabricQueue<Vec<u8>> =
391 FabricQueue::with_capacity(self.queue_capacity, self.queue_stride)?;
392
393 while let Some(bytes) = (self.source)()? {
394 if !queue.push(&bytes)? {
395 return Err(FabricError::CapacityExceeded);
396 }
397 }
398
399 for stage in self.stages.iter_mut() {
401 let mut next_queue: FabricQueue<Vec<u8>> =
402 FabricQueue::with_capacity(self.queue_capacity, self.queue_stride)?;
403
404 while let Some(item_bytes) = queue.pop()? {
405 match &mut stage.kind {
406 StageKind::Map(f) => {
407 if let Some(out_bytes) = f(&item_bytes)? {
408 if !next_queue.push(&out_bytes)? {
409 return Err(FabricError::CapacityExceeded);
410 }
411 }
412 }
413 StageKind::Filter(pred) => {
414 if pred(&item_bytes)? && !next_queue.push(&item_bytes)? {
415 return Err(FabricError::CapacityExceeded);
416 }
417 }
418 }
419 }
420
421 queue = next_queue;
422 }
423
424 let mut result = Vec::new();
426 while let Some(item_bytes) = queue.pop()? {
427 let item: T =
428 postcard::from_bytes(&item_bytes).map_err(|_| FabricError::IoError(-1))?;
429 result.push(item);
430 }
431 Ok(result)
432 }
433}
434
435pub fn make_map_stage(f: MapFn) -> StageEntry {
437 StageEntry {
438 kind: StageKind::Map(f),
439 constraint: NodeConstraint::Any,
440 buffer_size: 1024,
441 }
442}
443
444pub fn make_filter_stage(f: FilterFn) -> StageEntry {
446 StageEntry {
447 kind: StageKind::Filter(f),
448 constraint: NodeConstraint::Any,
449 buffer_size: 1024,
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::CountSink;
    use crate::source::VecSource;
    use grafos_std::host;

    /// Resets the grafos host mock and sets its FBMU arena size to 65536.
    fn setup_mock() {
        host::reset_mock();
        host::mock_set_fbmu_arena_size(65536);
    }

    #[test]
    fn simple_map_pipeline() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3, 4, 5]);
        let mut pipeline = Pipeline::from_source(source).map(|x: u32| x * 2).collect();
        let result = pipeline.run().expect("run");
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    #[test]
    fn filter_pipeline() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 5, 10, 15, 20, 25]);
        let mut pipeline = Pipeline::from_source(source)
            .filter(|x: &u32| *x > 10)
            .collect();
        let result = pipeline.run().expect("run");
        assert_eq!(result, vec![15, 20, 25]);
    }

    #[test]
    fn fold_pipeline() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3, 4, 5]);
        let mut pipeline = Pipeline::from_source(source).fold(0u64, |acc, x: u32| acc + x as u64);
        let result = pipeline.run().expect("run");
        assert_eq!(result, 15);
    }

    #[test]
    fn fold_pipeline_second_run_errors() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3]);
        let mut pipeline = Pipeline::from_source(source).fold(0u64, |acc, x: u32| acc + x as u64);
        assert_eq!(pipeline.run().expect("first run"), 6);
        // The seed was consumed by the first run, so a second run must fail.
        assert!(pipeline.run().is_err());
    }

    #[test]
    fn sink_pipeline() {
        setup_mock();
        let source = VecSource::new(vec![10u32, 20, 30]);
        let sink: CountSink<u32> = CountSink::new();
        let mut pipeline = Pipeline::from_source(source).sink(sink);
        pipeline.run().expect("run");
    }

    #[test]
    fn map_then_filter() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3, 4, 5]);
        let mut pipeline = Pipeline::from_source(source)
            .map(|x: u32| x * 2)
            .filter(|x: &u32| *x > 4)
            .collect();
        let result = pipeline.run().expect("run");
        assert_eq!(result, vec![6, 8, 10]);
    }

    #[test]
    fn multi_stage_map_filter_fold() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let mut pipeline = Pipeline::from_source(source)
            .map(|x: u32| x * 3)
            .filter(|x: &u32| *x > 10)
            .fold(0u64, |acc, x: u32| acc + x as u64);
        let result = pipeline.run().expect("run");
        // Survivors are 12, 15, 18, 21, 24, 27, 30, which sum to 147.
        assert_eq!(result, 147);
    }

    #[test]
    fn empty_source() {
        setup_mock();
        let source = VecSource::<u32>::new(vec![]);
        let mut pipeline = Pipeline::from_source(source).collect();
        let result = pipeline.run().expect("run");
        assert!(result.is_empty());
    }

    #[test]
    fn empty_source_fold() {
        setup_mock();
        let source = VecSource::<u32>::new(vec![]);
        let mut pipeline = Pipeline::from_source(source).fold(42u64, |acc, _x: u32| acc + 1);
        let result = pipeline.run().expect("run");
        assert_eq!(result, 42);
    }

    #[test]
    fn filter_removes_everything() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3]);
        let mut pipeline = Pipeline::from_source(source)
            .filter(|_x: &u32| false)
            .collect();
        let result = pipeline.run().expect("run");
        assert!(result.is_empty());
    }

    #[test]
    fn on_node_is_accepted() {
        setup_mock();
        let source = VecSource::new(vec![1u32]);
        let mut pipeline = Pipeline::from_source(source)
            .on_node(NodeConstraint::HasMemory(1024))
            .map(|x: u32| x + 1)
            .collect();
        let result = pipeline.run().expect("run");
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn buffer_size_is_accepted() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3]);
        let mut pipeline = Pipeline::from_source(source)
            .buffer_size(64)
            .map(|x: u32| x + 10)
            .collect();
        let result = pipeline.run().expect("run");
        assert_eq!(result, vec![11, 12, 13]);
    }

    #[test]
    fn buffered_pipeline_with_fabric_queue() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3, 4, 5]);

        let map_stage = make_map_stage(Box::new(|bytes: &[u8]| {
            let item: u32 = postcard::from_bytes(bytes).map_err(|_| FabricError::IoError(-1))?;
            let out = item * 2;
            let out_bytes = postcard::to_allocvec(&out).map_err(|_| FabricError::IoError(-1))?;
            Ok(Some(out_bytes))
        }));

        let mut buffered = BufferedPipeline::new(source, vec![map_stage], 32, 64);
        let result = buffered.run().expect("run");
        assert_eq!(result, vec![2u32, 4, 6, 8, 10]);
    }

    #[test]
    fn backpressure_small_buffer() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3]);
        let mut buffered = BufferedPipeline::new(source, vec![], 4, 64);
        let result = buffered.run().expect("run");
        assert_eq!(result, vec![1u32, 2, 3]);
    }

    #[test]
    fn pipeline_handle_run() {
        use core::sync::atomic::{AtomicU32, Ordering};

        // The handle discards its output, so count map invocations to prove the
        // stage actually ran.
        static CALLS: AtomicU32 = AtomicU32::new(0);

        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3]);
        let mut handle = Pipeline::from_source(source)
            .map(|x: u32| {
                CALLS.fetch_add(1, Ordering::SeqCst);
                x + 100
            })
            .build()
            .expect("build");
        handle.run().expect("run");
        assert_eq!(CALLS.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn count_sink_standalone() {
        use crate::stage::Sink as _;
        let mut sink: CountSink<u32> = CountSink::new();
        sink.accept(10).unwrap();
        sink.accept(20).unwrap();
        sink.accept(30).unwrap();
        assert_eq!(sink.count(), 3);
    }

    #[test]
    fn buffered_pipeline_filter() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3, 4, 5]);

        let filter_stage = make_filter_stage(Box::new(|bytes: &[u8]| {
            let item: u32 = postcard::from_bytes(bytes).map_err(|_| FabricError::IoError(-1))?;
            Ok(item > 3)
        }));

        let mut buffered = BufferedPipeline::new(source, vec![filter_stage], 32, 64);
        let result = buffered.run().expect("run");
        assert_eq!(result, vec![4u32, 5]);
    }

    #[test]
    fn buffered_pipeline_map_then_filter() {
        setup_mock();
        let source = VecSource::new(vec![1u32, 2, 3, 4, 5]);

        let map_stage = make_map_stage(Box::new(|bytes: &[u8]| {
            let item: u32 = postcard::from_bytes(bytes).map_err(|_| FabricError::IoError(-1))?;
            let out = item * 10;
            let out_bytes = postcard::to_allocvec(&out).map_err(|_| FabricError::IoError(-1))?;
            Ok(Some(out_bytes))
        }));

        let filter_stage = make_filter_stage(Box::new(|bytes: &[u8]| {
            let item: u32 = postcard::from_bytes(bytes).map_err(|_| FabricError::IoError(-1))?;
            Ok(item > 25)
        }));

        let mut buffered = BufferedPipeline::new(source, vec![map_stage, filter_stage], 32, 64);
        let result = buffered.run().expect("run");
        assert_eq!(result, vec![30u32, 40, 50]);
    }
}