Skip to content

Commit f5210f9

Browse files
committed
fix: managed arrow table iterates None list and struct incorrectly
1 parent a0e1e50 commit f5210f9

File tree

4 files changed

+53
-35
lines changed

4 files changed

+53
-35
lines changed

bigframes/core/local_data.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,16 @@ def _(
253253
value_generator = iter_array(
254254
array.flatten(), bigframes.dtypes.get_array_inner_type(dtype)
255255
)
256-
for (start, end) in _pairwise(array.offsets):
257-
arr_size = end.as_py() - start.as_py()
258-
yield list(itertools.islice(value_generator, arr_size))
256+
offset_generator = iter_array(array.offsets, bigframes.dtypes.INT_DTYPE)
257+
is_null_generator = iter_array(array.is_null(), bigframes.dtypes.BOOL_DTYPE)
258+
previous_offset = next(offset_generator)
259+
for is_null, offset in zip(is_null_generator, offset_generator):
260+
arr_size = offset - previous_offset
261+
previous_offset = offset
262+
if is_null:
263+
yield None
264+
else:
265+
yield list(itertools.islice(value_generator, arr_size))
259266

260267
@iter_array.register
261268
def _(
@@ -267,8 +274,14 @@ def _(
267274
sub_generators[field_name] = iter_array(array.field(field_name), dtype)
268275

269276
keys = list(sub_generators.keys())
270-
for row_values in zip(*sub_generators.values()):
271-
yield {key: value for key, value in zip(keys, row_values)}
277+
row_values_iter = zip(*sub_generators.values())
278+
is_null_iter = array.is_null()
279+
280+
for is_row_null, row_values in zip(is_null_iter, row_values_iter):
281+
if not is_row_null:
282+
yield {key: value for key, value in zip(keys, row_values)}
283+
else:
284+
yield None
272285

273286
for batch in table.to_batches():
274287
sub_generators: dict[str, Generator[Any, None, None]] = {}
@@ -354,7 +367,7 @@ def _adapt_arrow_array(array: pa.Array) -> tuple[pa.Array, bigframes.dtypes.Dtyp
354367
new_value = pa.ListArray.from_arrays(
355368
array.offsets, values, mask=array.is_null()
356369
)
357-
return new_value.fill_null([]), bigframes.dtypes.list_type(values_type)
370+
return new_value, bigframes.dtypes.list_type(values_type)
358371
if array.type == bigframes.dtypes.JSON_ARROW_TYPE:
359372
return _canonicalize_json(array), bigframes.dtypes.JSON_DTYPE
360373
target_type = logical_type_replacements(array.type)

tests/system/small/engines/test_read_local.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ def test_engines_read_local_w_zero_row_source(
8888
assert_equivalence_execution(local_node, REFERENCE_ENGINE, engine)
8989

9090

91-
# TODO: Fix sqlglot impl
92-
@pytest.mark.parametrize("engine", ["polars", "bq", "pyarrow"], indirect=True)
91+
@pytest.mark.parametrize(
92+
"engine", ["polars", "bq", "pyarrow", "bq-sqlglot"], indirect=True
93+
)
9394
def test_engines_read_local_w_nested_source(
9495
fake_session: bigframes.Session,
9596
nested_data_source: local_data.ManagedArrowTable,

tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_nested_structs_df/out.sql

Lines changed: 0 additions & 19 deletions
This file was deleted.

tests/unit/test_local_data.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,21 @@
2020

2121
pd_data = pd.DataFrame(
2222
{
23-
"ints": [10, 20, 30, 40],
24-
"nested_ints": [[1, 2], [3, 4, 5], [], [20, 30]],
25-
"structs": [{"a": 100}, {}, {"b": 200}, {"b": 300}],
23+
"ints": [10, 20, 30, 40, 50],
24+
"nested_ints": [[1, 2], None, [3, 4, 5], [], [20, 30]],
25+
"structs": [{"a": 100}, None, {}, {"b": 200}, {"b": 300}],
2626
}
2727
)
2828

2929
pd_data_normalized = pd.DataFrame(
3030
{
31-
"ints": pd.Series([10, 20, 30, 40], dtype=dtypes.INT_DTYPE),
31+
"ints": pd.Series([10, 20, 30, 40, 50], dtype=dtypes.INT_DTYPE),
3232
"nested_ints": pd.Series(
33-
[[1, 2], [3, 4, 5], [], [20, 30]], dtype=pd.ArrowDtype(pa.list_(pa.int64()))
33+
[[1, 2], None, [3, 4, 5], [], [20, 30]],
34+
dtype=pd.ArrowDtype(pa.list_(pa.int64())),
3435
),
3536
"structs": pd.Series(
36-
[{"a": 100}, {}, {"b": 200}, {"b": 300}],
37+
[{"a": 100}, None, {}, {"b": 200}, {"b": 300}],
3738
dtype=pd.ArrowDtype(pa.struct({"a": pa.int64(), "b": pa.int64()})),
3839
),
3940
}
@@ -122,11 +123,11 @@ def test_local_data_well_formed_round_trip_chunked():
122123

123124
def test_local_data_well_formed_round_trip_sliced():
124125
pa_table = pa.Table.from_pandas(pd_data, preserve_index=False)
125-
as_rechunked_pyarrow = pa.Table.from_batches(pa_table.slice(2, 4).to_batches())
126+
as_rechunked_pyarrow = pa.Table.from_batches(pa_table.slice(0, 4).to_batches())
126127
local_entry = local_data.ManagedArrowTable.from_pyarrow(as_rechunked_pyarrow)
127128
result = pd.DataFrame(local_entry.itertuples(), columns=pd_data.columns)
128129
pandas.testing.assert_frame_equal(
129-
pd_data_normalized[2:4].reset_index(drop=True),
130+
pd_data_normalized[0:4].reset_index(drop=True),
130131
result.reset_index(drop=True),
131132
check_dtype=False,
132133
)
@@ -143,3 +144,25 @@ def test_local_data_not_equal_other():
143144
local_entry2 = local_data.ManagedArrowTable.from_pandas(pd_data[::2])
144145
assert local_entry != local_entry2
145146
assert hash(local_entry) != hash(local_entry2)
147+
148+
149+
def test_local_data_itertuples_struct_none():
150+
pd_data = pd.DataFrame(
151+
{
152+
"structs": [{"a": 100}, None, {"b": 200}, {"b": 300}],
153+
}
154+
)
155+
local_entry = local_data.ManagedArrowTable.from_pandas(pd_data)
156+
result = list(local_entry.itertuples())
157+
assert result[1][0] is None
158+
159+
160+
def test_local_data_itertuples_list_none():
161+
pd_data = pd.DataFrame(
162+
{
163+
"lists": [[1, 2], None, [3, 4]],
164+
}
165+
)
166+
local_entry = local_data.ManagedArrowTable.from_pandas(pd_data)
167+
result = list(local_entry.itertuples())
168+
assert result[1][0] is None

0 commit comments

Comments (0)