|
17 | 17 |
|
18 | 18 | #include "arrow/dataset/file_base.h" |
19 | 19 |
|
| 20 | +#include "arrow/acero/accumulation_queue.h" |
20 | 21 | #include "arrow/acero/exec_plan.h" |
21 | 22 |
|
22 | 23 | #include <algorithm> |
@@ -555,24 +556,30 @@ Result<acero::ExecNode*> MakeWriteNode(acero::ExecPlan* plan, |
555 | 556 | auto node, |
556 | 557 | // To preserve order, explicitly sequence the exec batches. |
557 | 558 | // This requires the exec batch index to be set upstream (e.g. by SourceNode). |
558 | | - acero::MakeExecNode("consuming_sink", plan, std::move(inputs), |
559 | | - acero::ConsumingSinkNodeOptions{ |
560 | | - std::move(consumer), |
561 | | - {}, |
562 | | - /*sequence_output=*/write_options.preserve_order})); |
| 559 | + acero::MakeExecNode( |
| 560 | + "consuming_sink", plan, std::move(inputs), |
| 561 | + acero::ConsumingSinkNodeOptions{ |
| 562 | + std::move(consumer), |
| 563 | + {}, |
| 564 | + /*sequence_output=*/write_node_options.write_options.preserve_order})); |
563 | 565 |
|
564 | 566 | return node; |
565 | 567 | } |
566 | 568 |
|
567 | 569 | namespace { |
568 | 570 |
|
569 | | -class TeeNode : public acero::MapNode { |
| 571 | +class TeeNode : public acero::MapNode, |
| 572 | + public arrow::acero::util::SerialSequencingQueue::Processor { |
570 | 573 | public: |
571 | 574 | TeeNode(acero::ExecPlan* plan, std::vector<acero::ExecNode*> inputs, |
572 | 575 | std::shared_ptr<Schema> output_schema, |
573 | 576 | FileSystemDatasetWriteOptions write_options) |
574 | 577 | : MapNode(plan, std::move(inputs), std::move(output_schema)), |
575 | | - write_options_(std::move(write_options)) {} |
| 578 | + write_options_(std::move(write_options)) { |
| 579 | + if (write_options.preserve_order) { |
| 580 | + sequencer_ = acero::util::SerialSequencingQueue::Make(this); |
| 581 | + } |
| 582 | + } |
576 | 583 |
|
577 | 584 | Status StartProducing() override { |
578 | 585 | ARROW_ASSIGN_OR_RAISE( |
@@ -602,6 +609,18 @@ class TeeNode : public acero::MapNode { |
602 | 609 |
|
603 | 610 | const char* kind_name() const override { return "TeeNode"; } |
604 | 611 |
|
| 612 | + Status InputReceived(ExecNode* input, ExecBatch batch) override { |
| 613 | + DCHECK_EQ(input, inputs_[0]); |
| 614 | + if (sequencer_) { |
| 615 | + return sequencer_->InsertBatch(std::move(batch)); |
| 616 | + } |
| 617 | + return Process(std::move(batch)); |
| 618 | + } |
| 619 | + |
| 620 | + Status Process(ExecBatch batch) override { |
| 621 | + return acero::MapNode::InputReceived(inputs_[0], batch); |
| 622 | + } |
| 623 | + |
605 | 624 | void Finish() override { dataset_writer_->Finish(); } |
606 | 625 |
|
607 | 626 | Result<compute::ExecBatch> ProcessBatch(compute::ExecBatch batch) override { |
@@ -635,6 +654,7 @@ class TeeNode : public acero::MapNode { |
635 | 654 | std::unique_ptr<internal::DatasetWriter> dataset_writer_; |
636 | 655 | FileSystemDatasetWriteOptions write_options_; |
637 | 656 | std::atomic<int32_t> backpressure_counter_ = 0; |
| 657 | + std::unique_ptr<acero::util::SerialSequencingQueue> sequencer_; |
638 | 658 | }; |
639 | 659 |
|
640 | 660 | } // namespace |
|
0 commit comments