|
3 | 3 | using AbstractFFTs |
4 | 4 | using AbstractFFTs: Plan |
5 | 5 | using ChainRulesTestUtils |
| 6 | +using ChainRulesCore: NoTangent |
6 | 7 |
|
7 | 8 | using LinearAlgebra |
8 | 9 | using Random |
|
197 | 198 | @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 |
198 | 199 | end |
199 | 200 |
|
| 201 | +@testset "adjoint" begin |
| 202 | + @testset "complex fft adjoint" begin |
| 203 | + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) |
| 204 | + N = ndims(x) |
| 205 | + y = randn(size(x)) |
| 206 | + for dims in unique((1, 1:N, N)) |
| 207 | + P = plan_fft(x, dims) |
| 208 | + @test dot(y, P * x) ≈ dot(P' * y, x) |
| 209 | + @test_broken dot(y, P \ x) ≈ dot(P' \ y, x) |
| 210 | + Pinv = plan_ifft(x) |
| 211 | + @test dot(x, Pinv * y) ≈ dot(Pinv' * x, y) |
| 212 | + @test_broken dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) |
| 213 | + end |
| 214 | + end |
| 215 | + end |
| 216 | + @testset "real fft adjoint" begin |
| 217 | + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths |
| 218 | + N = ndims(x) |
| 219 | + for dims in unique((1, 1:N, N)) |
| 220 | + P = plan_rfft(similar(x), dims) |
| 221 | + y_real = randn(size(P * x)) |
| 222 | + y_imag = randn(size(P * x)) |
| 223 | + y = y_real .+ y_imag .* im |
| 224 | + @test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) ≈ dot(P' * y, x) |
| 225 | + @test_broken dot(y_real, real.(P \ x)) + dot(y_imag, imag.(P \ x)) ≈ dot(P' * y, x) |
| 226 | + Pinv = plan_irfft(similar(y), size(x)[first(dims)], dims) |
| 227 | + @test dot(x, Pinv * y) ≈ dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x)) |
| 228 | + @test_broken dot(x, Pinv \ y) ≈ dot(y_real, real.(Pinv' \ x)) + dot(y_imag, imag.(Pinv' \ x)) |
| 229 | + end |
| 230 | + end |
| 231 | + end |
| 232 | +end |
| 233 | + |
200 | 234 | @testset "ChainRules" begin |
201 | 235 | @testset "shift functions" begin |
202 | 236 | for x in (randn(3), randn(3, 4), randn(3, 4, 5)) |
|
218 | 252 | end |
219 | 253 |
|
220 | 254 | @testset "fft" begin |
221 | | - for x in (randn(3), randn(3, 4), randn(3, 4, 5)) |
| 255 | + for x in (randn(2), randn(2, 3), randn(3, 4, 5)) |
222 | 256 | N = ndims(x) |
223 | 257 | complex_x = complex.(x) |
224 | 258 | for dims in unique((1, 1:N, N)) |
|
229 | 263 | test_rrule(f, complex_x, dims) |
230 | 264 | end |
231 | 265 |
|
232 | | - test_frule(rfft, x, dims) |
233 | | - test_rrule(rfft, x, dims) |
| 266 | + for pf in (plan_fft, plan_ifft, plan_bfft) |
| 267 | + test_frule(*, pf(x, dims) ⊢ NoTangent(), x) |
| 268 | + test_rrule(*, pf(x, dims) ⊢ NoTangent(), x) |
| 269 | + test_frule(*, pf(complex_x, dims) ⊢ NoTangent(), complex_x) |
| 270 | + test_rrule(*, pf(complex_x, dims) ⊢ NoTangent(), complex_x) |
| 271 | + end |
234 | 272 |
|
235 | 273 | for f in (irfft, brfft) |
236 | 274 | for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) |
|
240 | 278 | test_rrule(f, complex_x, d, dims) |
241 | 279 | end |
242 | 280 | end |
| 281 | + |
| 282 | + for pf in (plan_irfft, plan_brfft) |
| 283 | + for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) |
| 284 | + test_frule(*, pf(complex_x, d, dims) ⊢ NoTangent(), complex_x) |
| 285 | + test_rrule(*, pf(complex_x, d, dims) ⊢ NoTangent(), complex_x) |
| 286 | + end |
| 287 | + end |
| 288 | + |
243 | 289 | end |
244 | 290 | end |
245 | 291 | end |
|
0 commit comments