summaryrefslogtreecommitdiff
path: root/assembly
diff options
context:
space:
mode:
authorSam Nystrom <sam@samnystrom.dev>2024-03-07 07:07:42 +0000
committerSam Nystrom <15555332-SamNystrom1@users.noreply.replit.com>2024-03-07 07:07:42 +0000
commitf193348ec5086619a4ffe4c08ebd7ef2dcc901b3 (patch)
treebace31cf98a9f73b7ea85f4e8c681ab1b1c1f46d /assembly
parent18c4c7e82e309e868e4e1e47460c5b3a0600847e (diff)
Vectorize linspace and intersperse implementations
linspace is 4 times faster and intersperse is 12-13 times faster.
Diffstat (limited to 'assembly')
-rw-r--r--assembly/index.ts68
1 files changed, 46 insertions, 22 deletions
diff --git a/assembly/index.ts b/assembly/index.ts
index cba8926..4bdb03f 100644
--- a/assembly/index.ts
+++ b/assembly/index.ts
@@ -1,29 +1,58 @@
+export function linspace(start: f32, stop: f32, n: i32): StaticArray<f32> {
+ const out = new StaticArray<f32>(n);
+ const ptr = changetype<i32>(out);
-export function linspace(start: f64, stop: f64, n: i32): Float64Array {
- const out = new Float64Array(n);
- for (let i = 0; i < n; i++) {
- out[i] = start + (stop - start) * <f64>i / <f64>(n-1);
+ const fac = (stop - start) / f32(n-1);
+ let value = f32x4(start, start + fac, start + fac*2, start + fac*3);
+ const vinc = f32x4(fac*4, fac*4, fac*4, fac*4);
+ for (let i = 0; i < n - n % 4; i += 4) {
+ v128.store(ptr + i * 4, value);
+ value = v128.add<f32>(value, vinc);
+ }
+ // fallthrough
+ switch (n % 4) {
+ case 3: out[n-3] = start + (stop - start) * f32(n-3) / f32(n-1);
+ case 2: out[n-2] = start + (stop - start) * f32(n-2) / f32(n-1);
+ case 1: out[n-1] = stop;
}
return out;
}
-export function intersperse(a: Float64Array, b: Float64Array): Float64Array {
- const len = a.length < b.length ? a.length : b.length;
- const out = new Float64Array(len * 2);
- for (let i = 0; i < out.length / 2; i++) {
- out[i*2] = a[i];
- out[i*2+1] = b[i];
+export function intersperse(a: StaticArray<f32>, b: StaticArray<f32>): StaticArray<f32> {
+ const aptr = changetype<i32>(a);
+ const bptr = changetype<i32>(b);
+ let len = a.length < b.length ? a.length : b.length;
+ const out = new StaticArray<f32>(len * 2);
+ const outptr = changetype<i32>(out);
+
+ for (let i = 0; i < len - len % 4; i += 4) {
+ const va = v128.load(aptr + i * 4);
+ const vb = v128.load(bptr + i * 4);
+ v128.store(outptr + (i*2) * 4, v128.shuffle<f32>(va, vb, 0, 4, 1, 5));
+ v128.store(outptr + (i*2+4) * 4, v128.shuffle<f32>(va, vb, 2, 6, 3, 7));
+ }
+ // fallthrough
+ switch (len % 4) {
+ case 3:
+ out[out.length-6] = a[len-3];
+ out[out.length-5] = b[len-3];
+ case 2:
+ out[out.length-4] = a[len-2];
+ out[out.length-3] = b[len-2];
+ case 1:
+ out[out.length-2] = a[len-1];
+ out[out.length-1] = b[len-1];
}
return out;
}
-export function dft(x: Float64Array): Float64Array {
- const out = new Float64Array(x.length);
+export function dft(x: StaticArray<f32>): StaticArray<f32> {
+ const out = new StaticArray<f32>(x.length);
for (let k = 0; k < out.length - out.length % 2; k += 2) {
for (let n = 0; n < x.length - x.length % 2; n += 2) {
- const y = -2.0 * Math.PI * <f64>k / <f64>x.length * <f64>n;
- const u = Math.cos(y);
- const v = Math.sin(y);
+ const y = -<f32>2.0 * Mathf.PI * <f32>k / <f32>x.length * <f32>n;
+ const u = Mathf.cos(y);
+ const v = Mathf.sin(y);
out[k] = x[n] * u - x[n+1] * v;
out[k+1] = x[n] * v + x[n+1] * u;
}
@@ -31,11 +60,6 @@ export function dft(x: Float64Array): Float64Array {
return out;
}
-export function fft(x: Float64Array): Float64Array {
- //const out = new Float64Array(x.length);
+export function fft(x: StaticArray<f32>): StaticArray<f32> {
return dft(x);
-}
-
-export function add(a: i32, b: i32): i32 {
- return a + b;
-}
+} \ No newline at end of file