@@ -36,15 +36,7 @@ using arrow::internal::VisitSetBitRunsVoid;
3636
3737struct TDigestBaseImpl : public ScalarAggregator {
// Constructor: initializes tdigest from the scaler and buffer_size, count to 0 and
// all_valid to true.
// The removed ('-') lines additionally assigned a per-instance out_type struct with
// separate mean and weight list fields followed by min, max and count; the added
// ('+') line keeps only the member initializers and leaves the body empty.
3838 explicit TDigestBaseImpl (std::shared_ptr<TDigest::Scaler> scaler, uint32_t buffer_size)
39- : tdigest{std::move (scaler), buffer_size}, count{0 }, all_valid{true } {
40- out_type = struct_ ({
41- field (" mean" , list (field (" item" , float64 (), false )), false ),
42- field (" weight" , list (field (" item" , float64 (), false )), false ),
43- field (" min" , float64 (), true ),
44- field (" max" , float64 (), true ),
45- field (" count" , uint64 (), false ),
46- });
47- }
39+ : tdigest{std::move (scaler), buffer_size}, count{0 }, all_valid{true } {}
4840
4941 Status MergeFrom (KernelContext*, KernelState&& src) override {
5042 const auto & other = checked_cast<const TDigestBaseImpl&>(src);
@@ -71,7 +63,20 @@ struct TDigestBaseImpl : public ScalarAggregator {
7163 TDigest tdigest;
7264 uint64_t count;
7365 bool all_valid;
74- std::shared_ptr<DataType> out_type;
// Returns a reference to the struct type shared by every instance. It is built once
// in a function-local static and has four fields, in this order:
//   centroids - non-nullable list whose non-nullable items are structs of
//               (mean: float64, weight: float64), both non-nullable
//   min, max  - nullable float64
//   count     - non-nullable uint64
// Code that indexes the struct positionally has to follow this order.
66+ static const std::shared_ptr<DataType>& out_type () {
67+ static auto out_type = struct_ ({
68+ field (" centroids" ,
69+ list (field (" item" ,
70+ struct_ ({field (" mean" , float64 (), false ),
71+ field (" weight" , float64 (), false )}),
72+ false )),
73+ false ),
74+ field (" min" , float64 (), true ),
75+ field (" max" , float64 (), true ),
76+ field (" count" , uint64 (), false ),
77+ });
78+ return out_type;
79+ }
7580};
7681
7782struct TDigestQuantileFinalizer : public TDigestBaseImpl {
@@ -126,7 +131,7 @@ struct TDigestCentroidFinalizer : public TDigestBaseImpl {
126131
127132 Status Finalize (KernelContext* ctx, Datum* out) override {
128133 if (!this ->all_valid ) {
129- *out = MakeNullScalar (out_type);
134+ *out = MakeNullScalar (out_type () );
130135 } else {
131136 // Float64Array
132137 const int64_t out_length = this ->tdigest .GetCentroidCount ();
@@ -145,10 +150,16 @@ struct TDigestCentroidFinalizer : public TDigestBaseImpl {
145150 std::tie (mean_buffer[i], weight_buffer[i]) = this ->tdigest .GetCentroid (i);
146151 }
147152
148- auto mean = std::make_shared<ListScalar>(MakeArray (mean_data),
149- list (field (" item" , float64 (), false )));
150- auto weight = std::make_shared<ListScalar>(MakeArray (weight_data),
151- list (field (" item" , float64 (), false )));
153+ ARROW_ASSIGN_OR_RAISE (
154+ auto centroids,
155+ StructArray::Make (
156+ {MakeArray (mean_data), MakeArray (weight_data)},
157+ {field (" mean" , float64 (), false ), field (" weight" , float64 (), false )}));
158+ auto centroids_scalar = std::make_shared<ListScalar>(
159+ centroids, list (field (" item" ,
160+ struct_ ({field (" mean" , float64 (), false ),
161+ field (" weight" , float64 (), false )}),
162+ false )));
152163 auto count = std::make_shared<UInt64Scalar>(this ->count );
153164 std::shared_ptr<Scalar> min, max;
154165 if (this ->count ) {
@@ -158,7 +169,8 @@ struct TDigestCentroidFinalizer : public TDigestBaseImpl {
158169 min = max = MakeNullScalar (float64 ());
159170 }
160171 *out = std::make_shared<StructScalar>(
161- std::vector<std::shared_ptr<Scalar>>{mean, weight, min, max, count}, out_type);
172+ std::vector<std::shared_ptr<Scalar>>{centroids_scalar, min, max, count},
173+ out_type ());
162174 }
163175
164176 return Status::OK ();
@@ -233,13 +245,15 @@ struct TDigestCentroidConsumerImpl : public TDigestFinalizer_T {
233245
234246 Status Consume (const Scalar* scalar) {
235247 const auto * input_struct_scalar = checked_cast<const StructScalar*>(scalar);
236- auto mean_array =
248+ auto centroids_array =
237249 checked_cast<const ListScalar*>(input_struct_scalar->value [0 ].get ())->value ;
238- auto weight_array =
239- checked_cast<const ListScalar*>(input_struct_scalar->value [1 ].get ())->value ;
240- auto min = checked_cast<const DoubleScalar*>(input_struct_scalar->value [2 ].get ());
241- auto max = checked_cast<const DoubleScalar*>(input_struct_scalar->value [3 ].get ());
242- auto count = checked_cast<const UInt64Scalar*>(input_struct_scalar->value [4 ].get ());
250+ auto centroids_struct_array = checked_cast<const StructArray*>(centroids_array.get ());
251+ auto mean_array = centroids_struct_array->field (0 );
252+ auto weight_array = centroids_struct_array->field (1 );
253+ checked_cast<const ListScalar*>(input_struct_scalar->value [1 ].get ())->value ;
254+ auto min = checked_cast<const DoubleScalar*>(input_struct_scalar->value [1 ].get ());
255+ auto max = checked_cast<const DoubleScalar*>(input_struct_scalar->value [2 ].get ());
256+ auto count = checked_cast<const UInt64Scalar*>(input_struct_scalar->value [3 ].get ());
243257 auto mean_double_array = checked_cast<const DoubleArray*>(mean_array.get ());
244258 auto weight_double_array = checked_cast<const DoubleArray*>(weight_array.get ());
245259 DCHECK_EQ (mean_double_array->length (), weight_double_array->length ());
@@ -282,11 +296,6 @@ struct TDigestCentroidConsumerImpl : public TDigestFinalizer_T {
282296template <typename ArrowType>
283297struct TDigestImpl
284298 : public TDigestInputConsumerImpl<ArrowType, TDigestQuantileFinalizer> {
285- // using TDigestBaseImpl::all_valid;
286- // using TDigestBaseImpl::count;
287- // using TDigestBaseImpl::out_type;
288- // using TDigestBaseImpl::tdigest;
289-
290299 explicit TDigestImpl (const TDigestOptions& options, const DataType& in_type,
291300 std::shared_ptr<TDigest::Scaler> scaler)
292301 : TDigestInputConsumerImpl<ArrowType, TDigestQuantileFinalizer>(
@@ -441,47 +450,52 @@ struct TDigestInitState {
441450 }
442451};
443452
// Removed ('-') definition of the input type matcher.
// Matches() accepts a struct with exactly five fields where field 0 is a LIST,
// field 1 has a type equal to field 0's, fields 2 and 3 are DOUBLE and field 4 is
// UINT64. The list item type itself is not inspected.
// Equals() treats any other TDigestCentroidTypeMatcher as equal; getMatcher() hands
// out one shared instance held in a function-local static.
// NOTE(review): the ToStringStatic() text says "fixed_size_list" for weight and
// "int64" for count, while Matches() tests Type::LIST and Type::UINT64.
444- struct TDigestCentroidTypeMatcher : public TypeMatcher {
445- ~TDigestCentroidTypeMatcher () override = default ;
446-
447- bool Matches (const DataType& type) const override {
448- if (Type::STRUCT == type.id ()) {
449- const auto & input_struct_type = checked_cast<const StructType&>(type);
450- if (5 == input_struct_type.num_fields ()) {
451- if (Type::LIST == input_struct_type.field (0 )->type ()->id () &&
452- input_struct_type.field (0 )->type ()->Equals (
453- input_struct_type.field (1 )->type ()) &&
454- Type::DOUBLE == input_struct_type.field (2 )->type ()->id () &&
455- Type::DOUBLE == input_struct_type.field (3 )->type ()->id () &&
456- Type::UINT64 == input_struct_type.field (4 )->type ()->id ()) {
457- return true ;
458- }
459- }
460- }
461- return false ;
462- }
463-
464- static std::string ToStringStatic () {
465- return " struct{mean:list<item: double not null>[N] not null, "
466- " weight:fixed_size_list<item: "
467- " double not null>[N] not null, min:float64, max:float64, count:int64 not "
468- " null}" ;
469- }
470- std::string ToString () const override { return ToStringStatic (); }
471-
472- bool Equals (const TypeMatcher& other) const override {
473- if (this == &other) {
474- return true ;
475- }
476- auto casted = dynamic_cast <const TDigestCentroidTypeMatcher*>(&other);
477- return casted != nullptr ;
478- }
479-
480- static std::shared_ptr<TDigestCentroidTypeMatcher> getMatcher () {
481- static auto matcher = std::make_shared<TDigestCentroidTypeMatcher>();
482- return matcher;
483- }
484- };
// Added ('+') lines: the matcher is kept only as commented-out code, adapted to a
// four-field layout (LIST, DOUBLE, DOUBLE, UINT64).
// NOTE(review), to resolve before re-enabling or else delete this dead block:
//  - The innermost condition tests input_struct_type.field(1) a second time; it
//    presumably was meant to test centroid_struct_type.field(1).
//  - centroid_struct_type is produced by casting input_struct_type.field(0)->type()
//    to StructType, although the enclosing condition requires that same field to be
//    Type::LIST; the list's item type is presumably what was intended.
//  - ToStringStatic() still describes separate mean and weight lists rather than the
//    four-field layout checked in Matches().
453+ // struct TDigestCentroidTypeMatcher : public TypeMatcher {
454+ // ~TDigestCentroidTypeMatcher() override = default;
455+
456+ // bool Matches(const DataType& type) const override {
457+ // if (Type::STRUCT == type.id()) {
458+ // const auto& input_struct_type = checked_cast<const StructType&>(type);
459+ // if (4 == input_struct_type.num_fields()) {
460+ // if (Type::LIST == input_struct_type.field(0)->type()->id() &&
461+ // Type::DOUBLE == input_struct_type.field(1)->type()->id() &&
462+ // Type::DOUBLE == input_struct_type.field(2)->type()->id() &&
463+ // Type::UINT64 == input_struct_type.field(3)->type()->id()) {
464+ // const auto& centroid_struct_type = checked_cast<const
465+ // StructType&>(input_struct_type.field(0)->type());
466+ // if (2 == centroid_struct_type.num_fields()) {
467+ // if (Type::DOUBLE == centroid_struct_type.field(0)->type()->id() &&
468+ // Type::DOUBLE == input_struct_type.field(1)->type()->id()){
469+ // return true;
470+ // }
471+ // }
472+ // }
473+ // }
474+ // }
475+ // return false;
476+ // }
477+
478+ // static std::string ToStringStatic() {
479+ // return "struct{mean:list<item: double not null>[N] not null, "
480+ // "weight:fixed_size_list<item: "
481+ // "double not null>[N] not null, min:float64, max:float64, count:int64 not "
482+ // "null}";
483+ // }
484+ // std::string ToString() const override { return ToStringStatic(); }
485+
486+ // bool Equals(const TypeMatcher& other) const override {
487+ // if (this == &other) {
488+ // return true;
489+ // }
490+ // auto casted = dynamic_cast<const TDigestCentroidTypeMatcher*>(&other);
491+ // return casted != nullptr;
492+ // }
493+
494+ // static std::shared_ptr<TDigestCentroidTypeMatcher> getMatcher() {
495+ // static auto matcher = std::make_shared<TDigestCentroidTypeMatcher>();
496+ // return matcher;
497+ // }
498+ // };
485499
486500Result<std::unique_ptr<KernelState>> TDigestInit (KernelContext* ctx,
487501 const KernelInitArgs& args) {
@@ -525,7 +539,7 @@ void AddTDigestKernels(KernelInit init,
// Output-type resolver: casts the kernel state to TDigestBaseImpl and returns its
// out_type. The `types` argument is not read in this body.
// Removed ('-') line: returned the out_type data member of that state object.
// Added ('+') line: calls out_type() through the same pointer.
// NOTE(review): if out_type() is the static member function, `base` serves only to
// name it; TDigestBaseImpl::out_type() would not need the state cast - confirm.
525539Result<TypeHolder> TDigestMapReduceType (KernelContext* ctx,
526540 const std::vector<TypeHolder>& types) {
527541 auto base = checked_cast<TDigestBaseImpl*>(ctx->state ());
528- return base->out_type ;
542+ return base->out_type () ;
529543}
530544
531545void AddTDigestMapKernels (KernelInit init,
@@ -538,14 +552,13 @@ void AddTDigestMapKernels(KernelInit init,
538552}
539553
// Builds a one-argument kernel signature whose output type is resolved by
// TDigestMapReduceType and passes it, with init and func, to AddAggKernel.
// Removed ('-') line: the input was constrained by the shared
// TDigestCentroidTypeMatcher instance.
// Added ('+') line: the input is constructed from TDigestBaseImpl::out_type()
// (presumably an exact-type match - confirm against InputType).
540554void AddTDigestReduceKernels (KernelInit init, ScalarAggregateFunction* func) {
541- auto sig = KernelSignature::Make ({InputType (TDigestCentroidTypeMatcher::getMatcher ())},
555+ auto sig = KernelSignature::Make ({InputType (TDigestBaseImpl::out_type ())},
542556 TDigestMapReduceType);
543557 AddAggKernel (std::move (sig), init, func);
544558}
545559
// Builds a one-argument kernel signature with a float64 output and passes it, with
// init and func, to AddAggKernel.
// Removed ('-') lines: the input was constrained by the shared
// TDigestCentroidTypeMatcher instance.
// Added ('+') line: the input is constructed from TDigestBaseImpl::out_type()
// (presumably an exact-type match - confirm against InputType).
546560void AddTDigestQuantileKernels (KernelInit init, ScalarAggregateFunction* func) {
547- auto sig = KernelSignature::Make ({InputType (TDigestCentroidTypeMatcher::getMatcher ())},
548- float64 ());
561+ auto sig = KernelSignature::Make ({InputType (TDigestBaseImpl::out_type ())}, float64 ());
549562 AddAggKernel (std::move (sig), init, func);
550563}
551564
@@ -628,7 +641,7 @@ std::shared_ptr<ScalarFunction> AddTDigestQuantileScalarKernels() {
628641 std::make_shared<ScalarFunction>(" tdigest_quantile_element_wise" , Arity::Unary (),
629642 tdigest_quantile_doc, &default_tdigest_options);
630643 auto output = OutputType{TDigestQuantileScalarImpl::ResolveOutput};
631- ScalarKernel kernel ({InputType (TDigestCentroidTypeMatcher::getMatcher ())}, output,
644+ ScalarKernel kernel ({InputType (TDigestBaseImpl::out_type ())}, output,
632645 TDigestQuantileScalarImpl::Exec, TDigestQuantileScalarImpl::Init);
633646 kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
634647 kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
0 commit comments