Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/commands/cmd_tdigest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,57 @@ class CommandTDigestMerge : public Commander {
TDigestMergeOptions options_;
};

class CommandTDigestTrimmedMean : public Commander {
 public:
  // TDIGEST.TRIMMED_MEAN key low_cut_quantile high_cut_quantile
  //
  // Parses the digest key and the two trimming quantiles. Only syntactic
  // float validation happens here; range checks (0 <= low < high <= 1) are
  // performed by the storage layer and surface as execution errors.
  Status Parse(const std::vector<std::string> &args) override {
    if (args.size() != 4) {
      return {Status::RedisParseErr, errWrongNumOfArguments};
    }

    key_name_ = args[1];

    auto low_cut_quantile = ParseFloat(args[2]);
    if (!low_cut_quantile) {
      return {Status::RedisParseErr, errValueIsNotFloat};
    }
    low_cut_quantile_ = *low_cut_quantile;

    auto high_cut_quantile = ParseFloat(args[3]);
    if (!high_cut_quantile) {
      return {Status::RedisParseErr, errValueIsNotFloat};
    }
    high_cut_quantile_ = *high_cut_quantile;

    return Status::OK();
  }

  // Replies with the trimmed mean as a bulk string, the literal "nan" when
  // the digest exists but holds no observations, or an error when the key
  // does not exist or the quantile range is rejected.
  Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
    TDigest tdigest(srv->storage, conn->GetNamespace());
    TDigestTrimmedMeanResult result;

    auto s = tdigest.TrimmedMean(ctx, key_name_, low_cut_quantile_, high_cut_quantile_, &result);
    if (!s.ok()) {
      if (s.IsNotFound()) {
        return {Status::RedisExecErr, errKeyNotFound};
      }
      return {Status::RedisExecErr, s.ToString()};
    }

    if (!result.mean.has_value()) {
      *output = redis::BulkString(kNan);
    } else {
      *output = redis::BulkString(util::Float2String(*result.mean));
    }

    return Status::OK();
  }

 private:
  std::string key_name_;
  // Brace-initialized so Execute never reads indeterminate values even if a
  // caller were to skip Parse (the originals were left uninitialized).
  double low_cut_quantile_ = 0.0;
  double high_cut_quantile_ = 0.0;
};

