Skip to content

Commit 22b3f6c

Browse files
author
Rafał Hibner
committed
Fix tdigest by storing min and max values
1 parent 01ee0c0 commit 22b3f6c

File tree

4 files changed

+203
-52
lines changed

4 files changed

+203
-52
lines changed

cpp/src/arrow/compute/kernels/aggregate_tdigest.cc

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ struct TDigestBaseImpl : public ScalarAggregator {
3838
auto output_size = tdigest.delta();
3939
out_type = struct_({field("mean", fixed_size_list(float64(), output_size), false),
4040
field("weight", fixed_size_list(float64(), output_size), false),
41-
field("count", uint64(), false)});
41+
field("count", uint64(), false), field("min", float64(), true),
42+
field("max", float64(), true)});
4243
}
4344

4445
Status MergeFrom(KernelContext*, KernelState&& src) override {
@@ -151,8 +152,15 @@ struct TDigestCentroidFinalizer : public TDigestBaseImpl {
151152
auto mean = std::make_shared<FixedSizeListScalar>(MakeArray(mean_data));
152153
auto weight = std::make_shared<FixedSizeListScalar>(MakeArray(weight_data));
153154
auto count = std::make_shared<UInt64Scalar>(this->count);
155+
std::shared_ptr<Scalar> min, max;
156+
if (this->count) {
157+
min = std::make_shared<DoubleScalar>(this->tdigest.Min());
158+
max = std::make_shared<DoubleScalar>(this->tdigest.Max());
159+
} else {
160+
min = max = MakeNullScalar(float64());
161+
}
154162
*out = std::make_shared<StructScalar>(
155-
std::vector<std::shared_ptr<Scalar>>{mean, weight, count}, out_type);
163+
std::vector<std::shared_ptr<Scalar>>{mean, weight, count, min, max}, out_type);
156164
}
157165

