Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
776bc5a
use cumsum from flox
Illviljan Dec 6, 2025
ae27632
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
a5f9326
Update groupby.py
Illviljan Dec 6, 2025
50ccca4
Update groupby.py
Illviljan Dec 6, 2025
f55531e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
06ac372
Update groupby.py
Illviljan Dec 6, 2025
31244e6
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
dd47536
Update groupby.py
Illviljan Dec 6, 2025
e867f12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
88e0ebc
Update groupby.py
Illviljan Dec 6, 2025
181d4a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
a82ec39
use apply_ufunc for dataset and dataarray handling
Illviljan Dec 6, 2025
6c6abed
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
24c3f1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
d8d0eaa
Update groupby.py
Illviljan Dec 6, 2025
55ff46a
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
33d1360
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
c97ae98
sync protocols with each other
Illviljan Dec 6, 2025
06b52ae
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
84f9b44
typing
Illviljan Dec 6, 2025
2978877
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
0a9adee
add dataset and version requirement
Illviljan Dec 6, 2025
ae9a3d8
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
c056d1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
d4873b9
Update _aggregations.py
Illviljan Dec 6, 2025
21cbde2
Update xarray/core/groupby.py
Illviljan Dec 6, 2025
4aebc47
Update groupby.py
Illviljan Dec 6, 2025
f4cab24
Update groupby.py
Illviljan Dec 6, 2025
23d9d50
Update groupby.py
Illviljan Dec 6, 2025
9b64db2
Update generate_aggregations.py
Illviljan Dec 6, 2025
928b158
Renove workaround in test
Illviljan Dec 7, 2025
130f98e
Update _aggregations.py
Illviljan Dec 7, 2025
5a3e754
Update _aggregations.py
Illviljan Dec 7, 2025
d912cda
Update test_groupby.py
Illviljan Dec 7, 2025
3bc8dc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2025
ec8ffd6
clean ups
Illviljan Dec 7, 2025
b0cf8c4
Merge branch 'main' into cumsum_flox
Illviljan Dec 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions xarray/core/_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6647,6 +6647,13 @@ def _flox_reduce(
) -> DataArray:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
**kwargs: Any,
) -> DataArray:
raise NotImplementedError()

def count(
self,
dim: Dims = None,
Expand Down Expand Up @@ -7904,13 +7911,27 @@ def cumsum(
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
"""
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)
if (
flox_available
and OPTIONS["use_flox"]
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_scan(
func="cumsum",
dim=dim,
skipna=skipna,
# fill_value=fill_value,
keep_attrs=keep_attrs,
**kwargs,
)
else:
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)

def cumprod(
self,
Expand Down
117 changes: 101 additions & 16 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from packaging.version import Version

from xarray.computation import ops
from xarray.computation.apply_ufunc import apply_ufunc
from xarray.computation.arithmetic import (
DataArrayGroupbyArithmetic,
DatasetGroupbyArithmetic,
Expand Down Expand Up @@ -1028,6 +1029,26 @@ def _maybe_unstack(self, obj):

return obj

def _parse_dim(self, dim: Dims) -> tuple[Hashable, ...]:
parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
elif dim is None:
parsed_dim_list = list()
# preserve order
for dim_ in itertools.chain(
*(grouper.codes.dims for grouper in self.groupers)
):
if dim_ not in parsed_dim_list:
parsed_dim_list.append(dim_)
parsed_dim = tuple(parsed_dim_list)
elif dim is ...:
parsed_dim = tuple(self._original_obj.dims)
else:
parsed_dim = tuple(dim)

return parsed_dim

def _flox_reduce(
self,
dim: Dims,
Expand Down Expand Up @@ -1088,22 +1109,7 @@ def _flox_reduce(
# set explicitly to avoid unnecessarily accumulating count
kwargs["min_count"] = 0

parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
elif dim is None:
parsed_dim_list = list()
# preserve order
for dim_ in itertools.chain(
*(grouper.codes.dims for grouper in self.groupers)
):
if dim_ not in parsed_dim_list:
parsed_dim_list.append(dim_)
parsed_dim = tuple(parsed_dim_list)
elif dim is ...:
parsed_dim = tuple(obj.dims)
else:
parsed_dim = tuple(dim)
parsed_dim = self._parse_dim(dim)

# Do this so we raise the same error message whether flox is present or not.
# Better to control it here than in flox.
Expand Down Expand Up @@ -1202,6 +1208,85 @@ def _flox_reduce(

return result

def _flox_scan(
self,
dim: Dims,
*,
func: str,
keep_attrs: bool | None = None,
skipna: bool | None = None,
**kwargs: Any,
) -> DataArray:
from flox import groupby_scan

obj = self._original_obj

if skipna or (
skipna is None and isinstance(func, str) and obj.dtype.kind in "cfO"
):
if "nan" not in func and func not in ["all", "any", "count"]:
func = f"nan{func}"

# if keep_attrs is None:
# keep_attrs = _get_keep_attrs(default=True)

parsed_dim = self._parse_dim(dim)

axis_ = obj.get_axis_num(parsed_dim)
axis = (axis_,) if isinstance(axis_, int) else axis_
codes = tuple(g.codes for g in self.groupers)
# g = groupby_scan(
# obj.data,
# *codes,
# func=func,
# expected_groups=None,
# axis=axis,
# dtype=None,
# method=None,
# engine=None,
# )
# result = obj.copy(data=g)

# return result

actual = apply_ufunc(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this is the way. eventually I'd like the apply_ufunc for reductions to live in Xarray too. So feel free to move that over if it helps. We could put it in flox_compat.py

groupby_scan,
obj,
*codes,
# input_core_dims=input_core_dims,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need this because we just want the full array forwarded

# for xarray's test_groupby_duplicate_coordinate_labels
# exclude_dims=set(dim_tuple),
# output_core_dims=[output_core_dims],
Comment on lines +1239 to +1241
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please delete

dask="allowed",
# dask_gufunc_kwargs=dict(
Copy link
Contributor

@dcherian dcherian Dec 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please delete.

# output_sizes=output_sizes,
# output_dtypes=[dtype] if dtype is not None else None,
# ),
keep_attrs=(
_get_keep_attrs(default=True) if keep_attrs is None else keep_attrs
),
kwargs=dict(
func=func,
expected_groups=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be the same as _flox_reduce. This is an important optimization.

axis=axis,
dtype=None,
method=None,
engine=None,
Comment on lines +1255 to +1257
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These we should grab from kwargs and forward just like _flox_reduce

),
)

return actual

# xarray_reduce(
# obj.drop_vars(non_numeric.keys()),
# *codes,
# dim=parsed_dim,
# expected_groups=expected_groups,
# isbin=False,
# keep_attrs=keep_attrs,
# **kwargs,
# )

def fillna(self, value: Any) -> T_Xarray:
"""Fill missing values in this object by group.

Expand Down
Loading