
Commit 5afbb2d

refactor result size statistics
1 parent 6d16001 · commit 5afbb2d

9 files changed: +280 additions, -175 deletions

bigframes/core/blocks.py

Lines changed: 41 additions & 34 deletions
@@ -37,7 +37,6 @@
     Optional,
     Sequence,
     Tuple,
-    TYPE_CHECKING,
     Union,
 )
 import warnings
@@ -70,9 +69,6 @@
 from bigframes.session import dry_runs, execution_spec
 from bigframes.session import executor as executors
 
-if TYPE_CHECKING:
-    from bigframes.session.executor import ExecuteResult
-
 # Type constraint for wherever column labels are used
 Label = typing.Hashable
 
@@ -98,7 +94,6 @@
 LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]]
 
 
-@dataclasses.dataclass
 class PandasBatches(Iterator[pd.DataFrame]):
     """Interface for mutable objects with state represented by a block value object."""
 
@@ -271,10 +266,14 @@ def shape(self) -> typing.Tuple[int, int]:
         except Exception:
             pass
 
-        row_count = self.session._executor.execute(
-            self.expr.row_count(),
-            execution_spec.ExecutionSpec(promise_under_10gb=True, ordered=False),
-        ).to_py_scalar()
+        row_count = (
+            self.session._executor.execute(
+                self.expr.row_count(),
+                execution_spec.ExecutionSpec(promise_under_10gb=True, ordered=False),
+            )
+            .batches()
+            .to_py_scalar()
+        )
         return (row_count, len(self.value_columns))
 
     @property
@@ -584,7 +583,7 @@ def to_arrow(
                 ordered=ordered,
             ),
         )
-        pa_table = execute_result.to_arrow_table()
+        pa_table = execute_result.batches().to_arrow_table()
 
         pa_index_labels = []
         for index_level, index_label in enumerate(self._index_labels):
@@ -636,17 +635,13 @@ def to_pandas(
             max_download_size, sampling_method, random_state
         )
 
-        ex_result = self._materialize_local(
+        return self._materialize_local(
             materialize_options=MaterializationOptions(
                 downsampling=sampling,
                 allow_large_results=allow_large_results,
                 ordered=ordered,
             )
         )
-        df = ex_result.to_pandas()
-        df = self._copy_index_to_pandas(df)
-        df.set_axis(self.column_labels, axis=1, copy=False)
-        return df, ex_result.query_job
 
     def _get_sampling_option(
         self,
@@ -683,7 +678,7 @@ def try_peek(
                 self.expr,
                 execution_spec.ExecutionSpec(promise_under_10gb=under_10gb, peek=n),
            )
-            df = result.to_pandas()
+            df = result.batches().to_pandas()
            return self._copy_index_to_pandas(df)
         else:
             return None
@@ -704,13 +699,14 @@ def to_pandas_batches(
             if (allow_large_results is not None)
             else not bigframes.options._allow_large_results
         )
-        execute_result = self.session._executor.execute(
+        execution_result = self.session._executor.execute(
             self.expr,
             execution_spec.ExecutionSpec(
                 promise_under_10gb=under_10gb,
                 ordered=True,
             ),
         )
+        result_batches = execution_result.batches()
 
         # To reduce the number of edge cases to consider when working with the
         # results of this, always return at least one DataFrame. See:
@@ -724,19 +720,21 @@ def to_pandas_batches(
         dfs = map(
             lambda a: a[0],
             itertools.zip_longest(
-                execute_result.to_pandas_batches(page_size, max_results),
+                result_batches.to_pandas_batches(page_size, max_results),
                 [0],
                 fillvalue=empty_val,
             ),
         )
         dfs = iter(map(self._copy_index_to_pandas, dfs))
 
-        total_rows = execute_result.total_rows
+        total_rows = result_batches.approx_total_rows
         if (total_rows is not None) and (max_results is not None):
             total_rows = min(total_rows, max_results)
 
         return PandasBatches(
-            dfs, total_rows, total_bytes_processed=execute_result.total_bytes_processed
+            dfs,
+            total_rows,
+            total_bytes_processed=execution_result.total_bytes_processed,
         )
 
     def _copy_index_to_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -754,7 +752,7 @@ def _copy_index_to_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
 
     def _materialize_local(
         self, materialize_options: MaterializationOptions = MaterializationOptions()
-    ) -> ExecuteResult:
+    ) -> tuple[pd.DataFrame, Optional[bigquery.QueryJob]]:
         """Run query and download results as a pandas DataFrame. Return the total number of results as well."""
         # TODO(swast): Allow for dry run and timeout.
         under_10gb = (
@@ -769,9 +767,11 @@ def _materialize_local(
                 ordered=materialize_options.ordered,
             ),
         )
+        result_batches = execute_result.batches()
+
         sample_config = materialize_options.downsampling
-        if execute_result.total_bytes is not None:
-            table_mb = execute_result.total_bytes / _BYTES_TO_MEGABYTES
+        if result_batches.approx_total_bytes is not None:
+            table_mb = result_batches.approx_total_bytes / _BYTES_TO_MEGABYTES
             max_download_size = sample_config.max_download_size
             fraction = (
                 max_download_size / table_mb
@@ -792,7 +792,7 @@
 
         # TODO: Maybe materialize before downsampling
         # Some downsampling methods
-        if fraction < 1 and (execute_result.total_rows is not None):
+        if fraction < 1 and (result_batches.approx_total_rows is not None):
             if not sample_config.enable_downsampling:
                 raise RuntimeError(
                     f"The data size ({table_mb:.2f} MB) exceeds the maximum download limit of "
@@ -811,7 +811,7 @@
                     "the downloading limit."
                 )
                 warnings.warn(msg, category=UserWarning)
-            total_rows = execute_result.total_rows
+            total_rows = result_batches.approx_total_rows
             # Remove downsampling config from subsequent invocations, as otherwise could result in many
             # iterations if downsampling undershoots
             return self._downsample(
@@ -823,7 +823,10 @@
                 MaterializationOptions(ordered=materialize_options.ordered)
             )
         else:
-            return execute_result
+            df = result_batches.to_pandas()
+            df = self._copy_index_to_pandas(df)
+            df.set_axis(self.column_labels, axis=1, copy=False)
+            return df, execute_result.query_job
 
     def _downsample(
         self, total_rows: int, sampling_method: str, fraction: float, random_state
@@ -1662,15 +1665,19 @@ def retrieve_repr_request_results(
                 ordered=True,
             ),
         )
-        row_count = self.session._executor.execute(
-            self.expr.row_count(),
-            execution_spec.ExecutionSpec(
-                promise_under_10gb=True,
-                ordered=False,
-            ),
-        ).to_py_scalar()
+        row_count = (
+            self.session._executor.execute(
+                self.expr.row_count(),
+                execution_spec.ExecutionSpec(
+                    promise_under_10gb=True,
+                    ordered=False,
+                ),
+            )
+            .batches()
+            .to_py_scalar()
+        )
 
-        head_df = head_result.to_pandas()
+        head_df = head_result.batches().to_pandas()
         return self._copy_index_to_pandas(head_df), row_count, head_result.query_job
 
     def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:
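
For reference, here is a minimal, self-contained sketch of the access pattern this refactor introduces in blocks.py: instead of reading totals and pandas/Arrow conversions off the executor result directly, callers now go through the object returned by .batches(), which carries the data plus approximate size statistics. Only the method and attribute names visible in the diff (batches(), to_pandas(), to_arrow_table(), to_py_scalar(), approx_total_rows, approx_total_bytes) are taken from the commit; the stub class itself is hypothetical and exists only to show the shape of the API, not the bigframes implementation.

# Illustrative stub only -- mimics the shape of ExecuteResult.batches() implied
# by this diff; it is not the bigframes implementation.
import dataclasses
from typing import List, Optional

import pandas as pd
import pyarrow as pa


@dataclasses.dataclass
class StubResultBatches:
    """Stands in for the object returned by ExecuteResult.batches()."""

    arrow_batches: List[pa.RecordBatch]
    approx_total_rows: Optional[int]   # estimate, may be None
    approx_total_bytes: Optional[int]  # estimate, may be None

    def to_arrow_table(self) -> pa.Table:
        return pa.Table.from_batches(self.arrow_batches)

    def to_pandas(self) -> pd.DataFrame:
        return self.to_arrow_table().to_pandas()

    def to_py_scalar(self):
        # Single-cell results (e.g. row_count queries) collapse to a Python scalar.
        return self.to_pandas().iloc[0, 0]


# Usage mirrors the new call sites: execute(...).batches().to_py_scalar()
batches = StubResultBatches(
    arrow_batches=[pa.RecordBatch.from_pydict({"row_count": [42]})],
    approx_total_rows=1,
    approx_total_bytes=8,
)
print(batches.to_py_scalar())  # 42

Moving the size statistics onto the batches object is what lets _materialize_local return a plain (DataFrame, QueryJob) tuple instead of leaking ExecuteResult to its callers, which is the other half of this file's changes.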

bigframes/core/bq_data.py

Lines changed: 123 additions & 1 deletion
@@ -14,13 +14,21 @@
 
 from __future__ import annotations
 
+import concurrent.futures
 import dataclasses
 import datetime
 import functools
+import os
+import queue
+import threading
 import typing
-from typing import Optional, Sequence, Tuple
+from typing import Any, Iterator, Optional, Sequence, Tuple
 
+from google.cloud import bigquery_storage_v1
 import google.cloud.bigquery as bq
+import google.cloud.bigquery_storage_v1.types as bq_storage_types
+from google.protobuf import timestamp_pb2
+import pyarrow as pa
 
 import bigframes.core.schema
 
@@ -82,3 +90,117 @@ class BigqueryDataSource:
     ordering: typing.Optional[orderings.RowOrdering] = None
     # Optimization field
     n_rows: Optional[int] = None
+
+
+_WORKER_TIME_INCREMENT = 0.05
+
+
+def _iter_stream(
+    stream_name: str,
+    storage_read_client: bigquery_storage_v1.BigQueryReadClient,
+    result_queue: queue.Queue,
+    stop_event: threading.Event,
+):
+    reader = storage_read_client.read_rows(stream_name)
+    for page in reader.rows():
+        try:
+            result_queue.put(page.to_arrow(), timeout=_WORKER_TIME_INCREMENT)
+        except queue.Full:
+            continue
+        if stop_event.is_set():
+            return
+
+
+def _iter_streams(
+    streams, storage_read_client: bigquery_storage_v1.BigQueryReadClient
+) -> Iterator[pa.RecordBatch]:
+    stop_event = threading.Event()
+    result_queue: queue.Queue = queue.Queue(
+        len(streams)
+    )  # each response is large, so small queue is appropriate
+
+    in_progress: list[concurrent.futures.Future] = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=len(streams)) as pool:
+        for stream in streams:
+            in_progress.append(
+                pool.submit(
+                    _iter_stream, stream, storage_read_client, result_queue, stop_event
+                )
+            )
+
+        while in_progress:
+            try:
+                yield result_queue.get(timeout=0.1)
+            except queue.Empty:
+                new_in_progress = []
+                for future in in_progress:
+                    if future.done():
+                        try:
+                            future.result()
+                        finally:
+                            stop_event.set()
+                            raise
+                    else:
+                        new_in_progress.append(future)
+                in_progress = new_in_progress
+
+
+@dataclasses.dataclass
+class ReadResult:
+    iter: Iterator[pa.RecordBatch]
+    approx_rows: int
+    approx_bytes: int
+
+
+def get_arrow_batches(
+    data: BigqueryDataSource,
+    columns: Sequence[str],
+    storage_read_client: bigquery_storage_v1.BigQueryReadClient,
+) -> ReadResult:
+    table_mod_options = {}
+    read_options_dict: dict[str, Any] = {"selected_fields": list(columns)}
+    if data.sql_predicate:
+        read_options_dict["row_restriction"] = data.sql_predicate
+    read_options = bq_storage_types.ReadSession.TableReadOptions(**read_options_dict)
+
+    if data.at_time:
+        snapshot_time = timestamp_pb2.Timestamp()
+        snapshot_time.FromDatetime(data.at_time)
+        table_mod_options["snapshot_time"] = snapshot_time
+    table_mods = bq_storage_types.ReadSession.TableModifiers(**table_mod_options)
+
+    requested_session = bq_storage_types.stream.ReadSession(
+        table=data.table.get_table_ref().to_bqstorage(),
+        data_format=bq_storage_types.DataFormat.ARROW,
+        read_options=read_options,
+        table_modifiers=table_mods,
+    )
+    # Single stream to maintain ordering
+    request = bq_storage_types.CreateReadSessionRequest(
+        parent=f"projects/{data.table.project_id}",
+        read_session=requested_session,
+        max_stream_count=1,
+    )
+
+    if data.ordering is not None:
+        max_streams = 1
+    else:
+        max_streams = os.cpu_count() or 8
+
+    session = storage_read_client.create_read_session(
+        request=request, max_stream_count=max_streams
+    )
+
+    if not session.streams:
+        batches: Iterator[pa.RecordBatch] = iter([])
+    else:
+        batches = _iter_streams(session.streams, storage_read_client)
+
+    def process_batch(pa_batch):
+        return pa.RecordBatch.from_arrays(pa_batch.columns, names=data.schema.names)
+
+    batches = map(process_batch, batches)
+
+    return ReadResult(
+        batches, session.estimated_row_count, session.estimated_total_bytes_scanned
+    )
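
The new multi-stream read path in bq_data.py fans rows from several BigQuery Storage read streams into a single iterator: each worker thread pushes Arrow pages into a small bounded queue, the consumer yields from that queue, and a stop event lets the consumer shut workers down early. The toy example below reproduces only that producer/consumer shape, with plain lists standing in for read streams; every name in it is illustrative and none of it is bigframes or BigQuery API.

# Toy, self-contained illustration of the fan-in pattern used by _iter_streams:
# worker threads push chunks into a bounded queue, one consumer yields them,
# and a stop event lets the consumer shut workers down. The "streams" here are
# plain lists of integers, not BigQuery read streams.
import concurrent.futures
import queue
import threading
from typing import Iterator, List


def _produce(chunks: List[int], out: queue.Queue, stop: threading.Event) -> None:
    for chunk in chunks:
        while not stop.is_set():
            try:
                out.put(chunk, timeout=0.05)  # brief timeout gives backpressure
                break
            except queue.Full:
                continue
        if stop.is_set():
            return


def fan_in(streams: List[List[int]]) -> Iterator[int]:
    stop = threading.Event()
    out: queue.Queue = queue.Queue(maxsize=len(streams))  # small, bounded buffer
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(streams)) as pool:
        pending = [pool.submit(_produce, s, out, stop) for s in streams]
        try:
            # Keep draining until every worker is done and the queue is empty.
            while pending or not out.empty():
                try:
                    yield out.get(timeout=0.1)
                except queue.Empty:
                    pending = [f for f in pending if not f.done()]
        finally:
            stop.set()  # stop workers if the consumer exits early


if __name__ == "__main__":
    print(sorted(fan_in([[1, 2, 3], [4, 5], [6]])))  # [1, 2, 3, 4, 5, 6]

The bounded queue is the design point worth noting: because each Arrow page can be large, workers block briefly with a timeout instead of buffering an unbounded amount of data in memory, which matches the "small queue is appropriate" comment in the diff.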

bigframes/core/indexes/base.py

Lines changed: 22 additions & 10 deletions
@@ -290,9 +290,13 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
         count_agg = ex_types.UnaryAggregation(agg_ops.count_op, ex.deref(offsets_id))
         count_result = filtered_block._expr.aggregate([(count_agg, "count")])
 
-        count_scalar = self._block.session._executor.execute(
-            count_result, ex_spec.ExecutionSpec(promise_under_10gb=True)
-        ).to_py_scalar()
+        count_scalar = (
+            self._block.session._executor.execute(
+                count_result, ex_spec.ExecutionSpec(promise_under_10gb=True)
+            )
+            .batches()
+            .to_py_scalar()
+        )
 
         if count_scalar == 0:
             raise KeyError(f"'{key}' is not in index")
@@ -301,9 +305,13 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
         if count_scalar == 1:
             min_agg = ex_types.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id))
             position_result = filtered_block._expr.aggregate([(min_agg, "position")])
-            position_scalar = self._block.session._executor.execute(
-                position_result, ex_spec.ExecutionSpec(promise_under_10gb=True)
-            ).to_py_scalar()
+            position_scalar = (
+                self._block.session._executor.execute(
+                    position_result, ex_spec.ExecutionSpec(promise_under_10gb=True)
+                )
+                .batches()
+                .to_py_scalar()
+            )
             return int(position_scalar)
 
         # Handle multiple matches based on index monotonicity
@@ -333,10 +341,14 @@ def _get_monotonic_slice(self, filtered_block, offsets_id: str) -> slice:
         combined_result = filtered_block._expr.aggregate(min_max_aggs)
 
         # Execute query and extract positions
-        result_df = self._block.session._executor.execute(
-            combined_result,
-            execution_spec=ex_spec.ExecutionSpec(promise_under_10gb=True),
-        ).to_pandas()
+        result_df = (
+            self._block.session._executor.execute(
+                combined_result,
+                execution_spec=ex_spec.ExecutionSpec(promise_under_10gb=True),
+            )
+            .batches()
+            .to_pandas()
+        )
         min_pos = int(result_df["min_pos"].iloc[0])
         max_pos = int(result_df["max_pos"].iloc[0])
 