158166
return Status::OK();
@@ -234,10 +242,20 @@ struct TDigestCentroidConsumerImpl : public TDigestFinalizer_T {
234242
checked_cast<const FixedSizeListScalar*>(input_struct_scalar->value[1].get())
235243
->value;
236244
auto count = checked_cast<const UInt64Scalar*>(input_struct_scalar->value[2].get());
245+
auto min = checked_cast<const DoubleScalar*>(input_struct_scalar->value[3].get());
246+
auto max = checked_cast<const DoubleScalar*>(input_struct_scalar->value[4].get());
237247
auto mean_double_array = checked_cast<const DoubleArray*>(mean_array.get());
238248
auto weight_double_array = checked_cast<const DoubleArray*>(weight_array.get());
239249
DCHECK_EQ(mean_double_array->length(), this->tdigest.delta());
240250
DCHECK_EQ(weight_double_array->length(), this->tdigest.delta());
251+
252+
if (min->is_valid) {
253+
DCHECK(max->is_valid);
254+
this->tdigest.SetMinMax(min->value, max->value);
255+
256+
} else {
257+
DCHECK(!max->is_valid);
258+
}
241259
for (int64_t i = 0; i < this->tdigest.delta(); i++) {
242260
if (mean_double_array->IsNull(i)) {
243261
break;
@@ -372,11 +390,13 @@ struct TDigestCentroidTypeMatcher : public TypeMatcher {
372390
static Result<uint32_t> getDelta(const DataType& type) {
373391
if (Type::STRUCT == type.id()) {
374392
const auto& input_struct_type = checked_cast<const StructType&>(type);
375-
if (3 == input_struct_type.num_fields()) {
393+
if (5 == input_struct_type.num_fields()) {
376394
if (Type::FIXED_SIZE_LIST == input_struct_type.field(0)->type()->id() &&
377395
input_struct_type.field(0)->type()->Equals(
378396
input_struct_type.field(1)->type()) &&
379-
Type::UINT64 == input_struct_type.field(2)->type()->id()) {
397+
Type::UINT64 == input_struct_type.field(2)->type()->id() &&
398+
Type::DOUBLE == input_struct_type.field(3)->type()->id() &&
399+
Type::DOUBLE == input_struct_type.field(4)->type()->id()) {
380400
auto fsl = checked_cast<const FixedSizeListType*>(
381401
input_struct_type.field(0)->type().get());
382402
return fsl->list_size();
@@ -391,7 +411,7 @@ struct TDigestCentroidTypeMatcher : public TypeMatcher {
391411

392412
static std::string ToStringStatic() {
393413
return "struct{mean:fixed_size_list<item: double>[N], weight:fixed_size_list<item: "
394-
"double>[N], count:int64}";
414+
"double>[N], count:int64, min:float64, max:float64}";
395415
}
396416
std::string ToString() const override { return ToStringStatic(); }
397417

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 168 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <gtest/gtest.h>
2727

2828
#include "arrow/array.h"
29+
#include "arrow/array/builder_nested.h"
2930
#include "arrow/chunked_array.h"
3031
#include "arrow/compute/api_aggregate.h"
3132
#include "arrow/compute/api_scalar.h"
@@ -4103,6 +4104,8 @@ class TestRandomQuantileKernel : public TestPrimitiveQuantileKernel<ArrowType> {
41034104
GenerateChunked(chunk_sizes, num_quantiles, &chunked, &quantiles);
41044105

41054106
VerifyTDigest(chunked, quantiles);
4107+
VerifyTDigestMapQuantile(chunked, quantiles);
4108+
VerifyTDigestMapReduceQuantile(chunked, quantiles);
41064109
}
41074110

41084111
void CheckTDigestsSliced(const std::vector<int>& chunk_sizes, int64_t num_quantiles) {
@@ -4119,6 +4122,8 @@ class TestRandomQuantileKernel : public TestPrimitiveQuantileKernel<ArrowType> {
41194122
};
41204123
for (const auto& os : offset_size) {
41214124
VerifyTDigest(chunked->Slice(os[0], os[1]), quantiles);
4125+
VerifyTDigestMapQuantile(chunked->Slice(os[0], os[1]), quantiles);
4126+
VerifyTDigestMapReduceQuantile(chunked->Slice(os[0], os[1]), quantiles);
41224127
}
41234128
}
41244129

@@ -4157,6 +4162,56 @@ class TestRandomQuantileKernel : public TestPrimitiveQuantileKernel<ArrowType> {
41574162
*chunked = ChunkedArray::Make(array_vector).ValueOrDie();
41584163
}
41594164

4165+
void VerifyTDigestMapQuantile(const std::shared_ptr<ChunkedArray>& chunked,
4166+
std::vector<double>& quantiles) {
4167+
ASSERT_OK_AND_ASSIGN(Datum centroids, TDigestMap(chunked));
4168+
TDigestQuantileOptions options(quantiles);
4169+
ASSERT_OK_AND_ASSIGN(Datum out, TDigestQuantile(centroids, options));
4170+
const auto& out_array = out.make_array();
4171+
ValidateOutput(*out_array);
4172+
ASSERT_EQ(out_array->length(), quantiles.size());
4173+
ASSERT_EQ(out_array->null_count(), 0);
4174+
AssertTypeEqual(out_array->type(), float64());
4175+
4176+
// linear interpolated exact quantile as reference
4177+
std::vector<std::vector<Datum>> exact =
4178+
NaiveQuantile(*chunked, quantiles, {QuantileOptions::LINEAR});
4179+
const double* approx = out_array->data()->GetValues<double>(1);
4180+
for (size_t i = 0; i < quantiles.size(); ++i) {
4181+
const auto& exact_scalar = checked_pointer_cast<DoubleScalar>(exact[i][0].scalar());
4182+
const double tolerance = std::fabs(exact_scalar->value) * 0.05;
4183+
EXPECT_NEAR(approx[i], exact_scalar->value, tolerance) << quantiles[i];
4184+
}
4185+
}
4186+
4187+
void VerifyTDigestMapReduceQuantile(const std::shared_ptr<ChunkedArray>& chunked,
4188+
std::vector<double>& quantiles) {
4189+
ArrayVector map_chunks;
4190+
for (const auto& chunk : chunked->chunks()) {
4191+
ASSERT_OK_AND_ASSIGN(Datum centroids, TDigestMap(chunk));
4192+
ASSERT_OK_AND_ASSIGN(auto map_chunk, MakeArrayFromScalar(*centroids.scalar(), 1));
4193+
map_chunks.push_back(std::move(map_chunk));
4194+
}
4195+
auto map_chunked = std::make_shared<ChunkedArray>(std::move(map_chunks));
4196+
TDigestQuantileOptions options(quantiles);
4197+
ASSERT_OK_AND_ASSIGN(Datum out, TDigestQuantile(map_chunked, options));
4198+
const auto& out_array = out.make_array();
4199+
ValidateOutput(*out_array);
4200+
ASSERT_EQ(out_array->length(), quantiles.size());
4201+
ASSERT_EQ(out_array->null_count(), 0);
4202+
AssertTypeEqual(out_array->type(), float64());
4203+
4204+
// linear interpolated exact quantile as reference
4205+
std::vector<std::vector<Datum>> exact =
4206+
NaiveQuantile(*chunked, quantiles, {QuantileOptions::LINEAR});
4207+
const double* approx = out_array->data()->GetValues<double>(1);
4208+
for (size_t i = 0; i < quantiles.size(); ++i) {
4209+
const auto& exact_scalar = checked_pointer_cast<DoubleScalar>(exact[i][0].scalar());
4210+
const double tolerance = std::fabs(exact_scalar->value) * 0.05;
4211+
EXPECT_NEAR(approx[i], exact_scalar->value, tolerance) << quantiles[i];
4212+
}
4213+
}
4214+
41604215
void VerifyTDigest(const std::shared_ptr<ChunkedArray>& chunked,
41614216
std::vector<double>& quantiles) {
41624217
TDigestOptions options(quantiles);
@@ -4456,9 +4511,11 @@ TEST(TestTDigestKernel, Options) {
44564511

44574512
TEST(TestTDigestMapKernel, Options) {
44584513
auto input_type = float64();
4459-
auto output_type = struct_({field("mean", fixed_size_list(float64(), 5), false),
4460-
field("weight", fixed_size_list(float64(), 5), false),
4461-
field("count", uint64(), false)});
4514+
auto output_type =
4515+
struct_({field("mean", fixed_size_list(float64(), 5), false),
4516+
field("weight", fixed_size_list(float64(), 5), false),
4517+
field("count", uint64(), false), field("min", float64(), true),
4518+
field("max", float64(), true)});
44624519
TDigestMapOptions keep_nulls(/*delta=*/5, /*buffer_size=*/500,
44634520
/*skip_nulls=*/false,
44644521
/*scaler=*/TDigestMapOptions::Scaler::K0);
@@ -4470,74 +4527,139 @@ TEST(TestTDigestMapKernel, Options) {
44704527
TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]"), keep_nulls),
44714528
ResultWith(ScalarFromJSON(output_type,
44724529
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, 2, "
4473-
"2, null, null],\"count\":6}")));
4530+
"2, null, null],\"count\":6,\"min\":1.0,\"max\":6.0}")));
44744531
EXPECT_THAT(
44754532
TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0, 4.0, 5.0]"), keep_nulls),
44764533
ResultWith(ScalarFromJSON(output_type,
44774534
"{\"mean\":[1.5, 3.5, 5.0, null, null],\"weight\":[2, 2, "
4478-
"1, null, null],\"count\":5}")));
4479-
EXPECT_THAT(TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0, 4.0]"), keep_nulls),
4480-
ResultWith(ScalarFromJSON(output_type,
4481-
"{\"mean\":[1.0, 2.0, 3.0, 4.0, "
4482-
"null],\"weight\":[1,1,1,1,null],\"count\":4}")));
4483-
4484-
EXPECT_THAT(TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0]"), keep_nulls),
4485-
ResultWith(ScalarFromJSON(output_type,
4486-
"{\"mean\":[1.0,2.0,3.0,null,null],\"weight\":[1,"
4487-
"1,1,null,null],\"count\":3}")));
4535+
"1, null, null],\"count\":5,\"min\":1.0,\"max\":5.0}")));
4536+
EXPECT_THAT(
4537+
TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0, 4.0]"), keep_nulls),
4538+
ResultWith(ScalarFromJSON(
4539+
output_type,
4540+
"{\"mean\":[1.0, 2.0, 3.0, 4.0, "
4541+
"null],\"weight\":[1,1,1,1,null],\"count\":4,\"min\":1.0,\"max\":4.0}")));
4542+
4543+
EXPECT_THAT(
4544+
TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0]"), keep_nulls),
4545+
ResultWith(ScalarFromJSON(output_type,
4546+
"{\"mean\":[1.0,2.0,3.0,null,null],\"weight\":[1,"
4547+
"1,1,null,null],\"count\":3,\"min\":1.0,\"max\":3.0}")));
44884548
EXPECT_THAT(TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0, null]"), keep_nulls),
44894549
ResultWith(ScalarFromJSON(output_type, "null")));
44904550
EXPECT_THAT(TDigestMap(ScalarFromJSON(input_type, "1.0"), keep_nulls),
4491-
ResultWith(ScalarFromJSON(output_type,
4492-
"{\"mean\":[1.0,null,null,null,null],\"weight\":["
4493-
"1,null,null,null,null],\"count\":1}")));
4551+
ResultWith(ScalarFromJSON(
4552+
output_type,
4553+
"{\"mean\":[1.0,null,null,null,null],\"weight\":["
4554+
"1,null,null,null,null],\"count\":1,\"min\":1.0,\"max\":1.0}")));
44944555
EXPECT_THAT(TDigestMap(ScalarFromJSON(input_type, "null"), keep_nulls),
44954556
ResultWith(ScalarFromJSON(output_type, "null")));
44964557

