grafos_dsp/
fft.rs

1//! Cooley-Tukey radix-2 FFT and inverse FFT.
2
3extern crate alloc;
4use alloc::vec::Vec;
5
6use crate::types::Complex;
7
8/// Compute the forward FFT of a complex input using Cooley-Tukey radix-2.
9///
10/// `input` length must be a power of 2.
11pub fn fft(input: &[Complex]) -> Vec<Complex> {
12    let n = input.len();
13    assert!(n.is_power_of_two(), "FFT input length must be a power of 2");
14
15    if n == 1 {
16        return input.to_vec();
17    }
18
19    let mut output = input.to_vec();
20    bit_reverse_permutation(&mut output);
21    butterfly(&mut output, false);
22    output
23}
24
25/// Compute the inverse FFT of a complex input.
26///
27/// `input` length must be a power of 2.
28pub fn ifft(input: &[Complex]) -> Vec<Complex> {
29    let n = input.len();
30    assert!(
31        n.is_power_of_two(),
32        "IFFT input length must be a power of 2"
33    );
34
35    if n == 1 {
36        return input.to_vec();
37    }
38
39    let mut output = input.to_vec();
40    bit_reverse_permutation(&mut output);
41    butterfly(&mut output, true);
42
43    let scale = 1.0 / n as f32;
44    for c in output.iter_mut() {
45        c.re *= scale;
46        c.im *= scale;
47    }
48    output
49}
50
51/// Forward FFT of real-valued input. Returns N/2+1 complex bins.
52pub fn fft_real(input: &[f32]) -> Vec<Complex> {
53    let complex_input: Vec<Complex> = input.iter().map(|&x| Complex::new(x, 0.0)).collect();
54    fft(&complex_input)
55}
56
57/// Inverse FFT returning real-valued output. Takes N complex bins
58/// (full spectrum) and returns N real samples.
59pub fn ifft_real(input: &[Complex]) -> Vec<f32> {
60    ifft(input).iter().map(|c| c.re).collect()
61}
62
/// Reorders `data` in place so that the element at index `i` ends up at the
/// index whose `log2(len)`-bit binary representation is `i` reversed.
///
/// This is the input ordering required by the in-place iterative butterfly
/// stages. `data.len()` must be a power of two. Slices of length 0 or 1 are
/// already in bit-reversed order and are left untouched.
fn bit_reverse_permutation<T>(data: &mut [T]) {
    let n = data.len();
    // For n == 1, log2(n) is 0, so the shift below would be by `usize::BITS`,
    // which overflows (a panic in debug builds). Nothing needs to move anyway.
    if n < 2 {
        return;
    }
    debug_assert!(
        n.is_power_of_two(),
        "bit_reverse_permutation requires a power-of-two length"
    );

    // Reversing all usize::BITS bits and shifting right keeps only the low
    // log2(n) reversed bits.
    let shift = usize::BITS - n.trailing_zeros();
    for i in 0..n {
        let j = i.reverse_bits() >> shift;
        // Swap each pair once, from its smaller index.
        if i < j {
            data.swap(i, j);
        }
    }
}
73
74fn butterfly(data: &mut [Complex], inverse: bool) {
75    let n = data.len();
76    let mut len = 2;
77    while len <= n {
78        let half = len / 2;
79        let angle_sign = if inverse { 1.0 } else { -1.0 };
80        let angle = angle_sign * 2.0 * core::f32::consts::PI / len as f32;
81
82        for start in (0..n).step_by(len) {
83            let mut w = Complex::new(1.0, 0.0);
84            let w_step = Complex::new(angle.cos(), angle.sin());
85            for k in 0..half {
86                let even = data[start + k];
87                let odd = data[start + k + half] * w;
88                data[start + k] = even + odd;
89                data[start + k + half] = even - odd;
90                w = w * w_step;
91            }
92        }
93        len *= 2;
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Returns true when `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f32, b: f32, tol: f32) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn fft_single_element() {
        let input = vec![Complex::new(42.0, 0.0)];
        let output = fft(&input);
        assert_eq!(output.len(), 1);
        assert!(approx_eq(output[0].re, 42.0, 1e-6));
    }

    #[test]
    fn fft_impulse_is_flat() {
        // The DFT of a unit impulse at index 0 is 1 in every bin.
        let mut input = vec![Complex::new(0.0, 0.0); 8];
        input[0] = Complex::new(1.0, 0.0);
        for (k, c) in fft(&input).iter().enumerate() {
            assert!(
                approx_eq(c.re, 1.0, 1e-6) && approx_eq(c.im, 0.0, 1e-6),
                "bin {k}: got ({}, {})",
                c.re,
                c.im
            );
        }
    }

    #[test]
    #[should_panic(expected = "FFT input length must be a power of 2")]
    fn fft_rejects_non_power_of_two() {
        let input = vec![Complex::new(0.0, 0.0); 3];
        let _ = fft(&input);
    }

    #[test]
    #[should_panic(expected = "IFFT input length must be a power of 2")]
    fn ifft_rejects_empty_input() {
        let _ = ifft(&[]);
    }

    #[test]
    fn fft_ifft_roundtrip() {
        let input: Vec<Complex> = (0..8).map(|i| Complex::new(i as f32, 0.0)).collect();
        let freq = fft(&input);
        let recovered = ifft(&freq);
        for (i, c) in recovered.iter().enumerate() {
            assert!(
                approx_eq(c.re, i as f32, 1e-4),
                "sample {i}: expected {}, got {}",
                i as f32,
                c.re
            );
            assert!(
                approx_eq(c.im, 0.0, 1e-4),
                "sample {i}: imaginary part {} should be ~0",
                c.im
            );
        }
    }

    #[test]
    fn fft_ifft_roundtrip_larger() {
        // Many butterfly stages: twiddle rounding error would show up here first.
        let input: Vec<Complex> = (0..256)
            .map(|i| Complex::new((i % 17) as f32 - 8.0, (i % 5) as f32))
            .collect();
        let recovered = ifft(&fft(&input));
        assert_eq!(recovered.len(), input.len());
        for (i, (orig, got)) in input.iter().zip(recovered.iter()).enumerate() {
            assert!(
                approx_eq(got.re, orig.re, 1e-3) && approx_eq(got.im, orig.im, 1e-3),
                "sample {i}: expected ({}, {}), got ({}, {})",
                orig.re,
                orig.im,
                got.re,
                got.im
            );
        }
    }

    #[test]
    fn fft_known_sine_wave() {
        // 8-sample sine wave at bin 1 (frequency = sample_rate/8)
        let n = 8;
        let input: Vec<Complex> = (0..n)
            .map(|i| {
                let angle = 2.0 * core::f32::consts::PI * (i as f32) / (n as f32);
                Complex::new(angle.sin(), 0.0)
            })
            .collect();
        let freq = fft(&input);

        let magnitudes: Vec<f32> = freq.iter().map(|c| c.magnitude()).collect();
        // Bin 0 (DC) should be ~0
        assert!(approx_eq(magnitudes[0], 0.0, 1e-4));
        // Bin 1 should have magnitude ~4 (N/2)
        assert!(
            approx_eq(magnitudes[1], 4.0, 1e-3),
            "bin 1 magnitude: {}",
            magnitudes[1]
        );
        // Bin 7 (mirror of bin 1) should also have magnitude ~4
        assert!(
            approx_eq(magnitudes[7], 4.0, 1e-3),
            "bin 7 magnitude: {}",
            magnitudes[7]
        );
        // Bins 2..=6 should be ~0
        for (k, &m) in magnitudes.iter().enumerate().take(7).skip(2) {
            assert!(approx_eq(m, 0.0, 1e-3), "bin {k} magnitude: {m}");
        }
    }

    #[test]
    fn fft_real_returns_full_spectrum() {
        let input: [f32; 8] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let freq = fft_real(&input);
        // All N bins are returned, not only the N/2+1 non-redundant ones.
        assert_eq!(freq.len(), input.len());
        // Real input gives a conjugate-symmetric spectrum: X[N-k] == conj(X[k]).
        let n = freq.len();
        for (k, bin) in freq.iter().enumerate().skip(1) {
            let mirror = freq[n - k];
            assert!(
                approx_eq(bin.re, mirror.re, 1e-3) && approx_eq(bin.im, -mirror.im, 1e-3),
                "bins {k} and {}: ({}, {}) vs ({}, {})",
                n - k,
                bin.re,
                bin.im,
                mirror.re,
                mirror.im
            );
        }
    }

    #[test]
    fn fft_real_roundtrip() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let freq = fft_real(&input);
        let recovered = ifft_real(&freq);
        for (i, &val) in recovered.iter().enumerate() {
            assert!(
                approx_eq(val, input[i], 1e-4),
                "sample {i}: expected {}, got {val}",
                input[i]
            );
        }
    }
}