summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Sam Nystrom <sam@samnystrom.dev> 2024-03-08 20:27:56 +0000
committer Sam Nystrom <sam@samnystrom.dev> 2024-03-09 02:05:58 -0500
commit 395f3700281da999226a10c1b24da52252f49c95 (patch)
tree 6434c9cd1992cdfbb174f8e6ffaad8bc4b806fdb
parent 41213e45761fc1dd795d462ad7bc719533efd09e (diff)
Implement vectorized FFT
-rw-r--r-- assembly/index.ts 123
-rw-r--r-- src/wasm.ts 1
2 files changed, 90 insertions, 34 deletions
diff --git a/assembly/index.ts b/assembly/index.ts
index 7af2896..e94b0a8 100644
--- a/assembly/index.ts
+++ b/assembly/index.ts
@@ -1,9 +1,34 @@
+/** Broadcasts the scalar `x` into all four f32 lanes of a v128. */
+@inline
+function vf32(x: f32): v128 {
+  return f32x4.splat(x);
+}
+
+// Module-level four-lane f32 constants; every lane of each vector holds the same value.
+// NOTE(review): the hunks visible in this diff replace the uses of `one`, `deg_to_rad` and
+// `rad_to_deg` with vf32(...) calls; confirm whether they are still referenced elsewhere.
+// All lanes 0.
 const zero = f32x4(0,0,0,0);
+// All lanes 1.
 const one = f32x4(1,1,1,1);
+// All lanes Mathf.PI.
 const pi = f32x4(Mathf.PI, Mathf.PI, Mathf.PI, Mathf.PI);
+// Lane-wise pi / 180: multiply by this to turn degrees into radians.
 const deg_to_rad = v128.div<f32>(pi, f32x4(180,180,180,180));
+// Lane-wise 180 / pi: multiply by this to turn radians into degrees.
 const rad_to_deg = v128.div<f32>(f32x4(180,180,180,180), pi);