4497-
EXPECT_THAT(TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0, null]"), skip_nulls),
4498-
ResultWith(ScalarFromJSON(output_type,
4499-
"{\"mean\":[1.0,2.0,3.0,null,null],\"weight\":[1,"
4500-
"1,1,null,null],\"count\":3}")));
4558+
EXPECT_THAT(
4559+
TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0, null]"), skip_nulls),
4560+
ResultWith(ScalarFromJSON(output_type,
4561+
"{\"mean\":[1.0,2.0,3.0,null,null],\"weight\":[1,"
4562+
"1,1,null,null],\"count\":3,\"min\":1.0,\"max\":3.0}")));
45014563
EXPECT_THAT(TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, null]"), skip_nulls),
4502-
ResultWith(ScalarFromJSON(output_type,
4503-
"{\"mean\":[1.0,2.0,null,null,null],\"weight\":["
4504-
"1,1,null,null,null],\"count\":2}")));
4564+
ResultWith(ScalarFromJSON(
4565+
output_type,
4566+
"{\"mean\":[1.0,2.0,null,null,null],\"weight\":["
4567+
"1,1,null,null,null],\"count\":2,\"min\":1.0,\"max\":2.0}")));
45054568
EXPECT_THAT(TDigestMap(ScalarFromJSON(input_type, "1.0"), skip_nulls),
4506-
ResultWith(ScalarFromJSON(output_type,
4507-
"{\"mean\":[1.0,null,null,null,null],\"weight\":["
4508-
"1,null,null,null,null],\"count\":1}")));
4569+
ResultWith(ScalarFromJSON(
4570+
output_type,
4571+
"{\"mean\":[1.0,null,null,null,null],\"weight\":["
4572+
"1,null,null,null,null],\"count\":1,\"min\":1.0,\"max\":1.0}")));
45094573
EXPECT_THAT(TDigestMap(ScalarFromJSON(input_type, "null"), skip_nulls),
4510-
ResultWith(ScalarFromJSON(output_type,
4511-
"{\"mean\":[null,null,null,null,null],\"weight\":"
4512-
"[null,null,null,null,null],\"count\":0}")));
4574+
ResultWith(ScalarFromJSON(
4575+
output_type,
4576+
"{\"mean\":[null,null,null,null,null],\"weight\":"
4577+
"[null,null,null,null,null],\"count\":0,\"min\":null,\"max\":null}")));
45134578
}
45144579

