Skip to content

Commit 3425d96

Browse files
jsignell, Illviljan, benbovy
authored
Coerce IndexVariable to Variable when assigning to data variables or coordinates (#10909)
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
Co-authored-by: Benoit Bovy <benbovy@gmail.com>
1 parent d42c464 commit 3425d96

File tree

6 files changed

+73
-29
lines changed

6 files changed

+73
-29
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ Deprecations
3434
Bug Fixes
3535
~~~~~~~~~
3636

37+
- When assigning an indexed coordinate to a data variable or coordinate, coerce it from
38+
``IndexVariable`` to ``Variable`` (:issue:`9859`, :issue:`10829`, :pull:`10909`)
39+
By `Julia Signell <https://github.com/jsignell>`_
3740
- The NetCDF4 backend will now claim to be able to read any URL except for one that contains
3841
the substring zarr. This restores backward compatibility after
3942
:pull:`10804` broke workflows that relied on ``xr.open_dataset("http://...")``

xarray/core/coordinates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1265,7 +1265,7 @@ def create_coords_with_default_indexes(
12651265
variables.update(idx_vars)
12661266
all_variables.update(idx_vars)
12671267
else:
1268-
variables[name] = variable
1268+
variables[name] = variable.to_base_variable()
12691269

12701270
new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes)
12711271

xarray/structure/merge.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
emit_user_level_warning,
2323
equivalent,
2424
)
25-
from xarray.core.variable import Variable, as_variable, calculate_dimensions
25+
from xarray.core.variable import (
26+
IndexVariable,
27+
Variable,
28+
as_variable,
29+
calculate_dimensions,
30+
)
2631
from xarray.structure.alignment import deep_align
2732
from xarray.util.deprecation_helpers import (
2833
_COMPAT_DEFAULT,
@@ -1206,7 +1211,11 @@ def dataset_update_method(dataset: Dataset, other: CoercibleMapping) -> _MergeRe
12061211
if c not in value.dims and c in dataset.coords
12071212
]
12081213
if coord_names:
1209-
other[key] = value.drop_vars(coord_names)
1214+
value = value.drop_vars(coord_names)
1215+
if isinstance(value.variable, IndexVariable):
1216+
variable = value.variable.to_base_variable()
1217+
value = value._replace(variable=variable)
1218+
other[key] = value
12101219

12111220
return merge_core(
12121221
[dataset, other],

xarray/testing/assertions.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44
import warnings
55
from collections.abc import Hashable
6+
from typing import Any
67

78
import numpy as np
89
import pandas as pd
@@ -362,6 +363,17 @@ def _assert_indexes_invariants_checks(
362363
if isinstance(v, IndexVariable)
363364
}
364365
assert indexes.keys() <= index_vars, (set(indexes), index_vars)
366+
assert all(
367+
k in index_vars
368+
for k, v in possible_coord_variables.items()
369+
if v.dims == (k,)
370+
), {k: type(v) for k, v in possible_coord_variables.items()}
371+
372+
assert not any(
373+
isinstance(v, IndexVariable)
374+
for k, v in possible_coord_variables.items()
375+
if k not in indexes.keys()
376+
), {k: type(v) for k, v in possible_coord_variables.items()}
365377

366378
# check pandas index wrappers vs. coordinate data adapters
367379
for k, index in indexes.items():
@@ -401,11 +413,17 @@ def _assert_indexes_invariants_checks(
401413
)
402414

403415

404-
def _assert_variable_invariants(var: Variable, name: Hashable = None):
416+
def _assert_variable_invariants(
417+
var: Variable | Any,
418+
name: Hashable = None,
419+
) -> None:
405420
if name is None:
406421
name_or_empty: tuple = ()
407422
else:
408423
name_or_empty = (name,)
424+
425+
assert isinstance(var, Variable), {name: type(var)}
426+
409427
assert isinstance(var._dims, tuple), name_or_empty + (var._dims,)
410428
assert len(var._dims) == len(var._data.shape), name_or_empty + (
411429
var._dims,
@@ -418,35 +436,28 @@ def _assert_variable_invariants(var: Variable, name: Hashable = None):
418436

419437

420438
def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool):
421-
assert isinstance(da._variable, Variable), da._variable
422439
_assert_variable_invariants(da._variable)
423440

424441
assert isinstance(da._coords, dict), da._coords
425-
assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords
426442

427443
if check_default_indexes:
428444
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
429445
da.dims,
430446
{k: v.dims for k, v in da._coords.items()},
431447
)
432-
assert all(
433-
isinstance(v, IndexVariable)
434-
for (k, v) in da._coords.items()
435-
if v.dims == (k,)
436-
), {k: type(v) for k, v in da._coords.items()}
437448

438449
for k, v in da._coords.items():
439450
_assert_variable_invariants(v, k)
440451

