@@ -61,11 +61,14 @@ def state(self) -> list[pa.Scalar]:
6161
6262
6363class CollectTimestamps (Accumulator ):
64- def __init__ (self ):
64+ def __init__ (self , wrap_in_scalar : bool ):
6565 self ._values : list [datetime ] = []
66+ self .wrap_in_scalar = wrap_in_scalar
6667
6768 def state (self ) -> list [pa .Scalar ]:
68- return [pa .scalar (self ._values , type = pa .list_ (pa .timestamp ("ns" )))]
69+ if self .wrap_in_scalar :
70+ return [pa .scalar (self ._values , type = pa .list_ (pa .timestamp ("ns" )))]
71+ return [pa .array (self ._values , type = pa .timestamp ("ns" ))]
6972
7073 def update (self , values : pa .Array ) -> None :
7174 self ._values .extend (values .to_pylist ())
@@ -76,7 +79,9 @@ def merge(self, states: list[pa.Array]) -> None:
7679 self ._values .extend (state )
7780
7881 def evaluate (self ) -> pa .Scalar :
79- return pa .scalar (self ._values , type = pa .list_ (pa .timestamp ("ns" )))
82+ if self .wrap_in_scalar :
83+ return pa .scalar (self ._values , type = pa .list_ (pa .timestamp ("ns" )))
84+ return pa .array (self ._values , type = pa .timestamp ("ns" ))
8085
8186
8287@pytest .fixture
@@ -240,28 +245,46 @@ def test_register_udaf(ctx, df) -> None:
240245 assert df_result .collect ()[0 ][0 ][0 ].as_py () == 14.0
241246
242247
243- def test_udaf_list_timestamp_return (ctx ) -> None :
244- timestamps = [
248+ @pytest .mark .parametrize ("wrap_in_scalar" , [True , False ])
249+ def test_udaf_list_timestamp_return (ctx , wrap_in_scalar ) -> None :
250+ timestamps1 = [
245251 datetime (2024 , 1 , 1 , tzinfo = timezone .utc ),
246252 datetime (2024 , 1 , 2 , tzinfo = timezone .utc ),
247253 ]
248- batch = pa .RecordBatch .from_arrays (
249- [pa .array (timestamps , type = pa .timestamp ("ns" ))],
254+ timestamps2 = [
255+ datetime (2024 , 1 , 3 , tzinfo = timezone .utc ),
256+ datetime (2024 , 1 , 4 , tzinfo = timezone .utc ),
257+ ]
258+ batch1 = pa .RecordBatch .from_arrays (
259+ [pa .array (timestamps1 , type = pa .timestamp ("ns" ))],
250260 names = ["ts" ],
251261 )
252- df = ctx .create_dataframe ([[batch ]], name = "timestamp_table" )
262+ batch2 = pa .RecordBatch .from_arrays (
263+ [pa .array (timestamps2 , type = pa .timestamp ("ns" ))],
264+ names = ["ts" ],
265+ )
266+ df = ctx .create_dataframe ([[batch1 ], [batch2 ]], name = "timestamp_table" )
267+
268+ list_type = pa .list_ (
269+ pa .field ("item" , type = pa .timestamp ("ns" ), nullable = wrap_in_scalar )
270+ )
253271
254272 collect = udaf (
255- CollectTimestamps ,
273+ lambda : CollectTimestamps ( wrap_in_scalar ) ,
256274 pa .timestamp ("ns" ),
257- pa . list_ ( pa . timestamp ( "ns" )) ,
258- [pa . list_ ( pa . timestamp ( "ns" )) ],
275+ list_type ,
276+ [list_type ],
259277 volatility = "immutable" ,
260278 )
261279
262280 result = df .aggregate ([], [collect (column ("ts" ))]).collect ()[0 ]
263281
264- assert result .column (0 ) == pa .array (
265- [timestamps ],
266- type = pa .list_ (pa .timestamp ("ns" )),
282+ # There is no guarantee about the ordering of the batches, so perform a sort
283+ # to get consistent results. Alternatively we could sort on evaluate().
284+ assert (
285+ result .column (0 ).values .sort ()
286+ == pa .array (
287+ [[* timestamps1 , * timestamps2 ]],
288+ type = list_type ,
289+ ).values
267290 )