45154580
TEST(TestTDigestReduceKernel, Basic) {
45164581
auto type = struct_({field("mean", fixed_size_list(float64(), 5), false),
45174582
field("weight", fixed_size_list(float64(), 5), false),
4518-
field("count", uint64(), false)});
4583+
field("count", uint64(), false), field("min", float64(), true),
4584+
field("max", float64(), true)});
45194585
TDigestReduceOptions options(/*scaler=*/TDigestMapOptions::Scaler::K0);
45204586
EXPECT_THAT(
4521-
TDigestReduce(ArrayFromJSON(type,
4522-
"["
4523-
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, "
4524-
"2, 2, null, null],\"count\":6},"
4525-
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, "
4526-
"2, 2, null, null],\"count\":6}"
4527-
"]"),
4528-
options),
4587+
TDigestReduce(
4588+
ArrayFromJSON(type,
4589+
"["
4590+
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, "
4591+
"2, 2, null, null],\"count\":6,\"min\":1.0,\"max\":6.0},"
4592+
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, "
4593+
"2, 2, null, null],\"count\":6,\"min\":1.0,\"max\":6.0}"
4594+
"]"),
4595+
options),
45294596
ResultWith(ScalarFromJSON(type,
4530-
"{\"mean\":[1.5, 1.5, 3.5, 3.5, 5.5],\"weight\":[2, 2, "
4531-
"2, 2, 2],\"count\":12}")));
4597+
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[4, 4, "
4598+
"4, null, null],\"count\":12,\"min\":1.0,\"max\":6.0}")));
45324599

