Skip to content

Commit d91642d

Browse files
authored
add support for hashable dims (#56)
1 parent f51ac73 commit d91642d

File tree

6 files changed

+80
-62
lines changed

6 files changed

+80
-62
lines changed

docs/source/background/einops.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ of the elements respectively: `->`, space as delimiter and parenthesis:
1313
side in the einops notation is only used to label the dimensions.
1414
In fact, 5/7 examples in https://einops.rocks/api/rearrange/ fall in this category.
1515
This is not necessary when working with xarray objects.
16-
* In xarray dimension names can be any {term}`hashable <xarray:name>`. `xarray-einstats` only
17-
supports strings as dimension names, but the space can't be used as delimiter.
16+
* In xarray dimension names can be any {term}`hashable <xarray:name>`.
1817
* In xarray dimensions are labeled and the order doesn't matter.
1918
This might seem the same as the first reason but it is not. When splitting
2019
or stacking dimensions you need (and want) the names of both parent and children dimensions.
@@ -25,8 +24,8 @@ of the elements respectively: `->`, space as delimiter and parenthesis:
2524

2625
However, there are also many cases in which dimension names in xarray will be strings
2726
without any spaces or parentheses in them. So similarly to the option of
28-
doing `da.stack(dim=("dim1", "dim2"))` which can't be used for all valid
29-
dimension names but is generally easier to write and less error prone,
27+
doing `da.stack(dim=["dim1", "dim2"])` which can't be used for all valid
28+
dimension names but is generally easier to write and less error prone,
3029
`xarray_einstats.einops` also provides two possible syntaxes.
3130

3231
The guiding principle of the einops module is to take the input expressions
@@ -37,7 +36,7 @@ labeled, we can take advantage of that during the translation process
3736
and thus support "partial" expressions that cover only the dimensions
3837
that will be modified.
3938

40-
Another important consideration is to take into account that _in xarray_,
39+
Another important consideration is to take into account that _in xarray_
4140
dimension order should not matter, hence the constraint of using dicts
4241
on the left side. Imposing this constraint also
4342
makes our job of filling in the "partial" expressions much easier.

src/xarray_einstats/einops.py

Lines changed: 46 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
1515
"""
1616
import warnings
17+
from collections.abc import Hashable
1718

1819
import einops
1920
import xarray as xr
@@ -61,7 +62,7 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True):
6162
allow_dict, allow_list : bool, optional
6263
Whether or not to allow lists or dicts as elements of ``redims``.
6364
When processing ``in_dims`` for example we need the names of
64-
the variables to be decomposed so dicts are required and lists/tuples
65+
the variables to be decomposed so dicts are required and lists
6566
are not accepted.
6667
6768
Returns
@@ -85,14 +86,14 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True):
8586
8687
from xarray_einstats.einops import process_pattern_list, DimHandler
8788
handler = DimHandler()
88-
process_pattern_list(["a", {"b": ("c", "d")}, ("e", "f", "g")], handler)
89+
process_pattern_list(["a", {"b": ["c", "d"]}, ["e", "f", "g"]], handler)
8990
9091
"""
9192
out = []
9293
out_names = []
9394
txt = []
9495
for subitem in redims:
95-
if isinstance(subitem, str):
96+
if isinstance(subitem, Hashable):
9697
out.append(subitem)
9798
out_names.append(subitem)
9899
txt.append(handler.get_name(subitem))
@@ -103,8 +104,10 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True):
103104
f"found {len(subitem)}: {subitem.keys()}"
104105
)
105106
key, values = list(subitem.items())[0]
106-
if isinstance(values, str):
107-
raise ValueError("Found values of type str in a pattern dict, use xarray.rename")
107+
if isinstance(values, Hashable):
108+
raise ValueError(
109+
"Found values of hashable type in a pattern dict, use xarray.rename"
110+
)
108111
out.extend(values)
109112
out_names.append(key)
110113
txt.append(f"( {handler.get_names(values)} )")
@@ -182,7 +185,7 @@ def translate_pattern(pattern):
182185
return dims
183186

