Skip to content

Commit fdf3dd7

Browse files
committed
test: add and fix unit test for block split rounding
1 parent 9148797 commit fdf3dd7

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import mock
16+
17+
import pandas as pd
18+
19+
import bigframes
20+
import bigframes.core.blocks as blocks
21+
22+
23+
def test_block_split_rounding():
24+
# Setup a mock block with a specific shape
25+
mock_session = mock.create_autospec(spec=bigframes.Session)
26+
# Block.from_local needs a real-ish session for some things, but we can mock shape[0]
27+
28+
# Let's use a real Block with local data for simplicity if possible
29+
df = pd.DataFrame({"a": range(29757)})
30+
block = blocks.Block.from_local(df, mock_session)
31+
32+
# We need to mock the internal behavior of split or check the result sizes
33+
# Since split returns new Blocks, we can check their shapes if they are computed.
34+
# But split calls block.slice which calls block.expr.slice...
35+
36+
# Instead of full execution, let's just test the rounding logic by mocking block.shape
37+
with mock.patch.object(
38+
blocks.Block, "shape", new_callable=mock.PropertyMock
39+
) as mock_shape:
40+
mock_shape.return_value = (29757, 1)
41+
42+
# We need to mock other things that split calls to avoid full execution
43+
with mock.patch.object(blocks.Block, "create_constant") as mock_create_constant:
44+
mock_create_constant.return_value = (block, "random_col")
45+
with mock.patch.object(
46+
blocks.Block, "promote_offsets"
47+
) as mock_promote_offsets:
48+
mock_promote_offsets.return_value = (block, "offset_col")
49+
with mock.patch.object(
50+
blocks.Block, "apply_unary_op"
51+
) as mock_apply_unary_op:
52+
mock_apply_unary_op.return_value = (block, "unary_col")
53+
with mock.patch.object(
54+
blocks.Block, "apply_binary_op"
55+
) as mock_apply_binary_op:
56+
mock_apply_binary_op.return_value = (block, "binary_col")
57+
with mock.patch.object(
58+
blocks.Block, "order_by"
59+
) as mock_order_by:
60+
mock_order_by.return_value = block
61+
with mock.patch.object(blocks.Block, "slice") as mock_slice:
62+
mock_slice.return_value = block
63+
64+
# Call split
65+
block.split(fracs=(0.8, 0.2))
66+
67+
# Check calls to slice
68+
# Expected sample_sizes with round():
69+
# round(0.8 * 29757) = 23806
70+
# round(0.2 * 29757) = 5951
71+
72+
calls = mock_slice.call_args_list
73+
assert len(calls) == 2
74+
assert calls[0].kwargs["start"] == 0
75+
assert calls[0].kwargs["stop"] == 23806
76+
assert calls[1].kwargs["start"] == 23806
77+
assert calls[1].kwargs["stop"] == 23806 + 5951

0 commit comments

Comments
 (0)