Skip to content

Commit 02af3ad

Browse files
committed
fix unit tests
1 parent 90c89a1 commit 02af3ad

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

bigframes/core/compile/sqlglot/aggregations/op_registration.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,16 @@ def arg_checker(*args, **kwargs):
4141
)
4242
return item(*args, **kwargs)
4343

44-
key = op if isinstance(op, type) else type(op)
44+
key = str(op)
4545
if key in self._registered_ops:
4646
raise ValueError(f"{key} is already registered")
47-
self._registered_ops[str(key)] = item
47+
self._registered_ops[key] = item
4848
return arg_checker
4949

5050
return decorator
5151

5252
def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
53-
if isinstance(op, agg_ops.WindowOp):
54-
if not hasattr(op, "name"):
55-
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
56-
else:
57-
key = typing.cast(str, op.name)
58-
if key in self._registered_ops:
59-
return self._registered_ops[key]
60-
return self._registered_ops[str(type(op))]
61-
return self._registered_ops[op]
53+
key = op if isinstance(op, type) else type(op)
54+
if str(key) not in self._registered_ops:
55+
raise ValueError(f"{key} is already not registered")
56+
return self._registered_ops[str(key)]

tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def test_func(op: agg_ops.SizeOp, input: sge.Expression) -> sge.Expression:
2929
return input
3030

3131
assert reg[agg_ops.SizeOp()](op, input) == test_func(op, input)
32-
assert reg[agg_ops.SizeOp.name](op, input) == test_func(op, input)
3332

3433

3534
def test_register_function_first_argument_is_not_agg_op_raise_error():

0 commit comments

Comments
 (0)