184187

185-
def _rearrange(da, out_dims, in_dims=None, **kwargs):
188+
def _rearrange(da, out_dims, in_dims=None, dim_lengths=None):
186189
"""Wrap `einops.rearrange <https://einops.rocks/api/rearrange/>`_.
187190
188191
This is the function that actually interfaces with ``einops``.
@@ -198,11 +201,14 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs):
198201
See docstring of :func:`~xarray_einstats.einops.rearrange`
199202
in_dims : list of str or dict, optional
200203
See docstring of :func:`~xarray_einstats.einops.rearrange`
201-
kwargs : dict, optional
204+
dim_lengths : dict, optional
202205
kwargs with key equal to dimension names in ``out_dims``
203206
(that is, strings or dict keys) are passed to einops.rearrange
204207
the rest of keys are passed to :func:`xarray.apply_ufunc`
205208
"""
209+
if dim_lengths is None:
210+
dim_lengths = {}
211+
206212
da_dims = da.dims
207213

208214
handler = DimHandler()
@@ -231,9 +237,9 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs):
231237
{non_core_pattern} {handler.get_names(missing_out_dims)} {out_pattern}"
232238

233239
axes_lengths = {
234-
handler.rename_kwarg(k): v for k, v in kwargs.items() if k in out_names + out_dims
240+
handler.rename_kwarg(k): v for k, v in dim_lengths.items() if k in out_names + out_dims
235241
}
236-
kwargs = {k: v for k, v in kwargs.items() if k not in out_names + out_dims}
242+
kwargs = {k: v for k, v in dim_lengths.items() if k not in out_names + out_dims}
237243
return xr.apply_ufunc(
238244
einops.rearrange,
239245
da,
@@ -245,7 +251,7 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs):
245251
)
246252

247253

248-
def rearrange(da, pattern, pattern_in=None, **kwargs):
254+
def rearrange(da, pattern, pattern_in=None, dim_lengths=None, **dim_lengths_kwargs):
249255
"""Expose `einops.rearrange <https://einops.rocks/api/rearrange/>`_ with an xarray-like API.
250256
251257
It has two possible syntaxes which are independent and somewhat complementary.
@@ -268,12 +274,12 @@ def rearrange(da, pattern, pattern_in=None, **kwargs):
268274
a default name.
269275
270276
If `pattern` is not a string, then it must be a list where each of its elements
271-
is one of: ``str``, ``list`` (to stack those dimensions and give them an
272-
arbitrary name) or ``dict of {str: list}`` (to stack the dimensions indicated
277+
is one of: :term:`python:hashable`, ``list`` (to stack those dimensions and
278+
give them an arbitrary name) or ``dict`` (to stack the dimensions indicated
273279
as values of the dictionary and name the resulting dimensions with the key).
274280
275-
`pattern` is then interpreted as the output side of the einops pattern. See
276-
TODO for more details.
281+
`pattern` is then interpreted as the output side of the einops pattern.
282+
See :ref:`about_einops` for more details.
277283
pattern_in : list of [str or dict], optional
278284
The input pattern for the dimensions. It can only be provided if `pattern`
279285
is a ``list``. Also, note this is only necessary if you want to split some dimensions.
@@ -282,28 +288,22 @@ def rearrange(da, pattern, pattern_in=None, **kwargs):
282288
with the only difference that ``list`` elements are not allowed, the same way
283289
that ``(dim1 dim2)=dim`` is required on the left hand side when using string
284290
patterns.
285-
kwargs : dict, optional
286-
Passed to :func:`xarray_einstats.einops.rearrange`
291+
dim_lengths, **dim_lengths_kwargs : dict, optional
292+
If the keys are dimensions present in `pattern` they will be passed to
293+
`einops.rearrange <https://einops.rocks/api/rearrange/>`_, otherwise,
294+
they are passed to :func:`xarray.apply_ufunc`.
287295
288296
Returns
289297
-------
290298
xarray.DataArray
291299
292-
Notes
293-
-----
294-
Unlike for general xarray objects, where dimension
295-
names can be :term:`hashable <xarray:name>` here
296-
dimension names are not recommended but required to be
297-
strings for both cases. Future releases however might
298-
support this when using lists as `pattern`, comment
299-
on :issue:`50` if you are interested in the feature
300-
or could help implement it.
301-
302-
303300
See Also
304301
--------
305302
xarray_einstats.einops.reduce
306303
"""
304+
if dim_lengths is None:
305+
dim_lengths = {}
306+
dim_lengths = {**dim_lengths, **dim_lengths_kwargs}
307307
if isinstance(pattern, str):
308308
if "->" in pattern:
309309
in_pattern, out_pattern = pattern.split("->")
@@ -312,11 +312,11 @@ def rearrange(da, pattern, pattern_in=None, **kwargs):
312312
out_pattern = pattern
313313
in_dims = None
314314
out_dims = translate_pattern(out_pattern)
315-
return _rearrange(da, out_dims=out_dims, in_dims=in_dims, **kwargs)
316-
return _rearrange(da, out_dims=pattern, in_dims=pattern_in, **kwargs)
315+
return _rearrange(da, out_dims=out_dims, in_dims=in_dims, dim_lengths=dim_lengths)
316+
return _rearrange(da, out_dims=pattern, in_dims=pattern_in, dim_lengths=dim_lengths)
317317

