From 1db01eaf7f1a0e06abe689c55f746518598d5f4c Mon Sep 17 00:00:00 2001
From: Sam Nystrom
Date: Thu, 7 Mar 2024 21:47:56 -0500
Subject: Begin vectorizing the Math node

---
 assembly/index.ts | 230 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 230 insertions(+)

(limited to 'assembly/index.ts')

diff --git a/assembly/index.ts b/assembly/index.ts
index 4bdb03f..1a9ca94 100644
--- a/assembly/index.ts
+++ b/assembly/index.ts
@@ -1,3 +1,233 @@
+/**
+ * Operation selector for the math kernels in this file.
+ *
+ * Members are auto-numbered from 0 in declaration order, so inserting or
+ * reordering entries changes the numeric value of every later member.
+ */
+export enum MathOp {
+  Add,
+  Sub,
+  Mul,
+  Div,
+  Pow,
+  Log,
+  Sqrt,
+  Exp,
+
+  Min,
+  Max,
+  Lt,
+  Gt,
+  Sign,
+
+  Round,
+  Floor,
+  Ceil,
+  Trunc,
+  Frac,
+  Mod,
+  Snap,
+  Clamp,
+
+  Sin,
+  Cos,
+  Tan,
+  Asin,
+  Acos,
+  Atan,
+  Atan2,
+  Sinh,
+  Cosh,
+  Tanh,
+
+  ToRad,
+  ToDeg,
+}
+
+/**
+ * Applies a unary operation to a single value.
+ *
+ * Only Sqrt, Exp and Sign are implemented; every other op returns 0.
+ */
+function unaryMathS(op: MathOp, x: f32): f32 {
+  switch (op) {
+    case MathOp.Sqrt: return Mathf.sqrt(x);
+    case MathOp.Exp: return Mathf.exp(x);
+    case MathOp.Sign: return Mathf.sign(x);
+    default: return 0;
+  }
+}
+
+/**
+ * Applies a unary operation to four packed f32 lanes.
+ *
+ * Supports the same ops as unaryMathS; every other op returns all-zero lanes.
+ */
+function unaryMathV(op: MathOp, x: v128): v128 {
+  const zero = f32x4(0, 0, 0, 0);
+  switch (op) {
+    case MathOp.Sqrt: return v128.sqrt<f32>(x);
+    // Exp is evaluated lane by lane with the scalar routine.
+    case MathOp.Exp: return f32x4(
+      Mathf.exp(v128.extract_lane<f32>(x, 0)),
+      Mathf.exp(v128.extract_lane<f32>(x, 1)),
+      Mathf.exp(v128.extract_lane<f32>(x, 2)),
+      Mathf.exp(v128.extract_lane<f32>(x, 3)),
+    );
+    case MathOp.Sign: {
+      // Lane comparisons yield all-ones / all-zero bit masks rather than
+      // 1.0 / 0.0, so the masks may only be used to select values; doing
+      // arithmetic on them produces garbage floats.
+      const isNeg = v128.lt<f32>(x, zero);
+      const isPos = v128.gt<f32>(x, zero);
+      const signedOne = v128.bitselect(f32x4(-1, -1, -1, -1), f32x4(1, 1, 1, 1), isNeg);
+      // Lanes that are neither positive nor negative (+0, -0, NaN) keep x.
+      return v128.bitselect(signedOne, x, v128.or(isPos, isNeg));
+    }
+    default: return zero;
+  }
+}
+
+/**
+ * Applies a binary operation to two values.
+ *
+ * Only Add, Sub, Mul and Div are implemented; every other op returns 0.
+ */
+function binaryMathS(op: MathOp, a: f32, b: f32): f32 {
+  switch (op) {
+    case MathOp.Add: return a + b;
+    case MathOp.Sub: return a - b;
+    case MathOp.Mul: return a * b;
+    case MathOp.Div: return a / b;
+    default: return 0;
+  }
+}
+
+/**
+ * Applies a binary operation lane-wise to two vectors of four f32 lanes.
+ *
+ * Supports the same ops as binaryMathS; every other op returns all-zero lanes.
+ */
+function binaryMathV(op: MathOp, a: v128, b: v128): v128 {
+  switch (op) {
+    case MathOp.Add: return v128.add<f32>(a, b);
+    case MathOp.Sub: return v128.sub<f32>(a, b);
+    case MathOp.Mul: return v128.mul<f32>(a, b);
+    case MathOp.Div: return v128.div<f32>(a, b);
+    default: return f32x4(0, 0, 0, 0);
+  }
+}
+
+export { unaryMathS as mathS };
+
+/**
+ * Applies a unary operation to every element of x.
+ *
+ * @param op operation to apply (see unaryMathS for the supported set)
+ * @param x input values; not modified
+ * @returns a new array of the same length as x
+ */
+export function mathV(op: MathOp, x: StaticArray<f32>): StaticArray<f32> {
+  const len = x.length;
+  const out = new StaticArray<f32>(len);
+  const outptr = changetype<usize>(out);
+  const xptr = changetype<usize>(x);
+  // Number of elements covered by whole 4-lane vectors.
+  const vecLen = len - len % 4;
+  for (let i = 0; i < vecLen; i += 4) {
+    const offset = <usize>i << 2; // 4 bytes per f32
+    v128.store(outptr + offset, unaryMathV(op, v128.load(xptr + offset)));
+  }
+  // The remaining 0-3 elements take the scalar path.
+  for (let i = vecLen; i < len; i++) {
+    out[i] = unaryMathS(op, x[i]);
+  }
+  return out;
+}
+
+export { binaryMathS as mathSS };
+
+/**
+ * Computes `a[i] op b` for every element of a.
+ *
+ * @param op operation to apply (see binaryMathS for the supported set)
+ * @param a left-hand operands; not modified
+ * @param b right-hand operand shared by all elements
+ * @returns a new array of the same length as a
+ */
+export function mathVS(op: MathOp, a: StaticArray<f32>, b: f32): StaticArray<f32> {
+  const len = a.length;
+  const out = new StaticArray<f32>(len);
+  const outptr = changetype<usize>(out);
+  const aptr = changetype<usize>(a);
+  const vb = f32x4.splat(b);
+  // Number of elements covered by whole 4-lane vectors.
+  const vecLen = len - len % 4;
+  for (let i = 0; i < vecLen; i += 4) {
+    const offset = <usize>i << 2; // 4 bytes per f32
+    v128.store(outptr + offset, binaryMathV(op, v128.load(aptr + offset), vb));
+  }
+  // The remaining 0-3 elements take the scalar path.
+  for (let i = vecLen; i < len; i++) {
+    out[i] = binaryMathS(op, a[i], b);
+  }
+  return out;
+}
+
+/**
+ * Computes `a op b[i]` for every element of b.
+ *
+ * @param op operation to apply (see binaryMathS for the supported set)
+ * @param a left-hand operand shared by all elements
+ * @param b right-hand operands; not modified
+ * @returns a new array of the same length as b
+ */
+export function mathSV(op: MathOp, a: f32, b: StaticArray<f32>): StaticArray<f32> {
+  const len = b.length;
+  const out = new StaticArray<f32>(len);
+  const outptr = changetype<usize>(out);
+  const bptr = changetype<usize>(b);
+  const va = f32x4.splat(a);
+  // Number of elements covered by whole 4-lane vectors.
+  const vecLen = len - len % 4;
+  for (let i = 0; i < vecLen; i += 4) {
+    const offset = <usize>i << 2; // 4 bytes per f32
+    v128.store(outptr + offset, binaryMathV(op, va, v128.load(bptr + offset)));
+  }
+  // The remaining 0-3 elements take the scalar path.
+  for (let i = vecLen; i < len; i++) {
+    out[i] = binaryMathS(op, a, b[i]);
+  }
+  return out;
+}
+
+/**
+ * Computes `a[i] op b[i]` element-wise.
+ *
+ * @param op operation to apply (see binaryMathS for the supported set)
+ * @param a left-hand operands; not modified
+ * @param b right-hand operands; not modified
+ * @returns a new array whose length is the shorter of the two inputs
+ */
+export function mathVV(op: MathOp, a: StaticArray<f32>, b: StaticArray<f32>): StaticArray<f32> {
+  const len = a.length < b.length ? a.length : b.length;
+  const out = new StaticArray<f32>(len);
+  const outptr = changetype<usize>(out);
+  const aptr = changetype<usize>(a);
+  const bptr = changetype<usize>(b);
+  // Number of elements covered by whole 4-lane vectors.
+  const vecLen = len - len % 4;
+  for (let i = 0; i < vecLen; i += 4) {
+    const offset = <usize>i << 2; // 4 bytes per f32
+    v128.store(outptr + offset, binaryMathV(op, v128.load(aptr + offset), v128.load(bptr + offset)));
+  }
+  // The remaining 0-3 elements take the scalar path.
+  for (let i = vecLen; i < len; i++) {
+    out[i] = binaryMathS(op, a[i], b[i]);
+  }
+  return out;
+}
+
 export function linspace(start: f32, stop: f32, n: i32): StaticArray {
   const out = new StaticArray(n);
   const ptr = changetype(out);
-- 
cgit v1.2.3