Skip to content

Commit 7cd7371

Browse files
committed
unit test for local executor
1 parent 63661be commit 7cd7371

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

tests/unit/session/test_local_scan_executor.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,92 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import pyarrow
17+
import pytest
18+
19+
from bigframes import dtypes
20+
from bigframes.core import identifiers, local_data, nodes
21+
from bigframes.session import local_scan_executor
22+
from bigframes.testing import mocks
23+
24+
25+
@pytest.fixture
26+
def object_under_test():
27+
return local_scan_executor.LocalScanExecutor()
28+
29+
30+
def create_read_local_node(arrow_table: pyarrow.Table):
31+
session = mocks.create_bigquery_session()
32+
local_data_source = local_data.ManagedArrowTable.from_pyarrow(arrow_table)
33+
return nodes.ReadLocalNode(
34+
local_data_source=local_data_source,
35+
session=session,
36+
scan_list=nodes.ScanList(
37+
items=tuple(
38+
nodes.ScanItem(
39+
id=identifiers.ColumnId(column_name),
40+
dtype=dtypes.arrow_dtype_to_bigframes_dtype(
41+
arrow_table.field(column_name).type
42+
),
43+
source_id=column_name,
44+
)
45+
for column_name in arrow_table.column_names
46+
),
47+
),
48+
)
49+
50+
51+
@pytest.mark.parametrize(
52+
("start", "stop", "expected_rows"),
53+
(
54+
(None, None, 10),
55+
(0, None, 10),
56+
(4, None, 6),
57+
(None, 10, 10),
58+
(None, 7, 7),
59+
(1, 9, 8),
60+
),
61+
)
62+
def test_local_scan_executor_with_slice(start, stop, expected_rows, object_under_test):
63+
pyarrow_table = pyarrow.Table.from_pydict(
64+
{
65+
"rowindex": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
66+
"letters": ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"],
67+
}
68+
)
69+
assert pyarrow_table.num_rows == 10
70+
71+
local_node = create_read_local_node(pyarrow_table)
72+
plan = nodes.SliceNode(
73+
child=local_node,
74+
start=start,
75+
stop=stop,
76+
)
77+
78+
result = object_under_test.execute(plan, ordered=True)
79+
result_table = pyarrow.Table.from_batches(result.arrow_batches)
80+
assert result_table.num_rows == expected_rows
81+
82+
83+
@pytest.mark.parametrize(
84+
("start", "stop", "step"),
85+
(
86+
(-1, None, 1),
87+
(None, -1, 1),
88+
(None, None, 2),
89+
(None, None, -1),
90+
),
91+
)
92+
def test_local_scan_executor_with_slice_unsupported_inputs(
93+
start, stop, step, object_under_test
94+
):
95+
local_node = create_read_local_node(pyarrow.Table.from_pydict({"col": [1, 2, 3]}))
96+
plan = nodes.SliceNode(
97+
child=local_node,
98+
start=start,
99+
stop=stop,
100+
step=step,
101+
)
102+
assert object_under_test.execute(plan, ordered=True) is None

0 commit comments

Comments
 (0)