Skip to content
Draft
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
776bc5a
use cumsum from flox
Illviljan Dec 6, 2025
ae27632
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
a5f9326
Update groupby.py
Illviljan Dec 6, 2025
50ccca4
Update groupby.py
Illviljan Dec 6, 2025
f55531e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
06ac372
Update groupby.py
Illviljan Dec 6, 2025
31244e6
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
dd47536
Update groupby.py
Illviljan Dec 6, 2025
e867f12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
88e0ebc
Update groupby.py
Illviljan Dec 6, 2025
181d4a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
a82ec39
use apply_ufunc for dataset and dataarray handling
Illviljan Dec 6, 2025
6c6abed
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
24c3f1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
d8d0eaa
Update groupby.py
Illviljan Dec 6, 2025
55ff46a
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
33d1360
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
c97ae98
sync protocols with each other
Illviljan Dec 6, 2025
06b52ae
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
84f9b44
typing
Illviljan Dec 6, 2025
2978877
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
0a9adee
add dataset and version requirement
Illviljan Dec 6, 2025
ae9a3d8
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
c056d1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
d4873b9
Update _aggregations.py
Illviljan Dec 6, 2025
21cbde2
Update xarray/core/groupby.py
Illviljan Dec 6, 2025
4aebc47
Update groupby.py
Illviljan Dec 6, 2025
f4cab24
Update groupby.py
Illviljan Dec 6, 2025
23d9d50
Update groupby.py
Illviljan Dec 6, 2025
9b64db2
Update generate_aggregations.py
Illviljan Dec 6, 2025
928b158
Remove workaround in test
Illviljan Dec 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 66 additions & 15 deletions xarray/core/_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3655,6 +3655,17 @@ def _flox_reduce(
) -> Dataset:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> Dataset:
raise NotImplementedError()

