Skip to content

Commit d32d8d2

Browse files
committed
add unit tests
1 parent 7e7b0cd commit d32d8d2

File tree

1 file changed

+189
-0
lines changed

1 file changed

+189
-0
lines changed
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest.mock as mock
16+
17+
import pytest
18+
import sqlglot.expressions as sge
19+
20+
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
21+
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
22+
import bigframes.operations as ops
23+
24+
25+
def test_register_unary_op():
26+
compiler = scalar_compiler.ScalarOpCompiler()
27+
28+
class MockUnaryOp(ops.UnaryOp):
29+
name = "mock_unary_op"
30+
31+
mock_op = MockUnaryOp()
32+
mock_impl = mock.Mock()
33+
34+
@compiler.register_unary_op(mock_op)
35+
def _(expr: TypedExpr) -> sge.Expression:
36+
mock_impl(expr)
37+
return sge.Identifier(this="output")
38+
39+
arg = TypedExpr(sge.Identifier(this="input"), "string")
40+
result = compiler.compile_row_op(mock_op, [arg])
41+
assert result == sge.Identifier(this="output")
42+
mock_impl.assert_called_once_with(arg)
43+
44+
45+
def test_register_unary_op_pass_op():
46+
compiler = scalar_compiler.ScalarOpCompiler()
47+
48+
class MockUnaryOp(ops.UnaryOp):
49+
name = "mock_unary_op_pass_op"
50+
51+
mock_op = MockUnaryOp()
52+
mock_impl = mock.Mock()
53+
54+
@compiler.register_unary_op(mock_op, pass_op=True)
55+
def _(expr: TypedExpr, op: ops.UnaryOp) -> sge.Expression:
56+
mock_impl(expr, op)
57+
return sge.Identifier(this="output")
58+
59+
arg = TypedExpr(sge.Identifier(this="input"), "string")
60+
result = compiler.compile_row_op(mock_op, [arg])
61+
assert result == sge.Identifier(this="output")
62+
mock_impl.assert_called_once_with(arg, mock_op)
63+
64+
65+
def test_register_binary_op():
66+
compiler = scalar_compiler.ScalarOpCompiler()
67+
68+
class MockBinaryOp(ops.BinaryOp):
69+
name = "mock_binary_op"
70+
71+
mock_op = MockBinaryOp()
72+
mock_impl = mock.Mock()
73+
74+
@compiler.register_binary_op(mock_op)
75+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
76+
mock_impl(left, right)
77+
return sge.Identifier(this="output")
78+
79+
arg1 = TypedExpr(sge.Identifier(this="input1"), "string")
80+
arg2 = TypedExpr(sge.Identifier(this="input2"), "string")
81+
result = compiler.compile_row_op(mock_op, [arg1, arg2])
82+
assert result == sge.Identifier(this="output")
83+
mock_impl.assert_called_once_with(arg1, arg2)
84+
85+
86+
def test_register_binary_op_pass_on():
87+
compiler = scalar_compiler.ScalarOpCompiler()
88+
89+
class MockBinaryOp(ops.BinaryOp):
90+
name = "mock_binary_op_pass_op"
91+
92+
mock_op = MockBinaryOp()
93+
mock_impl = mock.Mock()
94+
95+
@compiler.register_binary_op(mock_op, pass_op=True)
96+
def _(left: TypedExpr, right: TypedExpr, op: ops.BinaryOp) -> sge.Expression:
97+
mock_impl(left, right, op)
98+
return sge.Identifier(this="output")
99+
100+
arg1 = TypedExpr(sge.Identifier(this="input1"), "string")
101+
arg2 = TypedExpr(sge.Identifier(this="input2"), "string")
102+
result = compiler.compile_row_op(mock_op, [arg1, arg2])
103+
assert result == sge.Identifier(this="output")
104+
mock_impl.assert_called_once_with(arg1, arg2, mock_op)
105+
106+
107+
def test_register_ternary_op():
108+
compiler = scalar_compiler.ScalarOpCompiler()
109+
110+
class MockTernaryOp(ops.TernaryOp):
111+
name = "mock_ternary_op"
112+
113+
mock_op = MockTernaryOp()
114+
mock_impl = mock.Mock()
115+
116+
@compiler.register_ternary_op(mock_op)
117+
def _(arg1: TypedExpr, arg2: TypedExpr, arg3: TypedExpr) -> sge.Expression:
118+
mock_impl(arg1, arg2, arg3)
119+
return sge.Identifier(this="output")
120+
121+
arg1 = TypedExpr(sge.Identifier(this="input1"), "string")
122+
arg2 = TypedExpr(sge.Identifier(this="input2"), "string")
123+
arg3 = TypedExpr(sge.Identifier(this="input3"), "string")
124+
result = compiler.compile_row_op(mock_op, [arg1, arg2, arg3])
125+
assert result == sge.Identifier(this="output")
126+
mock_impl.assert_called_once_with(arg1, arg2, arg3)
127+
128+
129+
def test_register_nary_op():
130+
compiler = scalar_compiler.ScalarOpCompiler()
131+
132+
class MockNaryOp(ops.NaryOp):
133+
name = "mock_nary_op"
134+
135+
mock_op = MockNaryOp()
136+
mock_impl = mock.Mock()
137+
138+
@compiler.register_nary_op(mock_op)
139+
def _(*args: TypedExpr) -> sge.Expression:
140+
mock_impl(*args)
141+
return sge.Identifier(this="output")
142+
143+
arg1 = TypedExpr(sge.Identifier(this="input1"), "string")
144+
arg2 = TypedExpr(sge.Identifier(this="input2"), "string")
145+
result = compiler.compile_row_op(mock_op, [arg1, arg2])
146+
assert result == sge.Identifier(this="output")
147+
mock_impl.assert_called_once_with(arg1, arg2)
148+
149+
150+
def test_register_nary_op_pass_on():
151+
compiler = scalar_compiler.ScalarOpCompiler()
152+
153+
class MockNaryOp(ops.NaryOp):
154+
name = "mock_nary_op_pass_op"
155+
156+
mock_op = MockNaryOp()
157+
mock_impl = mock.Mock()
158+
159+
@compiler.register_nary_op(mock_op, pass_op=True)
160+
def _(*args: TypedExpr, op: ops.NaryOp) -> sge.Expression:
161+
mock_impl(*args, op=op)
162+
return sge.Identifier(this="output")
163+
164+
arg1 = TypedExpr(sge.Identifier(this="input1"), "string")
165+
arg2 = TypedExpr(sge.Identifier(this="input2"), "string")
166+
arg3 = TypedExpr(sge.Identifier(this="input3"), "string")
167+
arg4 = TypedExpr(sge.Identifier(this="input4"), "string")
168+
result = compiler.compile_row_op(mock_op, [arg1, arg2, arg3, arg4])
169+
assert result == sge.Identifier(this="output")
170+
mock_impl.assert_called_once_with(arg1, arg2, arg3, arg4, op=mock_op)
171+
172+
173+
def test_register_duplicate_op_raises():
174+
compiler = scalar_compiler.ScalarOpCompiler()
175+
176+
class MockUnaryOp(ops.UnaryOp):
177+
name = "mock_unary_op_duplicate"
178+
179+
mock_op = MockUnaryOp()
180+
181+
@compiler.register_unary_op(mock_op)
182+
def _(expr: TypedExpr) -> sge.Expression:
183+
return sge.Identifier(this="output")
184+
185+
with pytest.raises(ValueError):
186+
187+
@compiler.register_unary_op(mock_op)
188+
def _(expr: TypedExpr) -> sge.Expression:
189+
return sge.Identifier(this="output2")

0 commit comments

Comments
 (0)