FFT in Haskell and Futhark
2024-12-24 (updated: 2024-12-24 11:11)
The Fourier transform is one of the fundamental tools in analysis. From the perspective of approximation theory, it gives us one of the orthogonal function bases, the natural basis for periodic functions. To use this numerically, as with any basis, we must sample from the function to be approximated, and periodic functions have a wonderful property that the optimal points to sample are uniformly spaced on the interval. This is very different from polynomial bases like Chebyshev or Legendre polynomials, where choosing the points is somewhat involved. For Fourier, this leads us to the Discrete Fourier Transform (DFT), a cornerstone of signal processing.
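Let $x$ be our signal sampled at $N$ points, i.e. $x = (x_0, x_1, \dots, x_{N-1})$ is a sequence of real or complex numbers. Let $\omega_N = e^{-2\pi i/N}$ be a primitive $N$-th root of unity, the powers of which are sometimes called “twiddle” factors in this context. We can then define the DFT element-wise as follows.

$$X_k = \sum_{n=0}^{N-1} x_n\,\omega_N^{kn}, \qquad k = 0, 1, \dots, N-1$$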
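We see that each element $X_k$ is the dot product of $x$ with a vector of the twiddle factors, so we can represent this as a matrix multiplication $X = F_N x$, where

$$F_N = \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega_N & \omega_N^2 & \cdots & \omega_N^{N-1} \\ 1 & \omega_N^2 & \omega_N^4 & \cdots & \omega_N^{2(N-1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_N^{N-1} & \omega_N^{2(N-1)} & \cdots & \omega_N^{(N-1)^2} \end{bmatrix}$$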
We will also denote this as $X = \mathrm{DFT}(x)$. The complexity of matrix-vector multiplication is trivially $O(N^2)$, as we must access every element of the matrix. In the 1960s, during the great boom of numerical research, James Cooley and John Tukey published their paper exploiting the structure of this matrix for a log-linear, $O(N \log N)$, solution, ushering in the age of real-time digital signal processing. This family of algorithms is called the Fast Fourier Transform (FFT).
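To make the $O(N^2)$ baseline concrete, here is a minimal sketch of the DFT computed straight from its definition, one dot product per output element. It only needs `Data.Complex` from `base`; the name `dft` is hypothetical. A direct version like this is also handy as a reference when testing faster implementations.

```haskell
import Data.Complex

-- Naive O(N^2) DFT: X_k = sum_n x_n * omega_N^(k n), with omega_N = exp(-2 pi i / N).
dft :: [Complex Double] -> [Complex Double]
dft xs = [sum [x * omega ((k * j) `mod` n) | (j, x) <- zip [0 ..] xs] | k <- [0 .. n - 1]]
  where
    n = length xs
    -- omega_N^m as a point on the unit circle; the exponent is reduced
    -- mod N above since omega_N^N = 1
    omega m = cis (-2 * pi * fromIntegral m / fromIntegral n)
```

Each of the $N$ outputs sums $N$ products, which is exactly the $N^2$ cost of the matrix-vector product.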
The key insight of the Cooley-Tukey FFT is that one can split the signal in half and recursively compute the FFT of each half. This is one of the early examples of a divide-and-conquer algorithm, along with merge sort, which has the same $O(N \log N)$ time complexity. For simplicity we assume $N$ is a power of two, so that we can keep dividing in half until we reach a single element, though production FFT implementations do not require this. Recursive algorithms are often most naturally expressed in functional languages, so we derive a recursive form to implement in Haskell.
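First we identify the base case $N = 1$, which is simply the identity $X_0 = x_0$. For $N > 1$ it is instructive to work small examples out by hand in full using the matrix multiplication $X = F_N x$. For $N = 2$ we have $\omega_2 = e^{-\pi i} = -1$, and we get the following.

$$\begin{bmatrix} X_0 \\ X_1 \end{bmatrix} = \begin{bmatrix} 1 & 1 \\ 1 & -1 \end{bmatrix} \begin{bmatrix} x_0 \\ x_1 \end{bmatrix} = \begin{bmatrix} x_0 + x_1 \\ x_0 - x_1 \end{bmatrix}$$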
When drawn out as a data flow diagram, as you would see in more hardware-adjacent expositions, this forms a cross-over, leading to the name butterfly for the combining stage of the FFT.
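The trick to the recursion is splitting $x$ into its even- and odd-indexed components. This can be seen by writing the $N = 4$ case out by hand, with $\omega_4 = -i$. I will leave out the derivation here, but the result should look like the following.

$$\begin{aligned} X_0 &= (x_0 + x_2) + (x_1 + x_3) \\ X_1 &= (x_0 - x_2) + \omega_4 (x_1 - x_3) \\ X_2 &= (x_0 + x_2) - (x_1 + x_3) \\ X_3 &= (x_0 - x_2) - \omega_4 (x_1 - x_3) \end{aligned}$$

The parenthesized terms are exactly the $N = 2$ transforms of the evens $(x_0, x_2)$ and of the odds $(x_1, x_3)$.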
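The $N = 4$ case can be written out directly in Haskell as two $N = 2$ transforms, one over the evens $(x_0, x_2)$ and one over the odds $(x_1, x_3)$, combined by butterflies with twiddles $\omega_4^0 = 1$ and $\omega_4^1 = -i$. This is a hand-unrolled sketch for illustration; `butterfly`, `fft2`, and `fft4` are hypothetical names.

```haskell
import Data.Complex

-- One radix-2 butterfly: inputs a and b with twiddle t give (a + t b, a - t b).
butterfly :: Complex Double -> Complex Double -> Complex Double -> (Complex Double, Complex Double)
butterfly t a b = (a + t * b, a - t * b)

-- N = 2: X_0 = x_0 + x_1, X_1 = x_0 - x_1 (twiddle omega_2^0 = 1).
fft2 :: Complex Double -> Complex Double -> (Complex Double, Complex Double)
fft2 = butterfly 1

-- N = 4: transform the evens (x_0, x_2) and the odds (x_1, x_3),
-- then combine with twiddles omega_4^0 = 1 and omega_4^1 = -i.
fft4 :: [Complex Double] -> [Complex Double]
fft4 [x0, x1, x2, x3] = [y0, y1, y2, y3]
  where
    (u0, u1) = fft2 x0 x2            -- (x_0 + x_2, x_0 - x_2)
    (v0, v1) = fft2 x1 x3            -- (x_1 + x_3, x_1 - x_3)
    (y0, y2) = butterfly 1 u0 v0
    (y1, y3) = butterfly (0 :+ (-1)) u1 v1
fft4 _ = error "fft4 expects exactly four samples"
```

For example, `fft4 [1, 2, 3, 4]` should give $10$, $-2 + 2i$, $-2$, and $-2 - 2i$.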
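We then combine the two sub-problems. Writing $u = \mathrm{DFT}(x_0, x_2, \dots, x_{N-2})$ and $v = \mathrm{DFT}(x_1, x_3, \dots, x_{N-1})$ for the length-$N/2$ transforms of the even- and odd-indexed samples, for $0 \le k < N/2$ we get

$$X_k = u_k + \omega_N^k v_k, \qquad X_{k+N/2} = u_k + \omega_N^{k+N/2} v_k = u_k - \omega_N^k v_k.$$

This is not entirely intuitive, and I encourage you to look in an introductory numerical analysis textbook if you would like to be guided through the derivation. Note that the last equality just uses $\omega_N^{N/2} = -1$ to simplify; this is very helpful computationally, as the top and bottom halves of the output are now much more similar. From this we have the motivation for the recursive definition we will implement. Let $w = (\omega_N^0, \omega_N^1, \dots, \omega_N^{N/2-1})$ be the vector of twiddle factors, with $\odot$ denoting element-wise “broadcasting” multiplication. Then we can derive the following, abusing matrix notation somewhat.

$$\mathrm{DFT}(x) = \begin{bmatrix} u + w \odot v \\ u - w \odot v \end{bmatrix}$$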
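Counting work in this recursion recovers the log-linear bound: each call on a length-$N$ input makes two half-size calls and then does $O(N)$ work combining them ($N/2$ twiddle products and $N$ additions or subtractions). Writing $T(N)$ for the cost on a length-$N$ input, with constants $c$ and $c_0$:

$$T(N) = 2\,T(N/2) + cN, \quad T(1) = c_0 \;\Longrightarrow\; T(N) = cN\log_2 N + c_0 N = O(N \log N),$$

since there are $\log_2 N$ levels of recursion, each costing $cN$ in total.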
A minimal Haskell implementation of this recursive form is quite elegant.
```haskell
import Data.Complex

-- Split a list into its even- and odd-indexed elements.
split :: [a] -> ([a], [a])
split [] = ([], [])
split [_] = error "input size must be power of two"
split (x:y:xs) =
  let (es, os) = split xs
  in (x:es, y:os)

-- Combine the transforms u (evens) and v (odds) of a length-n signal:
-- the top half is u + w*v and the bottom half is u - w*v.
mergeRadix2 :: [Complex Double] -> [Complex Double] -> Int -> [Complex Double]
mergeRadix2 u v n = zipWith (+) u q ++ zipWith (-) u q
  where q = zipWith (*) v w
        n2 = length u - 1
        -- Twiddle factors omega_n^k for k = 0 .. n/2 - 1
        w = [exp (0 :+ (-2 * pi * fromIntegral k / fromIntegral n)) | k <- [0..n2]]

fft :: [Complex Double] -> [Complex Double]
fft [] = []
fft [z] = [z]
fft zs = mergeRadix2 (fft evens) (fft odds) (length zs)
  where (evens, odds) = split zs
```
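The repeated even/odd splitting in a recursive FFT like this one hides a structure that iterative versions rely on: if you follow the splits all the way down to single elements, the samples end up in bit-reversed index order. The sketch below checks this on lists. The names `bitReverse`, `splitOrder`, and `bitReversal` are hypothetical, and the `(!!)` lookups make it quadratic, so it is for illustration only.

```haskell
import Data.Bits (shiftL, shiftR, (.&.), (.|.))

-- Reverse the lowest `bits` bits of i; with bits = 3, 1 (001) maps to 4 (100).
bitReverse :: Int -> Int -> Int
bitReverse bits i = go bits i 0
  where
    go :: Int -> Int -> Int -> Int
    go 0 _ acc = acc
    go b x acc = go (b - 1) (x `shiftR` 1) ((acc `shiftL` 1) .|. (x .&. 1))

-- The order in which an even/odd recursion reaches single elements:
-- split into evens and odds, recurse on each, concatenate.
splitOrder :: [a] -> [a]
splitOrder [] = []
splitOrder [x] = [x]
splitOrder xs = splitOrder evens ++ splitOrder odds
  where
    indexed = zip [0 :: Int ..] xs
    evens = [x | (j, x) <- indexed, even j]
    odds = [x | (j, x) <- indexed, odd j]

-- Reorder a list of length 2^bits so that position i holds element bitReverse bits i.
bitReversal :: Int -> [a] -> [a]
bitReversal bits xs = [xs !! bitReverse bits i | i <- [0 .. length xs - 1]]
```

For $N = 8$, both `splitOrder [0 .. 7]` and `bitReversal 3 [0 .. 7]` give the order 0, 4, 2, 6, 1, 5, 3, 7.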
One might immediately ask about performance. Yes, this implementation is meant only to be instructive, but explicitly recursive implementations can be competitive. The first place to look is FFTW, the state-of-the-art software FFT library, which takes a “bag of algorithms + planner” approach. Its code generator is written in OCaml and applies many optimization passes to produce a portable C library, and many of its algorithm variants are recursive.
The obvious suspects in a numerical algorithm optimization such as this are:
- Avoiding memory reallocation and optimizing cache locality.
- Using lookup tables or otherwise avoiding trigonometric calculation.
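As a sketch of the second point: the version above recomputes $N/2$ complex exponentials at every level of the recursion, about $(N/2)\log_2 N$ in total. One alternative is to build the twiddle table once for the full length and hand each half-size sub-problem every other entry, since $\omega_{N/2}^k = \omega_N^{2k}$. The names `twiddles`, `everyOther`, `splitEvenOdd`, and `fftTable` are hypothetical, the input is assumed to have power-of-two length, and this still uses lists, so it does nothing for the first point.

```haskell
import Data.Complex

-- Twiddle table for length n: omega_n^k = exp(-2 pi i k / n) for k = 0 .. n/2 - 1.
twiddles :: Int -> [Complex Double]
twiddles n = [cis (-2 * pi * fromIntegral k / fromIntegral n) | k <- [0 .. n `div` 2 - 1]]

-- Keep the entries at indices 0, 2, 4, ...
everyOther :: [a] -> [a]
everyOther (x:_:xs) = x : everyOther xs
everyOther xs = xs

-- Split into even- and odd-indexed elements, rejecting odd lengths.
splitEvenOdd :: [a] -> ([a], [a])
splitEvenOdd [] = ([], [])
splitEvenOdd [_] = error "input size must be power of two"
splitEvenOdd (x:y:xs) = let (es, os) = splitEvenOdd xs in (x:es, y:os)

-- Radix-2 FFT that computes the twiddle table once, for the full length,
-- and passes every other entry down to each half-size sub-problem.
fftTable :: [Complex Double] -> [Complex Double]
fftTable zs = go (twiddles (length zs)) zs
  where
    go :: [Complex Double] -> [Complex Double] -> [Complex Double]
    go _ [] = []
    go _ [z] = [z]
    go w xs = zipWith (+) u q ++ zipWith (-) u q
      where
        (evens, odds) = splitEvenOdd xs
        w' = everyOther w      -- omega_{m/2}^k = omega_m^(2k)
        u = go w' evens
        v = go w' odds
        q = zipWith (*) v w    -- w holds exactly (length xs / 2) twiddles here
```

Only $N/2$ exponentials are computed in total; the rest of the work is the same butterfly arithmetic as before.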
Implementing the FFT in Futhark
I wanted to try Futhark, the pure functional array language (implemented in Haskell) that compiles to C, CUDA, or OpenCL, and thought this algorithm would be a good fit. There is a Stockham variant in the Futhark packages for reference, but I implemented Cooley-Tukey radix-2. Unfortunately, Futhark does not support recursion, and it is not clear (to me at least) whether it ever will. My understanding is that it may become possible in the future, but there are fundamental difficulties: the stack cannot be used willy-nilly on a GPU, so any recursion would be limited in nature. For now, you have to unroll it into a loop manually. This means we cannot implement a recursive FFT, and must instead use the more complicated iterative approach.
I attempted to use Claude for this, to see how it would do with a relatively obscure programming language. Surprisingly, it mostly worked, though it consistently got indexing wrong and mostly did not use the array combinators correctly. The main points of the iterative approach are that the successive even/odd splits amount to a single rearrangement of the input, the “bit reversal permutation”, and that we must do much tedious indexing to keep track of the arithmetic combinations, which are the “butterflies” mentioned previously. Without going into depth, here is my implementation.
```futhark
-- Complex numbers as (real, imaginary) pairs
type complex = (f64, f64)

def complex_add ((a, b): complex) ((c, d): complex): complex = (a + c, b + d)
def complex_sub ((a, b): complex) ((c, d): complex): complex = (a - c, b - d)
def complex_mul ((a, b): complex) ((c, d): complex): complex = (a * c - b * d, a * d + b * c)

-- Twiddle factor omega_n^k = exp(-2 pi i k / n)
def twiddle (k: i64) (n: i64): complex =
  let angle = -2.0 * f64.pi * f64.i64 k / f64.i64 n
  in (f64.cos angle, f64.sin angle)

-- Move element i to position rev(i), where rev reverses the low log2(n) bits
def bit_reversal [n] 't (input: [n]t): [n]t =
  let bits = i64.f64 (f64.log2 (f64.i64 n))
  let indices = map (\i ->
                       let rev = loop rev = 0 for j < bits do
                                   (rev << 1) | ((i >> j) & 1)
                       in rev
                    ) (iota n)
  in spread n (input[0]) indices input

-- Type to hold butterfly operation parameters
type butterfly_params = {
  upper_idx: i64,   -- Index of upper butterfly input
  lower_idx: i64,   -- Index of lower butterfly input
  twiddle: complex  -- Twiddle factor for this butterfly
}

-- Calculate butterfly parameters for a given stage
def get_butterfly_params (stage: i64) (n: i64) (i: i64): butterfly_params =
  let butterfly_size = 1 << (stage + 1)         -- Size of entire butterfly
  let half_size = butterfly_size >> 1           -- Size of half butterfly
  let group = i / butterfly_size                -- Which group of butterflies
  let k = i % half_size                         -- Position within half
  let group_start = group * butterfly_size      -- Start index of this group
  let twiddle_idx = k * (n / butterfly_size)    -- Index for twiddle factor
  in {
    upper_idx = group_start + k,
    lower_idx = group_start + k + half_size,
    twiddle = twiddle twiddle_idx n
  }

-- Perform single butterfly operation
def butterfly_op (data: []complex) (p: butterfly_params) (is_upper: bool): complex =
  if is_upper
  then complex_add data[p.upper_idx]
                   (complex_mul data[p.lower_idx] p.twiddle)
  else complex_sub data[p.upper_idx]
                   (complex_mul data[p.lower_idx] p.twiddle)

-- Main FFT function
def fft [n] (input: [n]complex): [n]complex =
  let bits = i64.f64 (f64.log2 (f64.i64 n))
  -- This method can only handle arrays of length 2^bits
  in assert (n == 1 << bits) (
    -- First apply bit reversal permutation
    let reordered = bit_reversal input
    -- Perform log2(n) stages of butterfly operations
    in loop data = reordered for stage < bits do
         -- For each stage, compute butterfly parameters and perform operations
         let butterfly_size = 1 << (stage + 1)
         let half_size = butterfly_size >> 1
         let params = map (get_butterfly_params stage n) (iota n)
         in map2 (\p i ->
                    let is_upper = (i % butterfly_size) < half_size
                    in butterfly_op data p is_upper
                 ) params (iota n))
```
This is not particularly optimized. Futhark performs fusion of array operations and uses uniqueness types to track when it is safe to overwrite memory in place while remaining pure; I did not make use of in-place updates here. I did make sure to use the `spread` and `map2` array combinators when traversing, which should in principle let the compiler parallelize the code, though I did not test this, as I don’t have CUDA running on my laptop.
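For readers more at home in Haskell, here is one way to transliterate the Futhark stage loop, useful mainly for checking the index arithmetic. It uses immutable `Data.Array` arrays from the `array` package that ships with GHC; each stage builds a fresh array with one element per index, mirroring the `map2 ... (iota n)` in `fft`, and the indexing follows `get_butterfly_params`. The names `revBits`, `fftIter`, and `stageStep` are hypothetical, and like the Futhark version it allocates a new array per stage rather than updating in place.

```haskell
import Data.Array (Array, listArray, elems, (!))
import Data.Bits (shiftL, shiftR, (.&.), (.|.))
import Data.Complex

-- Reverse the lowest `bits` bits of i (same loop as bit_reversal).
revBits :: Int -> Int -> Int
revBits bits i = foldl (\acc j -> (acc `shiftL` 1) .|. ((i `shiftR` j) .&. 1)) 0 [0 .. bits - 1]

-- Iterative radix-2 FFT for power-of-two lengths (at least 1):
-- bit-reverse the input, then run log2 n butterfly stages.
fftIter :: [Complex Double] -> [Complex Double]
fftIter xs
  | n /= 1 `shiftL` bits = error "input size must be power of two"
  | otherwise = elems (foldl stageStep start [0 .. bits - 1])
  where
    n = length xs
    bits = length (takeWhile (< n) (iterate (* 2) 1))
    input = listArray (0, n - 1) xs
    start = listArray (0, n - 1) [input ! revBits bits i | i <- [0 .. n - 1]]

    -- One stage: every output index computes its own butterfly result.
    stageStep :: Array Int (Complex Double) -> Int -> Array Int (Complex Double)
    stageStep arr stage = listArray (0, n - 1) (map element [0 .. n - 1])
      where
        size = 1 `shiftL` (stage + 1)   -- butterfly_size
        half = size `shiftR` 1          -- half_size
        element i =
          let k = i `mod` half
              upper = (i `div` size) * size + k
              lower = upper + half
              t = cis (-2 * pi * fromIntegral (k * (n `div` size)) / fromIntegral n)
              q = t * (arr ! lower)
          in if i `mod` size < half then arr ! upper + q else arr ! upper - q
```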
Futhark is slowly emerging from being an academic project into a serious tool, and the ecosystem is still in its infancy. I wanted to try implementing some of my research in eigensolvers, but the linear algebra module is at the level of an undergraduate research project and does not appear to support complex matrices at the moment. Personally, I probably will not use it further for now, but it is very much the direction I would like numerical algorithms to go: functional DSLs (or full languages) that compile to highly portable, highly optimized code.