@@ -345,16 +345,6 @@ function brfft_output_size(sz::Dims{N}, d::Integer, region) where {N}
345345 return ntuple (i -> i == d1 ? d : sz[i], Val (N))
346346end
347347
348- function output_size (p:: Plan )
349- if projection_style (p) == :none
350- return size (p)
351- elseif projection_style (p) == :real
352- return rfft_output_size (size (p), region (p))
353- elseif projection_style (p) == :real_inv
354- return brfft_output_size (size (p), irfft_dim (p), region (p))
355- end
356- end
357-
358348plan_irfft (x:: AbstractArray{Complex{T}} , d:: Integer , region; kws... ) where {T} =
359349 ScaledPlan (plan_brfft (x, d, region; kws... ),
360350 normalization (T, brfft_output_size (x, d, region), region))
@@ -590,11 +580,19 @@ plan_brfft
590580
591581# #############################################################################
592582
593- # Projection style (:none, :real, or :real_inv) to handle real FFTs
594- function projection_style end
595- # Length of halved dimension, needed only for irfft
583+ struct NoProjectionStyle end
584+ struct RealProjectionStyle end
585+ struct RealInverseProjectionStyle end
586+ const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}
587+
596588function irfft_dim end
597589
590+ ProjectionStyle (p:: Plan ) = error (" No projection style defined for plan" )
591+ output_size (p:: Plan ) = _output_size (p, ProjectionStyle (p))
592+ _output_size (p:: Plan , :: NoProjectionStyle ) = size (p)
593+ _output_size (p:: Plan , :: RealProjectionStyle ) = rfft_output_size (size (p), region (p))
594+ _output_size (p:: Plan , :: RealInverseProjectionStyle ) = brfft_output_size (size (p), irfft_dim (p), region (p))
595+
598596mutable struct AdjointPlan{T,P} <: Plan{T}
599597 p:: P
600598 pinv:: Plan
@@ -611,31 +609,38 @@ Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T}(p)
611609size (p:: AdjointPlan ) = output_size (p)
612610output_size (p:: AdjointPlan ) = size (p)
613611
614- function Base.:* (p:: AdjointPlan{T} , x:: AbstractArray ) where {T}
612+ Base.:* (p:: AdjointPlan , x:: AbstractArray ) = _mul (p, x, ProjectionStyle (p. p))
613+
614+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: NoProjectionStyle ) where {T}
615+ dims = region (p. p)
616+ N = normalization (T, size (p. p), dims)
617+ return 1 / N * (p. p \ x)
618+ end
619+
620+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: RealProjectionStyle ) where {T}
615621 dims = region (p. p)
622+ N = normalization (T, size (p. p), dims)
616623 halfdim = first (dims)
617624 d = size (p. p, halfdim)
618625 n = output_size (p. p, halfdim)
619- if projection_style (p. p) == :none
620- N = normalization (T, size (p. p), dims)
621- return 1 / N * (p. p \ x)
622- elseif projection_style (p. p) == :real
623- N = normalization (T, size (p. p), dims)
624- scale = reshape (
625- [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
626- ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
627- )
628- return 1 / N * (p. p \ (x ./ scale))
629- elseif projection_style (p. p) == :real_inv
630- N = normalization (real (T), output_size (p. p), dims)
631- scale = reshape (
632- [(i == 1 || (i == d && 2 * (i - 1 )) == n) ? 1 : 2 for i in 1 : d],
633- ntuple (i -> i == first (dims) ? d : 1 , Val (ndims (x)))
634- )
635- return 1 / N * scale .* (p. p \ x)
636- else
637- error (" plan must define a valid projection style" )
638- end
626+ scale = reshape (
627+ [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
628+ ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
629+ )
630+ return 1 / N * (p. p \ (x ./ scale))
631+ end
632+
633+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: RealInverseProjectionStyle ) where {T}
634+ dims = region (p. p)
635+ N = normalization (real (T), output_size (p. p), dims)
636+ halfdim = first (dims)
637+ n = size (p. p, halfdim)
638+ d = output_size (p. p, halfdim)
639+ scale = reshape (
640+ [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
641+ ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
642+ )
643+ return 1 / N * scale .* (p. p \ x)
639644end
640645
641646plan_inv (p:: AdjointPlan ) = AdjointPlan (plan_inv (p. p))
0 commit comments