From 395f3700281da999226a10c1b24da52252f49c95 Mon Sep 17 00:00:00 2001 From: Sam Nystrom Date: Fri, 8 Mar 2024 20:27:56 +0000 Subject: Implement vectorized FFT --- assembly/index.ts | 123 +++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 90 insertions(+), 33 deletions(-) (limited to 'assembly') diff --git a/assembly/index.ts b/assembly/index.ts index 7af2896..e94b0a8 100644 --- a/assembly/index.ts +++ b/assembly/index.ts @@ -1,9 +1,34 @@ +@inline +function vf32(x: f32): v128 { + return f32x4(x,x,x,x); +} + const zero = f32x4(0,0,0,0); const one = f32x4(1,1,1,1); const pi = f32x4(Mathf.PI, Mathf.PI, Mathf.PI, Mathf.PI); const deg_to_rad = v128.div(pi, f32x4(180,180,180,180)); const rad_to_deg = v128.div(f32x4(180,180,180,180), pi); +@inline +function vsin(x: v128): v128 { + return f32x4( + Mathf.sin(v128.extract_lane(x, 0)), + Mathf.sin(v128.extract_lane(x, 1)), + Mathf.sin(v128.extract_lane(x, 2)), + Mathf.sin(v128.extract_lane(x, 3)), + ); +} + +@inline +function vcos(x: v128): v128 { + return f32x4( + Mathf.cos(v128.extract_lane(x, 0)), + Mathf.cos(v128.extract_lane(x, 1)), + Mathf.cos(v128.extract_lane(x, 2)), + Mathf.cos(v128.extract_lane(x, 3)), + ); +} + export enum MathOp { Add, Sub, @@ -78,17 +103,17 @@ function unaryMathV(op: MathOp, x: v128): v128 { switch (op) { case MathOp.Sqrt: return v128.sqrt(x); - case MathOp.Sign: return v128.sub(v128.gt(x, zero), v128.lt(x, zero)); + case MathOp.Sign: return v128.sub(v128.gt(x, vf32(0)), v128.lt(x, vf32(0))); case MathOp.Round: return v128.nearest(x); case MathOp.Floor: return v128.floor(x); case MathOp.Ceil: return v128.ceil(x); case MathOp.Trunc: return v128.trunc(x); case MathOp.Frac: return v128.sub(x, v128.trunc(x)); - case MathOp.Clamp: return v128.max(zero, v128.min(x, one)); + case MathOp.Clamp: return v128.max(zero, v128.min(x, vf32(1))); - case MathOp.ToRad: return v128.mul(x, deg_to_rad); - case MathOp.ToDeg: return v128.mul(x, rad_to_deg); + case MathOp.ToRad: return v128.mul(x, vf32(Mathf.PI / 180)); + case MathOp.ToDeg: return v128.mul(x, vf32(180 / Mathf.PI)); // fallthrough case MathOp.Exp: @@ -107,7 +132,7 @@ function unaryMathV(op: MathOp, x: v128): v128 { unaryMathS(op, v128.extract_lane(x, 2)), unaryMathS(op, v128.extract_lane(x, 3)), ); - default: return zero; + default: return vf32(0); } } @@ -158,7 +183,7 @@ function binaryMathV(op: MathOp, a: v128, b: v128): v128 { binaryMathS(op, v128.extract_lane(a, 2), v128.extract_lane(b, 2)), binaryMathS(op, v128.extract_lane(a, 3), v128.extract_lane(b, 3)), ); - default: return f32x4(0,0,0,0); + default: return vf32(0); } }; @@ -167,8 +192,8 @@ export { unaryMathS as mathS }; export function mathV(op: MathOp, x: StaticArray): StaticArray { const len = x.length; const out = new StaticArray(len); - const outptr = changetype(out); - const xptr = changetype(x); + const outptr = changetype(out); + const xptr = changetype(x); for (let i = 0; i < len - len % 4; i += 4) { v128.store(outptr + i*4, unaryMathV(op, v128.load(xptr + i*4))); } @@ -186,9 +211,9 @@ export { binaryMathS as mathSS }; export function mathVS(op: MathOp, a: StaticArray, b: f32): StaticArray { const len = a.length; const out = new StaticArray(len); - const outptr = changetype(out); - const aptr = changetype(a); - const vb = f32x4(b,b,b,b); + const outptr = changetype(out); + const aptr = changetype(a); + const vb = vf32(b); for (let i = 0; i < len - len % 4; i += 4) { v128.store(outptr + i*4, binaryMathV(op, v128.load(aptr + i*4), vb)); } @@ -204,9 +229,9 @@ export function mathVS(op: MathOp, a: StaticArray, b: f32): StaticArray): StaticArray { const len = b.length; const out = new StaticArray(len); - const outptr = changetype(out); - const bptr = changetype(b); - const va = f32x4(a,a,a,a); + const outptr = changetype(out); + const bptr = changetype(b); + const va = vf32(a); for (let i = 0; i < len - len % 4; i += 4) { v128.store(outptr + i*4, binaryMathV(op, va, v128.load(bptr + i*4))); } @@ -222,9 +247,9 @@ export function mathSV(op: MathOp, a: f32, b: StaticArray): StaticArray, b: StaticArray): StaticArray { const len = a.length < b.length ? a.length : b.length; const out = new StaticArray(len); - const outptr = changetype(out); - const aptr = changetype(a); - const bptr = changetype(b); + const outptr = changetype(out); + const aptr = changetype(a); + const bptr = changetype(b); for (let i = 0; i < len - len % 4; i += 4) { v128.store(outptr + i*4, binaryMathV(op, v128.load(aptr + i*4), v128.load(bptr + i*4))); } @@ -239,11 +264,11 @@ export function mathVV(op: MathOp, a: StaticArray, b: StaticArray): St export function linspace(start: f32, stop: f32, n: i32): StaticArray { const out = new StaticArray(n); - const ptr = changetype(out); + const ptr = changetype(out); const fac = (stop - start) / f32(n-1); let value = f32x4(start, start + fac, start + fac*2, start + fac*3); - const vinc = f32x4(fac*4, fac*4, fac*4, fac*4); + const vinc = vf32(fac*4); for (let i = 0; i < n - n % 4; i += 4) { v128.store(ptr + i * 4, value); value = v128.add(value, vinc); @@ -258,11 +283,11 @@ export function linspace(start: f32, stop: f32, n: i32): StaticArray { } export function intersperse(a: StaticArray, b: StaticArray): StaticArray { - const aptr = changetype(a); - const bptr = changetype(b); + const aptr = changetype(a); + const bptr = changetype(b); let len = a.length < b.length ? a.length : b.length; const out = new StaticArray(len * 2); - const outptr = changetype(out); + const outptr = changetype(out); for (let i = 0; i < len - len % 4; i += 4) { const va = v128.load(aptr + i * 4); @@ -294,20 +319,52 @@ export function unzip(x: StaticArray): StaticArray { return out; } -export function dft(x: StaticArray): StaticArray { - const out = new StaticArray(x.length); - for (let k = 0; k < out.length - out.length % 2; k += 2) { - for (let n = 0; n < x.length - x.length % 2; n += 2) { - const y = -2.0 * Mathf.PI * k / x.length * n; - const u = Mathf.cos(y); - const v = Mathf.sin(y); - out[k] += x[n] * u - x[n+1] * v; - out[k+1] += x[n] * v + x[n+1] * u; +function ditfft2(x: StaticArray, n: i32, s: i32, out: StaticArray): void { + if (n == 1) { + out[0] = x[0]; + out[1] = x[1]; + return; + } else if (n == 2) { + for (let k = 0; k < out.length - out.length % 2; k += 2) { + for (let n = 0; n < x.length - x.length % 2; n += 2) { + const y = -2 * Mathf.PI * k / x.length * n; + const u = Mathf.cos(y); + const v = Mathf.sin(y); + out[k] += x[n] * u - x[n+1] * v; + out[k+1] += x[n] * v + x[n+1] * u; + } } + return; + } + ditfft2(x.slice>(0, n), n/2, 2*s, out.slice>(0, n)); + ditfft2(x.slice>(s * 2), n/2, 2*s, out.slice>(n)); + const outptr = changetype(out); + const twiddle = vf32(-2 * Mathf.PI / n); + let vk = f32x4(0,1,2,3); + for (let k = 0; k < n/2 - n/2 % 2; k += 4) { + const p = v128.load(outptr + k*2 * 4); + const y = v128.mul(twiddle, vk); + const tw = vsin(v128.neg(v128.sub(y, f32x4(Mathf.PI/2, 0, Mathf.PI/2, 0)))); + const q = v128.mul(tw, v128.load(outptr + (k*2 + k*2) * 4)); + v128.store(outptr + k*2 * 4, v128.add(p, q)); + v128.store(outptr + (k*2 + n/2) * 4, v128.sub(p, q)); + vk = v128.add(vk, vf32(1)); + } + for (let k = n/2 - n/2 % 2; k < n/2; k++) { + const pr = unchecked(out[k*2]); + const pi = unchecked(out[k*2+1]); + const y = -2 * Mathf.PI * k / n; + const qr = Mathf.cos(y) * out[k*2+n/2] + const qi = Mathf.sin(y) * out[k*2+n/2+1]; + unchecked(out[k*2] = pr + qr); + unchecked(out[k*2+1] = pi + qi); + unchecked(out[k*2+n/2] = pr - qr); + unchecked(out[k*2+n/2+1] = pi - qi); } - return out; } export function fft(x: StaticArray): StaticArray { - return dft(x); + const out = new StaticArray(x.length); + ditfft2(x, x.length/2, 1, out); + return out; } \ No newline at end of file -- cgit v1.2.3