
Commit 15f9b9b

Authored by ilan-gold, richard-berg, pre-commit-ci[bot], and dcherian
(fix): handle internal type promotion and nas for extension arrays properly (#10423)
* Improve support for pandas Extension Arrays (#10301)
* (chore): remove non-reindex fixes
* merge
* (fix): minimize more api, mostly working
* (fix): allow through scalars with extension arrays in `result_type`
* (refactor): clean reindexing test
* (chore): remove redundant test
* (fix): some types
* (chore): remove commented out tests
* (fix): more typing
* (fix): more typing!
* (fix): `Scalar` as a type
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix: bring back import
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update xarray/core/extension_array.py (Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>)
* refactor: `args = ...` location
* fix: remove comment
* fix: do scalar checking inside `__extension_duck_array__where` + remove banned API
* chore: add better comment
* fix: mypy
* fix: mypy?
* fix: don't use `dtype` where it wasn't before
* fix: test fixture creation
* fix: apply suggestion

---------

Co-authored-by: Richard Berg <rberg@jumptrading.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
1 parent 3425d96 commit 15f9b9b

7 files changed: +478 additions, -111 deletions

properties/test_pandas_roundtrip.py
Lines changed: 8 additions & 6 deletions

@@ -3,12 +3,14 @@
 """
 
 from functools import partial
+from typing import cast
 
 import numpy as np
 import pandas as pd
 import pytest
 
 import xarray as xr
+from xarray.core.dataset import Dataset
 
 pytest.importorskip("hypothesis")
 import hypothesis.extra.numpy as npst  # isort:skip
@@ -88,10 +90,10 @@ def test_roundtrip_dataarray(data, arr) -> None:
 
 
 @given(datasets_1d_vars())
-def test_roundtrip_dataset(dataset) -> None:
+def test_roundtrip_dataset(dataset: Dataset) -> None:
     df = dataset.to_dataframe()
     assert isinstance(df, pd.DataFrame)
-    roundtripped = xr.Dataset(df)
+    roundtripped = xr.Dataset.from_dataframe(df)
     xr.testing.assert_identical(dataset, roundtripped)
 
 
@@ -101,7 +103,7 @@ def test_roundtrip_pandas_series(ser, ix_name) -> None:
     ser.index.name = ix_name
     arr = xr.DataArray(ser)
     roundtripped = arr.to_pandas()
-    pd.testing.assert_series_equal(ser, roundtripped)
+    pd.testing.assert_series_equal(ser, roundtripped)  # type: ignore[arg-type]
     xr.testing.assert_identical(arr, roundtripped.to_xarray())
 
 
@@ -119,7 +121,7 @@ def test_roundtrip_pandas_dataframe(df) -> None:
     df.columns.name = "cols"
     arr = xr.DataArray(df)
     roundtripped = arr.to_pandas()
-    pd.testing.assert_frame_equal(df, roundtripped)
+    pd.testing.assert_frame_equal(df, cast(pd.DataFrame, roundtripped))
     xr.testing.assert_identical(arr, roundtripped.to_xarray())
 
 
@@ -143,8 +145,8 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
         pd.arrays.IntervalArray(
             [pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)]
         ),
-        pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])),
-        pd.arrays.DatetimeArray._from_sequence(
+        pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])),  # type: ignore[attr-defined]
+        pd.arrays.DatetimeArray._from_sequence(  # type: ignore[attr-defined]
             pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D")
         ),
         np.array([1, 2, 3], dtype="int64"),
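
For illustration (not part of the diff): a minimal sketch of the roundtrip shape the updated property test now drives through `Dataset.from_dataframe`. The Categorical column is an assumed-supported example; exact dtype preservation depends on the xarray and pandas versions in use.

    import pandas as pd
    import xarray as xr

    ds = xr.Dataset(
        {"tag": ("index", pd.Categorical(["a", "b", "a"]))},
        coords={"index": [0, 1, 2]},
    )
    df = ds.to_dataframe()
    assert isinstance(df, pd.DataFrame)
    # the test now rebuilds the Dataset via from_dataframe instead of Dataset(df)
    roundtripped = xr.Dataset.from_dataframe(df)
    xr.testing.assert_identical(ds, roundtripped)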

xarray/core/dtypes.py
Lines changed: 79 additions & 28 deletions

@@ -1,15 +1,20 @@
 from __future__ import annotations
 
 import functools
-from typing import Any
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, TypeVar, cast
 
 import numpy as np
-import pandas as pd
+from pandas.api.extensions import ExtensionDtype
 
 from xarray.compat import array_api_compat, npcompat
 from xarray.compat.npcompat import HAS_STRING_DTYPE
 from xarray.core import utils
 
+if TYPE_CHECKING:
+    from typing import Any
+
+
 # Use as a sentinel value to indicate a dtype appropriate NA value.
 NA = utils.ReprObject("<NA>")
 
@@ -47,8 +52,10 @@ def __eq__(self, other):
     (np.bytes_, np.str_),  # numpy promotes to unicode
 )
 
+T_dtype = TypeVar("T_dtype", np.dtype, ExtensionDtype)
 
-def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
+
+def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
     """Simpler equivalent of pandas.core.common._maybe_promote
 
     Parameters
@@ -63,7 +70,13 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
     # N.B. these casting rules should match pandas
     dtype_: np.typing.DTypeLike
     fill_value: Any
-    if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
+    if utils.is_allowed_extension_array_dtype(dtype):
+        return dtype, cast(ExtensionDtype, dtype).na_value  # type: ignore[redundant-cast]
+    if not isinstance(dtype, np.dtype):
+        raise TypeError(
+            f"dtype {dtype} must be one of an extension array dtype or numpy dtype"
+        )
+    elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
         # for now, we always promote string dtypes to object for consistency with existing behavior
         # TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
         dtype_ = object
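
Illustrative sketch of the new `maybe_promote` branch (xarray.core.dtypes is internal API, so treat the exact output as an assumption): an extension dtype is returned unchanged together with its own `na_value`, while numpy dtypes keep the existing promotion path.

    import numpy as np
    import pandas as pd
    from xarray.core import dtypes

    # extension dtype: returned as-is, fill value is the dtype's own NA (pd.NA for Int64)
    print(dtypes.maybe_promote(pd.Int64Dtype()))
    # numpy dtype: promoted so it can hold a missing value, e.g. int64 -> float64 with NaN
    print(dtypes.maybe_promote(np.dtype("int64")))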
@@ -213,7 +226,7 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
 
     if isinstance(dtype, np.dtype):
         return npcompat.isdtype(dtype, kind)
-    elif pd.api.types.is_extension_array_dtype(dtype):  # noqa: TID251
+    elif utils.is_allowed_extension_array_dtype(dtype):
         # we never want to match pandas extension array dtypes
         return False
     else:
@@ -222,23 +235,67 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
         return xp.isdtype(dtype, kind)
 
 
-def preprocess_types(t):
-    if isinstance(t, str | bytes):
-        return type(t)
-    elif isinstance(dtype := getattr(t, "dtype", t), np.dtype) and (
-        np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)
-    ):
+def maybe_promote_to_variable_width(
+    array_or_dtype: np.typing.ArrayLike
+    | np.typing.DTypeLike
+    | ExtensionDtype
+    | str
+    | bytes,
+    *,
+    should_return_str_or_bytes: bool = False,
+) -> np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype:
+    if isinstance(array_or_dtype, str | bytes):
+        if should_return_str_or_bytes:
+            return array_or_dtype
+        return type(array_or_dtype)
+    elif isinstance(
+        dtype := getattr(array_or_dtype, "dtype", array_or_dtype), np.dtype
+    ) and (np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)):
         # drop the length from numpy's fixed-width string dtypes, it is better to
         # recalculate
         # TODO(keewis): remove once the minimum version of `numpy.result_type` does this
         # for us
         return dtype.type
     else:
-        return t
+        return array_or_dtype
+
+
+def should_promote_to_object(
+    arrays_and_dtypes: Iterable[
+        np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype
+    ],
+    xp,
+) -> bool:
+    """
+    Test whether the given arrays_and_dtypes, when evaluated individually, match the
+    type promotion rules found in PROMOTE_TO_OBJECT.
+    """
+    np_result_types = set()
+    for arr_or_dtype in arrays_and_dtypes:
+        try:
+            result_type = array_api_compat.result_type(
+                maybe_promote_to_variable_width(arr_or_dtype), xp=xp
+            )
+            if isinstance(result_type, np.dtype):
+                np_result_types.add(result_type)
+        except TypeError:
+            # passing individual objects to xp.result_type (i.e., what `array_api_compat.result_type` calls) means NEP-18 implementations won't have
+            # a chance to intercept special values (such as NA) that numpy core cannot handle.
+            # Thus they are considered as types that don't need promotion i.e., the `arr_or_dtype` that rose the `TypeError` will not contribute to `np_result_types`.
+            pass
+
+    if np_result_types:
+        for left, right in PROMOTE_TO_OBJECT:
+            if any(np.issubdtype(t, left) for t in np_result_types) and any(
+                np.issubdtype(t, right) for t in np_result_types
+            ):
+                return True
+
+    return False
 
 
 def result_type(
-    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike | None,
+    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype,
     xp=None,
 ) -> np.dtype:
     """Like np.result_type, but with type promotion rules matching pandas.
@@ -263,19 +320,13 @@ def result_type(
     if xp is None:
         xp = get_array_namespace(arrays_and_dtypes)
 
-    types = {
-        array_api_compat.result_type(preprocess_types(t), xp=xp)
-        for t in arrays_and_dtypes
-    }
-    if any(isinstance(t, np.dtype) for t in types):
-        # only check if there's numpy dtypes – the array API does not
-        # define the types we're checking for
-        for left, right in PROMOTE_TO_OBJECT:
-            if any(np.issubdtype(t, left) for t in types) and any(
-                np.issubdtype(t, right) for t in types
-            ):
-                return np.dtype(object)
-
-    return array_api_compat.result_type(
-        *map(preprocess_types, arrays_and_dtypes), xp=xp
+    if should_promote_to_object(arrays_and_dtypes, xp):
+        return np.dtype(object)
+    maybe_promote = functools.partial(
+        maybe_promote_to_variable_width,
+        # let extension arrays handle their own str/bytes
+        should_return_str_or_bytes=any(
+            map(utils.is_allowed_extension_array_dtype, arrays_and_dtypes)
+        ),
     )
+    return array_api_compat.result_type(*map(maybe_promote, arrays_and_dtypes), xp=xp)
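
A hedged sketch of what the refactored `result_type` preserves (again internal API): the PROMOTE_TO_OBJECT rules still apply via `should_promote_to_object`, ordinary numpy promotion is untouched, and only when an extension-array dtype is present are str/bytes inputs passed through unchanged.

    import numpy as np
    from xarray.core import dtypes

    # bytes mixed with str still promotes to object, as before the refactor
    print(dtypes.result_type(np.array([b"a"]), np.array(["b"])))  # expected: object
    # plain numpy promotion is unchanged
    print(dtypes.result_type(np.array([1.0]), np.array([2])))  # expected: float64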

xarray/core/duck_array_ops.py
Lines changed: 28 additions & 28 deletions

@@ -27,9 +27,17 @@
 from xarray.compat import dask_array_compat, dask_array_ops
 from xarray.compat.array_api_compat import get_array_namespace
 from xarray.core import dtypes, nputils
-from xarray.core.extension_array import PandasExtensionArray
+from xarray.core.extension_array import (
+    PandasExtensionArray,
+    as_extension_array,
+)
 from xarray.core.options import OPTIONS
-from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
+from xarray.core.utils import (
+    is_allowed_extension_array_dtype,
+    is_duck_array,
+    is_duck_dask_array,
+    module_available,
+)
 from xarray.namedarray.parallelcompat import get_chunked_array_type
 from xarray.namedarray.pycompat import array_type, is_chunked_array
 
@@ -252,7 +260,14 @@ def astype(data, dtype, *, xp=None, **kwargs):
 
 
 def asarray(data, xp=np, dtype=None):
-    converted = data if is_duck_array(data) else xp.asarray(data)
+    if is_duck_array(data):
+        converted = data
+    elif is_allowed_extension_array_dtype(dtype):
+        # data may or may not be an ExtensionArray, so we can't rely on
+        # np.asarray to call our NEP-18 handler; gotta hook it ourselves
+        converted = PandasExtensionArray(as_extension_array(data, dtype))
+    else:
+        converted = xp.asarray(data)
 
     if dtype is None or converted.dtype == dtype:
         return converted
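
Sketch of the new `asarray` branch (internal API; assumes `as_extension_array` accepts a plain list): when an allowed extension dtype is requested, the data is materialised as a pandas ExtensionArray and wrapped in `PandasExtensionArray` rather than being routed through `np.asarray`.

    import pandas as pd
    from xarray.core import duck_array_ops

    wrapped = duck_array_ops.asarray([1, None, 3], dtype=pd.Int64Dtype())
    # expected: a PandasExtensionArray wrapper whose dtype is Int64
    print(type(wrapped).__name__, wrapped.dtype)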
@@ -264,29 +279,7 @@ def asarray(data, xp=np, dtype=None):
 
 
 def as_shared_dtype(scalars_or_arrays, xp=None):
-    """Cast arrays to a shared dtype using xarray's type promotion rules."""
-    extension_array_types = [
-        x.dtype
-        for x in scalars_or_arrays
-        if pd.api.types.is_extension_array_dtype(x)  # noqa: TID251
-    ]
-    if len(extension_array_types) >= 1:
-        non_nans = [x for x in scalars_or_arrays if not isna(x)]
-        if len(extension_array_types) == len(non_nans) and all(
-            isinstance(x, type(extension_array_types[0])) for x in extension_array_types
-        ):
-            return [
-                x
-                if not isna(x)
-                else PandasExtensionArray(
-                    type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype)
-                )
-                for x in scalars_or_arrays
-            ]
-        raise ValueError(
-            f"Cannot cast values to shared type, found values: {scalars_or_arrays}"
-        )
-
+    """Cast a arrays to a shared dtype using xarray's type promotion rules."""
     # Avoid calling array_type("cupy") repeatidely in the any check
     array_type_cupy = array_type("cupy")
     if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
@@ -295,7 +288,12 @@ def as_shared_dtype(scalars_or_arrays, xp=None):
         xp = cp
     elif xp is None:
         xp = get_array_namespace(scalars_or_arrays)
-
+    scalars_or_arrays = [
+        PandasExtensionArray(s_or_a)
+        if isinstance(s_or_a, pd.api.extensions.ExtensionArray)
+        else s_or_a
+        for s_or_a in scalars_or_arrays
+    ]
     # Pass arrays directly instead of dtypes to result_type so scalars
     # get handled properly.
     # Note that result_type() safely gets the dtype from dask arrays without
@@ -406,7 +404,9 @@ def where(condition, x, y):
     else:
         condition = astype(condition, dtype=dtype, xp=xp)
 
-    return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
+    promoted_x, promoted_y = as_shared_dtype([x, y], xp=xp)
+
+    return xp.where(condition, promoted_x, promoted_y)
 
 
 def where_method(data, cond, other=dtypes.NA):
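
At the public-API level, the behaviour targeted by the `as_shared_dtype`/`where` changes looks roughly like the following; whether the masked result keeps the extension dtype rather than falling back to object is an assumption about this commit, not a documented guarantee.

    import pandas as pd
    import xarray as xr

    da = xr.DataArray(pd.Categorical(["a", "b", "a"]), dims="x")
    mask = xr.DataArray([True, False, True], dims="x")
    masked = da.where(mask)  # the masked position is filled with the dtype's NA
    print(masked.dtype)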
