Skip to content

Commit 5f17c46

Browse files
author
Rafał Hibner
committed
Centroids as vector of structs
1 parent 046147f commit 5f17c46

File tree

2 files changed

+242
-181
lines changed

2 files changed

+242
-181
lines changed

cpp/src/arrow/compute/kernels/aggregate_tdigest.cc

Lines changed: 86 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,7 @@ using arrow::internal::VisitSetBitRunsVoid;
3636

3737
struct TDigestBaseImpl : public ScalarAggregator {
3838
explicit TDigestBaseImpl(std::shared_ptr<TDigest::Scaler> scaler, uint32_t buffer_size)
39-
: tdigest{std::move(scaler), buffer_size}, count{0}, all_valid{true} {
40-
out_type = struct_({
41-
field("mean", list(field("item", float64(), false)), false),
42-
field("weight", list(field("item", float64(), false)), false),
43-
field("min", float64(), true),
44-
field("max", float64(), true),
45-
field("count", uint64(), false),
46-
});
47-
}
39+
: tdigest{std::move(scaler), buffer_size}, count{0}, all_valid{true} {}
4840

4941
Status MergeFrom(KernelContext*, KernelState&& src) override {
5042
const auto& other = checked_cast<const TDigestBaseImpl&>(src);
@@ -71,7 +63,20 @@ struct TDigestBaseImpl : public ScalarAggregator {
7163
TDigest tdigest;
7264
uint64_t count;
7365
bool all_valid;
74-
std::shared_ptr<DataType> out_type;
66+
static const std::shared_ptr<DataType>& out_type() {
67+
static auto out_type = struct_({
68+
field("centroids",
69+
list(field("item",
70+
struct_({field("mean", float64(), false),
71+
field("weight", float64(), false)}),
72+
false)),
73+
false),
74+
field("min", float64(), true),
75+
field("max", float64(), true),
76+
field("count", uint64(), false),
77+
});
78+
return out_type;
79+
}
7580
};
7681

7782
struct TDigestQuantileFinalizer : public TDigestBaseImpl {
@@ -126,7 +131,7 @@ struct TDigestCentroidFinalizer : public TDigestBaseImpl {
126131

127132
Status Finalize(KernelContext* ctx, Datum* out) override {
128133
if (!this->all_valid) {
129-
*out = MakeNullScalar(out_type);
134+
*out = MakeNullScalar(out_type());
130135
} else {
131136
// Float64Array
132137
const int64_t out_length = this->tdigest.GetCentroidCount();
@@ -145,10 +150,16 @@ struct TDigestCentroidFinalizer : public TDigestBaseImpl {
145150
std::tie(mean_buffer[i], weight_buffer[i]) = this->tdigest.GetCentroid(i);
146151
}
147152

148-
auto mean = std::make_shared<ListScalar>(MakeArray(mean_data),
149-
list(field("item", float64(), false)));
150-
auto weight = std::make_shared<ListScalar>(MakeArray(weight_data),
151-
list(field("item", float64(), false)));
153+
ARROW_ASSIGN_OR_RAISE(
154+
auto centroids,
155+
StructArray::Make(
156+
{MakeArray(mean_data), MakeArray(weight_data)},
157+
{field("mean", float64(), false), field("weight", float64(), false)}));
158+
auto centroids_scalar = std::make_shared<ListScalar>(
159+
centroids, list(field("item",
160+
struct_({field("mean", float64(), false),
161+
field("weight", float64(), false)}),
162+
false)));
152163
auto count = std::make_shared<UInt64Scalar>(this->count);
153164
std::shared_ptr<Scalar> min, max;
154165
if (this->count) {
@@ -158,7 +169,8 @@ struct TDigestCentroidFinalizer : public TDigestBaseImpl {
158169
min = max = MakeNullScalar(float64());
159170
}
160171
*out = std::make_shared<StructScalar>(
161-
std::vector<std::shared_ptr<Scalar>>{mean, weight, min, max, count}, out_type);
172+
std::vector<std::shared_ptr<Scalar>>{centroids_scalar, min, max, count},
173+
out_type());
162174
}
163175

164176
return Status::OK();
@@ -233,13 +245,15 @@ struct TDigestCentroidConsumerImpl : public TDigestFinalizer_T {
233245

234246
Status Consume(const Scalar* scalar) {
235247
const auto* input_struct_scalar = checked_cast<const StructScalar*>(scalar);
236-
auto mean_array =
248+
auto centroids_array =
237249
checked_cast<const ListScalar*>(input_struct_scalar->value[0].get())->value;
238-
auto weight_array =
239-
checked_cast<const ListScalar*>(input_struct_scalar->value[1].get())->value;
240-
auto min = checked_cast<const DoubleScalar*>(input_struct_scalar->value[2].get());
241-
auto max = checked_cast<const DoubleScalar*>(input_struct_scalar->value[3].get());
242-
auto count = checked_cast<const UInt64Scalar*>(input_struct_scalar->value[4].get());
250+
auto centroids_struct_array = checked_cast<const StructArray*>(centroids_array.get());
251+
auto mean_array = centroids_struct_array->field(0);
252+
auto weight_array = centroids_struct_array->field(1);
253+
checked_cast<const ListScalar*>(input_struct_scalar->value[1].get())->value;
254+
auto min = checked_cast<const DoubleScalar*>(input_struct_scalar->value[1].get());
255+
auto max = checked_cast<const DoubleScalar*>(input_struct_scalar->value[2].get());
256+
auto count = checked_cast<const UInt64Scalar*>(input_struct_scalar->value[3].get());
243257
auto mean_double_array = checked_cast<const DoubleArray*>(mean_array.get());
244258
auto weight_double_array = checked_cast<const DoubleArray*>(weight_array.get());
245259
DCHECK_EQ(mean_double_array->length(), weight_double_array->length());
@@ -282,11 +296,6 @@ struct TDigestCentroidConsumerImpl : public TDigestFinalizer_T {
282296
template <typename ArrowType>
283297
struct TDigestImpl
284298
: public TDigestInputConsumerImpl<ArrowType, TDigestQuantileFinalizer> {
285-
// using TDigestBaseImpl::all_valid;
286-
// using TDigestBaseImpl::count;
287-
// using TDigestBaseImpl::out_type;
288-
// using TDigestBaseImpl::tdigest;
289-
290299
explicit TDigestImpl(const TDigestOptions& options, const DataType& in_type,
291300
std::shared_ptr<TDigest::Scaler> scaler)
292301
: TDigestInputConsumerImpl<ArrowType, TDigestQuantileFinalizer>(
@@ -441,47 +450,52 @@ struct TDigestInitState {
441450
}
442451
};
443452

444-
struct TDigestCentroidTypeMatcher : public TypeMatcher {
445-
~TDigestCentroidTypeMatcher() override = default;
446-
447-
bool Matches(const DataType& type) const override {
448-
if (Type::STRUCT == type.id()) {
449-
const auto& input_struct_type = checked_cast<const StructType&>(type);
450-
if (5 == input_struct_type.num_fields()) {
451-
if (Type::LIST == input_struct_type.field(0)->type()->id() &&
452-
input_struct_type.field(0)->type()->Equals(
453-
input_struct_type.field(1)->type()) &&
454-
Type::DOUBLE == input_struct_type.field(2)->type()->id() &&
455-
Type::DOUBLE == input_struct_type.field(3)->type()->id() &&
456-
Type::UINT64 == input_struct_type.field(4)->type()->id()) {
457-
return true;
458-
}
459-
}
460-
}
461-
return false;
462-
}
463-
464-
static std::string ToStringStatic() {
465-
return "struct{mean:list<item: double not null>[N] not null, "
466-
"weight:fixed_size_list<item: "
467-
"double not null>[N] not null, min:float64, max:float64, count:int64 not "
468-
"null}";
469-
}
470-
std::string ToString() const override { return ToStringStatic(); }
471-
472-
bool Equals(const TypeMatcher& other) const override {
473-
if (this == &other) {
474-
return true;
475-
}
476-
auto casted = dynamic_cast<const TDigestCentroidTypeMatcher*>(&other);
477-
return casted != nullptr;
478-
}
479-
480-
static std::shared_ptr<TDigestCentroidTypeMatcher> getMatcher() {
481-
static auto matcher = std::make_shared<TDigestCentroidTypeMatcher>();
482-
return matcher;
483-
}
484-
};
453+
// struct TDigestCentroidTypeMatcher : public TypeMatcher {
454+
// ~TDigestCentroidTypeMatcher() override = default;
455+
456+
// bool Matches(const DataType& type) const override {
457+
// if (Type::STRUCT == type.id()) {
458+
// const auto& input_struct_type = checked_cast<const StructType&>(type);
459+
// if (4 == input_struct_type.num_fields()) {
460+
// if (Type::LIST == input_struct_type.field(0)->type()->id() &&
461+
// Type::DOUBLE == input_struct_type.field(1)->type()->id() &&
462+
// Type::DOUBLE == input_struct_type.field(2)->type()->id() &&
463+
// Type::UINT64 == input_struct_type.field(3)->type()->id()) {
464+
// const auto& centroid_struct_type = checked_cast<const
465+
// StructType&>(input_struct_type.field(0)->type());
466+
// if (2 == centroid_struct_type.num_fields()) {
467+
// if (Type::DOUBLE == centroid_struct_type.field(0)->type()->id() &&
468+
// Type::DOUBLE == input_struct_type.field(1)->type()->id()){
469+
// return true;
470+
// }
471+
// }
472+
// }
473+
// }
474+
// }
475+
// return false;
476+
// }
477+
478+
// static std::string ToStringStatic() {
479+
// return "struct{mean:list<item: double not null>[N] not null, "
480+
// "weight:fixed_size_list<item: "
481+
// "double not null>[N] not null, min:float64, max:float64, count:int64 not "
482+
// "null}";
483+
// }
484+
// std::string ToString() const override { return ToStringStatic(); }
485+
486+
// bool Equals(const TypeMatcher& other) const override {
487+
// if (this == &other) {
488+
// return true;
489+
// }
490+
// auto casted = dynamic_cast<const TDigestCentroidTypeMatcher*>(&other);
491+
// return casted != nullptr;
492+
// }
493+
494+
// static std::shared_ptr<TDigestCentroidTypeMatcher> getMatcher() {
495+
// static auto matcher = std::make_shared<TDigestCentroidTypeMatcher>();
496+
// return matcher;
497+
// }
498+
// };
485499

486500
Result<std::unique_ptr<KernelState>> TDigestInit(KernelContext* ctx,
487501
const KernelInitArgs& args) {
@@ -525,7 +539,7 @@ void AddTDigestKernels(KernelInit init,
525539
Result<TypeHolder> TDigestMapReduceType(KernelContext* ctx,
526540
const std::vector<TypeHolder>& types) {
527541
auto base = checked_cast<TDigestBaseImpl*>(ctx->state());
528-
return base->out_type;
542+
return base->out_type();
529543
}
530544

531545
void AddTDigestMapKernels(KernelInit init,
@@ -538,14 +552,13 @@ void AddTDigestMapKernels(KernelInit init,
538552
}
539553

540554
void AddTDigestReduceKernels(KernelInit init, ScalarAggregateFunction* func) {
541-
auto sig = KernelSignature::Make({InputType(TDigestCentroidTypeMatcher::getMatcher())},
555+
auto sig = KernelSignature::Make({InputType(TDigestBaseImpl::out_type())},
542556
TDigestMapReduceType);
543557
AddAggKernel(std::move(sig), init, func);
544558
}
545559

546560
void AddTDigestQuantileKernels(KernelInit init, ScalarAggregateFunction* func) {
547-
auto sig = KernelSignature::Make({InputType(TDigestCentroidTypeMatcher::getMatcher())},
548-
float64());
561+
auto sig = KernelSignature::Make({InputType(TDigestBaseImpl::out_type())}, float64());
549562
AddAggKernel(std::move(sig), init, func);
550563
}
551564

@@ -628,7 +641,7 @@ std::shared_ptr<ScalarFunction> AddTDigestQuantileScalarKernels() {
628641
std::make_shared<ScalarFunction>("tdigest_quantile_element_wise", Arity::Unary(),
629642
tdigest_quantile_doc, &default_tdigest_options);
630643
auto output = OutputType{TDigestQuantileScalarImpl::ResolveOutput};
631-
ScalarKernel kernel({InputType(TDigestCentroidTypeMatcher::getMatcher())}, output,
644+
ScalarKernel kernel({InputType(TDigestBaseImpl::out_type())}, output,
632645
TDigestQuantileScalarImpl::Exec, TDigestQuantileScalarImpl::Init);
633646
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
634647
kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;

0 commit comments

Comments
 (0)