diff options
| author | Sam Nystrom <sam@samnystrom.dev> | 2024-03-07 21:47:56 -0500 |
|---|---|---|
| committer | Sam Nystrom <sam@samnystrom.dev> | 2024-03-07 21:47:56 -0500 |
| commit | 1db01eaf7f1a0e06abe689c55f746518598d5f4c (patch) | |
| tree | 7e1056da01130d6c1cbc6d9fbf4ec58910eb9184 /assembly/index.ts | |
| parent | f193348ec5086619a4ffe4c08ebd7ef2dcc901b3 (diff) | |
Begin vectorizing the Math node
Diffstat (limited to 'assembly/index.ts')
| -rw-r--r-- | assembly/index.ts | 159 |
1 files changed, 159 insertions, 0 deletions
diff --git a/assembly/index.ts b/assembly/index.ts index 4bdb03f..1a9ca94 100644 --- a/assembly/index.ts +++ b/assembly/index.ts @@ -1,3 +1,162 @@ +export enum MathOp { + Add, + Sub, + Mul, + Div, + Pow, + Log, + Sqrt, + Exp, + + Min, + Max, + Lt, + Gt, + Sign, + + Round, + Floor, + Ceil, + Trunc, + Frac, + Mod, + Snap, + Clamp, + + Sin, + Cos, + Tan, + Asin, + Acos, + Atan, + Atan2, + Sinh, + Cosh, + Tanh, + + ToRad, + ToDeg, +} + +function unaryMathS(op: MathOp, x: f32): f32 { + switch (op) { + case MathOp.Sqrt: return Mathf.sqrt(x); + case MathOp.Exp: return Mathf.exp(x); + case MathOp.Sign: return Mathf.sign(x); + default: return 0; + } +} + +function unaryMathV(op: MathOp, x: v128): v128 { + const zero = f32x4(0,0,0,0); + switch (op) { + case MathOp.Sqrt: return v128.sqrt<f32>(x); + case MathOp.Exp: return f32x4( + Mathf.exp(v128.extract_lane<f32>(x, 0)), + Mathf.exp(v128.extract_lane<f32>(x, 1)), + Mathf.exp(v128.extract_lane<f32>(x, 2)), + Mathf.exp(v128.extract_lane<f32>(x, 3)), + ); + case MathOp.Sign: return v128.sub<f32>(v128.gt<f32>(x, zero), v128.lt<f32>(x, zero)); + default: return zero; + } +} + +function binaryMathS(op: MathOp, a: f32, b: f32): f32 { + switch (op) { + case MathOp.Add: return a + b; + case MathOp.Sub: return a - b; + case MathOp.Mul: return a * b; + case MathOp.Div: return a / b; + default: return 0; + } +}; + +function binaryMathV(op: MathOp, a: v128, b: v128): v128 { + switch (op) { + case MathOp.Add: return v128.add<f32>(a, b); + case MathOp.Sub: return v128.sub<f32>(a, b); + case MathOp.Mul: return v128.mul<f32>(a, b); + case MathOp.Div: return v128.div<f32>(a, b); + default: return f32x4(0,0,0,0); + } +}; + +export { unaryMathS as mathS }; + +export function mathV(op: MathOp, x: StaticArray<f32>): StaticArray<f32> { + const len = x.length; + const out = new StaticArray<f32>(len); + const outptr = changetype<i32>(out); + const xptr = changetype<i32>(x); + for (let i = 0; i < len - len % 4; i += 4) { + v128.store(outptr + i*4, unaryMathV(op, v128.load(xptr + i*4))); + } + // fallthrough + switch (len % 4) { + case 3: out[len-3] = unaryMathS(op, x[len-3]); + case 2: out[len-2] = unaryMathS(op, x[len-2]); + case 1: out[len-1] = unaryMathS(op, x[len-1]); + } + return out; +} + +export { binaryMathS as mathSS }; + +export function mathVS(op: MathOp, a: StaticArray<f32>, b: f32): StaticArray<f32> { + const len = a.length; + const out = new StaticArray<f32>(len); + const outptr = changetype<i32>(out); + const aptr = changetype<i32>(a); + const vb = f32x4(b,b,b,b); + for (let i = 0; i < len - len % 4; i += 4) { + v128.store(outptr + i*4, binaryMathV(op, v128.load(aptr + i*4), vb)); + } + // fallthrough + switch (len % 4) { + case 3: out[len-3] = binaryMathS(op, a[len-3], b); + case 2: out[len-2] = binaryMathS(op, a[len-2], b); + case 1: out[len-1] = binaryMathS(op, a[len-1], b); + } + return out; +} + +export function mathSV(op: MathOp, a: f32, b: StaticArray<f32>): StaticArray<f32> { + const len = b.length; + const out = new StaticArray<f32>(len); + const outptr = changetype<i32>(out); + const bptr = changetype<i32>(b); + const va = f32x4(a,a,a,a); + for (let i = 0; i < len - len % 4; i += 4) { + v128.store(outptr + i*4, binaryMathV(op, va, v128.load(bptr + i*4))); + } + // fallthrough + switch (len % 4) { + case 3: out[len-3] = binaryMathS(op, a, b[len-3]); + case 2: out[len-2] = binaryMathS(op, a, b[len-2]); + case 1: out[len-1] = binaryMathS(op, a, b[len-1]); + } + return out; +} + +export function mathVV(op: MathOp, a: StaticArray<f32>, b: StaticArray<f32>): StaticArray<f32> { + const len = a.length < b.length ? a.length : b.length; + const out = new StaticArray<f32>(len); + const outptr = changetype<i32>(out); + const aptr = changetype<i32>(a); + const bptr = changetype<i32>(b); + for (let i = 0; i < len - len % 4; i += 4) { + v128.store(outptr + i*4, binaryMathV(op, v128.load(aptr + i*4), v128.load(bptr + i*4))); + } + // fallthrough + switch (len % 4) { + case 3: out[len-3] = binaryMathS(op, a[len-3], b[len-3]); + case 2: out[len-2] = binaryMathS(op, a[len-2], b[len-2]); + case 1: out[len-1] = binaryMathS(op, a[len-1], b[len-1]); + } + return out; +} + export function linspace(start: f32, stop: f32, n: i32): StaticArray<f32> { const out = new StaticArray<f32>(n); const ptr = changetype<i32>(out); |
