34 changes: 24 additions & 10 deletions xrspatial/proximity.py
@@ -420,16 +420,22 @@ def _process(

     target_values = np.asarray(target_values)

-    # x-y coordinates of each pixel.
-    # flatten the coords of input raster and reshape to 2d
-    xs = np.tile(raster[x].data, raster.shape[0]).reshape(raster.shape)
-    ys = np.repeat(raster[y].data, raster.shape[1]).reshape(raster.shape)
-
     if max_distance is None:
         max_distance = np.inf

+    # Get 1D coordinate arrays (these are small, just the axis coordinates)
+    x_coords = raster[x].data
+    y_coords = raster[y].data
+
+    # Ensure 1D coords are numpy arrays for max_possible_distance calculation
+    if da is not None and isinstance(x_coords, da.Array):
+        x_coords = x_coords.compute()
+    if da is not None and isinstance(y_coords, da.Array):
+        y_coords = y_coords.compute()
+
+    # Compute max_possible_distance using coordinate endpoints directly
     max_possible_distance = _distance(
-        xs[0][0], xs[-1][-1], ys[0][0], ys[-1][-1], distance_metric
+        x_coords[0], x_coords[-1], y_coords[0], y_coords[-1], distance_metric
     )

     @ngjit
@@ -620,13 +626,21 @@ def _process_dask(raster, xs, ys):
         return out

     if isinstance(raster.data, np.ndarray):
-        # numpy case
+        # numpy case - create full coordinate arrays as numpy
+        xs = np.tile(x_coords, raster.shape[0]).reshape(raster.shape)
+        ys = np.repeat(y_coords, raster.shape[1]).reshape(raster.shape)
         result = _process_numpy(raster.data, xs, ys)

     elif da is not None and isinstance(raster.data, da.Array):
-        # dask + numpy case
-        xs = da.from_array(xs, chunks=(raster.chunks))
-        ys = da.from_array(ys, chunks=(raster.chunks))
+        # dask case - create coordinate arrays as dask arrays directly
+        # This avoids materializing the full arrays in memory
+        # Convert 1D coords to dask arrays first
+        x_coords_da = da.from_array(x_coords, chunks=x_coords.shape[0])
+        y_coords_da = da.from_array(y_coords, chunks=y_coords.shape[0])
+        xs = da.tile(x_coords_da, (raster.shape[0], 1))
+        ys = da.repeat(y_coords_da, raster.shape[1]).reshape(raster.shape)
+        xs = xs.rechunk(raster.chunks)
+        ys = ys.rechunk(raster.chunks)
         result = _process_dask(raster, xs, ys)

     return result
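
As a minimal sketch of what the change does (not part of the diff; assumes numpy and dask are installed, with made-up sizes, chunks, and coordinate values): the numpy branch still materializes the full per-pixel grids eagerly, the dask branch now builds them lazily from the 1D axis coordinates, and the max-distance bound only needs the coordinate endpoints (shown here as a plain euclidean corner-to-corner span standing in for the internal _distance helper).

import numpy as np
import dask.array as da

# Hypothetical sizes and 1D axis coordinates; in the PR these come from
# raster[x].data and raster[y].data.
height, width = 4, 6
chunks = (2, 3)
x_coords = np.linspace(-180, 180, width)
y_coords = np.linspace(90, -90, height)

# Eager path (numpy input): materializes two full (height, width) grids.
xs_np = np.tile(x_coords, height).reshape((height, width))
ys_np = np.repeat(y_coords, width).reshape((height, width))

# Lazy path (dask input), mirroring the diff: the grids stay as task graphs
# and are realized chunk by chunk only when computed.
x_da = da.from_array(x_coords, chunks=x_coords.shape[0])
y_da = da.from_array(y_coords, chunks=y_coords.shape[0])
xs_da = da.tile(x_da, (height, 1)).rechunk(chunks)
ys_da = da.repeat(y_da, width).reshape((height, width)).rechunk(chunks)

# Both paths produce the same per-pixel coordinate grids.
assert np.array_equal(xs_da.compute(), xs_np)
assert np.array_equal(ys_da.compute(), ys_np)

# The endpoint-based bound from the first hunk needs only the 1D coordinate
# endpoints; with the euclidean metric it is the corner-to-corner span.
max_possible = np.hypot(x_coords[-1] - x_coords[0], y_coords[-1] - y_coords[0])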
69 changes: 69 additions & 0 deletions xrspatial/tests/test_proximity.py
@@ -281,3 +281,72 @@ def test_proximity_distance_against_qgis(raster, qgis_proximity_distance_target_

     general_output_checks(input_raster, xrspatial_result)
     np.testing.assert_allclose(xrspatial_result.data, qgis_result.data, rtol=1e-05, equal_nan=True)
+
+
+@pytest.mark.skipif(da is None, reason="dask is not installed")
+def test_proximity_dask_coord_arrays_are_lazy():
+    """
+    Test that coordinate arrays (xs, ys) are created as dask arrays
+    when input is a dask array, avoiding memory issues with large rasters.
+
+    This is a regression test for the issue where xs and ys were created
+    as numpy arrays before checking if the input was a dask array,
+    causing memory issues for large datasets.
+    """
+    from unittest.mock import patch
+
+    height, width = 100, 120
+    data = np.zeros((height, width), dtype=np.float64)
+    # Add some target pixels
+    data[10, 10] = 1.0
+    data[50, 60] = 2.0
+    data[90, 100] = 3.0
+
+    _lon = np.linspace(-180, 180, width)
+    _lat = np.linspace(90, -90, height)
+    raster = xr.DataArray(data, dims=['lat', 'lon'])
+    raster['lon'] = _lon
+    raster['lat'] = _lat
+    # Create dask-backed array with chunks
+    raster.data = da.from_array(data, chunks=(25, 30))
+
+    # Track calls to np.tile and np.repeat with the full raster shape
+    original_tile = np.tile
+    original_repeat = np.repeat
+    large_numpy_array_created = []
+
+    def tracking_tile(A, reps):
+        result = original_tile(A, reps)
+        # Check if result would be the size of the full coordinate grid
+        if result.size >= height * width:
+            large_numpy_array_created.append(('tile', result.shape))
+        return result
+
+    def tracking_repeat(a, repeats, axis=None):
+        result = original_repeat(a, repeats, axis=axis)
+        # Check if result would be the size of the full coordinate grid
+        if result.size >= height * width:
+            large_numpy_array_created.append(('repeat', result.shape))
+        return result
+
+    with patch.object(np, 'tile', tracking_tile):
+        with patch.object(np, 'repeat', tracking_repeat):
+            result = proximity(raster, x='lon', y='lat')
+
+    # Verify no large numpy coordinate arrays were created
+    assert len(large_numpy_array_created) == 0, (
+        f"Large numpy arrays were created for coordinates: {large_numpy_array_created}. "
+        "For dask inputs, coordinate arrays should be created using dask operations."
+    )
+
+    # Verify result is a dask array
+    assert isinstance(result.data, da.Array), "Result should be a dask array"
+
+    # Verify correctness by computing and checking a few values
+    computed = result.compute()
+    # Check that target pixels have distance 0
+    assert computed.data[10, 10] == 0.0
+    assert computed.data[50, 60] == 0.0
+    assert computed.data[90, 100] == 0.0
+    # Check that non-target pixels have positive distance
+    assert computed.data[0, 0] > 0.0
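
For reference, a minimal usage sketch of the behavior this test guards (not part of the PR; assumes xarray, dask, and xrspatial with this change are installed, with made-up data): calling proximity on a dask-backed DataArray should return a lazy dask-backed result that only materializes on compute().

import numpy as np
import xarray as xr
import dask.array as da
from xrspatial import proximity

# Small, made-up raster; with real data the dask chunking is what keeps the
# per-pixel coordinate grids from being materialized all at once.
data = np.zeros((100, 120), dtype=np.float64)
data[10, 10] = 1.0  # a single target pixel
raster = xr.DataArray(data, dims=['lat', 'lon'])
raster['lon'] = np.linspace(-180, 180, 120)
raster['lat'] = np.linspace(90, -90, 100)
raster.data = da.from_array(data, chunks=(25, 30))

result = proximity(raster, x='lon', y='lat')
print(isinstance(result.data, da.Array))   # True - still lazy
computed = result.compute()
print(computed.data[10, 10])               # 0.0 at the target pixel
print(computed.data[0, 0] > 0.0)           # True away from targets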