Commit 703fd7e

refactor materialization type normalize steps
1 parent 5afbb2d commit 703fd7e

File tree: 6 files changed (+16, -34 lines)

bigframes/core/bq_data.py
Lines changed: 2 additions & 1 deletion

@@ -30,6 +30,7 @@
 from google.protobuf import timestamp_pb2
 import pyarrow as pa
 
+from bigframes.core import pyarrow_utils
 import bigframes.core.schema
 
 if typing.TYPE_CHECKING:
@@ -197,7 +198,7 @@ def get_arrow_batches(
     batches = _iter_streams(session.streams, storage_read_client)
 
     def process_batch(pa_batch):
-        return pa.RecordBatch.from_arrays(pa_batch.columns, names=data.schema.names)
+        return pyarrow_utils.cast_batch(pa_batch, data.schema.to_pyarrow())
 
     batches = map(process_batch, batches)
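
Note on the replacement: pyarrow_utils.cast_batch is an existing bigframes helper whose implementation is not part of this commit; judging from the call sites, it casts each batch to the target pyarrow schema instead of merely rebuilding it with new column names. A minimal sketch of what such a helper plausibly does (the name cast_batch_sketch and its body are illustrative, not the actual bigframes code):

    import pyarrow as pa

    def cast_batch_sketch(batch: pa.RecordBatch, target: pa.Schema) -> pa.RecordBatch:
        # Cast every column to the type of the matching target field, so
        # downstream consumers see one consistent schema.
        arrays = [
            column if column.type == field.type else column.cast(field.type)
            for column, field in zip(batch.columns, target)
        ]
        return pa.RecordBatch.from_arrays(arrays, schema=target)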

bigframes/session/bq_caching_executor.py
Lines changed: 2 additions & 5 deletions

@@ -24,7 +24,6 @@
 import google.cloud.bigquery.job as bq_job
 import google.cloud.bigquery.table as bq_table
 import google.cloud.bigquery_storage_v1
-import pyarrow as pa
 
 import bigframes
 from bigframes import exceptions as bfe
@@ -321,7 +320,7 @@ def _export_gbq(
 
         # TODO(swast): plumb through the api_name of the user-facing api that
         # caused this query.
-        row_iter, query_job = self._run_execute_query(
+        _, query_job = self._run_execute_query(
             sql=sql,
             job_config=job_config,
         )
@@ -688,9 +687,7 @@ def _execute_plan_gbq(
            )
        else:
            return executor.LocalExecuteResult(
-                data=pa.Table.from_batches(
-                    iterator.to_arrow_iterable(), plan.schema.to_pyarrow()
-                ),
+                data=iterator.to_arrow(),
                bf_schema=plan.schema,
            )
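
Two independent changes here: _export_gbq now discards the row iterator it apparently never used, and _execute_plan_gbq lets the BigQuery client materialize the table directly. RowIterator.to_arrow() returns a fully built pyarrow.Table, making the manual pa.Table.from_batches(...) assembly unnecessary. A standalone sketch of the new call (client setup and query are illustrative only):

    from google.cloud import bigquery

    client = bigquery.Client()
    row_iter = client.query("SELECT 1 AS x, 'hello' AS y").result()
    # Materializes the whole result set as a pyarrow.Table in one call.
    table = row_iter.to_arrow()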

bigframes/session/direct_gbq_execution.py
Lines changed: 1 addition & 2 deletions

@@ -18,7 +18,6 @@
 from google.cloud import bigquery
 import google.cloud.bigquery.job as bq_job
 import google.cloud.bigquery.table as bq_table
-import pyarrow as pa
 
 from bigframes.core import compile, nodes
 from bigframes.core.compile import sqlglot
@@ -67,7 +66,7 @@ def execute(
 
         # just immediately downlaod everything for simplicity
         return executor.LocalExecuteResult(
-            data=pa.Table.from_batches(iterator.to_arrow_iterable()),
+            data=iterator.to_arrow(),
             bf_schema=plan.schema,
         )
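
The same simplification as in bq_caching_executor.py above. For reference, the two forms produce the same table for fully materialized results (sketch, with iterator a RowIterator as in the diff; a RowIterator is consumed once, so the lines below are alternatives, not a sequence):

    import pyarrow as pa

    # Before: hand-assemble the Table from streamed record batches.
    table = pa.Table.from_batches(iterator.to_arrow_iterable())
    # After: the client materializes the Table in one call.
    table = iterator.to_arrow()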

bigframes/session/executor.py
Lines changed: 9 additions & 8 deletions

@@ -27,7 +27,7 @@
 
 import bigframes
 import bigframes.core
-from bigframes.core import bq_data, pyarrow_utils
+from bigframes.core import bq_data, local_data, pyarrow_utils
 import bigframes.core.schema
 import bigframes.session._io.pandas as io_pandas
 import bigframes.session.execution_spec as ex_spec
@@ -70,7 +70,6 @@ def arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:
         result_rows = 0
 
         for batch in self._batches:
-            batch = pyarrow_utils.cast_batch(batch, self._schema.to_pyarrow())
             result_rows += batch.num_rows
 
             maximum_result_rows = bigframes.options.compute.maximum_result_rows
@@ -162,8 +161,10 @@ def batches(self) -> ResultsIterator:
 
 class LocalExecuteResult(ExecuteResult):
     def __init__(self, data: pa.Table, bf_schema: bigframes.core.schema.ArraySchema):
-        self._data = data
-        self._schema = bf_schema
+        self._data = local_data.ManagedArrowTable(
+            data.cast(bf_schema.to_pyarrow()), bf_schema
+        )
+        self._data.validate()
 
     @property
     def query_job(self) -> Optional[bigquery.QueryJob]:
@@ -175,14 +176,14 @@ def total_bytes_processed(self) -> Optional[int]:
 
     @property
     def schema(self) -> bigframes.core.schema.ArraySchema:
-        return self._schema
+        return self._data.schema
 
     def batches(self) -> ResultsIterator:
         return ResultsIterator(
-            iter(self._data.to_batches()),
+            iter(self._data.to_arrow()[1]),
             self.schema,
-            self._data.num_rows,
-            self._data.nbytes,
+            self._data.metadata.row_count,
+            self._data.metadata.total_bytes,
         )
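
With the cast moved into LocalExecuteResult.__init__, every local result is normalized and validated once at construction, and ResultsIterator no longer re-casts each batch. Consumers read batches via batches().arrow_batches, as the updated test further below shows. A usage sketch (assumes result is an already-built ExecuteResult):

    import pyarrow as pa

    results_iter = result.batches()
    # Batches already match result.schema; no per-batch casting is needed.
    table = pa.Table.from_batches(
        results_iter.arrow_batches, result.schema.to_pyarrow()
    )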

bigframes/session/polars_executor.py
Lines changed: 1 addition & 17 deletions

@@ -16,14 +16,11 @@
 import itertools
 from typing import Optional, TYPE_CHECKING
 
-import pyarrow as pa
-
 from bigframes.core import (
     agg_expressions,
     array_value,
     bigframe_node,
     expression,
-    local_data,
     nodes,
 )
 import bigframes.operations
@@ -154,22 +151,9 @@ def execute(
            lazy_frame = lazy_frame.limit(peek)
        pa_table = lazy_frame.collect().to_arrow()
        return executor.LocalExecuteResult(
-            data=pa.Table.from_batches(
-                map(self._adapt_batch, pa_table.to_batches()), plan.schema.to_pyarrow()
-            ),
+            data=pa_table,
            bf_schema=plan.schema,
        )
 
    def _can_execute(self, plan: bigframe_node.BigFrameNode):
        return all(_is_node_polars_executable(node) for node in plan.unique_nodes())
-
-    def _adapt_array(self, array: pa.Array) -> pa.Array:
-        target_type = local_data.logical_type_replacements(array.type)
-        if target_type != array.type:
-            # Safe is false to handle weird polars decimal scaling
-            return array.cast(target_type, safe=False)
-        return array
-
-    def _adapt_batch(self, batch: pa.RecordBatch) -> pa.RecordBatch:
-        new_arrays = [self._adapt_array(arr) for arr in batch.columns]
-        return pa.RecordBatch.from_arrays(new_arrays, names=batch.column_names)
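
The deleted _adapt_array/_adapt_batch pair existed because Polars can hand back Arrow decimals at a different precision/scale than the BigFrames logical type expects, hence the safe=False cast; that normalization now happens centrally in LocalExecuteResult. An isolated illustration of the kind of cast involved (values and types are made up):

    from decimal import Decimal

    import pyarrow as pa

    # A decimal column at one precision/scale...
    arr = pa.array([Decimal("1.5"), Decimal("2.25")], type=pa.decimal128(38, 10))
    # ...rescaled toward the target logical type; safe=False tolerates the change.
    rescaled = arr.cast(pa.decimal128(38, 9), safe=False)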

tests/unit/session/test_local_scan_executor.py
Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def test_local_scan_executor_with_slice(start, stop, expected_rows, object_under
    )
 
    result = object_under_test.execute(plan, ordered=True)
-    result_table = pyarrow.Table.from_batches(result.arrow_batches)
+    result_table = pyarrow.Table.from_batches(result.batches().arrow_batches)
    assert result_table.num_rows == expected_rows