def count(
self,
dim: Dims = None,
Expand Down Expand Up @@ -5015,14 +5026,28 @@ def cumsum(
Data variables:
da (time) float64 48B 1.0 2.0 3.0 3.0 4.0 nan
"""
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
numeric_only=True,
keep_attrs=keep_attrs,
**kwargs,
)
if (
flox_available
and OPTIONS["use_flox"]
and module_available("flox", minversion="0.10.5")
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_scan(
func="cumsum",
dim=dim,
skipna=skipna,
# fill_value=fill_value,
keep_attrs=keep_attrs,
**kwargs,
)
else:
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)

def cumprod(
self,
Expand Down Expand Up @@ -6647,6 +6672,17 @@ def _flox_reduce(
) -> DataArray:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> DataArray:
raise NotImplementedError()

def count(
self,
dim: Dims = None,
Expand Down Expand Up @@ -7904,13 +7940,28 @@ def cumsum(
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
"""
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)
if (
flox_available
and OPTIONS["use_flox"]
and module_available("flox", minversion="0.10.5")
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_scan(
func="cumsum",
dim=dim,
skipna=skipna,
# fill_value=fill_value,
keep_attrs=keep_attrs,
**kwargs,
)
else:
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)

def cumprod(
self,
Expand Down
93 changes: 77 additions & 16 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from packaging.version import Version

from xarray.computation import ops
from xarray.computation.apply_ufunc import apply_ufunc
from xarray.computation.arithmetic import (
DataArrayGroupbyArithmetic,
DatasetGroupbyArithmetic,
Expand Down Expand Up @@ -1028,6 +1029,26 @@ def _maybe_unstack(self, obj):

return obj

def _parse_dim(self, dim: Dims) -> tuple[Hashable, ...]:
parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
elif dim is None:
parsed_dim_list = list()
# preserve order
for dim_ in itertools.chain(
*(grouper.codes.dims for grouper in self.groupers)
):
if dim_ not in parsed_dim_list:
parsed_dim_list.append(dim_)
parsed_dim = tuple(parsed_dim_list)
elif dim is ...:
parsed_dim = tuple(self._original_obj.dims)
else:
parsed_dim = tuple(dim)

return parsed_dim

def _flox_reduce(
self,
dim: Dims,
Expand Down Expand Up @@ -1088,22 +1109,7 @@ def _flox_reduce(
# set explicitly to avoid unnecessarily accumulating count
kwargs["min_count"] = 0

parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
elif dim is None:
parsed_dim_list = list()
# preserve order
for dim_ in itertools.chain(
*(grouper.codes.dims for grouper in self.groupers)
):
if dim_ not in parsed_dim_list:
parsed_dim_list.append(dim_)
parsed_dim = tuple(parsed_dim_list)
elif dim is ...:
parsed_dim = tuple(obj.dims)
else:
parsed_dim = tuple(dim)
parsed_dim = self._parse_dim(dim)

# Do this so we raise the same error message whether flox is present or not.
# Better to control it here than in flox.
Expand Down Expand Up @@ -1202,6 +1208,61 @@ def _flox_reduce(

return result

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> DataArray:
from flox import groupby_scan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably requires a version guard.

Copy link
Contributor Author

@Illviljan Illviljan Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, now I see:
groupby_scan was added in: https://github.com/xarray-contrib/flox/releases/tag/v0.9.9 - OK
cumsum was added in: https://github.com/xarray-contrib/flox/releases/tag/v0.10.5 - NOK


obj = self._original_obj

parsed_dim = self._parse_dim(dim)

axis = obj.get_axis_num(parsed_dim)
# axis = (axis_,) if isinstance(axis_, int) else axis_
codes = tuple(g.codes for g in self.groupers)

def wrapper(array, *by, func: str, skipna: bool | None, **kwargs):
    """Call ``groupby_scan`` on ``array``, switching to the ``nan`` variant when skipping NaNs.

    Parameters
    ----------
    array
        The array to scan.
    *by
        Positional group arrays supplied by the caller. They are accepted
        but not used: the enclosing scope's ``codes`` are forwarded instead.
        NOTE(review): confirm that ignoring ``by`` is intentional.
    func
        Name of the scan; prefixed with ``"nan"`` when NaNs are skipped and
        the name does not already contain ``"nan"``.
    skipna
        ``True`` forces the ``nan`` variant. ``None`` selects it only when
        the dtype kind of ``array`` is one of ``"c"``, ``"f"`` or ``"O"``.
    **kwargs
        Forwarded unchanged to ``groupby_scan``.
    """
    # Look at the dtype of the array actually being scanned rather than an
    # object from the enclosing scope, which may not expose ``.dtype`` at all.
    if skipna or (skipna is None and array.dtype.kind in "cfO"):
        if "nan" not in func:
            func = f"nan{func}"

    return groupby_scan(array, *codes, func=func, **kwargs)

actual = apply_ufunc(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this is the way. eventually I'd like the apply_ufunc for reductions to live in Xarray too. So feel free to move that over if it helps. We could put it in flox_compat.py

wrapper,
obj,
*codes,
# input_core_dims=input_core_dims,
# for xarray's test_groupby_duplicate_coordinate_labels
# exclude_dims=set(dim_tuple),
# output_core_dims=[output_core_dims],
dask="allowed",
# dask_gufunc_kwargs=dict(
# output_sizes=output_sizes,
# output_dtypes=[dtype] if dtype is not None else None,
# ),
keep_attrs=(
_get_keep_attrs(default=True) if keep_attrs is None else keep_attrs
),
kwargs=dict(
func=func,
skipna=skipna,
expected_groups=None,
axis=axis,
dtype=None,
method=None,
engine=None,
),
)

return actual

def fillna(self, value: Any) -> T_Xarray:
"""Fill missing values in this object by group.

Expand Down
21 changes: 18 additions & 3 deletions xarray/util/generate_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import textwrap
from dataclasses import dataclass, field
from typing import NamedTuple
from typing import Literal, NamedTuple

MODULE_PREAMBLE = '''\
"""Mixin classes with reduction operations."""
Expand Down Expand Up @@ -132,6 +132,17 @@ def _flox_reduce(
dim: Dims,
**kwargs: Any,
) -> {obj}:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> DataArray:
raise NotImplementedError()"""

TEMPLATE_REDUCTION_SIGNATURE = '''
Expand Down Expand Up @@ -284,6 +295,7 @@ def __init__(
see_also_methods=(),
min_flox_version=None,
additional_notes="",
flox_aggregation_type: Literal["reduce", "scan"] = "reduce",
):
self.name = name
self.extra_kwargs = extra_kwargs
Expand All @@ -292,6 +304,7 @@ def __init__(
self.see_also_methods = see_also_methods
self.min_flox_version = min_flox_version
self.additional_notes = additional_notes
self.flox_aggregation_type = flox_aggregation_type
if bool_reduce:
self.array_method = f"array_{name}"
self.np_example_array = (
Expand Down Expand Up @@ -444,7 +457,7 @@ def generate_code(self, method, has_keep_attrs):

# median isn't enabled yet, because it would break if a single group was present in multiple
# chunks. The non-flox code path will just rechunk every group to a single chunk and execute the median
method_is_not_flox_supported = method.name in ("median", "cumsum", "cumprod")
method_is_not_flox_supported = method.name in ("median", "cumprod")
if method_is_not_flox_supported:
indent = 12
else:
Expand Down Expand Up @@ -476,7 +489,7 @@ def generate_code(self, method, has_keep_attrs):
+ f"""
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_reduce(
return self._flox_{method.flox_aggregation_type}(
func="{method.name}",
dim=dim,{extra_kwargs}
# fill_value=fill_value,
Expand Down Expand Up @@ -537,6 +550,8 @@ def generate_code(self, method, has_keep_attrs):
numeric_only=True,
see_also_methods=("cumulative",),
additional_notes=_CUM_NOTES,
min_flox_version="0.10.5",
flox_aggregation_type="scan",
),
Method(
"cumprod",
Expand Down
Loading