Skip to content

Commit 1180868

Browse files
authored
[Data] Support List Types for Unique Aggregator and encode_lists flag (#58916)
## Description

Basically the same idea as #58659.

The `Unique` aggregator uses the `pyarrow.compute.unique` function internally, which doesn't work with non-hashable types like lists. Similar in spirit to what I did for `ApproximateTopK`, list values are now converted to hashable tuples (via pandas) before deduplication; this turned out to be more efficient than a pickle or JSON dump/load round-trip.

Other improvements:

- The `ignore_nulls` flag didn't work at all. This flag now works properly.
- Had to force `ignore_nulls=False` for the datasets `unique` API for backwards compatibility (`ignore_nulls` defaults to `True`, so the behavior of the datasets `unique` API would otherwise change now that `ignore_nulls` actually works).

## Related issues

This PR replaces #58538.

## Additional information

[Design doc on my notion](https://www.notion.so/kyuds/Unique-Aggregator-Improvements-2b67a80e48eb80de9820edf9d4996e0a?source=copy_link)

---------

Signed-off-by: Daniel Shin <kyuseung1016@gmail.com>
Signed-off-by: kyuds <kyuseung1016@gmail.com>
1 parent 456d190 commit 1180868

File tree

4 files changed

+98
-7
lines changed

4 files changed

+98
-7
lines changed

python/ray/data/aggregate.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818

1919
import numpy as np
20+
import pyarrow as pa
2021
import pyarrow.compute as pc
2122

2223
from ray.data._internal.util import is_null
@@ -935,35 +936,60 @@ class Unique(AggregateFnV2[Set[Any], List[Any]]):
935936
ignore_nulls: Whether to ignore null values when collecting unique items.
936937
Default is True (nulls are excluded).
937938
alias_name: Optional name for the resulting column.
939+
encode_lists: If `True`, flatten list values and collect unique items over
940+
the individual list elements. If `False`, treat each whole list as a single
941+
item. `False` by default. Note that the flatten applied when `True` is a
942+
top-level flatten (not a recursive flatten).
938943
"""
939944

940945
def __init__(
941946
self,
942947
on: Optional[str] = None,
943948
ignore_nulls: bool = True,
944949
alias_name: Optional[str] = None,
950+
encode_lists: bool = False,
945951
):
946952
super().__init__(
947953
alias_name if alias_name else f"unique({str(on)})",
948954
on=on,
949955
ignore_nulls=ignore_nulls,
950956
zero_factory=set,
951957
)
958+
self._encode_lists = encode_lists
952959

953960
def combine(self, current_accumulator: Set[Any], new: Set[Any]) -> Set[Any]:
954961
return self._to_set(current_accumulator) | self._to_set(new)
955962

956963
def aggregate_block(self, block: Block) -> List[Any]:
957-
import pyarrow.compute as pac
958-
959964
col = BlockAccessor.for_block(block).to_arrow().column(self._target_col_name)
960-
return pac.unique(col).to_pylist()
965+
if pa.types.is_list(col.type):
966+
if self._encode_lists:
967+
col = pc.list_flatten(col)
968+
else:
969+
# pyarrow doesn't natively support calculating unique over
970+
# list-like objects (i.e., lists, tuples). Using pandas seems to be
971+
# much more efficient than doing something like json dump/load or
972+
# pickle dump/load.
973+
series = BlockAccessor.for_block(block).to_pandas()[
974+
self._target_col_name
975+
]
976+
series = series.map(lambda x: None if x is None else tuple(x))
977+
if self._ignore_nulls:
978+
series = series.dropna()
979+
return list(series.unique())
980+
if self._ignore_nulls:
981+
col = pc.drop_null(col)
982+
return pc.unique(col).to_pylist()
961983

962984
@staticmethod
963985
def _to_set(x):
964986
if isinstance(x, set):
965987
return x
966988
elif isinstance(x, list):
989+
if len(x) > 0 and isinstance(x[0], list):
990+
# necessary because pyarrow converts all tuples to
991+
# lists internally.
992+
x = map(lambda v: None if v is None else tuple(v), x)
967993
return set(x)
968994
else:
969995
return {x}

python/ray/data/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2963,7 +2963,7 @@ def unique(self, column: str) -> List[Any]:
29632963
29642964
>>> import ray
29652965
>>> ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
2966-
>>> ds.unique("target")
2966+
>>> sorted(ds.unique("target"))
29672967
[0, 1, 2]
29682968
29692969
One common use case is to convert the class labels
@@ -2986,7 +2986,7 @@ def unique(self, column: str) -> List[Any]:
29862986
Returns:
29872987
A list with unique elements in the given column.
29882988
""" # noqa: E501
2989-
ret = self._aggregate_on(Unique, column)
2989+
ret = self._aggregate_on(Unique, column, ignore_nulls=False)
29902990
return self._aggregate_result(ret)
29912991

29922992
@AllToAllAPI

python/ray/data/tests/test_custom_agg.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import Counter
2+
13
import numpy as np
24
import pytest
35

@@ -6,6 +8,7 @@
68
ApproximateQuantile,
79
ApproximateTopK,
810
MissingValuePercentage,
11+
Unique,
912
ZeroPercentage,
1013
)
1114
from ray.data.tests.conftest import * # noqa
@@ -496,6 +499,68 @@ def test_approximate_topk_encode_lists(self, ray_start_regular_shared_2_cpus):
496499
assert result["approx_topk(id)"][2] == {"id": 3, "count": 1}
497500

498501

502+
class TestUnique:
503+
"""Test cases for Unique aggregation."""
504+
505+
def test_unique_basic(self, ray_start_regular_shared_2_cpus):
506+
"""Test basic Unique aggregation."""
507+
data = [{"id": "a"}, {"id": "b"}, {"id": "b"}, {"id": None}]
508+
ds = ray.data.from_items(data)
509+
result = ds.aggregate(Unique(on="id", ignore_nulls=False))
510+
511+
answer = ["a", "b", None]
512+
513+
assert Counter(result["unique(id)"]) == Counter(answer)
514+
515+
def test_unique_ignores_nulls(self, ray_start_regular_shared_2_cpus):
516+
"""Test Unique properly ignores nulls."""
517+
data = [{"id": "a"}, {"id": None}, {"id": "b"}, {"id": "b"}, {"id": None}]
518+
ds = ray.data.from_items(data)
519+
result = ds.aggregate(Unique(on="id"))
520+
521+
assert sorted(result["unique(id)"]) == ["a", "b"]
522+
523+
def test_unique_custom_alias(self, ray_start_regular_shared_2_cpus):
524+
"""Test Unique with custom alias."""
525+
data = [{"id": "a"}, {"id": "b"}, {"id": "b"}]
526+
ds = ray.data.from_items(data)
527+
result = ds.aggregate(Unique(on="id", alias_name="custom"))
528+
529+
assert sorted(result["custom"]) == ["a", "b"]
530+
531+
def test_unique_list_datatype(self, ray_start_regular_shared_2_cpus):
532+
"""Test Unique works with non-hashable types like list."""
533+
data = [
534+
{"id": ["a", "b", "c"]},
535+
{"id": ["a", "b", "c"]},
536+
{"id": ["a", "b", "c"]},
537+
]
538+
ds = ray.data.from_items(data)
539+
result = ds.aggregate(Unique(on="id"))
540+
541+
assert result["unique(id)"][0] == ["a", "b", "c"]
542+
543+
def test_unique_encode_lists(self, ray_start_regular_shared_2_cpus):
544+
"""Test Unique works when encode_lists is True."""
545+
data = [{"id": ["a", "b", "c"]}, {"id": ["a", "a", "a", "b", None]}]
546+
ds = ray.data.from_items(data)
547+
result = ds.aggregate(Unique(on="id", encode_lists=True, ignore_nulls=False))
548+
549+
answer = ["a", "b", "c", None]
550+
551+
assert Counter(result["unique(id)"]) == Counter(answer)
552+
553+
def test_unique_encode_lists_ignores_nulls(self, ray_start_regular_shared_2_cpus):
554+
"""Test Unique will drop null values when encode_lists is True."""
555+
data = [{"id": ["a", "b", "c"]}, {"id": ["a", "a", "a", "b", None]}]
556+
ds = ray.data.from_items(data)
557+
result = ds.aggregate(Unique(on="id", encode_lists=True))
558+
559+
answer = ["a", "b", "c"]
560+
561+
assert Counter(result["unique(id)"]) == Counter(answer)
562+
563+
499564
if __name__ == "__main__":
500565
import sys
501566

python/ray/data/tests/test_groupby_e2e.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ def test_groupby_multi_agg_with_nans(
644644
Mean("B", alias_name="mean_b", ignore_nulls=ignore_nulls),
645645
Std("B", alias_name="std_b", ignore_nulls=ignore_nulls),
646646
Quantile("B", alias_name="quantile_b", ignore_nulls=ignore_nulls),
647-
Unique("B", alias_name="unique_b"),
647+
Unique("B", alias_name="unique_b", ignore_nulls=False),
648648
)
649649
)
650650

@@ -751,7 +751,7 @@ def test_groupby_aggregations_are_associative(
751751
Mean("B", alias_name="mean_b", ignore_nulls=ignore_nulls),
752752
Std("B", alias_name="std_b", ignore_nulls=ignore_nulls),
753753
Quantile("B", alias_name="quantile_b", ignore_nulls=ignore_nulls),
754-
Unique("B", alias_name="unique_b"),
754+
Unique("B", alias_name="unique_b", ignore_nulls=False),
755755
]
756756

757757
# Step 0: Prepare expected output (using Pandas)

0 commit comments

Comments
 (0)