Skip to content

Commit 50f5c09

Browse files
fix: Fix sample op volatility in absence of total order
1 parent 7e959b9 commit 50f5c09

File tree

20 files changed

+334
-68
lines changed

20 files changed

+334
-68
lines changed

bigframes/core/array_value.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -540,13 +540,22 @@ def explode(self, column_ids: typing.Sequence[str]) -> ArrayValue:
540540
offsets = tuple(ex.deref(id) for id in column_ids)
541541
return ArrayValue(nodes.ExplodeNode(child=self.node, column_ids=offsets))
542542

543-
def _uniform_sampling(self, fraction: float) -> ArrayValue:
543+
def _uniform_sampling(
544+
self, fraction: float, shuffle: bool, seed: Optional[int] = None
545+
) -> ArrayValue:
544546
"""Sampling the table on given fraction.
545547
546548
.. warning::
547549
The row numbers of the result are non-deterministic; avoid relying on them.
548550
"""
549-
return ArrayValue(nodes.RandomSampleNode(self.node, fraction))
551+
return ArrayValue(
552+
nodes.RandomSampleNode(self.node, fraction, shuffle=shuffle, seed=seed)
553+
)
554+
555+
def _shuffle(self, seed: Optional[int] = None):
556+
return ArrayValue(
557+
nodes.RandomSampleNode(self.node, fraction=1.0, shuffle=True, seed=seed)
558+
)
550559

551560
# Deterministically generate namespaced ids for new variables
552561
# These new ids are only unique within the current namespace.