std::vector<CommandKeyRange> GetMergeKeyRange(const std::vector<std::string> &args) {
auto numkeys = ParseInt<int>(args[2], 10).ValueOr(0);
return {{1, 1, 1}, {3, 2 + numkeys, 1}};
Expand All @@ -507,6 +558,7 @@ REDIS_REGISTER_COMMANDS(TDigest, MakeCmdAttr<CommandTDigestCreate>("tdigest.crea
MakeCmdAttr<CommandTDigestByRevRank>("tdigest.byrevrank", -3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandTDigestByRank>("tdigest.byrank", -3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandTDigestQuantile>("tdigest.quantile", -3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandTDigestTrimmedMean>("tdigest.trimmed_mean", 4, "read-only", 1, 1, 1),
MakeCmdAttr<CommandTDigestReset>("tdigest.reset", 2, "write", 1, 1, 1),
MakeCmdAttr<CommandTDigestMerge>("tdigest.merge", -4, "write", GetMergeKeyRange));
} // namespace redis
35 changes: 35 additions & 0 deletions src/types/redis_tdigest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,41 @@ rocksdb::Status TDigest::applyNewCentroids(ObserverOrUniquePtr<rocksdb::WriteBat
return rocksdb::Status::OK();
}

rocksdb::Status TDigest::TrimmedMean(engine::Context& ctx, const Slice& digest_name, double low_cut_quantile,
                                     double high_cut_quantile, TDigestTrimmedMeanResult* result) {
  auto ns_key = AppendNamespacePrefix(digest_name);
  TDigestMetadata metadata;

  {
    LockGuard guard(storage_->GetLockManager(), ns_key);
    if (auto status = getMetaDataByNsKey(ctx, ns_key, &metadata); !status.ok()) {
      return status;
    }

    // An existing but empty sketch yields an empty result; the command layer
    // translates that into a "nan" reply.
    if (metadata.total_observations == 0) {
      return rocksdb::Status::OK();
    }

    // Fold any buffered observations into centroids before reading them back.
    if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) {
      return status;
    }
  }

  // Dump centroids and create a DummyCentroids wrapper for the generic
  // t-digest algorithm.
  std::vector<Centroid> centroids;
  if (auto status = dumpCentroids(ctx, ns_key, metadata, &centroids); !status.ok()) {
    return status;
  }
  auto dump_centroids = DummyCentroids(metadata, centroids);

  auto trimmed_mean_result = TDigestTrimmedMean(dump_centroids, low_cut_quantile, high_cut_quantile);
  if (!trimmed_mean_result) {
    // Invalid quantile arguments (and an empty digest) surface here.
    return rocksdb::Status::InvalidArgument(trimmed_mean_result.Msg());
  }

  result->mean = *trimmed_mean_result;
  return rocksdb::Status::OK();
}

std::string TDigest::internalSegmentGuardPrefixKey(const TDigestMetadata& metadata, const std::string& ns_key,
SegmentType seg) const {
std::string prefix_key;
Expand Down
6 changes: 6 additions & 0 deletions src/types/redis_tdigest.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ struct TDigestQuantitleResult {
std::optional<std::vector<double>> quantiles;
};

// Result of TDigest::TrimmedMean. `mean` is left empty when the digest exists
// but holds no observations; the command layer then replies with "nan".
struct TDigestTrimmedMeanResult {
  std::optional<double> mean;
};

class TDigest : public SubKeyScanner {
public:
using Slice = rocksdb::Slice;
Expand Down Expand Up @@ -85,6 +89,8 @@ class TDigest : public SubKeyScanner {
std::vector<double>* result);
rocksdb::Status ByRank(engine::Context& ctx, const Slice& digest_name, const std::vector<int>& inputs,
std::vector<double>* result);
rocksdb::Status TrimmedMean(engine::Context& ctx, const Slice& digest_name, double low_cut_quantile,
double high_cut_quantile, TDigestTrimmedMeanResult* result);
rocksdb::Status GetMetaData(engine::Context& context, const Slice& digest_name, TDigestMetadata* metadata);

private:
Expand Down
78 changes: 78 additions & 0 deletions src/types/tdigest.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,81 @@ inline Status TDigestRank(TD&& td, const std::vector<double>& inputs, std::vecto
}
return Status::OK();
}

// Computes ranks for `inputs`, dispatching at compile time to the forward or
// reverse implementation based on the runtime `reverse` flag.
template <typename TD>
inline Status TDigestRank(TD&& td, const std::vector<double>& inputs, bool reverse, std::vector<int>& result) {
  return reverse ? TDigestRankImpl<TD, true>(std::forward<TD>(td), inputs, result)
                 : TDigestRankImpl<TD, false>(std::forward<TD>(td), inputs, result);
}

// Computes the mean of the values lying between low_cut_quantile and
// high_cut_quantile (each in [0, 1], low < high) over a t-digest's centroids.
// Returns NaN when no centroid falls inside the trimmed range, and an
// InvalidArgument status for an empty digest or out-of-range quantiles.
template <typename TD>
inline StatusOr<double> TDigestTrimmedMean(TD&& td, double low_cut_quantile, double high_cut_quantile) {
  if (td.Size() == 0) {
    return Status{Status::InvalidArgument, "empty tdigest"};
  }

  // Validate quantile parameters.
  if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) {
    return Status{Status::InvalidArgument, "low cut quantile must be between 0 and 1"};
  }
  if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) {
    return Status{Status::InvalidArgument, "high cut quantile must be between 0 and 1"};
  }
  if (low_cut_quantile >= high_cut_quantile) {
    return Status{Status::InvalidArgument, "low cut quantile must be less than high cut quantile"};
  }

  // Resolve the value boundaries of the trimmed range. For the exact 0 and 1
  // quantiles use the digest's min/max instead of interpolating.
  //
  // NOTE: `td` is deliberately passed as an lvalue below. The original code
  // wrote std::forward<TD>(td) in both branches; since both branches can
  // execute, the second call could observe a moved-from digest when TD binds
  // an rvalue.
  double low_boundary = 0;
  double high_boundary = 0;

  if (low_cut_quantile == 0.0) {
    low_boundary = td.Min();
  } else {
    auto low_result = TDigestQuantile(td, low_cut_quantile);
    if (!low_result) {
      return low_result;
    }
    low_boundary = *low_result;
  }

  if (high_cut_quantile == 1.0) {
    high_boundary = td.Max();
  } else {
    auto high_result = TDigestQuantile(td, high_cut_quantile);
    if (!high_result) {
      return high_result;
    }
    high_boundary = *high_result;
  }

  // Accumulate the weighted mean over centroids whose mean lies inside the
  // trimmed value range; the full [0, 1] range includes every centroid.
  double total_weight_in_range = 0;
  double weighted_sum = 0;

  for (auto iter = td.Begin(); iter->Valid(); iter->Next()) {
    auto centroid = GET_OR_RET(iter->GetCentroid());

    if ((low_cut_quantile == 0.0 && high_cut_quantile == 1.0) ||
        (centroid.mean >= low_boundary && centroid.mean <= high_boundary)) {
      total_weight_in_range += centroid.weight;
      weighted_sum += centroid.mean * centroid.weight;
    }
  }

  // Nothing survived trimming: report NaN rather than dividing by zero.
  if (total_weight_in_range == 0) {
    return std::numeric_limits<double>::quiet_NaN();
  }

  return weighted_sum / total_weight_in_range;
}
63 changes: 63 additions & 0 deletions tests/gocase/unit/type/tdigest/tdigest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,69 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) {
require.EqualValues(t, expected[i], rank, "REVRANK mismatch at index %d", i)
}
})

