Skip to content

Commit 4cb7b1f

Browse files
committed
Adding fast_fourier_transform.py
Adding fast_fourier_transform.py
1 parent e2a78d4 commit 4cb7b1f

File tree

1 file changed

+319
-0
lines changed

1 file changed

+319
-0
lines changed
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
"""
2+
Fast Fourier Transform (FFT) using Divide and Conquer
3+
4+
The Fast Fourier Transform is a divide-and-conquer algorithm that computes the
5+
Discrete Fourier Transform (DFT) of a sequence in O(n log n) time, compared to
6+
O(n²) for the naive DFT computation.
7+
8+
The algorithm works by:
9+
1. Recursively dividing the DFT computation into smaller subproblems
10+
2. Using the symmetry and periodicity properties of complex exponentials
11+
3. Combining results using the "butterfly" operation
12+
13+
Key mathematical insight:
14+
- DFT of even-indexed elements and odd-indexed elements can be computed separately
15+
- Results are combined using complex exponentials (twiddle factors)
16+
17+
Time complexity: O(n log n)
18+
Space complexity: O(n log n) due to recursion
19+
20+
References:
21+
- https://en.wikipedia.org/wiki/Fast_Fourier_transform
22+
- Cooley-Tukey FFT algorithm (1965)
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import cmath
28+
from collections.abc import Sequence
29+
30+
31+
def fft(x: Sequence[float | complex]) -> list[complex]:
32+
"""
33+
Compute the Fast Fourier Transform of a sequence using divide and conquer.
34+
35+
This implementation uses the Cooley-Tukey algorithm, which recursively
36+
divides the DFT computation into smaller subproblems.
37+
38+
Args:
39+
x: Input sequence (list of real or complex numbers)
40+
41+
Returns:
42+
List of complex numbers representing the DFT of the input sequence
43+
44+
Raises:
45+
ValueError: If input length is not a power of 2
46+
47+
Examples:
48+
>>> import math
49+
>>> # Test with delta function [1, 0, 0, 0] -> constant spectrum [1, 1, 1, 1]
50+
>>> result = fft([1, 0, 0, 0])
51+
>>> all(abs(abs(x) - 1) < 1e-10 for x in result) # All should have magnitude 1
52+
True
53+
54+
>>> # Test with impulse at second position
55+
>>> result = fft([0, 1, 0, 0])
56+
>>> all(abs(abs(x) - 1) < 1e-10 for x in result) # All should have magnitude 1
57+
True
58+
59+
>>> # Test with real sine wave
60+
>>> n = 8
61+
>>> signal = [math.sin(2 * math.pi * k / n) for k in range(n)]
62+
>>> result = fft(signal)
63+
>>> len(result) == n
64+
True
65+
"""
66+
n = len(x)
67+
68+
# Check if length is power of 2
69+
if n <= 0 or (n & (n - 1)) != 0:
70+
raise ValueError("Input length must be a power of 2")
71+
72+
# Base case
73+
if n == 1:
74+
return [complex(x[0])]
75+
76+
# Divide: separate even and odd indexed elements
77+
even = [x[i] for i in range(0, n, 2)]
78+
odd = [x[i] for i in range(1, n, 2)]
79+
80+
# Conquer: recursively compute FFT of even and odd parts
81+
fft_even = fft(even)
82+
fft_odd = fft(odd)
83+
84+
# Combine: merge the results using butterfly operation
85+
result = [complex(0)] * n
86+
for k in range(n // 2):
87+
# Twiddle factor: e^(-2πik/n)
88+
twiddle = cmath.exp(-2j * cmath.pi * k / n)
89+
90+
# Butterfly operation
91+
butterfly = twiddle * fft_odd[k]
92+
result[k] = fft_even[k] + butterfly
93+
result[k + n // 2] = fft_even[k] - butterfly
94+
95+
return result
96+
97+
98+
def ifft(x: Sequence[complex]) -> list[complex]:
99+
"""
100+
Compute the Inverse Fast Fourier Transform using divide and conquer.
101+
102+
The IFFT is computed by taking the conjugate of the input, applying FFT,
103+
taking conjugate again, and scaling by 1/n.
104+
105+
Args:
106+
x: Input sequence (list of complex numbers)
107+
108+
Returns:
109+
List of complex numbers representing the IFFT of the input sequence
110+
111+
Examples:
112+
>>> # Test round-trip: FFT followed by IFFT should give original signal
113+
>>> original = [1, 2, 3, 4]
114+
>>> recovered = ifft(fft(original))
115+
>>> all(abs(recovered[i] - original[i]) < 1e-10 for i in range(len(original)))
116+
True
117+
118+
>>> # Test with complex input
119+
>>> original = [1+2j, 3-1j, 0+0j, 2+3j]
120+
>>> recovered = ifft(fft(original))
121+
>>> all(abs(recovered[i] - original[i]) < 1e-10 for i in range(len(original)))
122+
True
123+
"""
124+
n = len(x)
125+
126+
# Conjugate input
127+
x_conj = [complex(val.real, -val.imag) for val in x]
128+
129+
# Apply FFT
130+
result = fft(x_conj)
131+
132+
# Conjugate result and scale by 1/n
133+
return [complex(val.real / n, -val.imag / n) for val in result]
134+
135+
136+
def dft_naive(x: Sequence[float | complex]) -> list[complex]:
137+
"""
138+
Compute the Discrete Fourier Transform using the naive O(n²) algorithm.
139+
140+
This is provided for comparison and testing purposes.
141+
142+
Args:
143+
x: Input sequence (list of real or complex numbers)
144+
145+
Returns:
146+
List of complex numbers representing the DFT of the input sequence
147+
148+
Examples:
149+
>>> # Compare with FFT result
150+
>>> signal = [1, 2, 3, 4]
151+
>>> fft_result = fft(signal)
152+
>>> dft_result = dft_naive(signal)
153+
>>> all(abs(fft_result[i] - dft_result[i]) < 1e-10 for i in range(len(signal)))
154+
True
155+
"""
156+
n = len(x)
157+
result = []
158+
159+
for k in range(n):
160+
sum_val = complex(0)
161+
for j in range(n):
162+
# Compute e^(-2πijk/n)
163+
angle = -2 * cmath.pi * j * k / n
164+
sum_val += x[j] * cmath.exp(1j * angle)
165+
result.append(sum_val)
166+
167+
return result
168+
169+
170+
def pad_to_power_of_2(x: Sequence[float | complex]) -> list[float | complex]:
171+
"""
172+
Pad input sequence with zeros to make its length a power of 2.
173+
174+
Args:
175+
x: Input sequence
176+
177+
Returns:
178+
Padded sequence with length as power of 2
179+
180+
Examples:
181+
>>> pad_to_power_of_2([1, 2, 3])
182+
[1, 2, 3, 0]
183+
>>> pad_to_power_of_2([1, 2, 3, 4, 5])
184+
[1, 2, 3, 4, 5, 0, 0, 0]
185+
"""
186+
n = len(x)
187+
if n <= 0:
188+
return list(x)
189+
190+
# Find next power of 2
191+
next_power = 1
192+
while next_power < n:
193+
next_power *= 2
194+
195+
# Pad with zeros
196+
return list(x) + [0] * (next_power - n)
197+
198+
199+
def fft_magnitude_spectrum(x: Sequence[float | complex]) -> list[float]:
200+
"""
201+
Compute the magnitude spectrum of a signal using FFT.
202+
203+
Args:
204+
x: Input signal
205+
206+
Returns:
207+
List of magnitudes of the FFT coefficients
208+
209+
Examples:
210+
>>> # Test with a simple signal
211+
>>> signal = [1, 0, 1, 0]
212+
>>> spectrum = fft_magnitude_spectrum(signal)
213+
>>> len(spectrum) == len(signal)
214+
True
215+
>>> all(mag >= 0 for mag in spectrum) # All magnitudes should be non-negative
216+
True
217+
"""
218+
# Pad to power of 2 if necessary
219+
if len(x) & (len(x) - 1) != 0:
220+
x = pad_to_power_of_2(x)
221+
222+
# Compute FFT
223+
fft_result = fft(x)
224+
225+
# Return magnitudes
226+
return [abs(val) for val in fft_result]
227+
228+
229+
def convolution_fft(a: Sequence[float], b: Sequence[float]) -> list[float]:
230+
"""
231+
Compute convolution of two sequences using FFT.
232+
233+
Convolution in time domain equals pointwise multiplication in frequency domain.
234+
This provides an O(n log n) alternative to the naive O(n²) convolution.
235+
236+
Args:
237+
a: First sequence
238+
b: Second sequence
239+
240+
Returns:
241+
Convolution of a and b
242+
243+
Examples:
244+
>>> # Test convolution property
245+
>>> a = [1, 2, 3]
246+
>>> b = [1, 1]
247+
>>> result = convolution_fft(a, b)
248+
>>> len(result) >= len(a) + len(b) - 1
249+
True
250+
"""
251+
if not a or not b:
252+
return []
253+
254+
# Result length should be len(a) + len(b) - 1
255+
result_len = len(a) + len(b) - 1
256+
257+
# Pad both sequences to the same power of 2 length
258+
padded_len = 1
259+
while padded_len < result_len:
260+
padded_len *= 2
261+
262+
a_padded = list(a) + [0] * (padded_len - len(a))
263+
b_padded = list(b) + [0] * (padded_len - len(b))
264+
265+
# Compute FFT of both sequences
266+
fft_a = fft(a_padded)
267+
fft_b = fft(b_padded)
268+
269+
# Pointwise multiplication in frequency domain
270+
fft_product = [fft_a[i] * fft_b[i] for i in range(len(fft_a))]
271+
272+
# Inverse FFT to get convolution result
273+
conv_result = ifft(fft_product)
274+
275+
# Return only the valid part (real parts, since convolution of real signals is real)
276+
return [val.real for val in conv_result[:result_len]]
277+
278+
279+
if __name__ == "__main__":
280+
import doctest
281+
282+
doctest.testmod()
283+
284+
# Example usage and demonstration
285+
print("Fast Fourier Transform Demonstration")
286+
print("=" * 40)
287+
288+
# Example 1: Simple signal
289+
print("\n1. Simple 4-point signal:")
290+
signal = [1, 2, 3, 4]
291+
print(f"Input: {signal}")
292+
293+
fft_result = fft(signal)
294+
print("FFT result:")
295+
for i, val in enumerate(fft_result):
296+
print(f" X[{i}] = {val:.3f}")
297+
298+
# Verify with naive DFT
299+
dft_result = dft_naive(signal)
300+
matches_dft = all(
301+
abs(fft_result[i] - dft_result[i]) < 1e-10 for i in range(len(signal))
302+
)
303+
print(f"\nVerification - FFT matches DFT: {matches_dft}")
304+
305+
# Test round-trip
306+
recovered = ifft(fft_result)
307+
print(f"Round-trip test (IFFT of FFT): {[f'{val.real:.3f}' for val in recovered]}")
308+
309+
# Example 2: Magnitude spectrum
310+
print("\n2. Magnitude spectrum:")
311+
spectrum = fft_magnitude_spectrum(signal)
312+
print(f"Magnitudes: {[f'{mag:.3f}' for mag in spectrum]}")
313+
314+
# Example 3: Convolution using FFT
315+
print("\n3. Convolution using FFT:")
316+
a = [1, 2, 3]
317+
b = [1, 1, 1]
318+
conv_result = convolution_fft(a, b)
319+
print(f"Convolution of {a} and {b}: {[f'{val:.3f}' for val in conv_result]}")

0 commit comments

Comments
 (0)