45334600
EXPECT_THAT(
4534-
TDigestReduce(ScalarFromJSON(type,
4535-
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, "
4536-
"2, 2, null, null],\"count\":6}"),
4537-
options),
4601+
TDigestReduce(
4602+
ScalarFromJSON(type,
4603+
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, "
4604+
"2, 2, null, null],\"count\":6,\"min\":1.0,\"max\":6.0}"),
4605+
options),
45384606
ResultWith(ScalarFromJSON(type,
45394607
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, 2, "
4540-
"2, null, null],\"count\":6}")));
4608+
"2, null, null],\"count\":6,\"min\":1.0,\"max\":6.0}")));
4609+
}
4610+
4611+
TEST(TestTDigestQuantileKernel, Basic) {
4612+
auto input_type =
4613+
struct_({field("mean", fixed_size_list(float64(), 5), false),
4614+
field("weight", fixed_size_list(float64(), 5), false),
4615+
field("count", uint64(), false), field("min", float64(), true),
4616+
field("max", float64(), true)});
4617+
4618+
auto output_type = float64();
4619+
4620+
auto input_array =
4621+
ArrayFromJSON(input_type,
4622+
"["
4623+
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, "
4624+
"2, 2, null, null],\"count\":6,\"min\":1.0,\"max\":6.0},"
4625+
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, "
4626+
"2, 2, null, null],\"count\":6,\"min\":1.0,\"max\":6.0}"
4627+
"]");
4628+
4629+
TDigestQuantileOptions multiple(/*q=*/{0.1, 0.5, 0.9}, /*min_count=*/12);
4630+
TDigestQuantileOptions min_count(/*q=*/0.5, /*min_count=*/13);
4631+
4632+
EXPECT_THAT(TDigestQuantile(input_array, multiple),
4633+
ResultWith(ArrayFromJSON(output_type, "[1.5666666666666667, 3.5, 5.5]")));
4634+
EXPECT_THAT(TDigestQuantile(input_array, min_count),
4635+
ResultWith(ArrayFromJSON(output_type, "[null]")));
4636+
}
4637+
4638+
TEST(TestTDigestMapReduceQuantileKernel, Basic) {
4639+
auto input_type =
4640+
struct_({field("mean", fixed_size_list(float64(), 5), false),
4641+
field("weight", fixed_size_list(float64(), 5), false),
4642+
field("count", uint64(), false), field("min", float64(), true),
4643+
field("max", float64(), true)});
4644+
4645+
auto output_type = float64();
4646+
4647+
auto input_array =
4648+
ArrayFromJSON(input_type,
4649+
"["
4650+
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, "
4651+
"2, 2, null, null],\"count\":6,\"min\":1.0,\"max\":6.0},"
4652+
"{\"mean\":[1.5, 3.5, 5.5, null, null],\"weight\":[2, "
4653+
"2, 2, null, null],\"count\":6,\"min\":1.0,\"max\":6.0}"
4654+
"]");
4655+
4656+
TDigestQuantileOptions multiple(/*q=*/{0.1, 0.5, 0.9}, /*min_count=*/12);
4657+
TDigestQuantileOptions min_count(/*q=*/0.5, /*min_count=*/13);
4658+
4659+
EXPECT_THAT(TDigestQuantile(input_array, multiple),
4660+
ResultWith(ArrayFromJSON(output_type, "[1.5666666666666667, 3.5, 5.5]")));
4661+
EXPECT_THAT(TDigestQuantile(input_array, min_count),
4662+
ResultWith(ArrayFromJSON(output_type, "[null]")));
45414663
}
45424664