441-
if da._indexes is not None:
442-
_assert_indexes_invariants_checks(
443-
da._indexes, da._coords, da.dims, check_default=check_default_indexes
444-
)
452+
assert da._indexes is not None
453+
_assert_indexes_invariants_checks(
454+
da._indexes, da._coords, da.dims, check_default=check_default_indexes
455+
)
445456

446457

447458
def _assert_dataset_invariants(ds: Dataset, check_default_indexes: bool):
448459
assert isinstance(ds._variables, dict), type(ds._variables)
449-
assert all(isinstance(v, Variable) for v in ds._variables.values()), ds._variables
460+
450461
for k, v in ds._variables.items():
451462
_assert_variable_invariants(v, k)
452463

@@ -466,17 +477,10 @@ def _assert_dataset_invariants(ds: Dataset, check_default_indexes: bool):
466477
ds._dims[k] == v.sizes[k] for v in ds._variables.values() for k in v.sizes
467478
), (ds._dims, {k: v.sizes for k, v in ds._variables.items()})
468479

469-
if check_default_indexes:
470-
assert all(
471-
isinstance(v, IndexVariable)
472-
for (k, v) in ds._variables.items()
473-
if v.dims == (k,)
474-
), {k: type(v) for k, v in ds._variables.items() if v.dims == (k,)}
475-
476-
if ds._indexes is not None:
477-
_assert_indexes_invariants_checks(
478-
ds._indexes, ds._variables, ds._dims, check_default=check_default_indexes
479-
)
480+
assert ds._indexes is not None
481+
_assert_indexes_invariants_checks(
482+
ds._indexes, ds._variables, ds._dims, check_default=check_default_indexes
483+
)
480484

481485
assert isinstance(ds._encoding, type(None) | dict)
482486
assert isinstance(ds._attrs, type(None) | dict)

xarray/tests/test_dataarray.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,6 +1702,20 @@ def should_add_coord_to_array(self, name, var, dims):
17021702
assert_identical(actual, expected, check_default_indexes=False)
17031703
assert "x_bnds" not in actual.dims
17041704

1705+
def test_assign_coords_uses_base_variable_class(self) -> None:
1706+
a = DataArray([0, 1, 3], dims=["x"], coords={"x": [0, 1, 2]})
1707+
a = a.assign_coords(foo=a.x)
1708+
1709+
# explicit check
1710+
assert isinstance(a["x"].variable, IndexVariable)
1711+
assert not isinstance(a["foo"].variable, IndexVariable)
1712+
1713+
# test internal invariant checks when comparing the data arrays
1714+
expected = DataArray(
1715+
[0, 1, 3], dims=["x"], coords={"x": [0, 1, 2], "foo": ("x", [0, 1, 2])}
1716+
)
1717+
assert_identical(a, expected)
1718+
17051719
def test_coords_alignment(self) -> None:
17061720
lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])])
17071721
rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])])

xarray/tests/test_dataset.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4311,9 +4311,11 @@ def test_to_stacked_array_preserves_dtype(self) -> None:
43114311
# coordinate created from variables names should be of string dtype
43124312
data = np.array(["a", "a", "a", "b"], dtype="<U1")
43134313
expected_stacked_variable = DataArray(name="variable", data=data, dims="z")
4314+
4315+
# coerce from `IndexVariable` to `Variable` before comparing
43144316
assert_identical(
4315-
stacked.coords["variable"].drop_vars(["z", "variable", "y"]),
4316-
expected_stacked_variable,
4317+
stacked["variable"].variable.to_base_variable(),
4318+
expected_stacked_variable.variable,
43174319
)
43184320

43194321
def test_to_stacked_array_transposed(self) -> None:
@@ -4779,6 +4781,18 @@ def test_setitem_using_list_errors(self, var_list, data, error_regex) -> None:
47794781
with pytest.raises(ValueError, match=error_regex):
47804782
actual[var_list] = data
47814783

4784+
def test_setitem_uses_base_variable_class_even_for_index_variables(self) -> None:
4785+
ds = Dataset(coords={"x": [1, 2, 3]})
4786+
ds["y"] = ds["x"]
4787+
4788+
# explicit check
4789+
assert isinstance(ds["x"].variable, IndexVariable)
4790+
assert not isinstance(ds["y"].variable, IndexVariable)
4791+
4792+
# test internal invariant checks when comparing the datasets
4793+
expected = Dataset(data_vars={"y": ("x", [1, 2, 3])}, coords={"x": [1, 2, 3]})
4794+
assert_identical(ds, expected)
4795+
47824796
def test_assign(self) -> None:
47834797
ds = Dataset()
47844798
actual = ds.assign(x=[0, 1, 2], y=2)

0 commit comments

Comments
 (0)