Skip to content

Commit 3249a73

Browse files
authored
Merge pull request #408 from ev-br/fftfreq_dtype
ENH: test_{r,}fftfreq dtype argument
2 parents 4dae9dc + f24e548 commit 3249a73

File tree

1 file changed

+34
-9
lines changed

1 file changed

+34
-9
lines changed

array_api_tests/test_fft.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -279,18 +279,43 @@ def test_ihfft(x, data):
279279
ph.assert_shape("ihfft", out_shape=out.shape, expected=expected_shape)
280280

281281

282-
@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
282+
@given(
283+
n=st.integers(1, 100),
284+
kw=hh.kwargs(d=st.floats(0.1, 5), dtype=hh.real_floating_dtypes),
285+
)
283286
def test_fftfreq(n, kw):
284-
out = xp.fft.fftfreq(n, **kw)
285-
ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n})
286-
287+
repro_snippet = ph.format_snippet(f"xp.fft.fftfreq({n!r}, **kw) with {kw = }")
288+
try:
289+
out = xp.fft.fftfreq(n, **kw)
290+
ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n})
291+
292+
dt = kw.get("dtype", None)
293+
if dt is None:
294+
dt = xp.__array_namespace_info__().default_dtypes()["real floating"]
295+
assert out.dtype == dt
296+
except Exception as exc:
297+
ph.add_note(exc, repro_snippet)
298+
raise
287299

288-
@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
300+
@given(
301+
n=st.integers(1, 100),
302+
kw=hh.kwargs(d=st.floats(0.1, 5), dtype=hh.real_floating_dtypes)
303+
)
289304
def test_rfftfreq(n, kw):
290-
out = xp.fft.rfftfreq(n, **kw)
291-
ph.assert_shape(
292-
"rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n}
293-
)
305+
repro_snippet = ph.format_snippet(f"xp.fft.rfftfreq({n!r}, **kw) with {kw = }")
306+
try:
307+
out = xp.fft.rfftfreq(n, **kw)
308+
ph.assert_shape(
309+
"rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n}
310+
)
311+
312+
dt = kw.get("dtype", None)
313+
if dt is None:
314+
dt = xp.__array_namespace_info__().default_dtypes()["real floating"]
315+
assert out.dtype == dt
316+
except Exception as exc:
317+
ph.add_note(exc, repro_snippet)
318+
raise
294319

295320

296321
@pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"])

0 commit comments

Comments
 (0)