Skip to content

Commit c2c5a35

Browse files
author
Rafał Hibner
committed
Fix array consume of centroids
1 parent 0de6e41 commit c2c5a35

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

cpp/src/arrow/compute/kernels/aggregate_tdigest.cc

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -231,16 +231,20 @@ struct TDigestCentroidConsumerImpl : public TDigestFinalizer_T {
231231
auto mean_double_array = checked_cast<const DoubleArray*>(mean_array.get());
232232
auto weight_double_array = checked_cast<const DoubleArray*>(weight_array.get());
233233
DCHECK_EQ(mean_double_array->length(), weight_double_array->length());
234-
if (min->is_valid) {
234+
auto count_uint64 = count->value;
235+
if (count_uint64) {
236+
DCHECK(min->is_valid);
235237
DCHECK(max->is_valid);
238+
this->count += count_uint64;
236239
this->tdigest.SetMinMax(min->value, max->value);
240+
for (int64_t i = 0; i < mean_double_array->length(); i++) {
241+
this->tdigest.NanAdd(mean_double_array->Value(i), weight_double_array->Value(i));
242+
}
237243
} else {
244+
DCHECK(!min->is_valid);
238245
DCHECK(!max->is_valid);
239246
}
240-
for (int64_t i = 0; i < mean_double_array->length(); i++) {
241-
this->tdigest.NanAdd(mean_double_array->Value(i), weight_double_array->Value(i));
242-
}
243-
this->count += count->value;
247+
244248
return Status::OK();
245249
}
246250
Status Consume(KernelContext*, const ExecSpan& batch) override {
@@ -252,7 +256,7 @@ struct TDigestCentroidConsumerImpl : public TDigestFinalizer_T {
252256
if (batch[0].is_array()) {
253257
std::shared_ptr<Array> array = MakeArray(batch[0].array.ToArrayData());
254258
for (int i = 0; i < array->length(); ++i) {
255-
ARROW_ASSIGN_OR_RAISE(auto scalar, array->GetScalar(0));
259+
ARROW_ASSIGN_OR_RAISE(auto scalar, array->GetScalar(i));
256260
ARROW_RETURN_NOT_OK(Consume(scalar.get()));
257261
}
258262
} else {

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4640,6 +4640,32 @@ TEST(TestTDigestReduceKernel, Basic) {
46404640
ResultWith(ScalarFromJSON(type,
46414641
"{\"mean\":[1.5, 3.5, 5.5],\"weight\":[2, 2, "
46424642
"2],\"min\":1.0,\"max\":6.0,\"count\":6}")));
4643+
4644+
EXPECT_THAT(TDigestReduce(
4645+
ArrayFromJSON(
4646+
type,
4647+
"["
4648+
"{\"mean\":[],\"weight\":[],\"min\":null,\"max\":null,\"count\":0},"
4649+
"{\"mean\":[1.5, 3.5, 5.5],\"weight\":[2, 2, "
4650+
"2],\"min\":1.0,\"max\":6.0,\"count\":6}"
4651+
"]"),
4652+
options),
4653+
ResultWith(ScalarFromJSON(type,
4654+
"{\"mean\":[1.5, 3.5, 5.5],\"weight\":[2, 2, "
4655+
"2],\"min\":1.0,\"max\":6.0,\"count\":6}")));
4656+
4657+
EXPECT_THAT(TDigestReduce(
4658+
ArrayFromJSON(
4659+
type,
4660+
"["
4661+
"{\"mean\":[1.5, 3.5, 5.5],\"weight\":[2, 2, "
4662+
"2],\"min\":1.0,\"max\":6.0,\"count\":6},"
4663+
"{\"mean\":[],\"weight\":[],\"min\":null,\"max\":null,\"count\":0}"
4664+
"]"),
4665+
options),
4666+
ResultWith(ScalarFromJSON(type,
4667+
"{\"mean\":[1.5, 3.5, 5.5],\"weight\":[2, "
4668+
"2, 2],\"min\":1.0,\"max\":6.0,\"count\":6}")));
46434669
}
46444670

46454671
TEST(TestTDigestQuantileKernel, Basic) {

0 commit comments

Comments
 (0)