|
16 | 16 |
|
17 | 17 | import abc |
18 | 18 | import dataclasses |
| 19 | +import datetime |
19 | 20 | import functools |
20 | 21 | import itertools |
21 | | -from typing import Iterator, Literal, Optional, Union |
| 22 | +from typing import Any, Iterator, Literal, Optional, Sequence, Union |
22 | 23 |
|
23 | | -from google.cloud import bigquery |
| 24 | +from google.cloud import bigquery, bigquery_storage_v1 |
24 | 25 | import pandas as pd |
25 | 26 | import pyarrow |
| 27 | +import pyarrow as pa |
26 | 28 |
|
27 | 29 | import bigframes |
28 | 30 | import bigframes.core |
|
38 | 40 | ) |
39 | 41 |
|
40 | 42 |
|
41 | | -@dataclasses.dataclass(frozen=True) |
42 | | -class ExecuteResult: |
43 | | - _arrow_batches: Iterator[pyarrow.RecordBatch] |
44 | | - schema: bigframes.core.schema.ArraySchema |
45 | | - query_job: Optional[bigquery.QueryJob] = None |
46 | | - total_bytes: Optional[int] = None |
47 | | - total_rows: Optional[int] = None |
48 | | - total_bytes_processed: Optional[int] = None |
| 43 | +class ExecuteResult(abc.ABC): |
| 44 | + @property |
| 45 | + @abc.abstractmethod |
| 46 | + def query_job(self) -> Optional[bigquery.QueryJob]: |
| 47 | + ... |
| 48 | + |
| 49 | + @property |
| 50 | + @abc.abstractmethod |
| 51 | + def total_bytes(self) -> Optional[int]: |
| 52 | + ... |
| 53 | + |
| 54 | + @property |
| 55 | + @abc.abstractmethod |
| 56 | + def total_rows(self) -> Optional[int]: |
| 57 | + ... |
| 58 | + |
| 59 | + @property |
| 60 | + @abc.abstractmethod |
| 61 | + def total_bytes_processed(self) -> Optional[int]: |
| 62 | + ... |
| 63 | + |
| 64 | + @property |
| 65 | + @abc.abstractmethod |
| 66 | + def schema(self) -> bigframes.core.schema.ArraySchema: |
| 67 | + ... |
| 68 | + |
| 69 | + @abc.abstractmethod |
| 70 | + def _get_arrow_batches(self) -> Iterator[pyarrow.RecordBatch]: |
| 71 | + ... |
49 | 72 |
|
50 | 73 | @property |
51 | 74 | def arrow_batches(self) -> Iterator[pyarrow.RecordBatch]: |
52 | 75 | result_rows = 0 |
53 | 76 |
|
54 | | - for batch in self._arrow_batches: |
| 77 | + for batch in self._get_arrow_batches(): |
55 | 78 | batch = pyarrow_utils.cast_batch(batch, self.schema.to_pyarrow()) |
56 | 79 | result_rows += batch.num_rows |
57 | 80 |
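For reference, pyarrow_utils.cast_batch is a bigframes-internal helper; a minimal standalone sketch of the per-column cast it is assumed to perform (not the actual implementation, and it assumes the batch and target schema share column order):

    import pyarrow as pa

    def cast_batch_sketch(batch: pa.RecordBatch, target: pa.Schema) -> pa.RecordBatch:
        # Cast each column to the type declared at the same position in the target schema.
        arrays = [
            batch.column(i).cast(target.field(i).type) for i in range(batch.num_columns)
        ]
        return pa.RecordBatch.from_arrays(arrays, schema=target)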
|
@@ -121,6 +144,171 @@ def to_py_scalar(self): |
121 | 144 | return column[0] |
122 | 145 |
|
123 | 146 |
|
| 147 | +class LocalExecuteResult(ExecuteResult): |
| 148 | + def __init__(self, data: pa.Table, bf_schema: bigframes.core.schema.ArraySchema): |
| 149 | + self._data = data |
| 150 | + self._schema = bf_schema |
| 151 | + |
| 152 | + @property |
| 153 | + def query_job(self) -> Optional[bigquery.QueryJob]: |
| 154 | + return None |
| 155 | + |
| 156 | + @property |
| 157 | + def total_bytes(self) -> Optional[int]: |
| 158 | + return None |
| 159 | + |
| 160 | + @property |
| 161 | + def total_rows(self) -> Optional[int]: |
| 162 | + return self._data.num_rows |
| 163 | + |
| 164 | + @property |
| 165 | + def total_bytes_processed(self) -> Optional[int]: |
| 166 | + return None |
| 167 | + |
| 168 | + @property |
| 169 | + def schema(self) -> bigframes.core.schema.ArraySchema: |
| 170 | + return self._schema |
| 171 | + |
| 172 | + def _get_arrow_batches(self) -> Iterator[pyarrow.RecordBatch]: |
| 173 | + return iter(self._data.to_batches()) |
| 174 | + |
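LocalExecuteResult wraps an already-materialized pyarrow table; a short standalone illustration of the pyarrow calls it builds on (the data below is made up):

    import pyarrow as pa

    table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})

    # table.num_rows backs total_rows; table.to_batches() backs _get_arrow_batches().
    rows_seen = sum(batch.num_rows for batch in table.to_batches())
    assert rows_seen == table.num_rows == 3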
| 175 | + |
| 176 | +class EmptyExecuteResult(ExecuteResult): |
| 177 | + def __init__( |
| 178 | + self, |
| 179 | + bf_schema: bigframes.core.schema.ArraySchema, |
| 180 | + query_job: Optional[bigquery.QueryJob] = None, |
| 181 | + ): |
| 182 | + self._schema = bf_schema |
| 183 | + self._query_job = query_job |
| 184 | + |
| 185 | + @property |
| 186 | + def query_job(self) -> Optional[bigquery.QueryJob]: |
| 187 | + return self._query_job |
| 188 | + |
| 189 | + @property |
| 190 | + def total_bytes(self) -> Optional[int]: |
| 191 | + return None |
| 192 | + |
| 193 | + @property |
| 194 | + def total_rows(self) -> Optional[int]: |
| 195 | + return 0 |
| 196 | + |
| 197 | + @property |
| 198 | + def total_bytes_processed(self) -> Optional[int]: |
| 199 | + if self.query_job: |
| 200 | + return self.query_job.total_bytes_processed |
| 201 | + return None |
| 202 | + |
| 203 | + @property |
| 204 | + def schema(self) -> bigframes.core.schema.ArraySchema: |
| 205 | + return self._schema |
| 206 | + |
| 207 | + def _get_arrow_batches(self) -> Iterator[pyarrow.RecordBatch]: |
| 208 | + return iter([]) |
| 209 | + |
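EmptyExecuteResult forwards total_bytes_processed straight from the query job; the same statistic is available on any bigquery.QueryJob, including a dry run (credentials and SQL below are placeholders):

    from google.cloud import bigquery

    client = bigquery.Client()  # uses application-default credentials and project
    job = client.query(
        "SELECT 1",  # placeholder query
        job_config=bigquery.QueryJobConfig(dry_run=True),
    )
    print(job.total_bytes_processed)  # bytes the query would scan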
| 210 | + |
| 211 | +class BQTableExecuteResult(ExecuteResult): |
| 212 | + def __init__( |
| 213 | + self, |
| 214 | + data: bigquery.TableReference, |
| 215 | + bf_schema: bigframes.core.schema.ArraySchema, |
| 216 | + bq_client: bigquery.Client, |
| 217 | + storage_client: bigquery_storage_v1.BigQueryReadClient, |
| 218 | + *, |
| 219 | + query_job: Optional[bigquery.QueryJob] = None, |
| 220 | + snapshot_time: Optional[datetime.datetime] = None, |
| 221 | + limit: Optional[int] = None, |
| 222 | + selected_fields: Optional[Sequence[str]] = None, |
| 223 | + sql_predicate: Optional[str] = None, |
| 224 | + ): |
| 225 | + self._data = data |
| 226 | + self._schema = bf_schema |
| 227 | + self._query_job = query_job |
| 228 | + self._bqclient = bq_client |
| 229 | + self._storage_client = storage_client |
| 230 | + self._snapshot_time = snapshot_time |
| 231 | + self._limit = limit |
| 232 | + self._selected_fields = selected_fields |
| 233 | + self._predicate = sql_predicate |
| 234 | + |
| 235 | + @property |
| 236 | + def query_job(self) -> Optional[bigquery.QueryJob]: |
| 237 | + return self._query_job |
| 238 | + |
| 239 | + @property |
| 240 | + def total_bytes(self) -> Optional[int]: |
| 241 | + return None |
| 242 | + |
| 243 | + @property |
| 244 | + def total_rows(self) -> Optional[int]: |
| 245 | +        return self._get_table_metadata().num_rows
| 246 | + |
| 247 | + @functools.cache |
| 248 | + def _get_table_metadata(self) -> bigquery.Table: |
| 249 | + return self._bqclient.get_table(self._data) |
| 250 | + |
| 251 | + @property |
| 252 | + def total_bytes_processed(self) -> Optional[int]: |
| 253 | + if self.query_job: |
| 254 | + return self.query_job.total_bytes_processed |
| 255 | + return None |
| 256 | + |
| 257 | + @property |
| 258 | + def schema(self) -> bigframes.core.schema.ArraySchema: |
| 259 | + return self._schema |
| 260 | + |
| 261 | + def _get_arrow_batches(self) -> Iterator[pyarrow.RecordBatch]: |
| 262 | + import google.cloud.bigquery_storage_v1.types as bq_storage_types |
| 263 | + from google.protobuf import timestamp_pb2 |
| 264 | + |
| 265 | +        table_mod_options: dict[str, Any] = {}
| 266 | + read_options_dict: dict[str, Any] = {} |
| 267 | + if self._selected_fields: |
| 268 | + read_options_dict["selected_fields"] = list(self._selected_fields) |
| 269 | + if self._predicate: |
| 270 | + read_options_dict["row_restriction"] = self._predicate |
| 271 | + read_options = bq_storage_types.ReadSession.TableReadOptions( |
| 272 | + **read_options_dict |
| 273 | + ) |
| 274 | + |
| 275 | + if self._snapshot_time: |
| 276 | + snapshot_time = timestamp_pb2.Timestamp() |
| 277 | + snapshot_time.FromDatetime(self._snapshot_time) |
| 278 | + table_mod_options["snapshot_time"] = snapshot_time = snapshot_time |
| 279 | + table_mods = bq_storage_types.ReadSession.TableModifiers(**table_mod_options) |
| 280 | + |
| 281 | +        requested_session = bq_storage_types.ReadSession(
| 282 | + table=self._data.to_bqstorage(), |
| 283 | + data_format=bq_storage_types.DataFormat.ARROW, |
| 284 | + read_options=read_options, |
| 285 | + table_modifiers=table_mods, |
| 286 | + ) |
| 287 | + # Single stream to maintain ordering |
| 288 | + request = bq_storage_types.CreateReadSessionRequest( |
| 289 | + parent=f"projects/{self._data.project}", |
| 290 | + read_session=requested_session, |
| 291 | + max_stream_count=1, |
| 292 | + ) |
| 293 | + session = self._storage_client.create_read_session(request=request) |
| 294 | + |
| 295 | + if not session.streams: |
| 296 | + batches: Iterator[pa.RecordBatch] = iter([]) |
| 297 | + else: |
| 298 | + reader = self._storage_client.read_rows(session.streams[0].name) |
| 299 | + rowstream = reader.rows() |
| 300 | + |
| 301 | + def process_page(page): |
| 302 | + pa_batch = page.to_arrow() |
| 303 | + return pa.RecordBatch.from_arrays( |
| 304 | + pa_batch.columns, names=self.schema.names |
| 305 | + ) |
| 306 | + |
| 307 | + batches = map(process_page, rowstream.pages) |
| 308 | + |
| 309 | + return batches |
| 310 | + |
| 311 | + |
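BQTableExecuteResult._get_arrow_batches follows the standard BigQuery Storage Read API pattern; a standalone sketch of that pattern for context (project, dataset, table, columns, and predicate are placeholders; error handling omitted):

    import datetime

    from google.cloud import bigquery_storage_v1
    from google.cloud.bigquery_storage_v1 import types
    from google.protobuf import timestamp_pb2

    client = bigquery_storage_v1.BigQueryReadClient()

    # Optional point-in-time read, mirroring the snapshot_time handling above.
    snapshot = timestamp_pb2.Timestamp()
    snapshot.FromDatetime(datetime.datetime.now(datetime.timezone.utc))

    session = client.create_read_session(
        request=types.CreateReadSessionRequest(
            parent="projects/my-project",  # placeholder project
            read_session=types.ReadSession(
                table="projects/my-project/datasets/my_dataset/tables/my_table",
                data_format=types.DataFormat.ARROW,
                read_options=types.ReadSession.TableReadOptions(
                    selected_fields=["id", "name"],  # optional column pruning
                    row_restriction="id > 0",  # optional SQL predicate
                ),
                table_modifiers=types.ReadSession.TableModifiers(snapshot_time=snapshot),
            ),
            max_stream_count=1,  # a single stream preserves row order
        )
    )

    if session.streams:
        reader = client.read_rows(session.streams[0].name)
        for page in reader.rows().pages:
            record_batch = page.to_arrow()  # one pyarrow.RecordBatch per page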
124 | 312 | @dataclasses.dataclass(frozen=True) |
125 | 313 | class HierarchicalKey: |
126 | 314 | columns: tuple[str, ...] |
|