@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T
1212
1313# size(p) should return the size of the input array for p
1414size (p:: Plan , d) = size (p)[d]
15+ output_size (p:: Plan , d) = output_size (p)[d]
1516ndims (p:: Plan ) = length (size (p))
1617length (p:: Plan ) = prod (size (p)):: Int
1718
@@ -254,6 +255,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
254255ScaledPlan (p:: ScaledPlan , α:: Number ) = ScaledPlan (p. p, p. scale * α)
255256
256257size (p:: ScaledPlan ) = size (p. p)
258+ output_size (p:: ScaledPlan ) = size (p)
257259
258260region (p:: ScaledPlan ) = region (p. p)
259261
@@ -301,9 +303,12 @@ for f in (:brfft, :irfft)
301303end
302304
303305for f in (:brfft , :irfft )
306+ pf = Symbol (" plan_" , f)
304307 @eval begin
305308 $ f (x:: AbstractArray{<:Real} , d:: Integer , region= 1 : ndims (x)) = $ f (complexfloat (x), d, region)
309+ $ pf (x:: AbstractArray{<:Real} , d:: Integer , region; kws... ) = $ pf (complexfloat (x), d, region; kws... )
306310 $ f (x:: AbstractArray{<:Complex{<:Union{Integer,Rational}}} , d:: Integer , region= 1 : ndims (x)) = $ f (complexfloat (x), d, region)
311+ $ pf (x:: AbstractArray{<:Complex{<:Union{Integer,Rational}}} , d:: Integer , region; kws... ) = $ pf (complexfloat (x), d, region; kws... )
307312 end
308313end
309314
@@ -343,6 +348,16 @@ function brfft_output_size(sz::Dims{N}, d::Integer, region) where {N}
343348 return ntuple (i -> i == d1 ? d : sz[i], Val (N))
344349end
345350
351+ function output_size (p:: Plan )
352+ if projection_style (p) == :none
353+ return size (p)
354+ elseif projection_style (p) == :real
355+ return rfft_output_size (size (p), region (p))
356+ elseif projection_style (p) == :real_inv
357+ return brfft_output_size (size (p), irfft_dim (p), region (p))
358+ end
359+ end
360+
346361plan_irfft (x:: AbstractArray{Complex{T}} , d:: Integer , region; kws... ) where {T} =
347362 ScaledPlan (plan_brfft (x, d, region; kws... ),
348363 normalization (T, brfft_output_size (x, d, region), region))
@@ -575,3 +590,58 @@ Pre-plan an optimized real-input unnormalized transform, similar to
575590the same as for [`brfft`](@ref).
576591"""
577592plan_brfft
593+
594+ # #############################################################################
595+
596+ region (p:: Plan ) = p. region
597+ region (p:: ScaledPlan ) = region (p. p)
598+
599+ # Projection style (:none, :real, or :real_inv) to handle real FFTs
600+ function projection_style end
601+ # Length of halved dimension, needed only for irfft
602+ function irfft_dim end
603+
604+ mutable struct AdjointPlan{T,P} <: Plan{T}
605+ p:: P
606+ pinv:: Plan
607+ AdjointPlan {T,P} (p) where {T,P} = new (p)
608+ # always have adjoint inside scaled
609+ AdjointPlan {T,P} (p:: P ) where {T,P<: ScaledPlan{T} } = ScaledPlan {T} (AdjointPlan {T} (p. p), p. scale)
610+ AdjointPlan {T,P} (p:: AdjointPlan{T} ) where {T,P} = new (p. p)
611+ end
612+
613+ AdjointPlan {T} (p:: P ) where {T,P} = AdjointPlan {T,P} (p)
614+ AdjointPlan (p:: Plan{T} ) where {T} = AdjointPlan {T} (p)
615+ Base. adjoint (p:: Plan{T} ) where {T} = AdjointPlan {T} (p)
616+
617+ size (p:: AdjointPlan ) = output_size (p)
618+ output_size (p:: AdjointPlan ) = size (p)
619+
620+ function Base.:* (p:: AdjointPlan{T} , x:: AbstractArray ) where {T}
621+ dims = region (p. p)
622+ halfdim = first (dims)
623+ d = size (p. p, halfdim)
624+ n = output_size (p. p, halfdim)
625+ if projection_style (p. p) == :none
626+ N = normalization (T, size (p. p), dims)
627+ return 1 / N * (p. p \ x)
628+ elseif projection_style (p. p) == :real
629+ N = normalization (T, size (p. p), dims)
630+ scale = reshape (
631+ [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
632+ ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
633+ )
634+ return 1 / N * (p. p \ (x ./ scale))
635+ elseif projection_style (p. p) == :real_inv
636+ N = normalization (real (T), output_size (p. p), dims)
637+ scale = reshape (
638+ [(i == 1 || (i == d && 2 * (i - 1 )) == n) ? 1 : 2 for i in 1 : d],
639+ ntuple (i -> i == first (dims) ? d : 1 , Val (ndims (x)))
640+ )
641+ return 1 / N * scale .* (p. p \ x)
642+ else
643+ error (" plan must define a valid projection style" )
644+ end
645+ end
646+
647+ plan_inv (p:: AdjointPlan ) = AdjointPlan (plan_inv (p. p))
0 commit comments