
Commit 6d16001

refactor data source classes
1 parent ca51638 commit 6d16001

File tree

11 files changed: +176 −151 lines

bigframes/core/array_value.py

Lines changed: 7 additions & 6 deletions
@@ -23,7 +23,7 @@
 import pandas
 import pyarrow as pa
 
-from bigframes.core import agg_expressions
+from bigframes.core import agg_expressions, bq_data
 import bigframes.core.expression as ex
 import bigframes.core.guid
 import bigframes.core.identifiers as ids
@@ -63,7 +63,7 @@ def from_pyarrow(cls, arrow_table: pa.Table, session: Session):
     def from_managed(cls, source: local_data.ManagedArrowTable, session: Session):
         scan_list = nodes.ScanList(
             tuple(
-                nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
+                nodes.ScanItem(ids.ColumnId(item.column), item.column)
                 for item in source.schema.items
             )
         )
@@ -100,7 +100,7 @@ def from_table(
         if offsets_col and primary_key:
             raise ValueError("must set at most one of 'offests', 'primary_key'")
         # define data source only for needed columns, this makes row-hashing cheaper
-        table_def = nodes.GbqTable.from_table(table, columns=schema.names)
+        table_def = bq_data.GbqTable.from_table(table, columns=schema.names)
 
         # create ordering from info
         ordering = None
@@ -114,12 +114,13 @@ def from_table(
         # Scan all columns by default, we define this list as it can be pruned while preserving source_def
         scan_list = nodes.ScanList(
             tuple(
-                nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
+                nodes.ScanItem(ids.ColumnId(item.column), item.column)
                 for item in schema.items
             )
         )
-        source_def = nodes.BigqueryDataSource(
+        source_def = bq_data.BigqueryDataSource(
             table=table_def,
+            schema=schema,
             at_time=at_time,
             sql_predicate=predicate,
             ordering=ordering,
@@ -130,7 +131,7 @@ def from_table(
     @classmethod
     def from_bq_data_source(
         cls,
-        source: nodes.BigqueryDataSource,
+        source: bq_data.BigqueryDataSource,
         scan_list: nodes.ScanList,
         session: Session,
     ):
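The net effect at these call sites: `ScanItem` no longer carries a dtype, and the logical schema now travels with the data source itself. A minimal sketch of the resulting construction pattern (the standalone helper and its signature are illustrative, not part of the commit):

```python
from bigframes.core import bq_data, nodes
import bigframes.core.identifiers as ids


def build_source(table, schema, at_time=None, predicate=None, ordering=None):
    # ScanItem now records only the bigframes column id and the source column name.
    scan_list = nodes.ScanList(
        tuple(
            nodes.ScanItem(ids.ColumnId(item.column), item.column)
            for item in schema.items
        )
    )
    # The logical schema is attached to the BigQuery data source directly.
    source_def = bq_data.BigqueryDataSource(
        table=bq_data.GbqTable.from_table(table, columns=schema.names),
        schema=schema,
        at_time=at_time,
        sql_predicate=predicate,
        ordering=ordering,
    )
    return source_def, scan_list
```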

bigframes/core/bq_data.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import dataclasses
+import datetime
+import functools
+import typing
+from typing import Optional, Sequence, Tuple
+
+import google.cloud.bigquery as bq
+
+import bigframes.core.schema
+
+if typing.TYPE_CHECKING:
+    import bigframes.core.ordering as orderings
+
+
+@dataclasses.dataclass(frozen=True)
+class GbqTable:
+    project_id: str = dataclasses.field()
+    dataset_id: str = dataclasses.field()
+    table_id: str = dataclasses.field()
+    physical_schema: Tuple[bq.SchemaField, ...] = dataclasses.field()
+    is_physically_stored: bool = dataclasses.field()
+    cluster_cols: typing.Optional[Tuple[str, ...]]
+
+    @staticmethod
+    def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable:
+        # Subsetting fields with columns can reduce cost of row-hash default ordering
+        if columns:
+            schema = tuple(item for item in table.schema if item.name in columns)
+        else:
+            schema = tuple(table.schema)
+        return GbqTable(
+            project_id=table.project,
+            dataset_id=table.dataset_id,
+            table_id=table.table_id,
+            physical_schema=schema,
+            is_physically_stored=(table.table_type in ["TABLE", "MATERIALIZED_VIEW"]),
+            cluster_cols=None
+            if table.clustering_fields is None
+            else tuple(table.clustering_fields),
+        )
+
+    def get_table_ref(self) -> bq.TableReference:
+        return bq.TableReference(
+            bq.DatasetReference(self.project_id, self.dataset_id), self.table_id
+        )
+
+    @property
+    @functools.cache
+    def schema_by_id(self):
+        return {col.name: col for col in self.physical_schema}
+
+
+@dataclasses.dataclass(frozen=True)
+class BigqueryDataSource:
+    """
+    Google BigQuery Data source.
+
+    This should not be modified once defined, as all attributes contribute to the default ordering.
+    """
+
+    table: GbqTable
+    schema: bigframes.core.schema.ArraySchema
+    at_time: typing.Optional[datetime.datetime] = None
+    # Added for backwards compatibility, not validated
+    sql_predicate: typing.Optional[str] = None
+    ordering: typing.Optional[orderings.RowOrdering] = None
+    # Optimization field
+    n_rows: Optional[int] = None
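A rough usage sketch of the relocated classes; the BigQuery client call and the table name are assumptions for illustration, not part of the commit:

```python
import google.cloud.bigquery as bq

from bigframes.core import bq_data

client = bq.Client()  # hypothetical client; requires ambient credentials
table = client.get_table("my-project.my_dataset.my_table")  # hypothetical table

gbq_table = bq_data.GbqTable.from_table(table, columns=("a", "b"))
print(gbq_table.get_table_ref())    # TableReference rebuilt from the stored ids
print(gbq_table.cluster_cols)       # clustering columns as a tuple, or None
print(gbq_table.schema_by_id["a"])  # physical SchemaField looked up by column name
```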

bigframes/core/compile/ibis_compiler/ibis_compiler.py

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@
 import bigframes_vendored.ibis.expr.types as ibis_types
 
 from bigframes import dtypes, operations
-from bigframes.core import expression, pyarrow_utils
+from bigframes.core import bq_data, expression, pyarrow_utils
 import bigframes.core.compile.compiled as compiled
 import bigframes.core.compile.concat as concat_impl
 import bigframes.core.compile.configs as configs
@@ -186,7 +186,7 @@ def compile_readtable(node: nodes.ReadTableNode, *args):
     # TODO(b/395912450): Remove workaround solution once b/374784249 got resolved.
     for scan_item in node.scan_list.items:
         if (
-            scan_item.dtype == dtypes.JSON_DTYPE
+            node.source.schema.get_type(scan_item.source_id) == dtypes.JSON_DTYPE
             and ibis_table[scan_item.source_id].type() == ibis_dtypes.string
         ):
             json_column = scalar_op_registry.parse_json(
@@ -204,7 +204,7 @@ def compile_readtable(node: nodes.ReadTableNode, *args):
 
 
 def _table_to_ibis(
-    source: nodes.BigqueryDataSource,
+    source: bq_data.BigqueryDataSource,
     scan_cols: typing.Sequence[str],
 ) -> ibis_types.Table:
     full_table_name = (
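With the dtype gone from `ScanItem`, the compiler resolves logical types from the data source's schema, keyed by the scanned column's source id. A hedged sketch of the lookup the JSON workaround now performs (the helper function and the `ibis_dtypes` import path are assumptions mirroring the module's existing alias):

```python
import bigframes_vendored.ibis.expr.datatypes as ibis_dtypes  # assumed vendored alias

from bigframes import dtypes


def needs_json_parse(node, scan_item, ibis_table) -> bool:
    # The logical type comes from the data source schema rather than the scan item.
    logical_type = node.source.schema.get_type(scan_item.source_id)
    physical_type = ibis_table[scan_item.source_id].type()
    return logical_type == dtypes.JSON_DTYPE and physical_type == ibis_dtypes.string
```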

bigframes/core/nodes.py

Lines changed: 16 additions & 69 deletions
@@ -16,7 +16,6 @@
 
 import abc
 import dataclasses
-import datetime
 import functools
 import itertools
 import typing
@@ -31,9 +30,7 @@
     Tuple,
 )
 
-import google.cloud.bigquery as bq
-
-from bigframes.core import agg_expressions, identifiers, local_data, sequences
+from bigframes.core import agg_expressions, bq_data, identifiers, local_data, sequences
 from bigframes.core.bigframe_node import BigFrameNode, COLUMN_SET
 import bigframes.core.expression as ex
 from bigframes.core.field import Field
@@ -599,14 +596,13 @@ def transform_children(self, t: Callable[[BigFrameNode], BigFrameNode]) -> LeafN
 
 class ScanItem(typing.NamedTuple):
     id: identifiers.ColumnId
-    dtype: bigframes.dtypes.Dtype  # Might be multiple logical types for a given physical source type
     source_id: str  # Flexible enough for both local data and bq data
 
     def with_id(self, id: identifiers.ColumnId) -> ScanItem:
-        return ScanItem(id, self.dtype, self.source_id)
+        return ScanItem(id, self.source_id)
 
     def with_source_id(self, source_id: str) -> ScanItem:
-        return ScanItem(self.id, self.dtype, source_id)
+        return ScanItem(self.id, source_id)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -661,7 +657,7 @@ def remap_source_ids(
     def append(
         self, source_id: str, dtype: bigframes.dtypes.Dtype, id: identifiers.ColumnId
    ) -> ScanList:
-        return ScanList((*self.items, ScanItem(id, dtype, source_id)))
+        return ScanList((*self.items, ScanItem(id, source_id)))
 
 
 @dataclasses.dataclass(frozen=True, eq=False)
@@ -677,8 +673,10 @@ class ReadLocalNode(LeafNode):
     @property
     def fields(self) -> Sequence[Field]:
         fields = tuple(
-            Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items
+            Field(col_id, self.local_data_source.schema.get_type(source_id))
+            for col_id, source_id in self.scan_list.items
         )
+
         if self.offsets_col is not None:
             return tuple(
                 itertools.chain(
@@ -726,7 +724,7 @@ def remap_vars(
     ) -> ReadLocalNode:
         new_scan_list = ScanList(
             tuple(
-                ScanItem(mappings.get(item.id, item.id), item.dtype, item.source_id)
+                ScanItem(mappings.get(item.id, item.id), item.source_id)
                 for item in self.scan_list.items
             )
         )
@@ -745,64 +743,10 @@ def remap_refs(
         return self
 
 
-@dataclasses.dataclass(frozen=True)
-class GbqTable:
-    project_id: str = dataclasses.field()
-    dataset_id: str = dataclasses.field()
-    table_id: str = dataclasses.field()
-    physical_schema: Tuple[bq.SchemaField, ...] = dataclasses.field()
-    is_physically_stored: bool = dataclasses.field()
-    cluster_cols: typing.Optional[Tuple[str, ...]]
-
-    @staticmethod
-    def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable:
-        # Subsetting fields with columns can reduce cost of row-hash default ordering
-        if columns:
-            schema = tuple(item for item in table.schema if item.name in columns)
-        else:
-            schema = tuple(table.schema)
-        return GbqTable(
-            project_id=table.project,
-            dataset_id=table.dataset_id,
-            table_id=table.table_id,
-            physical_schema=schema,
-            is_physically_stored=(table.table_type in ["TABLE", "MATERIALIZED_VIEW"]),
-            cluster_cols=None
-            if table.clustering_fields is None
-            else tuple(table.clustering_fields),
-        )
-
-    def get_table_ref(self) -> bq.TableReference:
-        return bq.TableReference(
-            bq.DatasetReference(self.project_id, self.dataset_id), self.table_id
-        )
-
-    @property
-    @functools.cache
-    def schema_by_id(self):
-        return {col.name: col for col in self.physical_schema}
-
-
-@dataclasses.dataclass(frozen=True)
-class BigqueryDataSource:
-    """
-    Google BigQuery Data source.
-
-    This should not be modified once defined, as all attributes contribute to the default ordering.
-    """
-
-    table: GbqTable
-    at_time: typing.Optional[datetime.datetime] = None
-    # Added for backwards compatibility, not validated
-    sql_predicate: typing.Optional[str] = None
-    ordering: typing.Optional[orderings.RowOrdering] = None
-    n_rows: Optional[int] = None
-
-
 ## Put ordering in here or just add order_by node above?
 @dataclasses.dataclass(frozen=True, eq=False)
 class ReadTableNode(LeafNode):
-    source: BigqueryDataSource
+    source: bq_data.BigqueryDataSource
     # Subset of physical schema column
     # Mapping of table schema ids to bfet id.
     scan_list: ScanList
@@ -826,8 +770,12 @@ def session(self):
     @property
     def fields(self) -> Sequence[Field]:
         return tuple(
-            Field(col_id, dtype, self.source.table.schema_by_id[source_id].is_nullable)
-            for col_id, dtype, source_id in self.scan_list.items
+            Field(
+                col_id,
+                self.source.schema.get_type(source_id),
+                self.source.table.schema_by_id[source_id].is_nullable,
+            )
+            for col_id, source_id in self.scan_list.items
         )
 
     @property
@@ -886,7 +834,7 @@ def remap_vars(
     ) -> ReadTableNode:
         new_scan_list = ScanList(
             tuple(
-                ScanItem(mappings.get(item.id, item.id), item.dtype, item.source_id)
+                ScanItem(mappings.get(item.id, item.id), item.source_id)
                 for item in self.scan_list.items
             )
         )
@@ -907,7 +855,6 @@ def with_order_cols(self):
         new_scan_cols = [
             ScanItem(
                 identifiers.ColumnId.unique(),
-                dtype=bigframes.dtypes.convert_schema_field(field)[1],
                 source_id=field.name,
             )
             for field in self.source.table.physical_schema
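After the move, `ScanItem` is a two-field NamedTuple and node fields derive their dtypes from the owning data source rather than from the scan list. A minimal sketch of how a `ReadTableNode` now assembles its logical fields (a free-standing restatement of the property above, for illustration only):

```python
from bigframes.core.field import Field


def read_table_fields(node):
    # dtype comes from the logical schema on the data source; nullability still
    # comes from the physical BigQuery schema, keyed by source column name.
    return tuple(
        Field(
            col_id,
            node.source.schema.get_type(source_id),
            node.source.table.schema_by_id[source_id].is_nullable,
        )
        for col_id, source_id in node.scan_list.items  # two-field ScanItem
    )
```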

bigframes/core/rewrite/fold_row_count.py

Lines changed: 1 addition & 5 deletions
@@ -15,7 +15,6 @@
 
 import pyarrow as pa
 
-from bigframes import dtypes
 from bigframes.core import local_data, nodes
 from bigframes.operations import aggregations
 
@@ -34,10 +33,7 @@ def fold_row_counts(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
         pa.table({"count": pa.array([node.child.row_count], type=pa.int64())})
     )
     scan_list = nodes.ScanList(
-        tuple(
-            nodes.ScanItem(out_id, dtypes.INT_DTYPE, "count")
-            for _, out_id in node.aggregations
-        )
+        tuple(nodes.ScanItem(out_id, "count") for _, out_id in node.aggregations)
     )
     return nodes.ReadLocalNode(
         local_data_source=local_data_source, scan_list=scan_list, session=node.session
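The rewrite no longer spells out `INT_DTYPE` for the synthetic count column; the int64 type is implied by the pyarrow-backed local data source. A sketch under the assumption that the enclosing call (not shown in this hunk) is `local_data.ManagedArrowTable.from_pyarrow`:

```python
import pyarrow as pa

from bigframes.core import local_data, nodes


def count_source(node):
    # One-row arrow table holding the child's row count; its int64 column type
    # supplies the dtype that ScanItem previously carried explicitly.
    local_data_source = local_data.ManagedArrowTable.from_pyarrow(
        pa.table({"count": pa.array([node.child.row_count], type=pa.int64())})
    )
    scan_list = nodes.ScanList(
        tuple(nodes.ScanItem(out_id, "count") for _, out_id in node.aggregations)
    )
    return local_data_source, scan_list
```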
