From 6a2198fe22f8cb4e7e61aed0287571e4c34b481d Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Thu, 25 Dec 2025 22:38:43 +0800 Subject: [PATCH] feat(tdigest): implement TDIGEST.TRIMMED_MEAN command --- src/commands/cmd_tdigest.cc | 52 ++++++++++++++ src/types/redis_tdigest.cc | 35 ++++++++++ src/types/redis_tdigest.h | 6 ++ src/types/tdigest.h | 70 +++++++++++++++++++ .../gocase/unit/type/tdigest/tdigest_test.go | 63 +++++++++++++++++ 5 files changed, 226 insertions(+) diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index 64dfafcd7e8..12065574f3e 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -412,6 +412,57 @@ class CommandTDigestMerge : public Commander { TDigestMergeOptions options_; }; +class CommandTDigestTrimmedMean : public Commander { + public: + Status Parse(const std::vector &args) override { + if (args.size() != 4) { + return {Status::RedisParseErr, errWrongNumOfArguments}; + } + + key_name_ = args[1]; + + auto low_cut_quantile = ParseFloat(args[2]); + if (!low_cut_quantile) { + return {Status::RedisParseErr, errValueIsNotFloat}; + } + low_cut_quantile_ = *low_cut_quantile; + + auto high_cut_quantile = ParseFloat(args[3]); + if (!high_cut_quantile) { + return {Status::RedisParseErr, errValueIsNotFloat}; + } + high_cut_quantile_ = *high_cut_quantile; + + return Status::OK(); + } + + Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + TDigest tdigest(srv->storage, conn->GetNamespace()); + TDigestTrimmedMeanResult result; + + auto s = tdigest.TrimmedMean(ctx, key_name_, low_cut_quantile_, high_cut_quantile_, &result); + if (!s.ok()) { + if (s.IsNotFound()) { + return {Status::RedisExecErr, errKeyNotFound}; + } + return {Status::RedisExecErr, s.ToString()}; + } + + if (!result.mean.has_value()) { + *output = redis::BulkString(kNan); + } else { + *output = redis::BulkString(util::Float2String(*result.mean)); + } + + return Status::OK(); + } + + private: + std::string key_name_; + double low_cut_quantile_; + double high_cut_quantile_; +}; + std::vector GetMergeKeyRange(const std::vector &args) { auto numkeys = ParseInt(args[2], 10).ValueOr(0); return {{1, 1, 1}, {3, 2 + numkeys, 1}}; @@ -425,6 +476,7 @@ REDIS_REGISTER_COMMANDS(TDigest, MakeCmdAttr("tdigest.crea MakeCmdAttr("tdigest.revrank", -3, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.rank", -3, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.quantile", -3, "read-only", 1, 1, 1), + MakeCmdAttr("tdigest.trimmed_mean", 4, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.reset", 2, "write", 1, 1, 1), MakeCmdAttr("tdigest.merge", -4, "write", GetMergeKeyRange)); } // namespace redis diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc index fec7aef1c9b..51b14c0e750 100644 --- a/src/types/redis_tdigest.cc +++ b/src/types/redis_tdigest.cc @@ -725,6 +725,41 @@ rocksdb::Status TDigest::applyNewCentroids(ObserverOrUniquePtrGetLockManager(), ns_key); + if (auto status = getMetaDataByNsKey(ctx, ns_key, &metadata); !status.ok()) { + return status; + } + + if (metadata.total_observations == 0) { + return rocksdb::Status::OK(); + } + + if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) { + return status; + } + } + + // Dump centroids and create DummyCentroids wrapper for TDigest algorithm + std::vector centroids; + if (auto status = dumpCentroids(ctx, ns_key, metadata, ¢roids); !status.ok()) { + return status; + } + auto dump_centroids = DummyCentroids(metadata, centroids); + auto trimmed_mean_result = TDigestTrimmedMean(dump_centroids, low_cut_quantile, high_cut_quantile); + if (!trimmed_mean_result) { + return rocksdb::Status::InvalidArgument(trimmed_mean_result.Msg()); + } + + result->mean = *trimmed_mean_result; + return rocksdb::Status::OK(); +} + std::string TDigest::internalSegmentGuardPrefixKey(const TDigestMetadata& metadata, const std::string& ns_key, SegmentType seg) const { std::string prefix_key; diff --git a/src/types/redis_tdigest.h b/src/types/redis_tdigest.h index 5daaed80c81..81917d6b8ea 100644 --- a/src/types/redis_tdigest.h +++ b/src/types/redis_tdigest.h @@ -53,6 +53,10 @@ struct TDigestQuantitleResult { std::optional> quantiles; }; +struct TDigestTrimmedMeanResult { + std::optional mean; +}; + class TDigest : public SubKeyScanner { public: using Slice = rocksdb::Slice; @@ -79,6 +83,8 @@ class TDigest : public SubKeyScanner { const TDigestMergeOptions& options); rocksdb::Status Rank(engine::Context& ctx, const Slice& digest_name, const std::vector& inputs, bool reverse, std::vector& result); + rocksdb::Status TrimmedMean(engine::Context& ctx, const Slice& digest_name, double low_cut_quantile, + double high_cut_quantile, TDigestTrimmedMeanResult* result); rocksdb::Status GetMetaData(engine::Context& context, const Slice& digest_name, TDigestMetadata* metadata); private: diff --git a/src/types/tdigest.h b/src/types/tdigest.h index d77b673f7a8..8ec68fc84e3 100644 --- a/src/types/tdigest.h +++ b/src/types/tdigest.h @@ -276,3 +276,73 @@ inline Status TDigestRank(TD&& td, const std::vector& inputs, bool rever return TDigestRankImpl(std::forward(td), inputs, result); } } + + +template +inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, double high_cut_quantile) { + if (td.Size() == 0) { + return Status{Status::InvalidArgument, "empty tdigest"}; + } + + // Validate quantile parameters + if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) { + return Status{Status::InvalidArgument, "low cut quantile must be between 0 and 1"}; + } + if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) { + return Status{Status::InvalidArgument, "high cut quantile must be between 0 and 1"}; + } + if (low_cut_quantile >= high_cut_quantile) { + return Status{Status::InvalidArgument, "low cut quantile must be less than high cut quantile"}; + } + + // Get boundary values for trimming + double low_boundary; + double high_boundary; + + // For 0 and 1 quantiles, use exact min/max values + if (low_cut_quantile == 0.0) { + low_boundary = td.Min(); + } else { + auto low_result = TDigestQuantile(std::forward(td), low_cut_quantile); + if (!low_result) { + return low_result; + } + low_boundary = *low_result; + } + + if (high_cut_quantile == 1.0) { + high_boundary = td.Max(); + } else { + auto high_result = TDigestQuantile(std::forward(td), high_cut_quantile); + if (!high_result) { + return high_result; + } + high_boundary = *high_result; + } + + // Calculate trimmed mean by iterating through centroids + auto iter = td.Begin(); + double total_weight_in_range = 0; + double weighted_sum = 0; + + while (iter->Valid()) { + auto centroid = GET_OR_RET(iter->GetCentroid()); + + // Check if centroid falls within the trimmed range + // For full range (0 to 1), include all centroids + if ((low_cut_quantile == 0.0 && high_cut_quantile == 1.0) || + (centroid.mean >= low_boundary && centroid.mean <= high_boundary)) { + total_weight_in_range += centroid.weight; + weighted_sum += centroid.mean * centroid.weight; + } + + iter->Next(); + } + + // Check if we have any data in the trimmed range + if (total_weight_in_range == 0) { + return std::numeric_limits::quiet_NaN(); + } + + return weighted_sum / total_weight_in_range; +} diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index 335ee0eff75..472d530dbe1 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -716,4 +716,67 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { require.EqualValues(t, expected[i], rank, "REVRANK mismatch at index %d", i) } }) + + t.Run("tdigest.trimmed_mean with different arguments", func(t *testing.T) { + keyPrefix := "tdigest_trimmed_mean_" + + // Test invalid arguments + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"key").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"key", "0.1").Err(), errMsgWrongNumberArg) + + // Test non-existent key + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist) + + // Test with empty tdigest + emptyKey := keyPrefix + "empty" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, "compression", "100").Err()) + rsp := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", "0.9") + require.NoError(t, rsp.Err()) + result, err := rsp.Result() + require.NoError(t, err) + require.Equal(t, "nan", result) + + // Test with sample data + key1 := keyPrefix + "test1" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key1, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key1, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + + // Test trimmed mean with trimming + rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.1", "0.9") + require.NoError(t, rsp.Err()) + result, err = rsp.Result() + require.NoError(t, err) + mean, err := strconv.ParseFloat(result.(string), 64) + require.NoError(t, err) + require.InDelta(t, 5.5, mean, 1.0) + + // Test with no trimming + rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0", "1") + require.NoError(t, rsp.Err()) + result, err = rsp.Result() + require.NoError(t, err) + mean, err = strconv.ParseFloat(result.(string), 64) + require.NoError(t, err) + require.InDelta(t, 5.5, mean, 0.1) + + // Test with invalid quantile ranges + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1") + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1") + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.9", "0.1").Err(), "low cut quantile must be less than high cut quantile") + + // Test with skewed data + key2 := keyPrefix + "skewed" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key2, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key2, "1", "1", "1", "1", "1", "10", "100").Err()) + + // Test trimming with outliers + rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key2, "0.2", "0.8") + require.NoError(t, rsp.Err()) + result, err = rsp.Result() + require.NoError(t, err) + mean, err = strconv.ParseFloat(result.(string), 64) + require.NoError(t, err) + require.Less(t, mean, 50.0) + }) }