1414
1515"""
1616import warnings
17+ from collections .abc import Hashable
1718
1819import einops
1920import xarray as xr
@@ -61,7 +62,7 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True):
6162 allow_dict, allow_list : bool, optional
6263 Whether or not to allow lists or dicts as elements of ``redims``.
6364 When processing ``in_dims`` for example we need the names of
64- the variables to be decomposed so dicts are required and lists/tuples
65+ the variables to be decomposed so dicts are required and lists
6566 are not accepted.
6667
6768 Returns
@@ -85,14 +86,14 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True):
8586
8687 from xarray_einstats.einops import process_pattern_list, DimHandler
8788 handler = DimHandler()
88- process_pattern_list(["a", {"b": ( "c", "d") }, ( "e", "f", "g") ], handler)
89+ process_pattern_list(["a", {"b": [ "c", "d"] }, [ "e", "f", "g"] ], handler)
8990
9091 """
9192 out = []
9293 out_names = []
9394 txt = []
9495 for subitem in redims :
95- if isinstance (subitem , str ):
96+ if isinstance (subitem , Hashable ):
9697 out .append (subitem )
9798 out_names .append (subitem )
9899 txt .append (handler .get_name (subitem ))
@@ -103,8 +104,10 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True):
103104 f"found { len (subitem )} : { subitem .keys ()} "
104105 )
105106 key , values = list (subitem .items ())[0 ]
106- if isinstance (values , str ):
107- raise ValueError ("Found values of type str in a pattern dict, use xarray.rename" )
107+ if isinstance (values , Hashable ):
108+ raise ValueError (
109+ "Found values of hashable type in a pattern dict, use xarray.rename"
110+ )
108111 out .extend (values )
109112 out_names .append (key )
110113 txt .append (f"( { handler .get_names (values )} )" )
@@ -182,7 +185,7 @@ def translate_pattern(pattern):
182185 return dims
183186
184187
185- def _rearrange (da , out_dims , in_dims = None , ** kwargs ):
188+ def _rearrange (da , out_dims , in_dims = None , dim_lengths = None ):
186189 """Wrap `einops.rearrange <https://einops.rocks/api/rearrange/>`_.
187190
188191 This is the function that actually interfaces with ``einops``.
@@ -198,11 +201,14 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs):
198201 See docstring of :func:`~xarray_einstats.einops.rearrange`
199202 in_dims : list of str or dict, optional
200203 See docstring of :func:`~xarray_einstats.einops.rearrange`
201- kwargs : dict, optional
204+ dim_lengths : dict, optional
202205 kwargs with key equal to dimension names in ``out_dims``
203206 (that is, strings or dict keys) are passed to einops.rearrange
204207 the rest of keys are passed to :func:`xarray.apply_ufunc`
205208 """
209+ if dim_lengths is None :
210+ dim_lengths = {}
211+
206212 da_dims = da .dims
207213
208214 handler = DimHandler ()
@@ -231,9 +237,9 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs):
231237 { non_core_pattern } { handler .get_names (missing_out_dims )} { out_pattern } "
232238
233239 axes_lengths = {
234- handler .rename_kwarg (k ): v for k , v in kwargs .items () if k in out_names + out_dims
240+ handler .rename_kwarg (k ): v for k , v in dim_lengths .items () if k in out_names + out_dims
235241 }
236- kwargs = {k : v for k , v in kwargs .items () if k not in out_names + out_dims }
242+ kwargs = {k : v for k , v in dim_lengths .items () if k not in out_names + out_dims }
237243 return xr .apply_ufunc (
238244 einops .rearrange ,
239245 da ,
@@ -245,7 +251,7 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs):
245251 )
246252
247253
248- def rearrange (da , pattern , pattern_in = None , ** kwargs ):
254+ def rearrange (da , pattern , pattern_in = None , dim_lengths = None , ** dim_lengths_kwargs ):
249255 """Expose `einops.rearrange <https://einops.rocks/api/rearrange/>`_ with an xarray-like API.
250256
251257 It has two possible syntaxes which are independent and somewhat complementary.
@@ -268,12 +274,12 @@ def rearrange(da, pattern, pattern_in=None, **kwargs):
268274 a default name.
269275
270276 If `pattern` is not a string, then it must be a list where each of its elements
271- is one of: ``str`` , ``list`` (to stack those dimensions and give them an
272- arbitrary name) or ``dict of {str: list} `` (to stack the dimensions indicated
277+ is one of: :term:`python:hashable` , ``list`` (to stack those dimensions and
278+ give them an arbitrary name) or ``dict`` (to stack the dimensions indicated
273279 as values of the dictionary and name the resulting dimensions with the key).
274280
275- `pattern` is then interpreted as the output side of the einops pattern. See
276- TODO for more details.
281+ `pattern` is then interpreted as the output side of the einops pattern.
282+ See :ref:`about_einops` for more details.
277283 pattern_in : list of [str or dict], optional
278284 The input pattern for the dimensions. It can only be provided if `pattern`
279285 is a ``list``. Also, note this is only necessary if you want to split some dimensions.
@@ -282,28 +288,22 @@ def rearrange(da, pattern, pattern_in=None, **kwargs):
282288 with the only difference that ``list`` elements are not allowed, the same way
283289 that ``(dim1 dim2)=dim`` is required on the left hand side when using string
284290 patterns.
285- kwargs : dict, optional
286- Passed to :func:`xarray_einstats.einops.rearrange`
291+ dim_lengths, **dim_lengths_kwargs : dict, optional
292+ If the keys are dimensions present in `pattern` they will be passed to
293+ `einops.rearrange <https://einops.rocks/api/rearrange/>`_, otherwise,
294+ they are passed to :func:`xarray.apply_ufunc`.
287295
288296 Returns
289297 -------
290298 xarray.DataArray
291299
292- Notes
293- -----
294- Unlike for general xarray objects, where dimension
295- names can be :term:`hashable <xarray:name>` here
296- dimension names are not recommended but required to be
297- strings for both cases. Future releases however might
298- support this when using lists as `pattern`, comment
299- on :issue:`50` if you are interested in the feature
300- or could help implement it.
301-
302-
303300 See Also
304301 --------
305302 xarray_einstats.einops.reduce
306303 """
304+ if dim_lengths is None :
305+ dim_lengths = {}
306+ dim_lengths = {** dim_lengths , ** dim_lengths_kwargs }
307307 if isinstance (pattern , str ):
308308 if "->" in pattern :
309309 in_pattern , out_pattern = pattern .split ("->" )
@@ -312,11 +312,11 @@ def rearrange(da, pattern, pattern_in=None, **kwargs):
312312 out_pattern = pattern
313313 in_dims = None
314314 out_dims = translate_pattern (out_pattern )
315- return _rearrange (da , out_dims = out_dims , in_dims = in_dims , ** kwargs )
316- return _rearrange (da , out_dims = pattern , in_dims = pattern_in , ** kwargs )
315+ return _rearrange (da , out_dims = out_dims , in_dims = in_dims , dim_lengths = dim_lengths )
316+ return _rearrange (da , out_dims = pattern , in_dims = pattern_in , dim_lengths = dim_lengths )
317317
318318
319- def _reduce (da , reduction , out_dims , in_dims = None , ** kwargs ):
319+ def _reduce (da , reduction , out_dims , in_dims = None , dim_lengths = None ):
320320 """Wrap `einops.reduce <https://einops.rocks/api/reduce/>`_.
321321
322322 This is the function that actually interfaces with ``einops``.
@@ -338,11 +338,14 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
338338 in_dims : list of str or dict, optional
339339 The input pattern for the dimensions.
340340 This is only necessary if you want to split some dimensions.
341- kwargs : dict, optional
341+ dim_lengths : dict, optional
342342 kwargs with key equal to dimension names in ``out_dims``
343343 (that is, strings or dict keys) are passed to einops.rearrange
344344 the rest of keys are passed to :func:`xarray.apply_ufunc`
345345 """
346+ if dim_lengths is None :
347+ dim_lengths = {}
348+
346349 da_dims = da .dims
347350
348351 handler = DimHandler ()
@@ -361,8 +364,8 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
361364 pattern = f"{ handler .get_names (missing_in_dims )} { in_pattern } -> { out_pattern } "
362365
363366 all_dims = set (out_dims + out_names + in_names + in_dims )
364- axes_lengths = {handler .rename_kwarg (k ): v for k , v in kwargs .items () if k in all_dims }
365- kwargs = {k : v for k , v in kwargs .items () if k not in all_dims }
367+ axes_lengths = {handler .rename_kwarg (k ): v for k , v in dim_lengths .items () if k in all_dims }
368+ kwargs = {k : v for k , v in dim_lengths .items () if k not in all_dims }
366369 return xr .apply_ufunc (
367370 einops .reduce ,
368371 da ,
@@ -375,7 +378,7 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
375378 )
376379
377380
378- def reduce (da , pattern , reduction , pattern_in = None , ** kwargs ):
381+ def reduce (da , pattern , reduction , pattern_in = None , dim_lengths = None , ** dim_lengths_kwargs ):
379382 """Expose `einops.reduce <https://einops.rocks/api/reduce/>`_ with an xarray-like API.
380383
381384 It has two possible syntaxes which are independent and somewhat complementary.
@@ -412,27 +415,22 @@ def reduce(da, pattern, reduction, pattern_in=None, **kwargs):
412415 The syntax and interpretation is the same as the case when `pattern` is a list,
413416 with the only difference that ``list`` elements are not allowed, the same way
414417 that ``(dim1 dim2)=dim`` is required on the left hand side when using string
415- kwargs : dict, optional
416- Passed to :func:`xarray_einstats.einops.reduce`
418+ dim_lengths, **dim_lengths_kwargs : dict, optional
419+ If the keys are dimensions present in `pattern` they will be passed to
420+ `einops.reduce <https://einops.rocks/api/reduce/>`_, otherwise,
421+ they are passed to :func:`xarray.apply_ufunc`.
417422
418423 Returns
419424 -------
420425 xarray.DataArray
421426
422- Notes
423- -----
424- Unlike for general xarray objects, where dimension
425- names can be :term:`hashable <xarray:name>` here
426- dimension names are not recommended but required to be
427- strings for both cases. Future releases however might
428- support this when using lists as `pattern`, comment
429- on :issue:`50` if you are interested in the feature
430- or could help implement it.
431-
432427 See Also
433428 --------
434429 xarray_einstats.einops.rearrange
435430 """
431+ if dim_lengths is None :
432+ dim_lengths = {}
433+ dim_lengths = {** dim_lengths , ** dim_lengths_kwargs }
436434 if isinstance (pattern , str ):
437435 if "->" in pattern :
438436 in_pattern , out_pattern = pattern .split ("->" )
@@ -441,8 +439,8 @@ def reduce(da, pattern, reduction, pattern_in=None, **kwargs):
441439 out_pattern = pattern
442440 in_dims = None
443441 out_dims = translate_pattern (out_pattern )
444- return _reduce (da , reduction , out_dims = out_dims , in_dims = in_dims , ** kwargs )
445- return _reduce (da , reduction , out_dims = pattern , in_dims = pattern_in , ** kwargs )
442+ return _reduce (da , reduction , out_dims = out_dims , in_dims = in_dims , dim_lengths = dim_lengths )
443+ return _reduce (da , reduction , out_dims = pattern , in_dims = pattern_in , dim_lengths = dim_lengths )
446444
447445
448446def raw_reduce (* args , ** kwargs ):
0 commit comments