318318

319-
def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
319+
def _reduce(da, reduction, out_dims, in_dims=None, dim_lengths=None):
320320
"""Wrap `einops.reduce <https://einops.rocks/api/reduce/>`_.
321321
322322
This is the function that actually interfaces with ``einops``.
@@ -338,11 +338,14 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
338338
in_dims : list of str or dict, optional
339339
The input pattern for the dimensions.
340340
This is only necessary if you want to split some dimensions.
341-
kwargs : dict, optional
341+
dim_lengths : dict, optional
342342
kwargs with key equal to dimension names in ``out_dims``
343343
(that is, strings or dict keys) are passed to einops.reduce
344344
the rest of keys are passed to :func:`xarray.apply_ufunc`
345345
"""
346+
if dim_lengths is None:
347+
dim_lengths = {}
348+
346349
da_dims = da.dims
347350

348351
handler = DimHandler()
@@ -361,8 +364,8 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
361364
pattern = f"{handler.get_names(missing_in_dims)} {in_pattern} -> {out_pattern}"
362365

363366
all_dims = set(out_dims + out_names + in_names + in_dims)
364-
axes_lengths = {handler.rename_kwarg(k): v for k, v in kwargs.items() if k in all_dims}
365-
kwargs = {k: v for k, v in kwargs.items() if k not in all_dims}
367+
axes_lengths = {handler.rename_kwarg(k): v for k, v in dim_lengths.items() if k in all_dims}
368+
kwargs = {k: v for k, v in dim_lengths.items() if k not in all_dims}
366369
return xr.apply_ufunc(
367370
einops.reduce,
368371
da,
@@ -375,7 +378,7 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
375378
)
376379

377380

