Skip to content

Commit b8a2d19

Browse files
fix duration tests
1 parent f07be8e commit b8a2d19

File tree

4 files changed

+12
-3
lines changed

4 files changed

+12
-3
lines changed

bigframes/core/pyarrow_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ def chunk_by_row_count(
7474
yield buffer.take_as_batches(len(buffer))
7575

7676

77+
def cast_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch:
    """Cast a RecordBatch to the target schema.

    Args:
        batch: The record batch to cast.
        schema: The pyarrow schema the returned batch must have.

    Returns:
        ``batch`` unchanged if its schema already matches ``schema``,
        otherwise a new RecordBatch with the same data cast to ``schema``.
    """
    # Fast path: nothing to do when the schema already matches.
    if batch.schema == schema:
        return batch
    # Newer pyarrow versions can cast a RecordBatch directly; prefer that to
    # avoid the Table round-trip below.
    if hasattr(batch, "cast"):
        return batch.cast(schema)
    # Older supported pyarrow versions lack RecordBatch.cast, so round-trip
    # through a single-batch Table instead.
    return pa.Table.from_batches([batch]).cast(schema).to_batches()[0]
82+
83+
7784
def truncate_pyarrow_iterable(
7885
batches: Iterable[pa.RecordBatch], max_results: int
7986
) -> Iterator[pa.RecordBatch]:

bigframes/session/executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:
5050
result_rows = 0
5151

5252
for batch in self._arrow_batches:
53+
batch = pyarrow_utils.cast_batch(batch, self.schema.to_pyarrow())
5354
result_rows += batch.num_rows
5455

5556
maximum_result_rows = bigframes.options.compute.maximum_result_rows

tests/system/small/test_dataframe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,6 +1695,7 @@ def test_get_dtypes(scalars_df_default_index):
16951695
"string_col": pd.StringDtype(storage="pyarrow"),
16961696
"time_col": pd.ArrowDtype(pa.time64("us")),
16971697
"timestamp_col": pd.ArrowDtype(pa.timestamp("us", tz="UTC")),
1698+
"duration_col": pd.ArrowDtype(pa.duration("us")),
16981699
}
16991700
pd.testing.assert_series_equal(
17001701
dtypes,

tests/system/small/test_dataframe_io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_sql_executes(scalars_df_default_index, bigquery_client):
5555
"""
5656
# Do some operations to make for more complex SQL.
5757
df = (
58-
scalars_df_default_index.drop(columns=["geography_col"])
58+
scalars_df_default_index.drop(columns=["geography_col", "duration_col"])
5959
.groupby("string_col")
6060
.max()
6161
)
@@ -87,7 +87,7 @@ def test_sql_executes_and_includes_named_index(
8787
"""
8888
# Do some operations to make for more complex SQL.
8989
df = (
90-
scalars_df_default_index.drop(columns=["geography_col"])
90+
scalars_df_default_index.drop(columns=["geography_col", "duration_col"])
9191
.groupby("string_col")
9292
.max()
9393
)
@@ -120,7 +120,7 @@ def test_sql_executes_and_includes_named_multiindex(
120120
"""
121121
# Do some operations to make for more complex SQL.
122122
df = (
123-
scalars_df_default_index.drop(columns=["geography_col"])
123+
scalars_df_default_index.drop(columns=["geography_col", "duration_col"])
124124
.groupby(["string_col", "bool_col"])
125125
.max()
126126
)

0 commit comments

Comments
 (0)