diff options
| -rw-r--r-- | assembly/index.ts | 123 | ||||
| -rw-r--r-- | src/wasm.ts | 1 |
2 files changed, 90 insertions, 34 deletions
diff --git a/assembly/index.ts b/assembly/index.ts index 7af2896..e94b0a8 100644 --- a/assembly/index.ts +++ b/assembly/index.ts @@ -1,9 +1,34 @@ +@inline +function vf32(x: f32): v128 { + return f32x4(x,x,x,x); +} + const zero = f32x4(0,0,0,0); const one = f32x4(1,1,1,1); const pi = f32x4(Mathf.PI, Mathf.PI, Mathf.PI, Mathf.PI); const deg_to_rad = v128.div<f32>(pi, f32x4(180,180,180,180)); const rad_to_deg = v128.div<f32>(f32x4(180,180,180,180), pi); +@inline +function vsin(x: v128): v128 { + return f32x4( + Mathf.sin(v128.extract_lane<f32>(x, 0)), + Mathf.sin(v128.extract_lane<f32>(x, 1)), + Mathf.sin(v128.extract_lane<f32>(x, 2)), + Mathf.sin(v128.extract_lane<f32>(x, 3)), + ); +} + +@inline +function vcos(x: v128): v128 { + return f32x4( + Mathf.cos(v128.extract_lane<f32>(x, 0)), + Mathf.cos(v128.extract_lane<f32>(x, 1)), + Mathf.cos(v128.extract_lane<f32>(x, 2)), + Mathf.cos(v128.extract_lane<f32>(x, 3)), + ); +} + export enum MathOp { Add, Sub, @@ -78,17 +103,17 @@ function unaryMathV(op: MathOp, x: v128): v128 { switch (op) { case MathOp.Sqrt: return v128.sqrt<f32>(x); - case MathOp.Sign: return v128.sub<f32>(v128.gt<f32>(x, zero), v128.lt<f32>(x, zero)); + case MathOp.Sign: return v128.sub<f32>(v128.gt<f32>(x, vf32(0)), v128.lt<f32>(x, vf32(0))); case MathOp.Round: return v128.nearest<f32>(x); case MathOp.Floor: return v128.floor<f32>(x); case MathOp.Ceil: return v128.ceil<f32>(x); case MathOp.Trunc: return v128.trunc<f32>(x); case MathOp.Frac: return v128.sub<f32>(x, v128.trunc<f32>(x)); - case MathOp.Clamp: return v128.max<f32>(zero, v128.min<f32>(x, one)); + case MathOp.Clamp: return v128.max<f32>(zero, v128.min<f32>(x, vf32(1))); - case MathOp.ToRad: return v128.mul<f32>(x, deg_to_rad); - case MathOp.ToDeg: return v128.mul<f32>(x, rad_to_deg); + case MathOp.ToRad: return v128.mul<f32>(x, vf32(Mathf.PI / 180)); + case MathOp.ToDeg: return v128.mul<f32>(x, vf32(180 / Mathf.PI)); // fallthrough case MathOp.Exp: @@ -107,7 +132,7 @@ function unaryMathV(op: MathOp, x: v128): v128 { unaryMathS(op, v128.extract_lane<f32>(x, 2)), unaryMathS(op, v128.extract_lane<f32>(x, 3)), ); - default: return zero; + default: return vf32(0); } } @@ -158,7 +183,7 @@ function binaryMathV(op: MathOp, a: v128, b: v128): v128 { binaryMathS(op, v128.extract_lane<f32>(a, 2), v128.extract_lane<f32>(b, 2)), binaryMathS(op, v128.extract_lane<f32>(a, 3), v128.extract_lane<f32>(b, 3)), ); - default: return f32x4(0,0,0,0); + default: return vf32(0); } }; @@ -167,8 +192,8 @@ export { unaryMathS as mathS }; export function mathV(op: MathOp, x: StaticArray<f32>): StaticArray<f32> { const len = x.length; const out = new StaticArray<f32>(len); - const outptr = changetype<i32>(out); - const xptr = changetype<i32>(x); + const outptr = changetype<usize>(out); + const xptr = changetype<usize>(x); for (let i = 0; i < len - len % 4; i += 4) { v128.store(outptr + i*4, unaryMathV(op, v128.load(xptr + i*4))); } @@ -186,9 +211,9 @@ export { binaryMathS as mathSS }; export function mathVS(op: MathOp, a: StaticArray<f32>, b: f32): StaticArray<f32> { const len = a.length; const out = new StaticArray<f32>(len); - const outptr = changetype<i32>(out); - const aptr = changetype<i32>(a); - const vb = f32x4(b,b,b,b); + const outptr = changetype<usize>(out); + const aptr = changetype<usize>(a); + const vb = vf32(b); for (let i = 0; i < len - len % 4; i += 4) { v128.store(outptr + i*4, binaryMathV(op, v128.load(aptr + i*4), vb)); } @@ -204,9 +229,9 @@ export function mathVS(op: MathOp, a: StaticArray<f32>, b: f32): StaticArray<f32 export function mathSV(op: MathOp, a: f32, b: StaticArray<f32>): StaticArray<f32> { const len = b.length; const out = new StaticArray<f32>(len); - const outptr = changetype<i32>(out); - const bptr = changetype<i32>(b); - const va = f32x4(a,a,a,a); + const outptr = changetype<usize>(out); + const bptr = changetype<usize>(b); + const va = vf32(a); for (let i = 0; i < len - len % 4; i += 4) { v128.store(outptr + i*4, binaryMathV(op, va, v128.load(bptr + i*4))); } @@ -222,9 +247,9 @@ export function mathSV(op: MathOp, a: f32, b: StaticArray<f32>): StaticArray<f32 export function mathVV(op: MathOp, a: StaticArray<f32>, b: StaticArray<f32>): StaticArray<f32> { const len = a.length < b.length ? a.length : b.length; const out = new StaticArray<f32>(len); - const outptr = changetype<i32>(out); - const aptr = changetype<i32>(a); - const bptr = changetype<i32>(b); + const outptr = changetype<usize>(out); + const aptr = changetype<usize>(a); + const bptr = changetype<usize>(b); for (let i = 0; i < len - len % 4; i += 4) { v128.store(outptr + i*4, binaryMathV(op, v128.load(aptr + i*4), v128.load(bptr + i*4))); } @@ -239,11 +264,11 @@ export function mathVV(op: MathOp, a: StaticArray<f32>, b: StaticArray<f32>): St export function linspace(start: f32, stop: f32, n: i32): StaticArray<f32> { const out = new StaticArray<f32>(n); - const ptr = changetype<i32>(out); + const ptr = changetype<usize>(out); const fac = (stop - start) / f32(n-1); let value = f32x4(start, start + fac, start + fac*2, start + fac*3); - const vinc = f32x4(fac*4, fac*4, fac*4, fac*4); + const vinc = vf32(fac*4); for (let i = 0; i < n - n % 4; i += 4) { v128.store(ptr + i * 4, value); value = v128.add<f32>(value, vinc); @@ -258,11 +283,11 @@ export function linspace(start: f32, stop: f32, n: i32): StaticArray<f32> { } export function intersperse(a: StaticArray<f32>, b: StaticArray<f32>): StaticArray<f32> { - const aptr = changetype<i32>(a); - const bptr = changetype<i32>(b); + const aptr = changetype<usize>(a); + const bptr = changetype<usize>(b); let len = a.length < b.length ? a.length : b.length; const out = new StaticArray<f32>(len * 2); - const outptr = changetype<i32>(out); + const outptr = changetype<usize>(out); for (let i = 0; i < len - len % 4; i += 4) { const va = v128.load(aptr + i * 4); @@ -294,20 +319,52 @@ export function unzip(x: StaticArray<f32>): StaticArray<f32> { return out; } -export function dft(x: StaticArray<f32>): StaticArray<f32> { - const out = new StaticArray<f32>(x.length); - for (let k = 0; k < out.length - out.length % 2; k += 2) { - for (let n = 0; n < x.length - x.length % 2; n += 2) { - const y = -<f32>2.0 * Mathf.PI * <f32>k / <f32>x.length * <f32>n; - const u = Mathf.cos(y); - const v = Mathf.sin(y); - out[k] += x[n] * u - x[n+1] * v; - out[k+1] += x[n] * v + x[n+1] * u; +function ditfft2(x: StaticArray<f32>, n: i32, s: i32, out: StaticArray<f32>): void { + if (n == 1) { + out[0] = x[0]; + out[1] = x[1]; + return; + } else if (n == 2) { + for (let k = 0; k < out.length - out.length % 2; k += 2) { + for (let n = 0; n < x.length - x.length % 2; n += 2) { + const y = -<f32>2 * Mathf.PI * <f32>k / <f32>x.length * <f32>n; + const u = Mathf.cos(y); + const v = Mathf.sin(y); + out[k] += x[n] * u - x[n+1] * v; + out[k+1] += x[n] * v + x[n+1] * u; + } } + return; + } + ditfft2(x.slice<StaticArray<f32>>(0, n), n/2, 2*s, out.slice<StaticArray<f32>>(0, n)); + ditfft2(x.slice<StaticArray<f32>>(s * 2), n/2, 2*s, out.slice<StaticArray<f32>>(n)); + const outptr = changetype<usize>(out); + const twiddle = vf32(-<f32>2 * Mathf.PI / <f32>n); + let vk = f32x4(0,1,2,3); + for (let k = 0; k < n/2 - n/2 % 2; k += 4) { + const p = v128.load(outptr + k*2 * 4); + const y = v128.mul<f32>(twiddle, vk); + const tw = vsin(v128.neg<f32>(v128.sub<f32>(y, f32x4(Mathf.PI/2, 0, Mathf.PI/2, 0)))); + const q = v128.mul<f32>(tw, v128.load(outptr + (k*2 + k*2) * 4)); + v128.store(outptr + k*2 * 4, v128.add<f32>(p, q)); + v128.store(outptr + (k*2 + n/2) * 4, v128.sub<f32>(p, q)); + vk = v128.add<f32>(vk, vf32(1)); + } + for (let k = n/2 - n/2 % 2; k < n/2; k++) { + const pr = unchecked(out[k*2]); + const pi = unchecked(out[k*2+1]); + const y = -<f32>2 * Mathf.PI * <f32>k / <f32>n; + const qr = Mathf.cos(y) * out[k*2+n/2] + const qi = Mathf.sin(y) * out[k*2+n/2+1]; + unchecked(out[k*2] = pr + qr); + unchecked(out[k*2+1] = pi + qi); + unchecked(out[k*2+n/2] = pr - qr); + unchecked(out[k*2+n/2+1] = pi - qi); } - return out; } export function fft(x: StaticArray<f32>): StaticArray<f32> { - return dft(x); + const out = new StaticArray<f32>(x.length); + ditfft2(x, x.length/2, 1, out); + return out; }
\ No newline at end of file diff --git a/src/wasm.ts b/src/wasm.ts index a62a8fc..cf7b552 100644 --- a/src/wasm.ts +++ b/src/wasm.ts @@ -6,6 +6,5 @@ export const { linspace, intersperse, unzip, - dft, fft, } = await instantiate(await WebAssembly.compileStreaming(fetch(url)), {});
\ No newline at end of file |
