Skip to content

Commit 1f85658

Browse files
fix name mappings for remote sources
1 parent bf8c827 commit 1f85658

File tree

5 files changed

+28
-11
lines changed

5 files changed

+28
-11
lines changed

bigframes/core/bq_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ def get_arrow_batches(
207207
batches = _iter_streams(session.streams, storage_read_client)
208208

209209
def process_batch(pa_batch):
210-
return pyarrow_utils.cast_batch(pa_batch, data.schema.to_pyarrow())
210+
return pyarrow_utils.cast_batch(
211+
pa_batch.select(columns), data.schema.select(columns).to_pyarrow()
212+
)
211213

212214
batches = map(process_batch, batches)
213215

bigframes/core/pyarrow_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ def cast_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch:
8484
)
8585

8686

87+
def rename_batch(batch: pa.RecordBatch, names: list[str]) -> pa.RecordBatch:
88+
if batch.schema.names == names:
89+
return batch
90+
# TODO: Use RecordBatch.rename_columns once min pyarrow>=16.0
91+
return pa.RecordBatch.from_arrays(batch.columns, names)
92+
93+
8794
def truncate_pyarrow_iterable(
8895
batches: Iterable[pa.RecordBatch], max_results: int
8996
) -> Iterator[pa.RecordBatch]:

bigframes/session/bq_caching_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ def _execute_plan_gbq(
687687
project_id=self.bqclient.project,
688688
storage_client=self.bqstoragereadclient,
689689
query_job=query_job,
690-
selected_fields=tuple(col for col in og_schema.names),
690+
selected_fields=tuple((col, col) for col in og_schema.names),
691691
)
692692
else:
693693
return executor.LocalExecuteResult(

bigframes/session/executor.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,16 @@ def __init__(
220220
*,
221221
query_job: Optional[bigquery.QueryJob] = None,
222222
limit: Optional[int] = None,
223-
selected_fields: Optional[Sequence[str]] = None,
223+
selected_fields: Optional[Sequence[tuple[str, str]]] = None,
224224
):
225225
self._data = data
226226
self._project_id = project_id
227227
self._query_job = query_job
228228
self._storage_client = storage_client
229229
self._limit = limit
230-
self._selected_fields = selected_fields
230+
self._selected_fields = selected_fields or [
231+
(name, name) for name in data.schema.names
232+
]
231233

232234
@property
233235
def query_job(self) -> Optional[bigquery.QueryJob]:
@@ -240,20 +242,24 @@ def total_bytes_processed(self) -> Optional[int]:
240242
return None
241243

242244
@property
245+
@functools.cache
243246
def schema(self) -> bigframes.core.schema.ArraySchema:
244-
schema = self._data.schema
245-
if self._selected_fields:
246-
return schema.select(self._selected_fields)
247-
return schema
247+
source_ids = [selection[0] for selection in self._selected_fields]
248+
return self._data.schema.select(source_ids).rename(dict(self._selected_fields))
248249

249250
def batches(self) -> ResultsIterator:
250251
read_batches = bq_data.get_arrow_batches(
251252
self._data,
252-
self._selected_fields or self._data.schema.names,
253+
[x[0] for x in self._selected_fields],
253254
self._storage_client,
254255
self._project_id,
255256
)
256-
arrow_batches = read_batches.iter
257+
arrow_batches: Iterator[pa.RecordBatch] = map(
258+
functools.partial(
259+
pyarrow_utils.rename_batch, names=list(self.schema.names)
260+
),
261+
read_batches.iter,
262+
)
257263
approx_bytes: Optional[int] = read_batches.approx_bytes
258264
approx_rows: Optional[int] = self._data.n_rows or read_batches.approx_rows
259265

bigframes/session/read_api_execution.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def execute(
5656
project_id=self.project,
5757
storage_client=self.bqstoragereadclient,
5858
limit=peek,
59-
selected_fields=[item.source_id for item in node.scan_list.items],
59+
selected_fields=[
60+
(item.source_id, item.id.sql) for item in node.scan_list.items
61+
],
6062
)
6163

6264
def _try_adapt_plan(

0 commit comments

Comments (0)