Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
776bc5a
use cumsum from flox
Illviljan Dec 6, 2025
ae27632
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
a5f9326
Update groupby.py
Illviljan Dec 6, 2025
50ccca4
Update groupby.py
Illviljan Dec 6, 2025
f55531e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
06ac372
Update groupby.py
Illviljan Dec 6, 2025
31244e6
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
dd47536
Update groupby.py
Illviljan Dec 6, 2025
e867f12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
88e0ebc
Update groupby.py
Illviljan Dec 6, 2025
181d4a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
a82ec39
use apply_ufunc for dataset and dataarray handling
Illviljan Dec 6, 2025
6c6abed
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
24c3f1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
d8d0eaa
Update groupby.py
Illviljan Dec 6, 2025
55ff46a
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
33d1360
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
c97ae98
sync protocols with each other
Illviljan Dec 6, 2025
06b52ae
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
84f9b44
typing
Illviljan Dec 6, 2025
2978877
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
0a9adee
add dataset and version requirement
Illviljan Dec 6, 2025
ae9a3d8
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
c056d1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
d4873b9
Update _aggregations.py
Illviljan Dec 6, 2025
21cbde2
Update xarray/core/groupby.py
Illviljan Dec 6, 2025
4aebc47
Update groupby.py
Illviljan Dec 6, 2025
f4cab24
Update groupby.py
Illviljan Dec 6, 2025
23d9d50
Update groupby.py
Illviljan Dec 6, 2025
9b64db2
Update generate_aggregations.py
Illviljan Dec 6, 2025
928b158
Remove workaround in test
Illviljan Dec 7, 2025
130f98e
Update _aggregations.py
Illviljan Dec 7, 2025
5a3e754
Update _aggregations.py
Illviljan Dec 7, 2025
d912cda
Update test_groupby.py
Illviljan Dec 7, 2025
3bc8dc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2025
ec8ffd6
clean ups
Illviljan Dec 7, 2025
b0cf8c4
Merge branch 'main' into cumsum_flox
Illviljan Dec 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 74 additions & 19 deletions xarray/core/_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3655,6 +3655,17 @@ def _flox_reduce(
) -> Dataset:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> Dataset:
# Abstract hook for flox-backed grouped scan (cumulative) operations.
# The concrete implementation lives on the GroupBy class in
# xarray/core/groupby.py; this mixin stub only defines the interface.
raise NotImplementedError()
Copy link
Contributor Author

@Illviljan Illviljan Dec 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made these changes manually now.
I'm not getting pytest-accept to correctly fix the docstrings in _aggregations.py, it's for example not indenting correctly. I'm not sure if this is just a Windows 10 thing.


