
Commit 9d3377a

fix remaining tests

1 parent a2c9679

4 files changed: 25 additions, 21 deletions

bigframes/core/local_data.py

Lines changed: 4 additions & 9 deletions
@@ -30,6 +30,7 @@
 import pyarrow as pa
 import pyarrow.parquet  # type: ignore
 
+from bigframes.core import pyarrow_utils
 import bigframes.core.schema as schemata
 import bigframes.dtypes
 
@@ -113,7 +114,9 @@ def to_arrow(
         schema = self.data.schema
         if duration_type == "int":
             schema = _schema_durations_to_ints(schema)
-        batches = map(functools.partial(_cast_pa_batch, schema=schema), batches)
+        batches = map(
+            functools.partial(pyarrow_utils.cast_batch, schema=schema), batches
+        )
 
         if offsets_col is not None:
             return schema.append(pa.field(offsets_col, pa.int64())), _append_offsets(
@@ -468,14 +471,6 @@ def _schema_durations_to_ints(schema: pa.Schema) -> pa.Schema:
     )
 
 
-# TODO: Use RecordBatch.cast once min pyarrow>=16.0
-def _cast_pa_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch:
-    return pa.record_batch(
-        [arr.cast(type) for arr, type in zip(batch.columns, schema.types)],
-        schema=schema,
-    )
-
-
 def _pairwise(iterable):
     do_yield = False
     a = None
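
For context, the call site above first rewrites the schema so duration fields become plain integers, then casts each batch to that schema. A minimal sketch of what such a rewrite can look like (an illustration only; the real _schema_durations_to_ints body is not part of this diff and may handle more cases, such as nested types):

import pyarrow as pa

def schema_durations_to_ints(schema: pa.Schema) -> pa.Schema:
    # Hypothetical stand-in for bigframes' _schema_durations_to_ints:
    # swap every top-level duration field for an int64 field of the same name.
    return pa.schema(
        [
            pa.field(field.name, pa.int64())
            if pa.types.is_duration(field.type)
            else field
            for field in schema
        ]
    )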

bigframes/core/pyarrow_utils.py

Lines changed: 5 additions & 2 deletions
@@ -77,8 +77,11 @@ def chunk_by_row_count(
 def cast_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch:
     if batch.schema == schema:
         return batch
-    # Newer pyarrow versions can directly cast batches, but older supported versions do not.
-    return pa.Table.from_batches([batch]).cast(schema).to_batches()[0]
+    # TODO: Use RecordBatch.cast once min pyarrow>=16.0
+    return pa.record_batch(
+        [arr.cast(type) for arr, type in zip(batch.columns, schema.types)],
+        schema=schema,
+    )
 
 
 def truncate_pyarrow_iterable(
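
A quick way to exercise the helper after this change (a sketch, assuming only a pyarrow installation; duration_col is just an example column name). Casting a duration array to int64 reinterprets each value as its integer count of the storage unit, and the column-wise Array.cast path works on pyarrow versions that predate RecordBatch.cast (added in pyarrow 16, per the TODO above):

import pyarrow as pa

from bigframes.core import pyarrow_utils

batch = pa.record_batch(
    [pa.array([1_500_000, None], type=pa.duration("us"))],
    names=["duration_col"],
)
target = pa.schema([pa.field("duration_col", pa.int64())])

casted = pyarrow_utils.cast_batch(batch, target)
print(casted.schema)     # duration_col: int64
print(casted.column(0))  # [1500000, null]: microsecond counts as plain int64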

tests/system/small/test_dataframe_io.py

Lines changed: 14 additions & 6 deletions
@@ -1006,7 +1006,9 @@ def test_to_sql_query_unnamed_index_included(
     assert idx_labels[0] is None
     assert idx_ids[0].startswith("bigframes")
 
-    pd_df = scalars_pandas_df_default_index.reset_index(drop=True)
+    pd_df = scalars_pandas_df_default_index.reset_index(drop=True).drop(
+        columns="duration_col"
+    )
     roundtrip = session.read_gbq(sql, index_col=idx_ids)
     roundtrip.index.names = [None]
     utils.assert_pandas_df_equal(roundtrip.to_pandas(), pd_df, check_index_type=False)
@@ -1026,7 +1028,9 @@ def test_to_sql_query_named_index_included(
     assert idx_labels[0] == "rowindex_2"
     assert idx_ids[0] == "rowindex_2"
 
-    pd_df = scalars_pandas_df_default_index.set_index("rowindex_2", drop=True)
+    pd_df = scalars_pandas_df_default_index.set_index("rowindex_2", drop=True).drop(
+        columns="duration_col"
+    )
     roundtrip = session.read_gbq(sql, index_col=idx_ids)
     utils.assert_pandas_df_equal(roundtrip.to_pandas(), pd_df)
 
@@ -1041,7 +1045,9 @@ def test_to_sql_query_unnamed_index_excluded(
     assert len(idx_labels) == 0
     assert len(idx_ids) == 0
 
-    pd_df = scalars_pandas_df_default_index.reset_index(drop=True)
+    pd_df = scalars_pandas_df_default_index.reset_index(drop=True).drop(
+        columns="duration_col"
+    )
     roundtrip = session.read_gbq(sql)
     utils.assert_pandas_df_equal(
         roundtrip.to_pandas(), pd_df, check_index_type=False, ignore_order=True
@@ -1060,9 +1066,11 @@ def test_to_sql_query_named_index_excluded(
     assert len(idx_labels) == 0
     assert len(idx_ids) == 0
 
-    pd_df = scalars_pandas_df_default_index.set_index(
-        "rowindex_2", drop=True
-    ).reset_index(drop=True)
+    pd_df = (
+        scalars_pandas_df_default_index.set_index("rowindex_2", drop=True)
+        .reset_index(drop=True)
+        .drop(columns="duration_col")
+    )
     roundtrip = session.read_gbq(sql)
     utils.assert_pandas_df_equal(
         roundtrip.to_pandas(), pd_df, check_index_type=False, ignore_order=True

tests/system/small/test_session.py

Lines changed: 2 additions & 4 deletions
@@ -1977,16 +1977,14 @@ def test_read_json_gcs_default_engine(session, scalars_dfs, gcs_folder):
 
     # The auto detects of BigQuery load job have restrictions to detect the bytes,
     # numeric and geometry types, so they're skipped here.
-    df = df.drop(columns=["bytes_col", "numeric_col", "geography_col"])
+    df = df.drop(columns=["bytes_col", "numeric_col", "geography_col", "duration_col"])
     scalars_df = scalars_df.drop(
         columns=["bytes_col", "numeric_col", "geography_col", "duration_col"]
     )
 
     # pandas read_json does not respect the dtype overrides for these columns
     df = df.drop(columns=["date_col", "datetime_col", "time_col"])
-    scalars_df = scalars_df.drop(
-        columns=["date_col", "datetime_col", "time_col", "duration_col"]
-    )
+    scalars_df = scalars_df.drop(columns=["date_col", "datetime_col", "time_col"])
 
     assert df.shape[0] == scalars_df.shape[0]
     pd.testing.assert_series_equal(df.dtypes, scalars_df.dtypes)
