Skip to content

Commit 5f13789

Browse files
committed
use fixed agg name and undo changes on op registeration
1 parent 66d4be7 commit 5f13789

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

bigframes/core/compile/sqlglot/aggregations/op_registration.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,12 @@ def arg_checker(*args, **kwargs):
4141
)
4242
return item(*args, **kwargs)
4343

44-
key = op if isinstance(op, type) else type(op)
45-
if key in self._registered_ops:
46-
raise ValueError(f"{key} is already registered")
44+
if hasattr(op, "name"):
45+
key = typing.cast(str, op.name)
46+
if key in self._registered_ops:
47+
raise ValueError(f"{key} is already registered")
48+
else:
49+
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
4750
self._registered_ops[key] = item
4851
return arg_checker
4952

@@ -55,7 +58,5 @@ def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
5558
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
5659
else:
5760
key = typing.cast(str, op.name)
58-
if key in self._registered_ops:
59-
return self._registered_ops[key]
60-
return self._registered_ops[type(op)]
61+
return self._registered_ops[key]
6162
return self._registered_ops[op]

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _(
116116
) -> sge.Expression:
117117
# TODO: Support interpolation argument
118118
# TODO: Support percentile_disc
119-
result = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
119+
result: sge.Expression = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
120120
if window is None:
121121
# PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause.
122122
result = sge.Window(this=result)

bigframes/operations/aggregations.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,9 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
239239

240240
@dataclasses.dataclass(frozen=True)
241241
class ApproxQuartilesOp(UnaryAggregateOp):
242+
name: typing.ClassVar[str] = "approx_quantile"
242243
quartile: int
243244

244-
@property
245-
def name(self):
246-
return f"{self.quartile * 25}%"
247-
248245
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
249246
if not dtypes.is_orderable(input_types[0]):
250247
raise TypeError(f"Type {input_types[0]} is not orderable")

0 commit comments

Comments
 (0)