Skip to content

Commit 390d753

Browse files
committed
Update unit test to pass back either pyarrow array or array wrapped as scalar
1 parent 33f0b7f commit 390d753

File tree

1 file changed

+37
-14
lines changed

1 file changed

+37
-14
lines changed

python/tests/test_udaf.py

Lines changed: 37 additions & 14 deletions
Original file line number · Diff line number · Diff line change
@@ -61,11 +61,14 @@ def state(self) -> list[pa.Scalar]:
6161

6262

6363
class CollectTimestamps(Accumulator):
64-
def __init__(self):
64+
def __init__(self, wrap_in_scalar: bool):
6565
self._values: list[datetime] = []
66+
self.wrap_in_scalar = wrap_in_scalar
6667

6768
def state(self) -> list[pa.Scalar]:
68-
return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))]
69+
if self.wrap_in_scalar:
70+
return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))]
71+
return [pa.array(self._values, type=pa.timestamp("ns"))]
6972

7073
def update(self, values: pa.Array) -> None:
7174
self._values.extend(values.to_pylist())
@@ -76,7 +79,9 @@ def merge(self, states: list[pa.Array]) -> None:
7679
self._values.extend(state)
7780

7881
def evaluate(self) -> pa.Scalar:
79-
return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))
82+
if self.wrap_in_scalar:
83+
return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))
84+
return pa.array(self._values, type=pa.timestamp("ns"))
8085

8186

8287
@pytest.fixture
@@ -240,28 +245,46 @@ def test_register_udaf(ctx, df) -> None:
240245
assert df_result.collect()[0][0][0].as_py() == 14.0
241246

242247

243-
def test_udaf_list_timestamp_return(ctx) -> None:
244-
timestamps = [
248+
@pytest.mark.parametrize("wrap_in_scalar", [True, False])
249+
def test_udaf_list_timestamp_return(ctx, wrap_in_scalar) -> None:
250+
timestamps1 = [
245251
datetime(2024, 1, 1, tzinfo=timezone.utc),
246252
datetime(2024, 1, 2, tzinfo=timezone.utc),
247253
]
248-
batch = pa.RecordBatch.from_arrays(
249-
[pa.array(timestamps, type=pa.timestamp("ns"))],
254+
timestamps2 = [
255+
datetime(2024, 1, 3, tzinfo=timezone.utc),
256+
datetime(2024, 1, 4, tzinfo=timezone.utc),
257+
]
258+
batch1 = pa.RecordBatch.from_arrays(
259+
[pa.array(timestamps1, type=pa.timestamp("ns"))],
250260
names=["ts"],
251261
)
252-
df = ctx.create_dataframe([[batch]], name="timestamp_table")
262+
batch2 = pa.RecordBatch.from_arrays(
263+
[pa.array(timestamps2, type=pa.timestamp("ns"))],
264+
names=["ts"],
265+
)
266+
df = ctx.create_dataframe([[batch1], [batch2]], name="timestamp_table")
267+
268+
list_type = pa.list_(
269+
pa.field("item", type=pa.timestamp("ns"), nullable=wrap_in_scalar)
270+
)
253271

254272
collect = udaf(
255-
CollectTimestamps,
273+
lambda: CollectTimestamps(wrap_in_scalar),
256274
pa.timestamp("ns"),
257-
pa.list_(pa.timestamp("ns")),
258-
[pa.list_(pa.timestamp("ns"))],
275+
list_type,
276+
[list_type],
259277
volatility="immutable",
260278
)
261279

262280
result = df.aggregate([], [collect(column("ts"))]).collect()[0]
263281

264-
assert result.column(0) == pa.array(
265-
[timestamps],
266-
type=pa.list_(pa.timestamp("ns")),
282+
# There is no guarantee about the ordering of the batches, so perform a sort
283+
# to get consistent results. Alternatively we could sort on evaluate().
284+
assert (
285+
result.column(0).values.sort()
286+
== pa.array(
287+
[[*timestamps1, *timestamps2]],
288+
type=list_type,
289+
).values
267290
)

0 commit comments

Comments (0)