|
9 | 9 | import einops |
10 | 10 | import xarray as xr |
11 | 11 |
|
12 | | -__all__ = ["rearrange", "raw_rearrange", "reduce", "raw_reduce"] |
| 12 | +__all__ = ["rearrange", "raw_rearrange", "reduce", "raw_reduce", "DaskBackend"] |
13 | 13 |
|
14 | 14 |
|
15 | 15 | class DimHandler: |
@@ -223,8 +223,15 @@ def rearrange(da, out_dims, in_dims=None, **kwargs): |
223 | 223 | missing_in_dims = [dim for dim in da_dims if dim not in in_names] |
224 | 224 | expected_missing = set(out_dims).union(in_names).difference(in_dims) |
225 | 225 | missing_out_dims = [dim for dim in da_dims if dim not in expected_missing] |
226 | | - pattern = f"{handler.get_names(missing_in_dims)} {in_pattern} ->\ |
227 | | - {handler.get_names(missing_out_dims)} {out_pattern}" |
| 226 | + |
| 227 | + # avoid using dimensions as core dims unnecessarily |
| 228 | + non_core_dims = [dim for dim in missing_in_dims if dim in missing_out_dims] |
| 229 | + missing_in_dims = [dim for dim in missing_in_dims if dim not in non_core_dims] |
| 230 | + missing_out_dims = [dim for dim in missing_out_dims if dim not in non_core_dims] |
| 231 | + |
| 232 | + non_core_pattern = handler.get_names(non_core_dims) |
| 233 | + pattern = f"{non_core_pattern} {handler.get_names(missing_in_dims)} {in_pattern} ->\ |
| 234 | + {non_core_pattern} {handler.get_names(missing_out_dims)} {out_pattern}" |
228 | 235 |
|
229 | 236 | axes_lengths = { |
230 | 237 | handler.rename_kwarg(k): v for k, v in kwargs.items() if k in out_names + out_dims |
@@ -395,3 +402,58 @@ def raw_reduce(da, pattern, reduction, **kwargs): |
395 | 402 | in_dims = None |
396 | 403 | out_dims = translate_pattern(out_pattern) |
397 | 404 | return reduce(da, reduction, out_dims=out_dims, in_dims=in_dims, **kwargs) |
| 405 | + |
| 406 | + |
| 407 | +class DaskBackend(einops._backends.AbstractBackend): # pylint: disable=protected-access |
| 408 | + """Dask backend class for einops. |
| 409 | +
|
| 410 | + It should be imported before using functions of :mod:`xarray_einstats.einops` |
| 411 | + on Dask-backed DataArrays. |
| 412 | + It doesn't need to be initialized or used explicitly. |
| 413 | +
|
| 414 | + Notes |
| 415 | + ----- |
| 416 | + Class created following the advice in |
| 417 | + `issue einops#120 <https://github.com/arogozhnikov/einops/issues/120>`_ about Dask support |
| 418 | + and from reading |
| 419 | + `einops/_backends <https://github.com/arogozhnikov/einops/blob/master/einops/_backends.py>`_, |
| 420 | + the source of the AbstractBackend class of which DaskBackend is a subclass. |
| 421 | + """ |
| 422 | + |
| 423 | + # pylint: disable=no-self-use |
| 424 | + framework_name = "dask" |
| 425 | + |
| 426 | + def __init__(self): |
| 427 | + """Initialize DaskBackend. |
| 428 | +
|
| 429 | + Contains the dask import to avoid errors when dask is not installed. |
| 430 | + """ |
| 431 | + import dask.array as dsar |
| 432 | + |
| 433 | + self.dsar = dsar |
| 434 | + |
| 435 | + def is_appropriate_type(self, tensor): |
| 436 | + """Recognizes tensors it can handle.""" |
| 437 | + return isinstance(tensor, self.dsar.core.Array) |
| 438 | + |
| 439 | + def from_numpy(self, x): # noqa: D102 |
| 440 | + return self.dsar.array(x) |
| 441 | + |
| 442 | + def to_numpy(self, x): # noqa: D102 |
| 443 | + return x.compute() |
| 444 | + |
| 445 | + def arange(self, start, stop): # noqa: D102 |
| 446 | + # supplementary method used only in testing, so should implement CPU version |
| 447 | + return self.dsar.arange(start, stop) |
| 448 | + |
| 449 | + def stack_on_zeroth_dimension(self, tensors: list): # noqa: D102 |
| 450 | + return self.dsar.stack(tensors) |
| 451 | + |
| 452 | + def tile(self, x, repeats): # noqa: D102 |
| 453 | + return self.dsar.tile(x, repeats) |
| 454 | + |
| 455 | + def is_float_type(self, x): # noqa: D102 |
| 456 | + return x.dtype in ("float16", "float32", "float64", "float128") |
| 457 | + |
| 458 | + def add_axis(self, x, new_position): # noqa: D102 |
| 459 | + return self.dsar.expand_dims(x, new_position) |
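
For reference, a minimal usage sketch of the new backend, not part of the PR: importing DaskBackend is all that is needed, since einops discovers the subclass on its own once dask is loaded. The dimension names, sizes, chunking, and the dict-based grouping in the pattern list below are illustrative assumptions layered on the rearrange API defined earlier in this file.

    # minimal sketch, assuming hypothetical dimension names and chunk sizes
    import dask.array as dsar
    import xarray as xr

    # importing DaskBackend is enough; it is never instantiated explicitly here
    from xarray_einstats.einops import DaskBackend, rearrange  # noqa: F401

    da = xr.DataArray(
        dsar.ones((4, 100, 3), chunks=(1, 50, 3)),
        dims=["chain", "draw", "team"],
    )

    # stack "chain" and "draw" into a hypothetical "sample" dimension;
    # the dask-backed data is handled by einops through DaskBackend
    out = rearrange(da, [{"sample": ["chain", "draw"]}, "team"])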