Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions xrspatial/tests/test_zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,56 @@ def test_nodata_values_crosstab_3d(
assert_input_data_unmodified(data_values_3d, copied_data_values_3d)


@pytest.mark.skipif(not dask_array_available(), reason="Requires Dask")
def test_crosstab_dask_from_dataset():
    """
    Test crosstab with dask arrays originating from xarray Datasets.

    This is a regression test for issue #777 where dask arrays created via
    Dataset.to_array().sel() had misaligned chunks that caused IndexError.

    The value raster is chunked as (2, 3) while the zone raster is chunked
    as (3, 4), so the two inputs reach ``crosstab`` with chunk layouts that
    do not line up.
    """
    # Simulate what happens with rioxarray band_as_variable=True
    data_band1 = np.array(
        [
            [0, 0, 1, 1, 2, 2, 3, 3],
            [0, 0, 1, 1, 2, 2, 3, 3],
            [0, 0, 1, 1, 2, 2, 3, 3],
        ],
        dtype=float,
    )
    data_band2 = np.array(
        [
            [1, 1, 2, 2, 3, 3, 0, 0],
            [1, 1, 2, 2, 3, 3, 0, 0],
            [1, 1, 2, 2, 3, 3, 0, 0],
        ],
        dtype=float,
    )

    # Both bands share the same (2, 3) chunking; the mismatch under test is
    # against the (3, 4) chunking given to the zones array further down.
    dask_band1 = da.from_array(data_band1, chunks=(2, 3))
    dask_band2 = da.from_array(data_band2, chunks=(2, 3))

    ds = xr.Dataset({
        'band_1': (['y', 'x'], dask_band1),
        'band_2': (['y', 'x'], dask_band2),
    })

    # This is the pattern from issue #777:
    # to_array().sel(variable='band_1', drop=True)
    values = ds.to_array().sel(variable='band_1', drop=True)

    # Create zones with chunks that differ from the values' chunks
    zones_data = np.array(
        [
            [0, 0, 1, 1, 2, 2, 3, 3],
            [0, 0, 1, 1, 2, 2, 3, 3],
            [0, 0, 1, 1, 2, 2, 3, 3],
        ],
        dtype=float,
    )
    zones_dask = da.from_array(zones_data, chunks=(3, 4))
    zones = xr.DataArray(zones_dask, dims=['y', 'x'])

    # This should not raise an error
    result = crosstab(zones, values)
    assert isinstance(result, dd.DataFrame)

    # Evaluate the lazy graph explicitly so a failure during computation is
    # reported here; the computed frame itself is not needed afterwards
    # because check_results is handed the lazy result.
    result.compute()

    # band_1 is identical to the zones raster, so each zone's 6 cells all
    # fall under the value equal to the zone id.
    expected = {
        'zone': [0.0, 1.0, 2.0, 3.0],
        0.0: [6, 0, 0, 0],
        1.0: [0, 6, 0, 0],
        2.0: [0, 0, 6, 0],
        3.0: [0, 0, 0, 6],
    }
    check_results('dask+numpy', result, expected)


def test_apply():

def func(x):
Expand Down
6 changes: 6 additions & 0 deletions xrspatial/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,12 @@ def crosstab(
if values.ndim not in [2, 3]:
raise ValueError("`values` must use either 2D or 3D coordinates.")

# For 2D values, validate and align chunks between zones and values
# This is critical for dask arrays that may come from different sources
# (e.g., xarray Datasets via to_array().sel())
if values.ndim == 2:
validate_arrays(zones, values)

agg_2d = ["percentage", "count"]
agg_3d_numpy = _DEFAULT_STATS.keys()
agg_3d_dask = ["count"]
Expand Down