|
| 1 | +# This file includes tests from dask.fft module: |
| 2 | +# https://github.com/dask/dask/blob/main/dask/array/tests/test_fft.py |
| 3 | + |
| 4 | +import contextlib |
| 5 | +from itertools import combinations_with_replacement |
| 6 | + |
| 7 | +import dask |
| 8 | +import dask.array as da |
| 9 | +import numpy as np |
| 10 | +import pytest |
| 11 | +from dask.array.numpy_compat import NUMPY_GE_200 |
| 12 | +from dask.array.utils import assert_eq, same_keys |
| 13 | + |
| 14 | +import mkl_fft.interfaces.dask_fft as dask_fft |
| 15 | + |
| 16 | +requires_dask_2024_8_2 = pytest.mark.skipif( |
| 17 | + dask.__version__ < "2024.8.2", |
| 18 | + reason="norm kwarg requires Dask >= 2024.8.2", |
| 19 | +) |
| 20 | + |
| 21 | +all_1d_funcnames = ["fft", "ifft", "rfft", "irfft", "hfft", "ihfft"] |
| 22 | + |
| 23 | +all_nd_funcnames = [ |
| 24 | + "fft2", |
| 25 | + "ifft2", |
| 26 | + "fftn", |
| 27 | + "ifftn", |
| 28 | + "rfft2", |
| 29 | + "irfft2", |
| 30 | + "rfftn", |
| 31 | + "irfftn", |
| 32 | +] |
| 33 | + |
| 34 | +if not da._array_expr_enabled(): |
| 35 | + |
| 36 | + nparr = np.arange(100).reshape(10, 10) |
| 37 | + darr = da.from_array(nparr, chunks=(1, 10)) |
| 38 | + darr2 = da.from_array(nparr, chunks=(10, 1)) |
| 39 | + darr3 = da.from_array(nparr, chunks=(10, 10)) |
| 40 | + |
| 41 | + |
| 42 | +@pytest.mark.parametrize("funcname", all_1d_funcnames) |
| 43 | +def test_cant_fft_chunked_axis(funcname): |
| 44 | + da_fft = getattr(dask_fft, funcname) |
| 45 | + |
| 46 | + bad_darr = da.from_array(nparr, chunks=(5, 5)) |
| 47 | + for i in range(bad_darr.ndim): |
| 48 | + with pytest.raises(ValueError): |
| 49 | + da_fft(bad_darr, axis=i) |
| 50 | + |
| 51 | + |
| 52 | +@pytest.mark.parametrize("funcname", all_1d_funcnames) |
| 53 | +def test_fft(funcname): |
| 54 | + da_fft = getattr(dask_fft, funcname) |
| 55 | + np_fft = getattr(np.fft, funcname) |
| 56 | + |
| 57 | + # pylint: disable=possibly-used-before-assignment |
| 58 | + assert_eq(da_fft(darr), np_fft(nparr)) |
| 59 | + |
| 60 | + |
| 61 | +@pytest.mark.parametrize("funcname", all_nd_funcnames) |
| 62 | +def test_fft2n_shapes(funcname): |
| 63 | + da_fft = getattr(dask_fft, funcname) |
| 64 | + np_fft = getattr(np.fft, funcname) |
| 65 | + |
| 66 | + # pylint: disable=possibly-used-before-assignment |
| 67 | + assert_eq(da_fft(darr3), np_fft(nparr)) |
| 68 | + assert_eq( |
| 69 | + da_fft(darr3, (8, 9), axes=(1, 0)), np_fft(nparr, (8, 9), axes=(1, 0)) |
| 70 | + ) |
| 71 | + assert_eq( |
| 72 | + da_fft(darr3, (12, 11), axes=(1, 0)), |
| 73 | + np_fft(nparr, (12, 11), axes=(1, 0)), |
| 74 | + ) |
| 75 | + |
| 76 | + if NUMPY_GE_200 and funcname.endswith("fftn"): |
| 77 | + ctx = pytest.warns( |
| 78 | + DeprecationWarning, |
| 79 | + match="`axes` should not be `None` if `s` is not `None`", |
| 80 | + ) |
| 81 | + else: |
| 82 | + ctx = contextlib.nullcontext() |
| 83 | + with ctx: |
| 84 | + expect = np_fft(nparr, (8, 9)) |
| 85 | + with ctx: |
| 86 | + actual = da_fft(darr3, (8, 9)) |
| 87 | + assert_eq(expect, actual) |
| 88 | + |
| 89 | + |
| 90 | +@requires_dask_2024_8_2 |
| 91 | +@pytest.mark.parametrize("funcname", all_1d_funcnames) |
| 92 | +def test_fft_n_kwarg(funcname): |
| 93 | + da_fft = getattr(dask_fft, funcname) |
| 94 | + np_fft = getattr(np.fft, funcname) |
| 95 | + |
| 96 | + assert_eq(da_fft(darr, 5), np_fft(nparr, 5)) |
| 97 | + assert_eq(da_fft(darr, 13), np_fft(nparr, 13)) |
| 98 | + assert_eq( |
| 99 | + da_fft(darr, 13, norm="backward"), np_fft(nparr, 13, norm="backward") |
| 100 | + ) |
| 101 | + assert_eq(da_fft(darr, 13, norm="ortho"), np_fft(nparr, 13, norm="ortho")) |
| 102 | + assert_eq( |
| 103 | + da_fft(darr, 13, norm="forward"), np_fft(nparr, 13, norm="forward") |
| 104 | + ) |
| 105 | + # pylint: disable=possibly-used-before-assignment |
| 106 | + assert_eq(da_fft(darr2, axis=0), np_fft(nparr, axis=0)) |
| 107 | + assert_eq(da_fft(darr2, 5, axis=0), np_fft(nparr, 5, axis=0)) |
| 108 | + assert_eq( |
| 109 | + da_fft(darr2, 13, axis=0, norm="backward"), |
| 110 | + np_fft(nparr, 13, axis=0, norm="backward"), |
| 111 | + ) |
| 112 | + assert_eq( |
| 113 | + da_fft(darr2, 12, axis=0, norm="ortho"), |
| 114 | + np_fft(nparr, 12, axis=0, norm="ortho"), |
| 115 | + ) |
| 116 | + assert_eq( |
| 117 | + da_fft(darr2, 12, axis=0, norm="forward"), |
| 118 | + np_fft(nparr, 12, axis=0, norm="forward"), |
| 119 | + ) |
| 120 | + |
| 121 | + |
| 122 | +@pytest.mark.parametrize("funcname", all_1d_funcnames) |
| 123 | +def test_fft_consistent_names(funcname): |
| 124 | + da_fft = getattr(dask_fft, funcname) |
| 125 | + |
| 126 | + assert same_keys(da_fft(darr, 5), da_fft(darr, 5)) |
| 127 | + assert same_keys(da_fft(darr2, 5, axis=0), da_fft(darr2, 5, axis=0)) |
| 128 | + assert not same_keys(da_fft(darr, 5), da_fft(darr, 13)) |
| 129 | + |
| 130 | + |
| 131 | +@pytest.mark.parametrize("funcname", all_nd_funcnames) |
| 132 | +@pytest.mark.parametrize("dtype", ["float32", "float64"]) |
| 133 | +def test_nd_ffts_axes(funcname, dtype): |
| 134 | + np_fft = getattr(np.fft, funcname) |
| 135 | + da_fft = getattr(dask_fft, funcname) |
| 136 | + |
| 137 | + shape = (7, 8, 9) |
| 138 | + chunk_size = (3, 3, 3) |
| 139 | + a = np.arange(np.prod(shape), dtype=dtype).reshape(shape) |
| 140 | + d = da.from_array(a, chunks=chunk_size) |
| 141 | + |
| 142 | + for num_axes in range(1, d.ndim): |
| 143 | + for axes in combinations_with_replacement(range(d.ndim), num_axes): |
| 144 | + cs = list(chunk_size) |
| 145 | + for i in axes: |
| 146 | + cs[i] = shape[i] |
| 147 | + d2 = d.rechunk(cs) |
| 148 | + if len(set(axes)) < len(axes): |
| 149 | + with pytest.raises(ValueError): |
| 150 | + da_fft(d2, axes=axes) |
| 151 | + else: |
| 152 | + r = da_fft(d2, axes=axes) |
| 153 | + er = np_fft(a, axes=axes) |
| 154 | + if np.lib.NumpyVersion(np.__version__) >= "2.0.0": |
| 155 | + check_dtype = True |
| 156 | + assert r.dtype == er.dtype |
| 157 | + else: |
| 158 | + check_dtype = False |
| 159 | + assert r.shape == er.shape |
| 160 | + |
| 161 | + assert_eq(r, er, check_dtype=check_dtype, rtol=1e-6, atol=1e-4) |
0 commit comments