45434665
TEST(TestTDigestKernel, ApproximateMedian) {

cpp/src/arrow/util/tdigest.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ class TDigest::TDigestImpl {
234234

235235
// merge input data with current tdigest
236236
void MergeInput(std::vector<std::pair<double, double>>& input) {
237-
total_weight_ += input.size();
237+
for (const auto& i : input) {
238+
total_weight_ += i.second;
239+
}
238240

239241
std::sort(input.begin(), input.end(),
240242
[](const std::pair<double, double>& lhs,
@@ -341,6 +343,10 @@ class TDigest::TDigestImpl {
341343
}
342344
return total_weight_ == 0 ? NAN : sum / total_weight_;
343345
}
346+
void SetMinMax(double min, double max) {
347+
min_ = std::min(min_, min);
348+
max_ = std::max(max_, max);
349+
}
344350

345351
double total_weight() const { return total_weight_; }
346352

@@ -413,6 +419,8 @@ std::optional<std::pair<double, double>> TDigest::GetCentroid(size_t i) const {
413419
return impl_->GetCentroid(i);
414420
}
415421

422+
void TDigest::SetMinMax(double min, double max) { impl_->SetMinMax(min, max); }
423+
416424
double TDigest::Mean() const {
417425
MergeInput();
418426
return impl_->Mean();

cpp/src/arrow/util/tdigest_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class ARROW_EXPORT TDigest {
107107
double Quantile(double q) const;
108108
std::optional<std::pair<double, double>> GetCentroid(size_t i) const;
109109

110+
void SetMinMax(double min, double max);
110111
double Min() const { return Quantile(0); }
111112
double Max() const { return Quantile(1); }
112113
double Mean() const;

0 commit comments

Comments
 (0)