bigframes/core/blocks.py

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -833,35 +833,46 @@ def _materialize_local(
833833
return df, execute_result.query_job
834834

835835
def _downsample(
836-
self, total_rows: int, sampling_method: str, fraction: float, random_state
836+
self,
837+
total_rows: int,
838+
sampling_method: str,
839+
fraction: float,
840+
random_state: Optional[int],
837841
) -> Block:
838842
# either selecting fraction or number of rows
839843
if sampling_method == _HEAD:
840844
filtered_block = self.slice(stop=int(total_rows * fraction))
841845
return filtered_block
842846
elif (sampling_method == _UNIFORM) and (random_state is None):
843-
filtered_expr = self.expr._uniform_sampling(fraction)
844-
block = Block(
845-
filtered_expr,
846-
index_columns=self.index_columns,
847-
column_labels=self.column_labels,
848-
index_labels=self.index.names,
849-
)
850-
return block
847+
return self.sample(fraction=fraction, shuffle=False, seed=random_state)
851848
elif sampling_method == _UNIFORM:
852-
block = self.split(
853-
fracs=(fraction,),
854-
random_state=random_state,
855-
sort=False,
856-
)[0]
857-
return block
849+
return self.sample(fraction=fraction, shuffle=False)
858850
else:
859851
# This part should never be called, just in case.
860852
raise NotImplementedError(
861853
f"The downsampling method {sampling_method} is not implemented, "
862854
f"please choose from {','.join(_SAMPLING_METHODS)}."
863855
)
864856

857+
def sample(
858+
self, fraction: float, shuffle: bool, seed: Optional[int] = None
859+
) -> Block:
860+
assert fraction <= 1.0 and fraction >= 0
861+
return Block(
862+
self.expr._uniform_sampling(fraction=fraction, shuffle=shuffle, seed=seed),
863+
index_columns=self.index_columns,
864+
column_labels=self.column_labels,
865+
index_labels=self.index.names,
866+
)
867+
868+
def shuffle(self, seed: Optional[int] = None) -> Block:
869+
return Block(
870+
self.expr._uniform_sampling(fraction=1.0, shuffle=True, seed=seed),
871+
index_columns=self.index_columns,
872+
column_labels=self.column_labels,
873+
index_labels=self.index.names,
874+
)
875+
865876
def split(
866877
self,
867878
ns: Iterable[int] = (),
@@ -894,22 +905,11 @@ def split(
894905
random_state = random.randint(-(2**63), 2**63 - 1)
895906

896907
# Create a new column with random_state value.
897-
block, random_state_col = block.create_constant(str(random_state))
908+
og_ordering_col = None
909+
if sort is False:
910+
block, og_ordering_col = block.promote_offsets()
898911

899-
# Create an ordering col and convert to string
900-
block, ordering_col = block.promote_offsets()
901-
block, string_ordering_col = block.apply_unary_op(
902-
ordering_col, ops.AsTypeOp(to_type=bigframes.dtypes.STRING_DTYPE)
903-
)
904-
905-
# Apply hash method to sum col and order by it.
906-
block, string_sum_col = block.apply_binary_op(
907-
string_ordering_col, random_state_col, ops.strconcat_op
908-
)
909-
block, hash_string_sum_col = block.apply_unary_op(string_sum_col, ops.hash_op)
910-
block = block.order_by(
911-
[ordering.OrderingExpression(ex.deref(hash_string_sum_col))]
912-
)
912+
block = block.shuffle(seed=random_state)
913913

914914
intervals = []
915915
cur = 0
@@ -934,21 +934,15 @@ def split(
934934
for sliced_block in sliced_blocks
935935
]
936936
elif sort is False:
937+
assert og_ordering_col is not None
937938
sliced_blocks = [
938939
sliced_block.order_by(
939-
[ordering.OrderingExpression(ex.deref(ordering_col))]
940-
)
940+
[ordering.OrderingExpression(ex.deref(og_ordering_col))]
941+
).drop_columns([og_ordering_col])
941942
for sliced_block in sliced_blocks
942943
]
943944

944-
drop_cols = [
945-
random_state_col,
946-
ordering_col,
947-
string_ordering_col,
948-
string_sum_col,
949-
hash_string_sum_col,
950-
]
951-
return [sliced_block.drop_columns(drop_cols) for sliced_block in sliced_blocks]
945+
return [sliced_block for sliced_block in sliced_blocks]
952946

953947
def _compute_dry_run(
954948
self,

bigframes/core/compile/ibis_compiler/default_ordering.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,11 @@ def _convert_to_nonnull_string(column: ibis_types.Value) -> ibis_types.StringVal
5757
)
5858

5959

60-
def gen_row_key(
60+
def gen_row_hash(
6161
columns: Sequence[ibis_types.Value],
6262
) -> bigframes_vendored.ibis.Value:
6363
ordering_hash_part = guid.generate_guid("bigframes_ordering_")
6464
ordering_hash_part2 = guid.generate_guid("bigframes_ordering_")
65-
ordering_rand_part = guid.generate_guid("bigframes_ordering_")
6665

6766
# All inputs into hash must be non-null or resulting hash will be null
6867
str_values = list(map(_convert_to_nonnull_string, columns))
@@ -81,11 +80,4 @@ def gen_row_key(
8180
.name(ordering_hash_part2)
8281
.cast(ibis_dtypes.String(nullable=True))
8382
)
84-
# Used to disambiguate between identical rows (which will have identical hash)
85-
random_value = (
86-
bigframes_vendored.ibis.random()
87-
.name(ordering_rand_part)
88-
.cast(ibis_dtypes.String(nullable=True))
89-
)
90-
91-
return full_row_hash.concat(full_row_hash_p2, random_value)
83+
return full_row_hash.concat(full_row_hash_p2)

bigframes/core/compile/ibis_compiler/ibis_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
7777

7878
def _replace_unsupported_ops(node: nodes.BigFrameNode):
7979
# TODO: Run all replacement rules as single bottom-up pass
80+
node = nodes.bottom_up(node, rewrites.rewrite_random_sample)
8081
node = nodes.bottom_up(node, rewrites.rewrite_slice)
8182
node = nodes.bottom_up(node, rewrites.rewrite_timedelta_expressions)
8283
node = nodes.bottom_up(node, rewrites.rewrite_range_rolling)

bigframes/core/compile/ibis_compiler/scalar_op_compiler.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,35 @@ def compile_row_op(
100100
impl = self._registry[op.name]
101101
return impl(inputs, op)
102102

103+
def register_nullary_op(
104+
self,
105+
op_ref: typing.Union[ops.NullaryOp, type[ops.NullaryOp]],
106+
pass_op: bool = False,
107+
):
108+
"""
109+
Decorator to register a nullary op implementation.
110+
111+
Args:
112+
op_ref (NullaryOp or NullaryOp type):
113+
Class or instance of operator that is implemented by the decorated function.
114+
pass_op (bool):
115+
Set to true if implementation takes the operator object as the last argument.
116+
This is needed for parameterized ops where parameters are part of op object.
117+
"""
118+
key = typing.cast(str, op_ref.name)
119+
120+
def decorator(impl: typing.Callable[..., ibis_types.Value]):
121+
def normalized_impl(args: typing.Sequence[ibis_types.Value], op: ops.RowOp):
122+
if pass_op:
123+
return impl(op)
124+
else:
125+
return impl()
126+
127+
self._register(key, normalized_impl)
128+
return impl
129+
130+
return decorator
131+
103132
def register_unary_op(
104133
self,
105134
op_ref: typing.Union[ops.UnaryOp, type[ops.UnaryOp]],

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1987,9 +1987,19 @@ def _construct_prompt(
19871987
return ibis.struct(prompt)
19881988

19891989

1990-
@scalar_op_compiler.register_nary_op(ops.RowKey, pass_op=True)
1991-
def rowkey_op_impl(*values: ibis_types.Value, op: ops.RowKey) -> ibis_types.Value:
1992-
return bigframes.core.compile.ibis_compiler.default_ordering.gen_row_key(values)
1990+
@scalar_op_compiler.register_nary_op(ops.RowHash, pass_op=True)
1991+
def rowkey_op_impl(*values: ibis_types.Value, op: ops.RowHash) -> ibis_types.Value:
1992+
return bigframes.core.compile.ibis_compiler.default_ordering.gen_row_hash(values)
1993+
1994+
1995+
@scalar_op_compiler.register_nullary_op(ops.rand_op, pass_op=False)
1996+
def rand_op_impl() -> ibis_types.Value:
1997+
return ibis.random()
1998+
1999+
2000+
@scalar_op_compiler.register_nullary_op(ops.gen_uuid_op, pass_op=False)
2001+
def gen_uuid_op_impl() -> ibis_types.Value:
2002+
return ibis.uuid()
19932003

19942004

19952005
# Helpers

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotI
386386

387387

388388
def _replace_unsupported_ops(node: nodes.BigFrameNode):
389+
node = nodes.bottom_up(node, rewrite.rewrite_random_sample)
389390
node = nodes.bottom_up(node, rewrite.rewrite_slice)
390391
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
391392
return node

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2424
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2525

26+
register_nullary_op = scalar_compiler.scalar_op_compiler.register_nullary_op
2627
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
2728
register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op
2829
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
@@ -173,7 +174,7 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
173174
return sge.Coalesce(this=left.expr, expressions=[right.expr])
174175

175176

176-
@register_nary_op(ops.RowKey)
177+
@register_nary_op(ops.RowHash)
177178
def _(*values: TypedExpr) -> sge.Expression:
178179
# All inputs into hash must be non-null or resulting hash will be null
179180
str_values = [_convert_to_nonnull_string_sqlglot(value) for value in values]
@@ -197,6 +198,16 @@ def _(*values: TypedExpr) -> sge.Expression:
197198
)
198199

199200

201+
@register_nullary_op(ops.rand_op)
202+
def _() -> sge.Expression:
203+
return sge.func("RAND")
204+
205+
206+
@register_nullary_op(ops.gen_uuid_op)
207+
def _() -> sge.Expression:
208+
return sge.func("GENERATE_UUID")
209+
210+
200211
# Helper functions
201212
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
202213
from_type = expr.dtype

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,35 @@ def compile_row_op(
9393
impl = self._registry[op.name]
9494
return impl(inputs, op)
9595

96+
def register_nullary_op(
97+
self,
98+
op_ref: typing.Union[ops.NullaryOp, type[ops.NullaryOp]],
99+
pass_op: bool = False,
100+
):
101+
"""
102+
Decorator to register a nullary op implementation.
103+
104+
Args:
105+
op_ref (NullaryOp or NullaryOp type):
106+
Class or instance of operator that is implemented by the decorated function.
107+
pass_op (bool):
108+
Set to true if implementation takes the operator object as the last argument.
109+
This is needed for parameterized ops where parameters are part of op object.
110+
"""
111+
key = typing.cast(str, op_ref.name)
112+
113+
def decorator(impl: typing.Callable[..., sge.Expression]):
114+
def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp):
115+
if pass_op:
116+
return impl(op)
117+
else:
118+
return impl()
119+
120+
self._register(key, normalized_impl)
121+
return impl
122+
123+
return decorator
124+
96125
def register_unary_op(
97126
self,
98127
op_ref: typing.Union[ops.UnaryOp, type[ops.UnaryOp]],

bigframes/core/nodes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1531,10 +1531,12 @@ def remap_refs(
15311531
@dataclasses.dataclass(frozen=True, eq=False)
15321532
class RandomSampleNode(UnaryNode):
15331533
fraction: float
1534+
shuffle: bool
1535+
seed: Optional[int] = None
15341536

15351537
@property
15361538
def deterministic(self) -> bool:
1537-
return False
1539+
return self.seed is not None
15381540

15391541
@property
15401542
def row_preserving(self) -> bool:

0 commit comments

Comments
 (0)