// Covers TDIGEST.TRIMMED_MEAN end to end: arity validation, missing and empty
// keys, uniform data, the untrimmed full range, invalid quantile ranges, and
// outlier trimming on skewed data.
t.Run("tdigest.trimmed_mean with different arguments", func(t *testing.T) {
	keyPrefix := "tdigest_trimmed_mean_"

	// Wrong arity: the command requires exactly key + low + high quantile.
	require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg)
	require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"key").Err(), errMsgWrongNumberArg)
	require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"key", "0.1").Err(), errMsgWrongNumberArg)

	// A key that was never created is an error, not a "nan" reply.
	require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)

	// An existing sketch with no observations replies with the string "nan".
	emptyKey := keyPrefix + "empty"
	require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, "compression", "100").Err())
	rsp := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", "0.9")
	require.NoError(t, rsp.Err())
	result, err := rsp.Result()
	require.NoError(t, err)
	require.Equal(t, "nan", result)

	// Uniform data 1..10 whose untrimmed mean is 5.5.
	key1 := keyPrefix + "test1"
	require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key1, "compression", "100").Err())
	require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key1, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err())

	// Trimming symmetric tails should keep the mean close to 5.5.
	rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.1", "0.9")
	require.NoError(t, rsp.Err())
	result, err = rsp.Result()
	require.NoError(t, err)
	mean, err := strconv.ParseFloat(result.(string), 64)
	require.NoError(t, err)
	require.InDelta(t, 5.5, mean, 1.0)

	// The full [0, 1] range degenerates to the plain arithmetic mean.
	rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0", "1")
	require.NoError(t, rsp.Err())
	result, err = rsp.Result()
	require.NoError(t, err)
	mean, err = strconv.ParseFloat(result.(string), 64)
	require.NoError(t, err)
	require.InDelta(t, 5.5, mean, 0.1)

	// Out-of-range or inverted quantiles are rejected with specific messages.
	require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1")
	require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1")
	require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.9", "0.1").Err(), "low cut quantile must be less than high cut quantile")

	// Skewed data dominated by small values plus a large outlier (100).
	key2 := keyPrefix + "skewed"
	require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key2, "compression", "100").Err())
	require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key2, "1", "1", "1", "1", "1", "10", "100").Err())

	// Trimming the top tail should pull the mean well below the outlier.
	rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key2, "0.2", "0.8")
	require.NoError(t, rsp.Err())
	result, err = rsp.Result()
	require.NoError(t, err)
	mean, err = strconv.ParseFloat(result.(string), 64)
	require.NoError(t, err)
	require.Less(t, mean, 50.0)
})
}

func TestTDigestByRankAndByRevRank(t *testing.T) {
Expand Down
Loading