The S matrix in the out-of-place fast Walsh–Hadamard transform
1. Introduction
In the Fast Walsh–Hadamard Transform (FWHT), computation is organized into a sequence of stages.
Each stage combines values in pairs — computing a sum and a difference — and then rearranges them to prepare for the next stage.
The stage operator in our out-of-place version is:

$$S = \tfrac{1}{\sqrt{2}}\, U \,(I_{n/2} \otimes H_2)\, P$$
where:
- $P$: permutes the data so that paired elements are adjacent (adjacent-pair grouping). With the pairing $(2j, 2j+1)$ used in the code below, $P$ is the identity.
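- $I_{n/2} \otimes H_2$: applies the $2 \times 2$ Hadamard kernel $H_2 = \begin{pmatrix} 1 & 1 \\ 1 & -1 \end{pmatrix}$ to each pair, i.e. $(u, v) \mapsto (u + v,\ u - v)$.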
- $U$: “unzips” the results so that all sums are in the lower half of the array and all differences in the upper half.
We normalize by $1/\sqrt{2}$ so that each stage is orthonormal, meaning $S^\top S = S S^\top = I$.
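To make the factorization concrete, the sketch below builds $S$ as a dense matrix for $n = 8$ from its factors and checks $S^\top S = I$ numerically. It assumes the adjacent pairing $(2j, 2j+1)$, so $P$ is taken as the identity, and it folds the $1/\sqrt{2}$ normalization into the kernel. The class `StageMatrix` and its helpers `mul`, `kernel`, and `unzip` are hypothetical names used only for illustration; materializing $S$ like this only makes sense for small $n$.

```java
public class StageMatrix {
    // Dense n x n matrix product C = X * Y (fine for small n only).
    static double[][] mul(double[][] x, double[][] y) {
        int n = x.length;
        double[][] c = new double[n][n];
        for (int i = 0; i < n; i++)
            for (int k = 0; k < n; k++)
                for (int j = 0; j < n; j++)
                    c[i][j] += x[i][k] * y[k][j];
        return c;
    }

    // Block-diagonal kernel kron(I_{n/2}, H_2 / sqrt 2): normalized 2x2 Hadamard on each adjacent pair.
    static double[][] kernel(int n) {
        double r = 1.0 / Math.sqrt(2.0);
        double[][] k = new double[n][n];
        for (int p = 0; p < n; p += 2) {
            k[p][p] = r;
            k[p][p + 1] = r;      // sum row:  (u + v) / sqrt 2
            k[p + 1][p] = r;
            k[p + 1][p + 1] = -r; // diff row: (u - v) / sqrt 2
        }
        return k;
    }

    // Unzip U: even slots (sums) move to the lower half, odd slots (diffs) to the upper half.
    static double[][] unzip(int n) {
        double[][] u = new double[n][n];
        for (int j = 0; j < n / 2; j++) {
            u[j][2 * j] = 1.0;
            u[j + n / 2][2 * j + 1] = 1.0;
        }
        return u;
    }

    public static void main(String[] args) {
        int n = 8;
        // S = U * kron(I, H_2 / sqrt 2) * P, with P = identity for adjacent pairing
        double[][] s = mul(unzip(n), kernel(n));

        // Largest deviation of S^T S from the identity.
        double maxErr = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                double dot = 0;
                for (int r = 0; r < n; r++) dot += s[r][i] * s[r][j];
                maxErr = Math.max(maxErr, Math.abs(dot - (i == j ? 1.0 : 0.0)));
            }
        }
        System.out.println("max |S^T S - I| = " + maxErr);
    }
}
```

Row $j$ of the product picks up row $2j$ of the kernel (a sum) and row $j + n/2$ picks up row $2j+1$ (a difference), which is exactly the output layout described above.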
2. Properties of $S$
- Sparse: exactly 2 non-zero entries per row and per column.
- Orthogonal (normalized): energy preserving, $\|Sx\|_2 = \|x\|_2$, so there is no growth in norm.
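- Finite order: $S$ is not self-inverse on its own, but in normalized form $S^{\log_2 n} = H_n$, the normalized Walsh–Hadamard matrix in natural (Sylvester) order, and $H_n^2 = I$.
- The transform therefore cycles every $2\log_2 n$ applications: $S^{2\log_2 n} = I$.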
- Part of the butterfly family: shares its structure with FFT stages, wavelet transforms, and random projection matrices.
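The properties listed above can be checked numerically. The hypothetical class `StageProperties` below recomputes one stage with the same layout as `applyStage` (assuming adjacent pairing), counts the non-zeros of $S$ column by column via $S e_c$, compares $\|x\|^2$ with $\|Sx\|^2$, and applies $S$ exactly $2\log_2 n$ times to confirm the vector comes back up to rounding. It is a sanity check under those assumptions, not a proof.

```java
import java.util.Arrays;

public class StageProperties {
    // One normalized stage S with the applyStage layout: sums -> lower half, diffs -> upper half.
    static double[] stage(double[] in) {
        int n = in.length, half = n / 2;
        double r = 1.0 / Math.sqrt(2.0);
        double[] out = new double[n];
        for (int j = 0; j < half; j++) {
            out[j] = (in[2 * j] + in[2 * j + 1]) * r;
            out[j + half] = (in[2 * j] - in[2 * j + 1]) * r;
        }
        return out;
    }

    // Squared Euclidean norm.
    static double normSq(double[] v) {
        double s = 0;
        for (double a : v) s += a * a;
        return s;
    }

    public static void main(String[] args) {
        int n = 8;
        int k = Integer.numberOfTrailingZeros(n); // log2(n) for a power of two

        // Sparsity: column c of S is S e_c; count non-zeros per row and per column.
        int[] rowNnz = new int[n], colNnz = new int[n];
        for (int c = 0; c < n; c++) {
            double[] e = new double[n];
            e[c] = 1.0;
            double[] col = stage(e);
            for (int r = 0; r < n; r++) {
                if (col[r] != 0.0) {
                    rowNnz[r]++;
                    colNnz[c]++;
                }
            }
        }
        System.out.println("non-zeros per row:    " + Arrays.toString(rowNnz));
        System.out.println("non-zeros per column: " + Arrays.toString(colNnz));

        // Energy preservation: ||S x||^2 equals ||x||^2 up to rounding.
        double[] x = {1, 2, 3, 4, 5, 6, 7, 8};
        System.out.println("||x||^2 = " + normSq(x) + ", ||Sx||^2 = " + normSq(stage(x)));

        // Finite order: applying S 2*log2(n) times returns x up to rounding.
        double[] y = x;
        for (int s = 0; s < 2 * k; s++) y = stage(y);
        double err = 0;
        for (int i = 0; i < n; i++) err = Math.max(err, Math.abs(y[i] - x[i]));
        System.out.println("max |S^(2k) x - x| = " + err);
    }
}
```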
3. Why it matters
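- The matrix $S$ is the building block: $H_n = S^{\log_2 n}$, where $H_n$ is the normalized Walsh–Hadamard matrix in natural (Sylvester) order.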
- It isolates the pairing, mixing, and unpacking logic into a clean, fixed operator.
- Useful for:
  - FWHT
  - Randomized transforms
  - Data mixing layers in neural networks
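As one illustration of the "randomized transforms" use, a common construction is the randomized Hadamard transform $y = H_n D x$, where $D$ is a diagonal matrix of random $\pm 1$ signs. The sketch below assumes that construction (a standard technique, not something this post specifies), realizes $H_n$ as $\log_2 n$ stages of $S$ with two swapped buffers, and draws the signs from `java.util.Random` with a caller-supplied seed. `RandomizedHadamard` and `transform` are hypothetical names, and the input length is assumed to be a power of two.

```java
import java.util.Random;

public class RandomizedHadamard {
    // One normalized out-of-place stage S: sums -> lower half, diffs -> upper half.
    static void stage(double[] in, double[] out) {
        int half = in.length / 2;
        double r = 1.0 / Math.sqrt(2.0);
        for (int j = 0; j < half; j++) {
            out[j] = (in[2 * j] + in[2 * j + 1]) * r;
            out[j + half] = (in[2 * j] - in[2 * j + 1]) * r;
        }
    }

    // y = H_n D x: random +/-1 signs (D, from a seeded RNG), then log2(n) stages of S.
    // Assumes x.length is a power of two >= 2.
    static double[] transform(double[] x, long seed) {
        int n = x.length;
        Random rng = new Random(seed);
        double[] a = new double[n];
        double[] b = new double[n];
        for (int i = 0; i < n; i++) {
            a[i] = rng.nextBoolean() ? x[i] : -x[i];
        }
        for (int len = n; len > 1; len >>= 1) { // runs log2(n) times
            stage(a, b);
            double[] t = a; // this stage's output is the next stage's input
            a = b;
            b = t;
        }
        return a;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4, 5, 6, 7, 8};
        double[] y = transform(x, 42L);
        double nx = 0, ny = 0;
        for (double v : x) nx += v * v;
        for (double v : y) ny += v * v;
        System.out.println("||x||^2 = " + nx + ", ||y||^2 = " + ny);
        for (double v : y) System.out.printf("%8.3f", v);
        System.out.println();
    }
}
```

Because both $D$ and every stage are orthogonal, the output keeps the input's norm; the random signs spread the energy of a spiky input across all coordinates.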
4. The algorithm in plain words
One stage works like this:
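- Take the input array $A$ of length $n$ (a power of two).
- Pair adjacent elements: $(A[0], A[1]), (A[2], A[3]), \dots, (A[n-2], A[n-1])$.
- For each pair $(u, v) = (A[2j], A[2j+1])$, $j = 0, \dots, n/2 - 1$:
  - Sum $= (u + v)/\sqrt{2}$ → goes to $B[j]$ in the lower half of the output.
  - Diff $= (u - v)/\sqrt{2}$ → goes to $B[j + n/2]$ in the upper half of the output.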
The output array $B$ now has the sums in B[0..n/2-1] and the differences in B[n/2..n-1].
Repeat with $B$ as the input for the next stage; after $\log_2 n$ stages the result is the full normalized transform $H_n A$.
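As a worked example of repeating the stage, take $A = (1, 2, 3, 4, 5, 6, 7, 8)$, so $n = 8$ and three stages complete the transform. Each line applies the sum/difference rule to adjacent pairs of the previous line:

```latex
\begin{aligned}
A      &= (1,\ 2,\ 3,\ 4,\ 5,\ 6,\ 7,\ 8)^\top \\
SA     &= \tfrac{1}{\sqrt{2}}\,(3,\ 7,\ 11,\ 15,\ -1,\ -1,\ -1,\ -1)^\top \\
S^2 A  &= \tfrac{1}{2}\,(10,\ 26,\ -2,\ -2,\ -4,\ -4,\ 0,\ 0)^\top \\
S^3 A  &= \tfrac{1}{2\sqrt{2}}\,(36,\ -4,\ -8,\ 0,\ -16,\ 0,\ 0,\ 0)^\top = H_8 A
\end{aligned}
```

The last line agrees with multiplying $A$ by the natural-order Walsh–Hadamard matrix directly; for example, the first entry is $(1 + 2 + \dots + 8)/\sqrt{8} = 36/\sqrt{8}$.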
Java Code:
```java
public class FWHTStage {
    /**
     * Applies one normalized FWHT stage S: sums in lower half, diffs in upper half.
     * @param input  Input array of length n (must be a power of two, n >= 2).
     * @param output Output array of length n (allocated by the caller; must not alias input).
     */
    public static void applyStage(double[] input, double[] output) {
        int n = input.length;
        if (n < 2 || (n & (n - 1)) != 0) {
            throw new IllegalArgumentException("Length must be a power of two (n >= 2)");
        }
        if (output.length != n || output == input) {
            throw new IllegalArgumentException("Output must be a separate array of the same length");
        }
        int half = n / 2;
        double norm = 1.0 / Math.sqrt(2.0);
        for (int i = 0, j = 0; i < n; i += 2, j++) {
            double u = input[i];
            double v = input[i + 1];
            output[j] = (u + v) * norm;        // sum -> lower half
            output[j + half] = (u - v) * norm; // diff -> upper half
        }
    }

    // Small test: prints all n outputs on one line
    public static void main(String[] args) {
        double[] A = {1, 2, 3, 4, 5, 6, 7, 8};
        double[] B = new double[A.length];
        applyStage(A, B);
        for (double x : B) {
            System.out.printf("%8.3f", x);
        }
        System.out.println();
    }
}
```
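Following "repeat with $B$ as the input for the next stage", a full transform is $\log_2 n$ applications of the stage with the two buffers swapping roles. The sketch below inlines the loop from `applyStage`, swaps buffer references after each stage instead of allocating new arrays, and compares the result with a direct $O(n^2)$ evaluation using the closed form $(H_n)_{ij} = (-1)^{\mathrm{popcount}(i \,\&\, j)}/\sqrt{n}$ for the natural (Sylvester) order. `FullFWHT`, `fwht`, and `direct` are hypothetical names chosen for this example.

```java
public class FullFWHT {
    // Full normalized FWHT: log2(n) out-of-place stages, swapping two buffers between stages.
    static double[] fwht(double[] x) {
        int n = x.length;
        if (n < 2 || (n & (n - 1)) != 0) {
            throw new IllegalArgumentException("Length must be a power of two (n >= 2)");
        }
        int half = n / 2;
        double norm = 1.0 / Math.sqrt(2.0);
        double[] a = x.clone();     // current stage input (x itself is left untouched)
        double[] b = new double[n]; // current stage output
        for (int len = n; len > 1; len >>= 1) { // log2(n) stages
            for (int j = 0; j < half; j++) {
                b[j] = (a[2 * j] + a[2 * j + 1]) * norm;        // sum  -> lower half
                b[j + half] = (a[2 * j] - a[2 * j + 1]) * norm; // diff -> upper half
            }
            double[] t = a; // the output becomes the next stage's input
            a = b;
            b = t;
        }
        return a;
    }

    // Reference O(n^2) evaluation: y_i = sum_j (-1)^popcount(i & j) x_j / sqrt(n).
    static double[] direct(double[] x) {
        int n = x.length;
        double[] y = new double[n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                double sign = (Integer.bitCount(i & j) % 2 == 0) ? 1.0 : -1.0;
                y[i] += sign * x[j];
            }
            y[i] /= Math.sqrt(n);
        }
        return y;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4, 5, 6, 7, 8};
        double[] fast = fwht(x);
        double[] ref = direct(x);
        double err = 0;
        for (int i = 0; i < x.length; i++) err = Math.max(err, Math.abs(fast[i] - ref[i]));
        for (double v : fast) System.out.printf("%8.3f", v);
        System.out.println();
        System.out.println("max |fwht - direct| = " + err);
    }
}
```

Only two arrays of length $n$ are ever allocated, and because the normalized $H_n$ is its own inverse, calling `fwht` twice should return the original input up to rounding.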