378-
def reduce(da, pattern, reduction, pattern_in=None, **kwargs):
381+
def reduce(da, pattern, reduction, pattern_in=None, dim_lengths=None, **dim_lengths_kwargs):
379382
"""Expose `einops.reduce <https://einops.rocks/api/reduce/>`_ with an xarray-like API.
380383
381384
It has two possible syntaxes which are independent and somewhat complementary.
@@ -412,27 +415,22 @@ def reduce(da, pattern, reduction, pattern_in=None, **kwargs):
412415
The syntax and interpretation is the same as the case when `pattern` is a list,
413416
with the only difference that ``list`` elements are not allowed, the same way
414417
that ``(dim1 dim2)=dim`` is required on the left hand side when using string
415-
kwargs : dict, optional
416-
Passed to :func:`xarray_einstats.einops.reduce`
418+
dim_lengths, **dim_lengths_kwargs : dict, optional
419+
If the keys are dimensions present in `pattern` they will be passed to
420+
`einops.reduce <https://einops.rocks/api/reduce/>`_, otherwise,
421+
they are passed to :func:`xarray.apply_ufunc`.
417422
418423
Returns
419424
-------
420425
xarray.DataArray
421426
422-
Notes
423-
-----
424-
Unlike for general xarray objects, where dimension
425-
names can be :term:`hashable <xarray:name>` here
426-
dimension names are not recommended but required to be
427-
strings for both cases. Future releases however might
428-
support this when using lists as `pattern`, comment
429-
on :issue:`50` if you are interested in the feature
430-
or could help implement it.
431-
432427
See Also
433428
--------
434429
xarray_einstats.einops.rearrange
435430
"""
431+
if dim_lengths is None:
432+
dim_lengths = {}
433+
dim_lengths = {**dim_lengths, **dim_lengths_kwargs}
436434
if isinstance(pattern, str):
437435
if "->" in pattern:
438436
in_pattern, out_pattern = pattern.split("->")
@@ -441,8 +439,8 @@ def reduce(da, pattern, reduction, pattern_in=None, **kwargs):
441439
out_pattern = pattern
442440
in_dims = None
443441
out_dims = translate_pattern(out_pattern)
444-
return _reduce(da, reduction, out_dims=out_dims, in_dims=in_dims, **kwargs)
445-
return _reduce(da, reduction, out_dims=pattern, in_dims=pattern_in, **kwargs)
442+
return _reduce(da, reduction, out_dims=out_dims, in_dims=in_dims, dim_lengths=dim_lengths)
443+
return _reduce(da, reduction, out_dims=pattern, in_dims=pattern_in, dim_lengths=dim_lengths)
446444

447445

448446
def raw_reduce(*args, **kwargs):

src/xarray_einstats/numba.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def ecdf(da, dims=None, *, npoints=None, **kwargs):
268268
dims = da.dims
269269
elif isinstance(dims, str):
270270
dims = [dims]
271-
total_points = np.product([da.sizes[d] for d in dims])
271+
total_points = np.prod([da.sizes[d] for d in dims])
272272
if npoints is None:
273273
npoints = min(total_points, 200)
274274
x = xr.DataArray(np.linspace(0, 1, npoints), dims=["quantile"])

tests/test_accessors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_einops_accessor_rearrange(data):
8888

8989
@pytest.mark.skipif(find_spec("einops") is None, reason="einops must be installed")
9090
def test_einops_accessor_reduce(data):
91-
pattern_in = [{"batch (hh.mm)": ("d1", "d2")}]
91+
pattern_in = [{"batch (hh.mm)": ["d1", "d2"]}]
9292
pattern = ["d1", "subject"]
9393
kwargs = {"d2": 2}
9494
input_data = data.rename({"batch": "batch (hh.mm)"})

