@@ -72,16 +72,13 @@ def _check_norm(norm):
7272
7373
7474def frwd_sc_1d (n , s ):
75- nn = n if n else s
75+ nn = n if n is not None else s
7676 return 1 / nn if nn != 0 else 1
7777
7878
79- def frwd_sc_nd (s , axes , x_shape ):
79+ def frwd_sc_nd (s , x_shape ):
8080 ss = s if s is not None else x_shape
81- if axes is not None :
82- nn = prod ([ss [ai ] for ai in axes ])
83- else :
84- nn = prod (ss )
81+ nn = prod (ss )
8582 return 1 / nn if nn != 0 else 1
8683
8784
@@ -837,14 +834,14 @@ def fftn(a, s=None, axes=None, norm=None):
837834 if norm in (None , "backward" ):
838835 fsc = 1.0
839836 elif norm == "forward" :
840- fsc = frwd_sc_nd (s , axes , x .shape )
837+ fsc = frwd_sc_nd (s , x .shape )
841838 else :
842- fsc = sqrt (frwd_sc_nd (s , axes , x .shape ))
839+ fsc = sqrt (frwd_sc_nd (s , x .shape ))
843840
844841 return trycall (
845842 mkl_fft .fftn ,
846843 (x ,),
847- {'shape ' : s , 'axes' : axes ,
844+ {'s ' : s , 'axes' : axes ,
848845 'fwd_scale' : fsc })
849846
850847
@@ -954,14 +951,14 @@ def ifftn(a, s=None, axes=None, norm=None):
954951 if norm in (None , "backward" ):
955952 fsc = 1.0
956953 elif norm == "forward" :
957- fsc = frwd_sc_nd (s , axes , x .shape )
954+ fsc = frwd_sc_nd (s , x .shape )
958955 else :
959- fsc = sqrt (frwd_sc_nd (s , axes , x .shape ))
956+ fsc = sqrt (frwd_sc_nd (s , x .shape ))
960957
961958 return trycall (
962959 mkl_fft .ifftn ,
963960 (x ,),
964- {'shape ' : s , 'axes' : axes ,
961+ {'s ' : s , 'axes' : axes ,
965962 'fwd_scale' : fsc })
966963
967964
@@ -1253,10 +1250,10 @@ def rfftn(a, s=None, axes=None, norm=None):
12531250 fsc = 1.0
12541251 elif norm == "forward" :
12551252 x = asanyarray (x )
1256- fsc = frwd_sc_nd (s , axes , x .shape )
1253+ fsc = frwd_sc_nd (s , x .shape )
12571254 else :
12581255 x = asanyarray (x )
1259- fsc = sqrt (frwd_sc_nd (s , axes , x .shape ))
1256+ fsc = sqrt (frwd_sc_nd (s , x .shape ))
12601257
12611258 return trycall (
12621259 mkl_fft .rfftn ,
@@ -1408,10 +1405,10 @@ def irfftn(a, s=None, axes=None, norm=None):
14081405 fsc = 1.0
14091406 elif norm == "forward" :
14101407 x = asanyarray (x )
1411- fsc = frwd_sc_nd (s , axes , x .shape )
1408+ fsc = frwd_sc_nd (s , x .shape )
14121409 else :
14131410 x = asanyarray (x )
1414- fsc = sqrt (frwd_sc_nd (s , axes , x .shape ))
1411+ fsc = sqrt (frwd_sc_nd (s , x .shape ))
14151412
14161413 return trycall (
14171414 mkl_fft .irfftn ,
0 commit comments