From 6b196ae314795e358fb77232f9c290838522cc45 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 21 Nov 2025 15:33:33 -0800 Subject: [PATCH] Add Dataset.subset() method for type-stable variable selection Addresses issue #3894 by providing a public API for selecting multiple variables that always returns a Dataset (unlike __getitem__) and accepts sequence types including tuples. This eliminates the need to convert tuples to lists when subsetting variables and provides better type stability for downstream code. Unlike using __getitem__ with a list, an explicit method is more discoverable through IDE autocomplete and documentation. Co-authored-by: Claude --- CLAUDE.md | 1 + doc/api/dataset.rst | 1 + doc/whats-new.rst | 3 ++ xarray/core/dataset.py | 96 ++++++++++++++++++++++++++++++++++++ xarray/tests/test_dataset.py | 51 +++++++++++++++++++ 5 files changed, 152 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index b4c0061bb43..ee1d9a7978d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -42,3 +42,4 @@ uv run dmypy run # Type checking with mypy GitHub issues/PRs unless specifically instructed - When creating commits, always include a co-authorship trailer: `Co-authored-by: Claude ` +- Submit upstream PRs against `main` diff --git a/doc/api/dataset.rst b/doc/api/dataset.rst index 733c9768d2f..efd9a092235 100644 --- a/doc/api/dataset.rst +++ b/doc/api/dataset.rst @@ -62,6 +62,7 @@ Dataset contents Dataset.rename_dims Dataset.swap_dims Dataset.expand_dims + Dataset.subset Dataset.drop_vars Dataset.drop_indexes Dataset.drop_duplicates diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 255f88d241e..7238c023968 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -17,6 +17,9 @@ New Features - :py:func:`combine_nested` now support :py:class:`DataTree` objects (:pull:`10849`). By `Stephan Hoyer `_. +- Added :py:meth:`Dataset.subset` method for type-stable selection of multiple + variables. 
Unlike indexing with ``__getitem__``, this method always returns + a Dataset and accepts sequence types (lists and tuples) (:issue:`3894`). Breaking Changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9c2c2f60db1..248bb2d3d39 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1183,6 +1183,8 @@ def as_numpy(self) -> Self: def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. + + For public API, use Dataset.subset() instead. """ variables: dict[Hashable, Variable] = {} coord_names = set() @@ -1223,6 +1225,100 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Self: return self._replace(variables, coord_names, dims, indexes=indexes) + def subset(self, names: Sequence[Hashable]) -> Self: + """Return a new Dataset with only the specified variables. + + This is a type-stable method for selecting multiple variables from a + Dataset. Unlike indexing with ``__getitem__``, this method always + returns a Dataset and accepts sequence types (lists, tuples) of variable + names. + + All coordinates needed for the selected variables are automatically + included in the returned Dataset. + + Parameters + ---------- + names : sequence of hashable + A sequence (list or tuple) of variable names to include in the + returned Dataset. The names must exist in the Dataset. + + Returns + ------- + Dataset + A new Dataset containing only the specified variables and their + required coordinates. + + Raises + ------ + TypeError + If ``names`` is not a sequence (e.g., if it's a set, generator, or string). + KeyError + If any of the variable names do not exist in the Dataset. + + See Also + -------- + Dataset.__getitem__ : Select variables using indexing notation. + Dataset.drop_vars : Remove variables from a Dataset. + + Examples + -------- + >>> ds = xr.Dataset( + ... { + ... 
"temperature": (["x", "y"], [[1, 2], [3, 4]]), + ... "pressure": (["x", "y"], [[5, 6], [7, 8]]), + ... "humidity": (["x"], [0.5, 0.6]), + ... }, + ... coords={"x": [10, 20], "y": [30, 40]}, + ... ) + >>> ds + Size: 112B + Dimensions: (x: 2, y: 2) + Coordinates: + * x (x) int64 16B 10 20 + * y (y) int64 16B 30 40 + Data variables: + temperature (x, y) int64 32B 1 2 3 4 + pressure (x, y) int64 32B 5 6 7 8 + humidity (x) float64 16B 0.5 0.6 + + Select a subset of variables using a list: + + >>> ds.subset(["temperature", "humidity"]) + Size: 80B + Dimensions: (x: 2, y: 2) + Coordinates: + * x (x) int64 16B 10 20 + * y (y) int64 16B 30 40 + Data variables: + temperature (x, y) int64 32B 1 2 3 4 + humidity (x) float64 16B 0.5 0.6 + + Unlike ``__getitem__``, this method accepts tuples: + + >>> vars_tuple = ("temperature", "pressure") + >>> ds.subset(vars_tuple) # Works with tuples + Size: 96B + Dimensions: (x: 2, y: 2) + Coordinates: + * x (x) int64 16B 10 20 + * y (y) int64 16B 30 40 + Data variables: + temperature (x, y) int64 32B 1 2 3 4 + pressure (x, y) int64 32B 5 6 7 8 + + The method always returns a Dataset, providing type stability: + + >>> result = ds.subset(["temperature"]) + >>> isinstance(result, xr.Dataset) + True + """ + # Validate that names is a sequence (but not a string) + if isinstance(names, str) or not isinstance(names, Sequence): + raise TypeError( + f"names must be a sequence (list or tuple), got {type(names).__name__}" + ) + return self._copy_listed(names) + def _construct_dataarray(self, name: Hashable) -> DataArray: """Construct a DataArray by indexing this dataset""" from xarray.core.dataarray import DataArray diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index e677430dfbf..cfb9f767e4d 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4451,6 +4451,57 @@ def test_getitem_multiple_dtype(self) -> None: dataset = Dataset({key: ("dim0", range(1)) for key in keys}) assert_identical(dataset, 
def test_subset(self) -> None:
    """Exercise ``Dataset.subset``: accepted inputs, coordinates and errors."""
    source = create_test_data()
    both_vars = Dataset({name: source[name] for name in ("var1", "var2")})
    only_var1 = Dataset({"var1": source["var1"]})

    # A list and a tuple of names must select the same variables; the tuple
    # form is the case raised in issue #3894.
    assert_identical(source.subset(["var1", "var2"]), both_vars)
    assert_identical(source.subset(("var1", "var2")), both_vars)

    # Selecting a single name still yields a Dataset (type stability).
    assert_identical(source.subset(["var1"]), only_var1)

    # Coordinates of the selected variables are carried over unchanged.
    grid = {"x": [10, 20], "y": [30, 40]}
    temperature = (["x", "y"], [[1, 2], [3, 4]])
    humidity = (["x"], [0.5, 0.6])
    full = Dataset(
        {
            "temperature": temperature,
            "pressure": (["x", "y"], [[5, 6], [7, 8]]),
            "humidity": humidity,
        },
        coords=grid,
    )
    kept = Dataset(
        {"temperature": temperature, "humidity": humidity},
        coords=grid,
    )
    assert_identical(full.subset(["temperature", "humidity"]), kept)

    # An unknown variable name is a KeyError.
    with pytest.raises(KeyError):
        source.subset(["var1", "notfound"])

    # Anything that is not a list/tuple-like sequence is a TypeError.
    rejected = (
        {"var1", "var2"},  # set
        (name for name in ["var1", "var2"]),  # generator
        "var1",  # string (valid Sequence type, but explicitly rejected)
    )
    for bad_names in rejected:
        with pytest.raises(TypeError, match="names must be a sequence"):
            source.subset(bad_names)  # type: ignore[arg-type]