@@ -262,23 +262,43 @@ def _iter_fftnd(
262262 axes = None ,
263263 out = None ,
264264 direction = + 1 ,
265- overwrite_x = False ,
266- scale_function = lambda n , ind : 1.0 ,
265+ scale_function = lambda ind : 1.0 ,
267266):
268267 a = np .asarray (a )
269268 s , axes = _init_nd_shape_and_axes (a , s , axes )
270- ovwr = overwrite_x
271- for ii in reversed (range (len (axes ))):
269+
270+ # Combine the two, but in reverse, to end with the first axis given.
271+ axes_and_s = list (zip (axes , s ))[::- 1 ]
272+ # We try to use in-place calculations where possible, which is
273+ # everywhere except when the size changes after the first FFT.
274+ size_changes = [axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n ]
275+
276+ # If there are any size changes, we cannot use out
277+ res = None if size_changes else out
278+ for ind , (axis , n ) in enumerate (axes_and_s ):
279+ if axis in size_changes :
280+ if axis == size_changes [- 1 ]:
281+ # Last size change, so any output should now be OK
282+ # (an error will be raised if not), and if no output is
283+ # required, we want a freshly allocated array of the right size.
284+ res = out
285+ elif res is not None and n < res .shape [axis ]:
286+ # For an intermediate step where we return fewer elements, we
287+ # can use a smaller view of the previous array.
288+ res = res [(slice (None ),) * axis + (slice (n ),)]
289+ else :
290+ # If we need more elements, we cannot use res.
291+ res = None
272292 a = _c2c_fft1d_impl (
273293 a ,
274- n = s [ii ],
275- axis = axes [ii ],
276- overwrite_x = ovwr ,
294+ n = n ,
295+ axis = axis ,
277296 direction = direction ,
278- fsc = scale_function (s [ ii ], ii ),
279- out = out ,
297+ fsc = scale_function (ind ),
298+ out = res ,
280299 )
281- ovwr = True
300+ # Default output for next iteration.
301+ res = a
282302 return a
283303
284304
@@ -360,7 +380,6 @@ def _c2c_fftnd_impl(
360380 x ,
361381 s = None ,
362382 axes = None ,
363- overwrite_x = False ,
364383 direction = + 1 ,
365384 fsc = 1.0 ,
366385 out = None ,
@@ -385,7 +404,6 @@ def _c2c_fftnd_impl(
385404 if _direct :
386405 return _direct_fftnd (
387406 x ,
388- overwrite_x = overwrite_x ,
389407 direction = direction ,
390408 fsc = fsc ,
391409 out = out ,
@@ -403,11 +421,7 @@ def _c2c_fftnd_impl(
403421 x ,
404422 axes ,
405423 _direct_fftnd ,
406- {
407- "overwrite_x" : overwrite_x ,
408- "direction" : direction ,
409- "fsc" : fsc ,
410- },
424+ {"direction" : direction , "fsc" : fsc },
411425 res ,
412426 )
413427 else :
@@ -418,97 +432,122 @@ def _c2c_fftnd_impl(
418432 axes = axes ,
419433 out = out ,
420434 direction = direction ,
421- overwrite_x = overwrite_x ,
422- scale_function = lambda n , i : fsc if i == 0 else 1.0 ,
435+ scale_function = lambda i : fsc if i == 0 else 1.0 ,
423436 )
424437
425438
426439def _r2c_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
427440 a = np .asarray (x )
428441 no_trim = (s is None ) and (axes is None )
429442 s , axes = _cook_nd_args (a , s , axes )
443+ axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
430444 la = axes [- 1 ]
445+
431446 # trim array, so that rfft avoids doing unnecessary computations
432447 if not no_trim :
433448 a = _trim_array (a , s , axes )
449+
450+ # last axis is not included since we calculate r2c FFT separately
451+ # and not in the loop
452+ axes_and_s = list (zip (axes , s ))[- 2 ::- 1 ]
453+ size_changes = [axis for axis , n in axes_and_s if a .shape [axis ] != n ]
454+ res = None if size_changes else out
455+
434456 # r2c along last axis
435- a = _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = out )
457+ a = _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = res )
458+ res = a
436459 if len (s ) > 1 :
437- if not no_trim :
438- ss = list (s )
439- ss [- 1 ] = a .shape [la ]
440- a = _pad_array (a , tuple (ss ), axes )
460+
441461 len_axes = len (axes )
442462 if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
463+ if not no_trim :
464+ ss = list (s )
465+ ss [- 1 ] = a .shape [la ]
466+ a = _pad_array (a , tuple (ss ), axes )
443467 # a series of ND c2c FFTs along last axis
444468 ss , aa = _remove_axis (s , axes , - 1 )
445- ind = [
446- slice (None , None , 1 ),
447- ] * len (s )
469+ ind = [slice (None , None , 1 )] * len (s )
448470 for ii in range (a .shape [la ]):
449471 ind [la ] = ii
450472 tind = tuple (ind )
451473 a_inp = a [tind ]
452- res = out [tind ] if out is not None else None
453- a_res = _c2c_fftnd_impl (
454- a_inp , s = ss , axes = aa , overwrite_x = True , direction = 1 , out = res
455- )
456- if a_res is not a_inp :
457- a [tind ] = a_res # copy in place
474+ res = out [tind ] if out is not None else a_inp
475+ _ = _c2c_fftnd_impl (a_inp , s = ss , axes = aa , direction = 1 , out = res )
476+ if out is not None :
477+ a = out
458478 else :
479+ # another size_changes check is needed if there are repeated axes
480+ # of last axis, since since FFT changes the shape along last axis
481+ size_changes = [
482+ axis for axis , n in axes_and_s if a .shape [axis ] != n
483+ ]
484+
459485 # a series of 1D c2c FFTs along all axes except last
460- for ii in range (len (axes ) - 2 , - 1 , - 1 ):
461- a = _c2c_fft1d_impl (a , s [ii ], axes [ii ], overwrite_x = True )
486+ for axis , n in axes_and_s :
487+ if axis in size_changes :
488+ if axis == size_changes [- 1 ]:
489+ res = out
490+ elif res is not None and n < res .shape [axis ]:
491+ res = res [(slice (None ),) * axis + (slice (n ),)]
492+ else :
493+ res = None
494+ a = _c2c_fft1d_impl (a , n , axis , out = res )
495+ res = a
462496 return a
463497
464498
465499def _c2r_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
466500 a = np .asarray (x )
467501 no_trim = (s is None ) and (axes is None )
468502 s , axes = _cook_nd_args (a , s , axes , invreal = True )
503+ axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
469504 la = axes [- 1 ]
470505 if not no_trim :
471506 a = _trim_array (a , s , axes )
472507 if len (s ) > 1 :
473- if not no_trim :
474- a = _pad_array (a , s , axes )
475- ovr_x = True if _datacopied (a , x ) else False
476508 len_axes = len (axes )
477509 if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
510+ if not no_trim :
511+ a = _pad_array (a , s , axes )
478512 # a series of ND c2c FFTs along last axis
479513 # due to need to write into a, we must copy
480- if not ovr_x :
481- a = a .copy ()
482- ovr_x = True
514+ a = a if _datacopied (a , x ) else a .copy ()
483515 if not np .issubdtype (a .dtype , np .complexfloating ):
484516 # complex output will be copied to input, copy is needed
485517 if a .dtype == np .float32 :
486518 a = a .astype (np .complex64 )
487519 else :
488520 a = a .astype (np .complex128 )
489- ovr_x = True
490521 ss , aa = _remove_axis (s , axes , - 1 )
491- ind = [
492- slice (None , None , 1 ),
493- ] * len (s )
522+ ind = [slice (None , None , 1 )] * len (s )
494523 for ii in range (a .shape [la ]):
495524 ind [la ] = ii
496525 tind = tuple (ind )
497526 a_inp = a [tind ]
498527 # out has real dtype and cannot be used in intermediate steps
499- a_res = _c2c_fftnd_impl (
500- a_inp , s = ss , axes = aa , overwrite_x = True , direction = - 1
528+ # ss and aa are reversed since np.irfftn uses forward order but
529+ # np.ifftn uses reverse order see numpy-gh-28950
530+ _ = _c2c_fftnd_impl (
531+ a_inp , s = ss [::- 1 ], axes = aa [::- 1 ], out = a_inp , direction = - 1
501532 )
502- if a_res is not a_inp :
503- a [tind ] = a_res # copy in place
504533 else :
505534 # a series of 1D c2c FFTs along all axes except last
506- for ii in range (len (axes ) - 1 ):
507- # out has real dtype and cannot be used in intermediate steps
508- a = _c2c_fft1d_impl (
509- a , s [ii ], axes [ii ], overwrite_x = ovr_x , direction = - 1
510- )
511- ovr_x = True
535+ # forward order, see numpy-gh-28950
536+ axes_and_s = list (zip (axes , s ))[:- 1 ]
537+ size_changes = [
538+ axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n
539+ ]
540+ # out has real dtype cannot be used for intermediate steps
541+ res = None
542+ for axis , n in axes_and_s :
543+ if axis in size_changes :
544+ if res is not None and n < res .shape [axis ]:
545+ # pylint: disable=unsubscriptable-object
546+ res = res [(slice (None ),) * axis + (slice (n ),)]
547+ else :
548+ res = None
549+ a = _c2c_fft1d_impl (a , n , axis , out = res , direction = - 1 )
550+ res = a
512551 # c2r along last axis
513552 a = _c2r_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = out )
514553 return a
0 commit comments