From f193348ec5086619a4ffe4c08ebd7ef2dcc901b3 Mon Sep 17 00:00:00 2001 From: Sam Nystrom Date: Thu, 7 Mar 2024 07:07:42 +0000 Subject: Vectorize linspace and intersperse implementations linspace is 4 times faster and intersperse is 12-13 times faster. --- assembly/index.ts | 68 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 22 deletions(-) (limited to 'assembly/index.ts') diff --git a/assembly/index.ts b/assembly/index.ts index cba8926..4bdb03f 100644 --- a/assembly/index.ts +++ b/assembly/index.ts @@ -1,29 +1,58 @@ +export function linspace(start: f32, stop: f32, n: i32): StaticArray { + const out = new StaticArray(n); + const ptr = changetype(out); -export function linspace(start: f64, stop: f64, n: i32): Float64Array { - const out = new Float64Array(n); - for (let i = 0; i < n; i++) { - out[i] = start + (stop - start) * i / (n-1); + const fac = (stop - start) / f32(n-1); + let value = f32x4(start, start + fac, start + fac*2, start + fac*3); + const vinc = f32x4(fac*4, fac*4, fac*4, fac*4); + for (let i = 0; i < n - n % 4; i += 4) { + v128.store(ptr + i * 4, value); + value = v128.add(value, vinc); + } + // fallthrough + switch (n % 4) { + case 3: out[n-3] = start + (stop - start) * f32(n-3) / f32(n-1); + case 2: out[n-2] = start + (stop - start) * f32(n-2) / f32(n-1); + case 1: out[n-1] = stop; } return out; } -export function intersperse(a: Float64Array, b: Float64Array): Float64Array { - const len = a.length < b.length ? a.length : b.length; - const out = new Float64Array(len * 2); - for (let i = 0; i < out.length / 2; i++) { - out[i*2] = a[i]; - out[i*2+1] = b[i]; +export function intersperse(a: StaticArray, b: StaticArray): StaticArray { + const aptr = changetype(a); + const bptr = changetype(b); + let len = a.length < b.length ? a.length : b.length; + const out = new StaticArray(len * 2); + const outptr = changetype(out); + + for (let i = 0; i < len - len % 4; i += 4) { + const va = v128.load(aptr + i * 4); + const vb = v128.load(bptr + i * 4); + v128.store(outptr + (i*2) * 4, v128.shuffle(va, vb, 0, 4, 1, 5)); + v128.store(outptr + (i*2+4) * 4, v128.shuffle(va, vb, 2, 6, 3, 7)); + } + // fallthrough + switch (len % 4) { + case 3: + out[out.length-6] = a[len-3]; + out[out.length-5] = b[len-3]; + case 2: + out[out.length-4] = a[len-2]; + out[out.length-3] = b[len-2]; + case 1: + out[out.length-2] = a[len-1]; + out[out.length-1] = b[len-1]; } return out; } -export function dft(x: Float64Array): Float64Array { - const out = new Float64Array(x.length); +export function dft(x: StaticArray): StaticArray { + const out = new StaticArray(x.length); for (let k = 0; k < out.length - out.length % 2; k += 2) { for (let n = 0; n < x.length - x.length % 2; n += 2) { - const y = -2.0 * Math.PI * k / x.length * n; - const u = Math.cos(y); - const v = Math.sin(y); + const y = -2.0 * Mathf.PI * k / x.length * n; + const u = Mathf.cos(y); + const v = Mathf.sin(y); out[k] = x[n] * u - x[n+1] * v; out[k+1] = x[n] * v + x[n+1] * u; } @@ -31,11 +60,6 @@ export function dft(x: Float64Array): Float64Array { return out; } -export function fft(x: Float64Array): Float64Array { - //const out = new Float64Array(x.length); +export function fft(x: StaticArray): StaticArray { return dft(x); -} - -export function add(a: i32, b: i32): i32 { - return a + b; -} +} \ No newline at end of file -- cgit v1.2.3