Skip to content

Commit 52cb7a6

Browse files
authored
Fix evaluation of distribution methods on scalars (#39)
* fix dist method eval on scalars * update changelog * add helpful error for ndarray with ndim>=2
1 parent aea6c6d commit 52cb7a6

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

docs/source/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
### Maintenance and fixes
99
* Fix issue in `linalg.svd` for non-square matrices {pull}`37`
10+
* Fix evaluation of distribution methods (e.g. `.pdf`) on scalars {pull}`38` and {pull}`39`
1011

1112
### Documentation
1213
* Ported NumPy tutorial on linear algebra with multidimensional arrays {pull}`37`

src/xarray_einstats/stats.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,15 @@ def _asdataarray(x_or_q, dim_name):
9191
"""
9292
if isinstance(x_or_q, xr.DataArray):
9393
return x_or_q
94-
return xr.DataArray(np.asarray(x_or_q), dims=[dim_name], coords={dim_name: np.asarray(x_or_q)})
94+
x_or_q_ary = np.asarray(x_or_q)
95+
if x_or_q_ary.ndim == 0:
96+
return xr.DataArray(x_or_q_ary)
97+
if x_or_q_ary.ndim == 1:
98+
return xr.DataArray(x_or_q_ary, dims=[dim_name], coords={dim_name: np.asarray(x_or_q)})
99+
raise ValueError(
100+
"To evaluate distribution methods on data with >=2 dims,"
101+
" the input needs to be a xarray.DataArray"
102+
)
95103

96104

97105
def _wrap_method(method):

tests/test_stats.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,23 @@ def data():
3232

3333
@pytest.mark.parametrize("wrapper", ("continuous", "discrete"))
3434
class TestRvWrappers:
35+
@pytest.mark.parametrize(
36+
"method", ("pxf", "logpxf", "cdf", "logcdf", "sf", "logsf", "ppf", "isf")
37+
)
38+
def test_eval_methods_scalar(self, data, wrapper, method):
39+
if wrapper == "continuous":
40+
dist = XrContinuousRV(stats.norm, data["mu"], data["sigma"])
41+
if "pxf" in method:
42+
method = method.replace("x", "d")
43+
else:
44+
dist = XrDiscreteRV(stats.poisson, data["mu"], data["sigma"])
45+
if "pxf" in method:
46+
method = method.replace("x", "m")
47+
meth = getattr(dist, method)
48+
out = meth(0.9)
49+
assert out.ndim == 3
50+
assert_dims_not_in_da(out, ["quantile", "point"])
51+
3552
@pytest.mark.parametrize(
3653
"method", ("pxf", "logpxf", "cdf", "logcdf", "sf", "logsf", "ppf", "isf")
3754
)

0 commit comments

Comments
 (0)