+/**
+ * Lane-wise sine of a four-lane f32 vector.
+ * Each lane is extracted, passed through scalar `Mathf.sin`, and repacked in order.
+ */
+@inline
+function vsin(x: v128): v128 {
+  const lane0 = Mathf.sin(v128.extract_lane<f32>(x, 0));
+  const lane1 = Mathf.sin(v128.extract_lane<f32>(x, 1));
+  const lane2 = Mathf.sin(v128.extract_lane<f32>(x, 2));
+  const lane3 = Mathf.sin(v128.extract_lane<f32>(x, 3));
+  return f32x4(lane0, lane1, lane2, lane3);
+}
+
+/**
+ * Lane-wise cosine of a four-lane f32 vector.
+ * Each lane is extracted, passed through scalar `Mathf.cos`, and repacked in order.
+ */
+@inline
+function vcos(x: v128): v128 {
+  const lane0 = Mathf.cos(v128.extract_lane<f32>(x, 0));
+  const lane1 = Mathf.cos(v128.extract_lane<f32>(x, 1));
+  const lane2 = Mathf.cos(v128.extract_lane<f32>(x, 2));
+  const lane3 = Mathf.cos(v128.extract_lane<f32>(x, 3));
+  return f32x4(lane0, lane1, lane2, lane3);
+}
+
export enum MathOp {
Add,
Sub,
@@ -78,17 +103,17 @@ function unaryMathV(op: MathOp, x: v128): v128 {
switch (op) {
case MathOp.Sqrt: return v128.sqrt<f32>(x);
- case MathOp.Sign: return v128.sub<f32>(v128.gt<f32>(x, zero), v128.lt<f32>(x, zero));
+ case MathOp.Sign: return v128.sub<f32>(v128.gt<f32>(x, vf32(0)), v128.lt<f32>(x, vf32(0)));
case MathOp.Round: return v128.nearest<f32>(x);
case MathOp.Floor: return v128.floor<f32>(x);
case MathOp.Ceil: return v128.ceil<f32>(x);
case MathOp.Trunc: return v128.trunc<f32>(x);
case MathOp.Frac: return v128.sub<f32>(x, v128.trunc<f32>(x));
- case MathOp.Clamp: return v128.max<f32>(zero, v128.min<f32>(x, one));
+ case MathOp.Clamp: return v128.max<f32>(zero, v128.min<f32>(x, vf32(1)));
- case MathOp.ToRad: return v128.mul<f32>(x, deg_to_rad);
- case MathOp.ToDeg: return v128.mul<f32>(x, rad_to_deg);
+ case MathOp.ToRad: return v128.mul<f32>(x, vf32(Mathf.PI / 180));
+ case MathOp.ToDeg: return v128.mul<f32>(x, vf32(180 / Mathf.PI));
// fallthrough
case MathOp.Exp:
@@ -107,7 +132,7 @@ function unaryMathV(op: MathOp, x: v128): v128 {
unaryMathS(op, v128.extract_lane<f32>(x, 2)),
unaryMathS(op, v128.extract_lane<f32>(x, 3)),
);
- default: return zero;
+ default: return vf32(0);
}
}
@@ -158,7 +183,7 @@ function binaryMathV(op: MathOp, a: v128, b: v128): v128 {
binaryMathS(op, v128.extract_lane<f32>(a, 2), v128.extract_lane<f32>(b, 2)),
binaryMathS(op, v128.extract_lane<f32>(a, 3), v128.extract_lane<f32>(b, 3)),
);
- default: return f32x4(0,0,0,0);
+ default: return vf32(0);
}
};
@@ -167,8 +192,8 @@ export { unaryMathS as mathS };
export function mathV(op: MathOp, x: StaticArray<f32>): StaticArray<f32> {
const len = x.length;
const out = new StaticArray<f32>(len);
- const outptr = changetype<i32>(out);
- const xptr = changetype<i32>(x);
+ const outptr = changetype<usize>(out);
+ const xptr = changetype<usize>(x);
for (let i = 0; i < len - len % 4; i += 4) {
v128.store(outptr + i*4, unaryMathV(op, v128.load(xptr + i*4)));
}
@@ -186,9 +211,9 @@ export { binaryMathS as mathSS };
export function mathVS(op: MathOp, a: StaticArray<f32>, b: f32): StaticArray<f32> {
const len = a.length;
const out = new StaticArray<f32>(len);
- const outptr = changetype<i32>(out);
- const aptr = changetype<i32>(a);
- const vb = f32x4(b,b,b,b);
+ const outptr = changetype<usize>(out);
+ const aptr = changetype<usize>(a);
+ const vb = vf32(b);
for (let i = 0; i < len - len % 4; i += 4) {
v128.store(outptr + i*4, binaryMathV(op, v128.load(aptr + i*4), vb));
}
@@ -204,9 +229,9 @@ export function mathVS(op: MathOp, a: StaticArray<f32>, b: f32): StaticArray<f32
export function mathSV(op: MathOp, a: f32, b: StaticArray<f32>): StaticArray<f32> {
const len = b.length;
const out = new StaticArray<f32>(len);
- const outptr = changetype<i32>(out);
- const bptr = changetype<i32>(b);
- const va = f32x4(a,a,a,a);
+ const outptr = changetype<usize>(out);
+ const bptr = changetype<usize>(b);
+ const va = vf32(a);
for (let i = 0; i < len - len % 4; i += 4) {
v128.store(outptr + i*4, binaryMathV(op, va, v128.load(bptr + i*4)));
}
@@ -222,9 +247,9 @@ export function mathSV(op: MathOp, a: f32, b: StaticArray<f32>): StaticArray<f32
export function mathVV(op: MathOp, a: StaticArray<f32>, b: StaticArray<f32>): StaticArray<f32> {
const len = a.length < b.length ? a.length : b.length;
const out = new StaticArray<f32>(len);
- const outptr = changetype<i32>(out);
- const aptr = changetype<i32>(a);
- const bptr = changetype<i32>(b);
+ const outptr = changetype<usize>(out);
+ const aptr = changetype<usize>(a);
+ const bptr = changetype<usize>(b);
for (let i = 0; i < len - len % 4; i += 4) {
v128.store(outptr + i*4, binaryMathV(op, v128.load(aptr + i*4), v128.load(bptr + i*4)));
}
@@ -239,11 +264,11 @@ export function mathVV(op: MathOp, a: StaticArray<f32>, b: StaticArray<f32>): St
export function linspace(start: f32, stop: f32, n: i32): StaticArray<f32> {
const out = new StaticArray<f32>(n);
- const ptr = changetype<i32>(out);
+ const ptr = changetype<usize>(out);
const fac = (stop - start) / f32(n-1);
let value = f32x4(start, start + fac, start + fac*2, start + fac*3);
- const vinc = f32x4(fac*4, fac*4, fac*4, fac*4);
+ const vinc = vf32(fac*4);
for (let i = 0; i < n - n % 4; i += 4) {
v128.store(ptr + i * 4, value);
value = v128.add<f32>(value, vinc);
@@ -258,11 +283,11 @@ export function linspace(start: f32, stop: f32, n: i32): StaticArray<f32> {
}
export function intersperse(a: StaticArray<f32>, b: StaticArray<f32>): StaticArray<f32> {
- const aptr = changetype<i32>(a);
- const bptr = changetype<i32>(b);
+ const aptr = changetype<usize>(a);
+ const bptr = changetype<usize>(b);
let len = a.length < b.length ? a.length : b.length;
const out = new StaticArray<f32>(len * 2);
- const outptr = changetype<i32>(out);
+ const outptr = changetype<usize>(out);
for (let i = 0; i < len - len % 4; i += 4) {
const va = v128.load(aptr + i * 4);
@@ -294,20 +319,52 @@ export function unzip(x: StaticArray<f32>): StaticArray<f32> {
return out;
}
-export function dft(x: StaticArray<f32>): StaticArray<f32> {
- const out = new StaticArray<f32>(x.length);
- for (let k = 0; k < out.length - out.length % 2; k += 2) {
- for (let n = 0; n < x.length - x.length % 2; n += 2) {
- const y = -<f32>2.0 * Mathf.PI * <f32>k / <f32>x.length * <f32>n;
- const u = Mathf.cos(y);
- const v = Mathf.sin(y);
- out[k] += x[n] * u - x[n+1] * v;
- out[k+1] += x[n] * v + x[n+1] * u;
+/**
+ * Recursive radix-2 decimation-in-time FFT over interleaved complex data
+ * (`[re0, im0, re1, im1, ...]`).
+ *
+ * Transforms the `n` complex points `x[xoff], x[xoff + s], x[xoff + 2*s], ...`
+ * and writes the `n` results to complex slots `outoff .. outoff + n - 1` of `out`.
+ * Every count, stride and offset is measured in complex elements (two f32 each).
+ *
+ * `n` must be a power of two: each level halves it with integer division, so
+ * other sizes silently drop points.
+ *
+ * @param x      interleaved complex input; only read
+ * @param n      number of complex points in this sub-transform
+ * @param s      stride between consecutive input points
+ * @param out    interleaved complex output, holding at least `outoff + n` points
+ * @param xoff   index of the first input point (defaults to 0)
+ * @param outoff index of the first output point (defaults to 0)
+ */
+function ditfft2(x: StaticArray<f32>, n: i32, s: i32, out: StaticArray<f32>, xoff: i32 = 0, outoff: i32 = 0): void {
+  if (n < 1) return;
+  if (n == 1) {
+    // A one-point transform is the identity: copy the (re, im) pair.
+    out[outoff*2] = x[xoff*2];
+    out[outoff*2 + 1] = x[xoff*2 + 1];
+    return;
   }
+  const half = n / 2;
+  // Even-indexed points are transformed into the first half of this output
+  // range, odd-indexed points into the second half. Offsets are passed rather
+  // than slices because slicing a StaticArray copies it, which would throw
+  // the sub-transform results away.
+  ditfft2(x, half, 2*s, out, xoff, outoff);
+  ditfft2(x, half, 2*s, out, xoff + s, outoff + half);
+
+  const outptr = changetype<usize>(out);
+  const step = -<f32>2 * Mathf.PI / <f32>n;
+  const vecEnd = half - half % 2;
+  // Two butterflies per iteration: one v128 holds [re_k, im_k, re_k+1, im_k+1].
+  for (let k = 0; k < vecEnd; k += 2) {
+    const pptr = outptr + <usize>((outoff + k) * 8);
+    const bptr = outptr + <usize>((outoff + k + half) * 8);
+    const y0 = step * <f32>k;
+    const y1 = step * <f32>(k + 1);
+    const c0 = Mathf.cos(y0);
+    const s0 = Mathf.sin(y0);
+    const c1 = Mathf.cos(y1);
+    const s1 = Mathf.sin(y1);
+    const wr = f32x4(c0, c0, c1, c1);
+    // Signs are folded in so that wr*b + wi*swap(b) is the complex product w*b.
+    const wi = f32x4(-s0, s0, -s1, s1);
+    const p = v128.load(pptr);
+    const b = v128.load(bptr);
+    // Swap re/im inside each complex pair: [im_k, re_k, im_k+1, re_k+1].
+    const bswap = v128.shuffle<f32>(b, b, 1, 0, 3, 2);
+    const q = v128.add<f32>(v128.mul<f32>(wr, b), v128.mul<f32>(wi, bswap));
+    v128.store(pptr, v128.add<f32>(p, q));
+    v128.store(bptr, v128.sub<f32>(p, q));
+  }
+  // Scalar remainder; for power-of-two sizes this only runs when half == 1.
+  for (let k = vecEnd; k < half; k++) {
+    const lo = (outoff + k) * 2;
+    const hi = (outoff + k + half) * 2;
+    const y = step * <f32>k;
+    const wr = Mathf.cos(y);
+    const wi = Mathf.sin(y);
+    const pr = unchecked(out[lo]);
+    const pim = unchecked(out[lo + 1]);
+    const br = unchecked(out[hi]);
+    const bim = unchecked(out[hi + 1]);
+    // Full complex multiply q = w * b, including the cross terms.
+    const qr = wr * br - wi * bim;
+    const qi = wr * bim + wi * br;
+    unchecked(out[lo] = pr + qr);
+    unchecked(out[lo + 1] = pim + qi);
+    unchecked(out[hi] = pr - qr);
+    unchecked(out[hi + 1] = pim - qi);
   }
-return out;
 }
+/**
+ * Runs `ditfft2` over `x` and returns the result in a newly allocated array
+ * of the same length. `x.length / 2` is passed as the point count, i.e. the
+ * input is handed over as (re, im) pairs.
+ */
 export function fft(x: StaticArray<f32>): StaticArray<f32> {
-return dft(x);
+  const floats = x.length;
+  const spectrum = new StaticArray<f32>(floats);
+  ditfft2(x, floats / 2, 1, spectrum);
+  return spectrum;
 } \ No newline at end of file
diff --git a/src/wasm.ts b/src/wasm.ts
index a62a8fc..cf7b552 100644
--- a/src/wasm.ts
+++ b/src/wasm.ts
@@ -6,6 +6,5 @@ export const {
linspace,
intersperse,
unzip,
- dft,
fft,
} = await instantiate(await WebAssembly.compileStreaming(fetch(url)), {}); \ No newline at end of file