Commit ca51638

refactor: ExecuteResult is reusable, sampleable
1 parent 7600001 commit ca51638

7 files changed: +255 -137 lines


bigframes/session/bq_caching_executor.py

Lines changed: 19 additions & 36 deletions
@@ -18,14 +18,14 @@
 import os
 import threading
 from typing import Literal, Mapping, Optional, Sequence, Tuple
-import warnings
 import weakref

 import google.api_core.exceptions
 from google.cloud import bigquery
 import google.cloud.bigquery.job as bq_job
 import google.cloud.bigquery.table as bq_table
 import google.cloud.bigquery_storage_v1
+import pyarrow as pa

 import bigframes
 from bigframes import exceptions as bfe
@@ -157,6 +157,7 @@ def __init__(
         self._semi_executors: Sequence[semi_executor.SemiExecutor] = (
             read_api_execution.ReadApiSemiExecutor(
                 bqstoragereadclient=bqstoragereadclient,
+                bqclient=self.bqclient,
                 project=self.bqclient.project,
             ),
             local_scan_executor.LocalScanExecutor(),
@@ -347,14 +348,9 @@ def _export_gbq(
         table.schema = array_value.schema.to_bigquery()
         self.bqclient.update_table(table, ["schema"])

-        return executor.ExecuteResult(
-            row_iter.to_arrow_iterable(
-                bqstorage_client=self.bqstoragereadclient,
-                max_stream_count=_MAX_READ_STREAMS,
-            ),
-            array_value.schema,
-            query_job,
-            total_bytes_processed=row_iter.total_bytes_processed,
+        return executor.EmptyExecuteResult(
+            bf_schema=array_value.schema,
+            query_job=query_job,
         )

     def dry_run(
@@ -672,41 +668,28 @@ def _execute_plan_gbq(
             query_with_job=(destination_table is not None),
         )

-        table_info: Optional[bigquery.Table] = None
-        if query_job and query_job.destination:
-            table_info = self.bqclient.get_table(query_job.destination)
-            size_bytes = table_info.num_bytes
-        else:
-            size_bytes = None
-
         # we could actually cache even when caching is not explicitly requested, but being conservative for now
         if cache_spec is not None:
-            assert table_info is not None
+            assert query_job and query_job.destination
             assert compiled.row_order is not None
+            table_info = self.bqclient.get_table(query_job.destination)
             self.cache.cache_results_table(
                 og_plan, table_info, compiled.row_order, num_rows=table_info.num_rows
             )

-        if size_bytes is not None and size_bytes >= MAX_SMALL_RESULT_BYTES:
-            msg = bfe.format_message(
-                "The query result size has exceeded 10 GB. In BigFrames 2.0 and "
-                "later, you might need to manually set `allow_large_results=True` in "
-                "the IO method or adjust the BigFrames option: "
-                "`bigframes.options.compute.allow_large_results=True`."
+        if query_job and query_job.destination:
+            return executor.BQTableExecuteResult(
+                data=query_job.destination,
+                bf_schema=og_schema,
+                bq_client=self.bqclient,
+                storage_client=self.bqstoragereadclient,
+                query_job=query_job,
+            )
+        else:
+            return executor.LocalExecuteResult(
+                data=pa.Table.from_batches(iterator.to_arrow_iterable()),
+                bf_schema=plan.schema,
             )
-            warnings.warn(msg, FutureWarning)
-
-        return executor.ExecuteResult(
-            _arrow_batches=iterator.to_arrow_iterable(
-                bqstorage_client=self.bqstoragereadclient,
-                max_stream_count=_MAX_READ_STREAMS,
-            ),
-            schema=og_schema,
-            query_job=query_job,
-            total_bytes=size_bytes,
-            total_rows=iterator.total_rows,
-            total_bytes_processed=iterator.total_bytes_processed,
-        )


     def _if_schema_match(
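The branch that replaces the removed 10 GB warning above is the main behavioral change in this file: when the query produced a destination table, the rows stay in BigQuery behind a BQTableExecuteResult and are only read on demand; otherwise the row iterator is materialized into an in-memory pyarrow table. Below is a hypothetical sketch of that decision rule in isolation, assuming the constructor signatures added in bigframes/session/executor.py later in this commit; the helper name and its argument list are illustrative, not part of the commit.

```python
import pyarrow as pa

from bigframes.session import executor


def wrap_query_result(query_job, row_iterator, bf_schema, bq_client, storage_client):
    # Illustrative only: mirrors the destination-table branch in _execute_plan_gbq.
    if query_job is not None and query_job.destination is not None:
        # Rows stay in BigQuery and are re-read lazily via the Storage Read API.
        return executor.BQTableExecuteResult(
            data=query_job.destination,
            bf_schema=bf_schema,
            bq_client=bq_client,
            storage_client=storage_client,
            query_job=query_job,
        )
    # No destination table: pull everything into a local pyarrow table now.
    return executor.LocalExecuteResult(
        data=pa.Table.from_batches(row_iterator.to_arrow_iterable()),
        bf_schema=bf_schema,
    )
```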

bigframes/session/direct_gbq_execution.py

Lines changed: 11 additions & 7 deletions
@@ -18,6 +18,7 @@
 from google.cloud import bigquery
 import google.cloud.bigquery.job as bq_job
 import google.cloud.bigquery.table as bq_table
+import pyarrow as pa

 from bigframes.core import compile, nodes
 from bigframes.core.compile import sqlglot
@@ -64,13 +65,16 @@ def execute(
             sql=compiled.sql,
         )

-        return executor.ExecuteResult(
-            _arrow_batches=iterator.to_arrow_iterable(),
-            schema=plan.schema,
-            query_job=query_job,
-            total_rows=iterator.total_rows,
-            total_bytes_processed=iterator.total_bytes_processed,
-        )
+        if query_job is not None and query_job.destination is not None:
+            return executor.BQTableExecuteResult(
+                data=query_job.destination,
+                bf_schema=plan.schema,
+            )
+        else:
+            return executor.LocalExecuteResult(
+                data=pa.Table.from_batches(iterator.to_arrow_iterable()),
+                bf_schema=plan.schema,
+            )

     def _run_execute_query(
         self,

bigframes/session/executor.py

Lines changed: 199 additions & 11 deletions
@@ -16,13 +16,15 @@

 import abc
 import dataclasses
+import datetime
 import functools
 import itertools
-from typing import Iterator, Literal, Optional, Union
+from typing import Any, Iterator, Literal, Optional, Sequence, Union

-from google.cloud import bigquery
+from google.cloud import bigquery, bigquery_storage_v1
 import pandas as pd
 import pyarrow
+import pyarrow as pa

 import bigframes
 import bigframes.core
@@ -38,20 +40,41 @@
 )


-@dataclasses.dataclass(frozen=True)
-class ExecuteResult:
-    _arrow_batches: Iterator[pyarrow.RecordBatch]
-    schema: bigframes.core.schema.ArraySchema
-    query_job: Optional[bigquery.QueryJob] = None
-    total_bytes: Optional[int] = None
-    total_rows: Optional[int] = None
-    total_bytes_processed: Optional[int] = None
+class ExecuteResult(abc.ABC):
+    @property
+    @abc.abstractmethod
+    def query_job(self) -> Optional[bigquery.QueryJob]:
+        ...
+
+    @property
+    @abc.abstractmethod
+    def total_bytes(self) -> Optional[int]:
+        ...
+
+    @property
+    @abc.abstractmethod
+    def total_rows(self) -> Optional[int]:
+        ...
+
+    @property
+    @abc.abstractmethod
+    def total_bytes_processed(self) -> Optional[int]:
+        ...
+
+    @property
+    @abc.abstractmethod
+    def schema(self) -> bigframes.core.schema.ArraySchema:
+        ...
+
+    @abc.abstractmethod
+    def _get_arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:
+        ...

     @property
     def arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:
         result_rows = 0

-        for batch in self._arrow_batches:
+        for batch in self._get_arrow_batches():
             batch = pyarrow_utils.cast_batch(batch, self.schema.to_pyarrow())
             result_rows += batch.num_rows

@@ -121,6 +144,171 @@ def to_py_scalar(self):
         return column[0]


+class LocalExecuteResult(ExecuteResult):
+    def __init__(self, data: pa.Table, bf_schema: bigframes.core.schema.ArraySchema):
+        self._data = data
+        self._schema = bf_schema
+
+    @property
+    def query_job(self) -> Optional[bigquery.QueryJob]:
+        return None
+
+    @property
+    def total_bytes(self) -> Optional[int]:
+        return None
+
+    @property
+    def total_rows(self) -> Optional[int]:
+        return self._data.num_rows
+
+    @property
+    def total_bytes_processed(self) -> Optional[int]:
+        return None
+
+    @property
+    def schema(self) -> bigframes.core.schema.ArraySchema:
+        return self._schema
+
+    def _get_arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:
+        return iter(self._data.to_batches())
+
+
+class EmptyExecuteResult(ExecuteResult):
+    def __init__(
+        self,
+        bf_schema: bigframes.core.schema.ArraySchema,
+        query_job: Optional[bigquery.QueryJob] = None,
+    ):
+        self._schema = bf_schema
+        self._query_job = query_job
+
+    @property
+    def query_job(self) -> Optional[bigquery.QueryJob]:
+        return self._query_job
+
+    @property
+    def total_bytes(self) -> Optional[int]:
+        return None
+
+    @property
+    def total_rows(self) -> Optional[int]:
+        return 0
+
+    @property
+    def total_bytes_processed(self) -> Optional[int]:
+        if self.query_job:
+            return self.query_job.total_bytes_processed
+        return None
+
+    @property
+    def schema(self) -> bigframes.core.schema.ArraySchema:
+        return self._schema
+
+    def _get_arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:
+        return iter([])
+
+
+class BQTableExecuteResult(ExecuteResult):
+    def __init__(
+        self,
+        data: bigquery.TableReference,
+        bf_schema: bigframes.core.schema.ArraySchema,
+        bq_client: bigquery.Client,
+        storage_client: bigquery_storage_v1.BigQueryReadClient,
+        *,
+        query_job: Optional[bigquery.QueryJob] = None,
+        snapshot_time: Optional[datetime.datetime] = None,
+        limit: Optional[int] = None,
+        selected_fields: Optional[Sequence[str]] = None,
+        sql_predicate: Optional[str] = None,
+    ):
+        self._data = data
+        self._schema = bf_schema
+        self._query_job = query_job
+        self._bqclient = bq_client
+        self._storage_client = storage_client
+        self._snapshot_time = snapshot_time
+        self._limit = limit
+        self._selected_fields = selected_fields
+        self._predicate = sql_predicate
+
+    @property
+    def query_job(self) -> Optional[bigquery.QueryJob]:
+        return self._query_job
+
+    @property
+    def total_bytes(self) -> Optional[int]:
+        return None
+
+    @property
+    def total_rows(self) -> Optional[int]:
+        return self._get_table_metadata().num_rows
+
+    @functools.cache
+    def _get_table_metadata(self) -> bigquery.Table:
+        return self._bqclient.get_table(self._data)
+
+    @property
+    def total_bytes_processed(self) -> Optional[int]:
+        if self.query_job:
+            return self.query_job.total_bytes_processed
+        return None
+
+    @property
+    def schema(self) -> bigframes.core.schema.ArraySchema:
+        return self._schema
+
+    def _get_arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:
+        import google.cloud.bigquery_storage_v1.types as bq_storage_types
+        from google.protobuf import timestamp_pb2
+
+        table_mod_options = {}
+        read_options_dict: dict[str, Any] = {}
+        if self._selected_fields:
+            read_options_dict["selected_fields"] = list(self._selected_fields)
+        if self._predicate:
+            read_options_dict["row_restriction"] = self._predicate
+        read_options = bq_storage_types.ReadSession.TableReadOptions(
+            **read_options_dict
+        )
+
+        if self._snapshot_time:
+            snapshot_time = timestamp_pb2.Timestamp()
+            snapshot_time.FromDatetime(self._snapshot_time)
+            table_mod_options["snapshot_time"] = snapshot_time
+        table_mods = bq_storage_types.ReadSession.TableModifiers(**table_mod_options)
+
+        requested_session = bq_storage_types.stream.ReadSession(
+            table=self._data.to_bqstorage(),
+            data_format=bq_storage_types.DataFormat.ARROW,
+            read_options=read_options,
+            table_modifiers=table_mods,
+        )
+        # Single stream to maintain ordering
+        request = bq_storage_types.CreateReadSessionRequest(
+            parent=f"projects/{self._data.project}",
+            read_session=requested_session,
+            max_stream_count=1,
+        )
+        session = self._storage_client.create_read_session(request=request)
+
+        if not session.streams:
+            batches: Iterator[pa.RecordBatch] = iter([])
+        else:
+            reader = self._storage_client.read_rows(session.streams[0].name)
+            rowstream = reader.rows()
+
+            def process_page(page):
+                pa_batch = page.to_arrow()
+                return pa.RecordBatch.from_arrays(
+                    pa_batch.columns, names=self.schema.names
+                )
+
+            batches = map(process_page, rowstream.pages)
+
+        return batches
+
+
 @dataclasses.dataclass(frozen=True)
 class HierarchicalKey:
     columns: tuple[str, ...]
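Together, these three concrete classes are what the commit title means by a reusable, sampleable ExecuteResult: consumers program against the abstract interface, and a BQTableExecuteResult translates selected_fields, sql_predicate, and snapshot_time into Storage Read API read options each time its batches are requested, so the same destination table can be re-read or sampled without re-running the query. Two minimal, hypothetical sketches follow; neither helper is part of this commit, and the column names and predicate are placeholders.

```python
import pyarrow as pa
from google.cloud import bigquery, bigquery_storage_v1

from bigframes.session import executor


def result_to_table(result) -> pa.Table:
    # Works with LocalExecuteResult, EmptyExecuteResult, or BQTableExecuteResult;
    # the caller no longer needs to know which. arrow_batches casts each batch
    # to the result's pyarrow schema before yielding it, so concatenating
    # against that schema is safe even when the iterator is empty.
    return pa.Table.from_batches(
        result.arrow_batches, schema=result.schema.to_pyarrow()
    )


def sample_result(
    table_ref: bigquery.TableReference,
    bf_schema,
    bq_client: bigquery.Client,
    storage_client: bigquery_storage_v1.BigQueryReadClient,
) -> executor.BQTableExecuteResult:
    # Re-read a previous query's destination table, restricted to two columns
    # and a predicate, without re-running the query.
    return executor.BQTableExecuteResult(
        data=table_ref,
        bf_schema=bf_schema,
        bq_client=bq_client,
        storage_client=storage_client,
        selected_fields=["user_id", "score"],  # becomes ReadSession selected_fields
        sql_predicate="score > 0.5",  # becomes ReadSession row_restriction
    )
```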

bigframes/session/local_scan_executor.py

Lines changed: 3 additions & 6 deletions
@@ -57,10 +57,7 @@ def execute(
         if (peek is not None) and (total_rows is not None):
             total_rows = min(peek, total_rows)

-        return executor.ExecuteResult(
-            _arrow_batches=arrow_table.to_batches(),
-            schema=plan.schema,
-            query_job=None,
-            total_bytes=None,
-            total_rows=total_rows,
+        return executor.LocalExecuteResult(
+            data=arrow_table,
+            bf_schema=plan.schema,
         )

bigframes/session/polars_executor.py

Lines changed: 3 additions & 5 deletions
@@ -153,11 +153,9 @@ def execute(
         if peek is not None:
             lazy_frame = lazy_frame.limit(peek)
         pa_table = lazy_frame.collect().to_arrow()
-        return executor.ExecuteResult(
-            _arrow_batches=iter(map(self._adapt_batch, pa_table.to_batches())),
-            schema=plan.schema,
-            total_bytes=pa_table.nbytes,
-            total_rows=pa_table.num_rows,
+        return executor.LocalExecuteResult(
+            data=pa.Table.from_batches(map(self._adapt_batch, pa_table.to_batches())),
+            bf_schema=plan.schema,
         )

     def _can_execute(self, plan: bigframe_node.BigFrameNode):
