@@ -29,16 +29,16 @@ testing: this would most commonly be a constructor such as `Array` or `CuArray`.
2929function test_fft_backend (array_constructor; test_real= true , test_inplace= true )
3030 @testset " fft correctness" begin
3131 # DFT along last dimension, results computed using FFTW
32- for (_x, _fftw_fft) in (
33- (collect (1 : 7 ),
32+ for (_x, dims, real_input, _fftw_fft) in (
33+ (collect (1 : 7 ), 1 , true ,
3434 [28.0 + 0.0im ,
3535 - 3.5 + 7.267824888003178im ,
3636 - 3.5 + 2.7911568610884143im ,
3737 - 3.5 + 0.7988521603655248im ,
3838 - 3.5 - 0.7988521603655248im ,
3939 - 3.5 - 2.7911568610884143im ,
4040 - 3.5 - 7.267824888003178im ]),
41- (collect (1 : 8 ),
41+ (collect (1 : 8 ), 1 , true ,
4242 [36.0 + 0.0im ,
4343 - 4.0 + 9.65685424949238im ,
4444 - 4.0 + 4.0im ,
@@ -47,21 +47,32 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
4747 - 4.0 - 1.6568542494923806im ,
4848 - 4.0 - 4.0im ,
4949 - 4.0 - 9.65685424949238im ]),
50- (collect (reshape (1 : 8 , 2 , 4 )),
50+ (collect (reshape (1 : 8 , 2 , 4 )), 2 , true ,
5151 [16.0 + 0.0im - 4.0 + 4.0im - 4.0 + 0.0im - 4.0 - 4.0im ;
5252 20.0 + 0.0im - 4.0 + 4.0im - 4.0 + 0.0im - 4.0 - 4.0im ]),
53- (collect (reshape (1 : 9 , 3 , 3 )),
53+ (collect (reshape (1 : 9 , 3 , 3 )), 2 , true ,
5454 [12.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ;
5555 15.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ;
5656 18.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ]),
57+ (collect (reshape (1 : 8 , 2 , 2 , 2 )), 1 : 2 , true ,
58+ [10.0 + 0.0im - 4.0 + 0.0im ; - 2.0 + 0.0im 0.0 + 0.0im ;;;
59+ 26.0 + 0.0im - 4.0 + 0.0im ; - 2.0 + 0.0im 0.0 + 0.0im ]),
60+ (collect (1 : 7 ) + im * collect (8 : 14 ), 1 , false ,
61+ [28.0 + 77.0im ,
62+ - 10.76782488800318 + 3.767824888003175im ,
63+ - 6.291156861088416 - 0.7088431389115883im ,
64+ - 4.298852160365525 - 2.7011478396344746im ,
65+ - 2.7011478396344764 - 4.298852160365524im ,
66+ - 0.7088431389115866 - 6.291156861088417im ,
67+ 3.767824888003177 - 10.76782488800318im ]),
68+ (collect (reshape (1 : 8 , 2 , 2 , 2 )) + im * reshape (9 : 16 , 2 , 2 , 2 ), 1 : 2 , false ,
69+ [10.0 + 42.0im - 4.0 - 4.0im ; - 2.0 - 2.0im 0.0 + 0.0im ;;;
70+ 26.0 + 58.0im - 4.0 - 4.0im ; - 2.0 - 2.0im 0.0 + 0.0im ]),
5771 )
5872 x = array_constructor (_x) # dummy array that will be passed to plans
59- x_real = float .(x) # for testing real FFTs
60- x_complex = complex .(x_real) # for testing complex FFTs
73+ x_complex = complex .(float .(x)) # for testing complex FFTs
6174 fftw_fft = array_constructor (_fftw_fft)
6275
63- dims = ndims (x) # TODO : this is a single dimension, should check multidimensional FFTs too
64-
6576 # FFT
6677 y = AbstractFFTs. fft (x_complex, dims)
6778 @test y ≈ fftw_fft
@@ -82,7 +93,7 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
8293 end
8394
8495 # BFFT
85- fftw_bfft = size (x_complex, dims) .* x_complex
96+ fftw_bfft = prod ( size (x_complex, d) for d in dims) .* x_complex
8697 @test AbstractFFTs. bfft (y, dims) ≈ fftw_bfft
8798 test_inplace && (@test AbstractFFTs. bfft! (copy (y), dims) ≈ fftw_bfft)
8899 plans_to_test = [plan_bfft (similar (y), dims)]
@@ -114,16 +125,18 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
114125 @test fftdims (P) == dims
115126 end
116127
117- if test_real
128+ if test_real && real_input
129+ x_real = float .(x) # for testing real FFTs
118130 # RFFT
119131 fftw_rfft = fftw_fft[
120- (Colon () for _ in 1 : (ndims (fftw_fft) - 1 )). .. ,
121- 1 : (size (fftw_fft, ndims (fftw_fft)) ÷ 2 + 1 )
132+ (Colon () for _ in 1 : (first (dims) - 1 )). .. ,
133+ 1 : (size (fftw_fft, first (dims)) ÷ 2 + 1 ),
134+ (Colon () for _ in (first (dims) + 1 ): ndims (fftw_fft)). ..
122135 ]
123136 ry = AbstractFFTs. rfft (x_real, dims)
124137 @test ry ≈ fftw_rfft
125- for P in [plan_rfft (x_real, dims), inv (plan_irfft (ry, size (x, dims), dims)),
126- AbstractFFTs. plan_inv (plan_irfft (ry, size (x, dims), dims))]
138+ for P in [plan_rfft (x_real, dims), inv (plan_irfft (ry, size (x, first ( dims) ), dims)),
139+ AbstractFFTs. plan_inv (plan_irfft (ry, size (x, first ( dims) ), dims))]
127140 @test eltype (P) <: Real
128141 @test P * x_real ≈ fftw_rfft
129142 @test mul! (similar (ry), P, x_real) ≈ fftw_rfft
@@ -132,18 +145,18 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
132145 end
133146
134147 # BRFFT
135- fftw_brfft = complex . (size (x, dims) .* x_real)
136- @test AbstractFFTs. brfft (ry, size (x_real, dims), dims) ≈ fftw_brfft
137- P = plan_brfft (ry, size (x_real, dims), dims)
148+ fftw_brfft = prod (size (x_real, d) for d in dims) .* x_real
149+ @test AbstractFFTs. brfft (ry, size (x_real, first ( dims) ), dims) ≈ fftw_brfft
150+ P = plan_brfft (ry, size (x_real, first ( dims) ), dims)
138151 @test P * ry ≈ fftw_brfft
139152 @test mul! (similar (x_real), P, ry) ≈ fftw_brfft
140153 @test P \ (P * ry) ≈ ry
141154 @test fftdims (P) == dims
142155
143156 # IRFFT
144157 fftw_irfft = x_complex
145- @test AbstractFFTs. irfft (ry, size (x, dims), dims) ≈ fftw_irfft
146- for P in [plan_irfft (ry, size (x, dims), dims), inv (plan_rfft (x_real, dims)),
158+ @test AbstractFFTs. irfft (ry, size (x, first ( dims) ), dims) ≈ fftw_irfft
159+ for P in [plan_irfft (ry, size (x, first ( dims) ), dims), inv (plan_rfft (x_real, dims)),
147160 AbstractFFTs. plan_inv (plan_rfft (x_real, dims))]
148161 @test P * ry ≈ fftw_irfft
149162 @test mul! (similar (x_real), P, ry) ≈ fftw_irfft
0 commit comments