|
1 | | -# ffts |
2 | | -function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims) |
3 | | - y = fft(x, dims) |
4 | | - Δy = fft(Δx, dims) |
5 | | - return y, Δy |
6 | | -end |
7 | | -function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims) |
8 | | - y = fft(x, dims) |
9 | | - project_x = ChainRulesCore.ProjectTo(x) |
10 | | - function fft_pullback(ȳ) |
11 | | - x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims)) |
12 | | - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
13 | | - end |
14 | | - return y, fft_pullback |
15 | | -end |
16 | | - |
17 | | -function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims) |
18 | | - y = rfft(x, dims) |
19 | | - Δy = rfft(Δx, dims) |
20 | | - return y, Δy |
21 | | -end |
22 | | -function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) |
23 | | - y = rfft(x, dims) |
24 | | - |
25 | | - # compute scaling factors |
26 | | - halfdim = first(dims) |
27 | | - d = size(x, halfdim) |
28 | | - n = size(y, halfdim) |
29 | | - scale = reshape( |
30 | | - [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], |
31 | | - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
32 | | - ) |
33 | | - |
34 | | - project_x = ChainRulesCore.ProjectTo(x) |
35 | | - function rfft_pullback(ȳ) |
36 | | - x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims)) |
37 | | - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
38 | | - end |
39 | | - return y, rfft_pullback |
40 | | -end |
41 | | - |
42 | | -function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims) |
43 | | - y = ifft(x, dims) |
44 | | - Δy = ifft(Δx, dims) |
45 | | - return y, Δy |
46 | | -end |
47 | | -function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims) |
48 | | - y = ifft(x, dims) |
49 | | - invN = normalization(y, dims) |
50 | | - project_x = ChainRulesCore.ProjectTo(x) |
51 | | - function ifft_pullback(ȳ) |
52 | | - x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims)) |
53 | | - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
54 | | - end |
55 | | - return y, ifft_pullback |
56 | | -end |
57 | | - |
58 | | -function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims) |
59 | | - y = irfft(x, d, dims) |
60 | | - Δy = irfft(Δx, d, dims) |
61 | | - return y, Δy |
62 | | -end |
63 | | -function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) |
64 | | - y = irfft(x, d, dims) |
65 | | - |
66 | | - # compute scaling factors |
67 | | - halfdim = first(dims) |
68 | | - n = size(x, halfdim) |
69 | | - invN = normalization(y, dims) |
70 | | - twoinvN = 2 * invN |
71 | | - scale = reshape( |
72 | | - [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], |
73 | | - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
74 | | - ) |
75 | | - |
76 | | - project_x = ChainRulesCore.ProjectTo(x) |
77 | | - function irfft_pullback(ȳ) |
78 | | - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) |
79 | | - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
| 1 | +for f in (:fft, :bfft, :ifft, :rfft) |
| 2 | + pf = Symbol("plan_", f) |
| 3 | + @eval begin |
| 4 | + function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, dims) |
| 5 | + y = $f(x, dims) |
| 6 | + Δy = $f(Δx, dims) |
| 7 | + return y, Δy |
| 8 | + end |
| 9 | + function ChainRulesCore.rrule(::typeof($f), x::T, dims) where {T<:AbstractArray} |
| 10 | + y = $f(x, dims) |
| 11 | + project_x = ChainRulesCore.ProjectTo(x) |
| 12 | + ax = axes(x) |
| 13 | + function fft_pullback(ȳ) |
| 14 | + x̄ = project_x($pf(similar(T, ax), dims)' * ChainRulesCore.unthunk(ȳ)) |
| 15 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 16 | + end |
| 17 | + return y, fft_pullback |
| 18 | + end |
80 | 19 | end |
81 | | - return y, irfft_pullback |
82 | 20 | end |
83 | 21 |
|
84 | | -function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims) |
85 | | - y = bfft(x, dims) |
86 | | - Δy = bfft(Δx, dims) |
87 | | - return y, Δy |
88 | | -end |
89 | | -function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims) |
90 | | - y = bfft(x, dims) |
91 | | - project_x = ChainRulesCore.ProjectTo(x) |
92 | | - function bfft_pullback(ȳ) |
93 | | - x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims)) |
94 | | - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() |
| 22 | +for f in (:brfft, :irfft) |
| 23 | + pf = Symbol("plan_", f) |
| 24 | + @eval begin |
| 25 | + function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, d::Int, dims) |
| 26 | + y = $f(x, d::Int, dims) |
| 27 | + Δy = $f(Δx, d::Int, dims) |
| 28 | + return y, Δy |
| 29 | + end |
| 30 | + function ChainRulesCore.rrule(::typeof($f), x::T, d::Int, dims) where {T<:AbstractArray} |
| 31 | + y = $f(x, d, dims) |
| 32 | + project_x = ChainRulesCore.ProjectTo(x) |
| 33 | + ax = axes(x) |
| 34 | + function fft_pullback(ȳ) |
| 35 | + x̄ = project_x($pf(similar(T, ax), d, dims)' * ChainRulesCore.unthunk(ȳ)) |
| 36 | + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
| 37 | + end |
| 38 | + return y, fft_pullback |
| 39 | + end |
95 | 40 | end |
96 | | - return y, bfft_pullback |
97 | | -end |
98 | | - |
99 | | -function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims) |
100 | | - y = brfft(x, d, dims) |
101 | | - Δy = brfft(Δx, d, dims) |
102 | | - return y, Δy |
103 | | -end |
104 | | -function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) |
105 | | - y = brfft(x, d, dims) |
106 | | - |
107 | | - # compute scaling factors |
108 | | - halfdim = first(dims) |
109 | | - n = size(x, halfdim) |
110 | | - scale = reshape( |
111 | | - [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], |
112 | | - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), |
113 | | - ) |
114 | | - |
115 | | - project_x = ChainRulesCore.ProjectTo(x) |
116 | | - function brfft_pullback(ȳ) |
117 | | - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) |
118 | | - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() |
119 | | - end |
120 | | - return y, brfft_pullback |
121 | 41 | end |
122 | 42 |
|
123 | 43 | # shift functions |
@@ -150,3 +70,19 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims) |
150 | 70 | end |
151 | 71 | return y, ifftshift_pullback |
152 | 72 | end |
| 73 | + |
| 74 | +# plans |
| 75 | +function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::Plan, x::AbstractArray) |
| 76 | + y = P * x |
| 77 | + Δy = P * Δx |
| 78 | + return y, Δy |
| 79 | +end |
| 80 | +function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray) |
| 81 | + y = P * x |
| 82 | + project_x = ChainRulesCore.ProjectTo(x) |
| 83 | + function fft_pullback(ȳ) |
| 84 | + x̄ = project_x(P' * ȳ) |
| 85 | + return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄ |
| 86 | + end |
| 87 | + return y, fft_pullback |
| 88 | +end |
0 commit comments