def count(
self,
dim: Dims = None,
Expand Down Expand Up @@ -5000,29 +5011,47 @@ def cumsum(
da (time) float64 48B 1.0 2.0 3.0 0.0 2.0 nan

>>> ds.groupby("labels").cumsum()
<xarray.Dataset> Size: 48B
<xarray.Dataset> Size: 120B
Dimensions: (time: 6)
Dimensions without coordinates: time
Coordinates:
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
Data variables:
da (time) float64 48B 1.0 2.0 3.0 3.0 4.0 1.0

Use ``skipna`` to control whether NaNs are ignored.

>>> ds.groupby("labels").cumsum(skipna=False)
<xarray.Dataset> Size: 48B
<xarray.Dataset> Size: 120B
Dimensions: (time: 6)
Dimensions without coordinates: time
Coordinates:
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
Data variables:
da (time) float64 48B 1.0 2.0 3.0 3.0 4.0 nan
"""
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
numeric_only=True,
keep_attrs=keep_attrs,
**kwargs,
)
if (
flox_available
and OPTIONS["use_flox"]
and module_available("flox", minversion="0.10.5")
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_scan(
func="cumsum",
dim=dim,
skipna=skipna,
# fill_value=fill_value,
keep_attrs=keep_attrs,
**kwargs,
)
else:
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)

def cumprod(
self,
Expand Down Expand Up @@ -6647,6 +6676,17 @@ def _flox_reduce(
) -> DataArray:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> DataArray:
# Abstract hook for flox-backed grouped scan (cumulative) operations,
# DataArray variant. The concrete implementation is provided by the
# GroupBy class in xarray/core/groupby.py; this mixin stub only
# defines the interface.
raise NotImplementedError()

def count(
self,
dim: Dims = None,
Expand Down Expand Up @@ -7904,13 +7944,28 @@ def cumsum(
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
"""
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)
if (
flox_available
and OPTIONS["use_flox"]
and module_available("flox", minversion="0.10.5")
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_scan(
func="cumsum",
dim=dim,
skipna=skipna,
# fill_value=fill_value,
keep_attrs=keep_attrs,
**kwargs,
)
else:
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)

def cumprod(
self,
Expand Down
90 changes: 74 additions & 16 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from packaging.version import Version

from xarray.computation import ops
from xarray.computation.apply_ufunc import apply_ufunc
from xarray.computation.arithmetic import (
DataArrayGroupbyArithmetic,
DatasetGroupbyArithmetic,
Expand Down Expand Up @@ -1028,6 +1029,26 @@ def _maybe_unstack(self, obj):

return obj

def _parse_dim(self, dim: Dims) -> tuple[Hashable, ...]:
    """Normalize a ``dim`` argument to a tuple of dimension names.

    Accepts a single name, ``None`` (default to the grouped dimensions,
    order-preserved and deduplicated), ``...`` (all dimensions of the
    original object), or an iterable of names.
    """
    if isinstance(dim, str):
        return (dim,)
    if dim is None:
        # Collect every dimension of every grouper's codes, keeping the
        # first occurrence of each name (dict preserves insertion order).
        grouped_dims = itertools.chain.from_iterable(
            grouper.codes.dims for grouper in self.groupers
        )
        return tuple(dict.fromkeys(grouped_dims))
    if dim is ...:
        return tuple(self._original_obj.dims)
    return tuple(dim)

def _flox_reduce(
self,
dim: Dims,
Expand Down Expand Up @@ -1088,22 +1109,7 @@ def _flox_reduce(
# set explicitly to avoid unnecessarily accumulating count
kwargs["min_count"] = 0

parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
elif dim is None:
parsed_dim_list = list()
# preserve order
for dim_ in itertools.chain(
*(grouper.codes.dims for grouper in self.groupers)
):
if dim_ not in parsed_dim_list:
parsed_dim_list.append(dim_)
parsed_dim = tuple(parsed_dim_list)
elif dim is ...:
parsed_dim = tuple(obj.dims)
else:
parsed_dim = tuple(dim)
parsed_dim = self._parse_dim(dim)

# Do this so we raise the same error message whether flox is present or not.
# Better to control it here than in flox.
Expand Down Expand Up @@ -1202,6 +1208,58 @@ def _flox_reduce(

return result

def _flox_scan(
    self,
    dim: Dims,
    *,
    func: str,
    skipna: bool | None = None,
    keep_attrs: bool | None = None,
    **kwargs: Any,
) -> T_Xarray:
    """Apply a grouped scan (cumulative) operation using ``flox.groupby_scan``.

    Parameters
    ----------
    dim : Dims
        Dimension(s) to scan along; normalized by ``self._parse_dim``.
    func : str
        Name of the flox scan aggregation (e.g. ``"cumsum"``).
    skipna : bool or None, optional
        If True (or None with a float/complex/object dtype), use the
        NaN-skipping variant of ``func``.
    keep_attrs : bool or None, optional
        Whether to preserve attrs on the result; defaults to the global
        ``keep_attrs`` option.
    **kwargs : Any
        Additional keyword arguments forwarded to ``flox.groupby_scan``.
    """
    from flox import groupby_scan

    parsed_dim = self._parse_dim(dim)
    # Move the scanned dimensions last so they line up with the negative
    # axes passed to flox below.
    obj = self._original_obj.transpose(..., *parsed_dim)
    axis = range(-len(parsed_dim), 0)
    codes = tuple(g.codes for g in self.groupers)

    def wrapper(array, *by, func: str, skipna: bool | None, **kwargs):
        # Match the skipna semantics of the non-flox reduce path: skip
        # NaNs by default for complex, float, and object dtypes.
        if skipna or (skipna is None and array.dtype.kind in "cfO"):
            if "nan" not in func:
                func = f"nan{func}"
        return groupby_scan(array, *codes, func=func, **kwargs)

    # apply_ufunc forwards the full arrays (no core dims) so that flox
    # sees the scanned dimensions in their entirety.
    actual = apply_ufunc(
        wrapper,
        obj,
        *codes,
        dask="allowed",
        keep_attrs=(
            _get_keep_attrs(default=True) if keep_attrs is None else keep_attrs
        ),
        kwargs=dict(
            func=func,
            skipna=skipna,
            # TODO(review): forward expected_groups / method / engine /
            # dtype from kwargs like _flox_reduce does — passing
            # expected_groups is an important optimization.
            expected_groups=None,
            axis=axis,
            dtype=None,
            method=None,
            engine=None,
        ),
    )

    return actual

def fillna(self, value: Any) -> T_Xarray:
"""Fill missing values in this object by group.

Expand Down
11 changes: 8 additions & 3 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from xarray import DataArray, Dataset, Variable, date_range
from xarray.core.groupby import _consolidate_slices
from xarray.core.types import InterpOptions, ResampleCompatible
from xarray.core.utils import module_available
from xarray.groupers import (
BinGrouper,
EncodedGroups,
Expand Down Expand Up @@ -2566,9 +2567,13 @@ def test_groupby_cumsum() -> None:
"group_id": ds.group_id,
},
)
# TODO: Remove drop_vars when GH6528 is fixed
# when Dataset.cumsum propagates indexes, and the group variable?
assert_identical(expected.drop_vars(["x", "group_id"]), actual)

if xr.get_options()["use_flox"] and module_available("flox", minversion="0.10.5"):
assert_identical(expected, actual)
else:
# TODO: Remove drop_vars when GH6528 is fixed
# when Dataset.cumsum propagates indexes, and the group variable?
assert_identical(expected.drop_vars(["x", "group_id"]), actual)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keeping this until min_version of flox is 0.10.5 at least.
Coordinates and docstrings might differ between using flox or not now though. Is that ok?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not ok imo. I think it might be fixed by simply propagating coordinates in the non-flox branch of the templated code. Might be easy


actual = ds.foo.groupby("group_id").cumsum(dim="x")
expected.coords["group_id"] = ds.group_id
Expand Down
21 changes: 18 additions & 3 deletions xarray/util/generate_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import textwrap
from dataclasses import dataclass, field
from typing import NamedTuple
from typing import Literal, NamedTuple

MODULE_PREAMBLE = '''\
"""Mixin classes with reduction operations."""
Expand Down Expand Up @@ -131,6 +131,17 @@ def _flox_reduce(
self,
dim: Dims,
**kwargs: Any,
) -> {obj}:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> {obj}:
raise NotImplementedError()"""

Expand Down Expand Up @@ -284,6 +295,7 @@ def __init__(
see_also_methods=(),
min_flox_version=None,
additional_notes="",
flox_aggregation_type: Literal["reduce", "scan"] = "reduce",
):
self.name = name
self.extra_kwargs = extra_kwargs
Expand All @@ -292,6 +304,7 @@ def __init__(
self.see_also_methods = see_also_methods
self.min_flox_version = min_flox_version
self.additional_notes = additional_notes
self.flox_aggregation_type = flox_aggregation_type
if bool_reduce:
self.array_method = f"array_{name}"
self.np_example_array = (
Expand Down Expand Up @@ -444,7 +457,7 @@ def generate_code(self, method, has_keep_attrs):

# median isn't enabled yet, because it would break if a single group was present in multiple
# chunks. The non-flox code path will just rechunk every group to a single chunk and execute the median
method_is_not_flox_supported = method.name in ("median", "cumsum", "cumprod")
method_is_not_flox_supported = method.name in ("median", "cumprod")
if method_is_not_flox_supported:
indent = 12
else:
Expand Down Expand Up @@ -476,7 +489,7 @@ def generate_code(self, method, has_keep_attrs):
+ f"""
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_reduce(
return self._flox_{method.flox_aggregation_type}(
func="{method.name}",
dim=dim,{extra_kwargs}
# fill_value=fill_value,
Expand Down Expand Up @@ -537,6 +550,8 @@ def generate_code(self, method, has_keep_attrs):
numeric_only=True,
see_also_methods=("cumulative",),
additional_notes=_CUM_NOTES,
min_flox_version="0.10.5",
flox_aggregation_type="scan",
),
Method(
"cumprod",
Expand Down
Loading