tests/test_einops.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ class TestRearrange:
6060
"args",
6161
(
6262
(
63-
{"pattern": [{"dex": ("drug dose (mg)", "experiment")}]},
63+
{"pattern": [{"dex": ["drug dose (mg)", "experiment"]}]},
6464
((4, 6, 8 * 15), ["batch", "subject", "dex"]),
6565
),
6666
(
6767
{
68-
"pattern_in": [{"drug dose (mg)": ("d1", "d2")}],
68+
"pattern_in": [{"drug dose (mg)": ["d1", "d2"]}],
6969
"pattern": ["d1", "d2", "batch"],
7070
"d1": 2,
7171
"d2": 4,
@@ -80,6 +80,16 @@ def test_rearrange(self, data, args):
8080
assert out_da.shape == shape
8181
assert list(out_da.dims) == dims
8282

83+
def test_rearrange_tuple_dim(self, data):
84+
out_da = rearrange(
85+
data.rename(drug=("drug dose", "mg")),
86+
pattern_in=[{("drug dose", "mg"): [("d", 1), ("d", 2)]}],
87+
pattern=[("d", 1), ("d", 2), "batch"],
88+
dim_lengths={("d", 1): 2, ("d", 2): 4},
89+
)
90+
assert out_da.shape == (6, 15, 2, 4, 4)
91+
assert list(out_da.dims) == ["subject", "experiment", ("d", 1), ("d", 2), "batch"]
92+
8393

8494
class TestRawReduce:
8595
@pytest.mark.parametrize(
@@ -110,16 +120,16 @@ class TestReduce:
110120
),
111121
(
112122
{
113-
"pattern_in": [{"batch (hh.mm)": ("d1", "d2")}],
123+
"pattern_in": [{"batch (hh.mm)": ["d1", "d2"]}],
114124
"pattern": ["d1", "subject"],
115125
"d2": 2,
116126
},
117127
((2, 6), ["d1", "subject"]),
118128
),
119129
(
120130
{
121-
"pattern_in": [{"drug": ("d1", "d2")}, {"batch (hh.mm)": ("b1", "b2")}],
122-
"pattern": ["subject", ("b1", "d1")],
131+
"pattern_in": [{"drug": ["d1", "d2"]}, {"batch (hh.mm)": ["b1", "b2"]}],
132+
"pattern": ["subject", ["b1", "d1"]],
123133
"d2": 4,
124134
"b2": 2,
125135
},
@@ -132,3 +142,14 @@ def test_reduce(self, data, args):
132142
out_da = reduce(data.rename({"batch": "batch (hh.mm)"}), reduction="mean", **kwargs)
133143
assert out_da.shape == shape
134144
assert list(out_da.dims) == dims
145+
146+
def test_reduce_tuple_dim(self, data):
147+
out_da = reduce(
148+
data.rename(drug=("drug dose", "mg")),
149+
reduction="mean",
150+
pattern_in=[{("drug dose", "mg"): [("d", 1), ("d", 2)]}],
151+
pattern=["subject", ("d", 2), "batch"],
152+
dim_lengths={("d", 1): 2, ("d", 2): 4},
153+
)
154+
assert out_da.shape == (6, 4, 4)
155+
assert list(out_da.dims) == ["subject", ("d", 2), "batch"]

tests/test_linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def test_svd(self, matrices):
246246
u_da, s_da, vh_da = svd(matrices, dims=("dim", "dim2"), out_append="_bis")
247247
s_full = xr.zeros_like(matrices)
248248
idx = xr.DataArray(np.arange(len(matrices["dim"])), dims="pointwise_sel")
249-
s_full.loc[{"dim": idx, "dim2": idx}] = s_da
249+
s_full.loc[{"dim": idx, "dim2": idx}] = s_da.rename(dim="pointwise_sel")
250250
compare = matmul(
251251
matmul(u_da, s_full, dims=[["dim", "dim_bis"], ["dim", "dim2"]]),
252252
vh_da,
@@ -259,7 +259,7 @@ def test_svd_non_square(self, matrices):
259259
s_full = xr.zeros_like(matrices)
260260
# experiment is shorter than dim
261261
idx = xr.DataArray(np.arange(len(matrices["experiment"])), dims="pointwise_sel")
262-
s_full.loc[{"experiment": idx, "dim": idx}] = s_da.transpose("batch", "experiment", "dim2")
262+
s_full.loc[{"experiment": idx, "dim": idx}] = s_da.rename(experiment="pointwise_sel")
263263
compare = matmul(
264264
matmul(u_da, s_full, dims=[["experiment", "experiment_bis"], ["experiment", "dim"]]),
265265
vh_da,

0 commit comments

Comments
 (0)