summary | refs | log | tree | commit | diff
path: root/assembly/index.ts
diff options
context:
space:
mode:
author: Sam Nystrom <sam@samnystrom.dev> 2024-03-07 21:47:56 -0500
committer: Sam Nystrom <sam@samnystrom.dev> 2024-03-07 21:47:56 -0500
commit: 1db01eaf7f1a0e06abe689c55f746518598d5f4c (patch)
tree: 7e1056da01130d6c1cbc6d9fbf4ec58910eb9184 /assembly/index.ts
parent: f193348ec5086619a4ffe4c08ebd7ef2dcc901b3 (diff)
Begin vectorizing the Math node
Diffstat (limited to 'assembly/index.ts')
-rw-r--r-- assembly/index.ts 159
1 files changed, 159 insertions, 0 deletions
diff --git a/assembly/index.ts b/assembly/index.ts
index 4bdb03f..1a9ca94 100644
--- a/assembly/index.ts
+++ b/assembly/index.ts
@@ -1,3 +1,162 @@
+/**
+ * Operation selector passed to the math kernels in this file.
+ *
+ * Members have no explicit initializers, so their numeric values follow
+ * declaration order (Add = 0 ... ToDeg = 32). Inserting or reordering a member
+ * renumbers every member declared after it.
+ *
+ * In the code visible here only Sqrt, Exp and Sign (unary kernels) and
+ * Add, Sub, Mul and Div (binary kernels) are implemented; any other member
+ * makes those kernels return 0.
+ *
+ * The group headings below are inferred from the member names.
+ */
+export enum MathOp {
+ // Arithmetic, powers, logarithm and exponential
+ Add,
+ Sub,
+ Mul,
+ Div,
+ Pow,
+ Log,
+ Sqrt,
+ Exp,
+
+ // Comparison and sign
+ Min,
+ Max,
+ Lt,
+ Gt,
+ Sign,
+
+ // Rounding and range
+ Round,
+ Floor,
+ Ceil,
+ Trunc,
+ Frac,
+ Mod,
+ Snap,
+ Clamp,
+
+ // Trigonometric and hyperbolic
+ Sin,
+ Cos,
+ Tan,
+ Asin,
+ Acos,
+ Atan,
+ Atan2,
+ Sinh,
+ Cosh,
+ Tanh,
+
+ // Angle unit conversion
+ ToRad,
+ ToDeg,
+}
+
+/**
+ * Scalar kernel for single-operand operations.
+ *
+ * @param op Operation to apply; only Sqrt, Exp and Sign are recognised.
+ * @param x  Operand.
+ * @returns  The result of `op` on `x`, or 0 when `op` is not recognised.
+ */
+function unaryMathS(op: MathOp, x: f32): f32 {
+  if (op == MathOp.Sqrt) return Mathf.sqrt(x);
+  if (op == MathOp.Exp) return Mathf.exp(x);
+  if (op == MathOp.Sign) return Mathf.sign(x);
+  // Anything not handled above yields zero.
+  return 0;
+}
+
+/**
+ * SIMD kernel for single-operand operations: applies `op` to each of the four
+ * f32 lanes of `x`.
+ *
+ * @param op Operation to apply; only Sqrt, Exp and Sign are recognised.
+ * @param x  Input vector, read as four f32 lanes.
+ * @returns  Four f32 results, or an all-zero vector when `op` is not recognised.
+ */
+function unaryMathV(op: MathOp, x: v128): v128 {
+  const zero = f32x4(0,0,0,0);
+  switch (op) {
+    case MathOp.Sqrt: return v128.sqrt<f32>(x);
+    // Exp is evaluated lane by lane through the scalar Mathf.exp.
+    case MathOp.Exp: return f32x4(
+      Mathf.exp(v128.extract_lane<f32>(x, 0)),
+      Mathf.exp(v128.extract_lane<f32>(x, 1)),
+      Mathf.exp(v128.extract_lane<f32>(x, 2)),
+      Mathf.exp(v128.extract_lane<f32>(x, 3)),
+    );
+    case MathOp.Sign: {
+      // SIMD comparisons yield per-lane bit masks (all ones / all zeros), not
+      // 1.0 / 0.0. All ones is a NaN bit pattern, so the masks must be used to
+      // select values rather than being subtracted as floats.
+      const isPositive = v128.gt<f32>(x, zero);
+      const isNegative = v128.lt<f32>(x, zero);
+      // Lanes that are neither > 0 nor < 0 (+0, -0, NaN) keep their input
+      // value, which is what Math.sign-style semantics prescribe.
+      const negativeOrSelf = v128.bitselect(f32x4(-1,-1,-1,-1), x, isNegative);
+      return v128.bitselect(f32x4(1,1,1,1), negativeOrSelf, isPositive);
+    }
+    default: return zero;
+  }
+}
+
+/**
+ * Scalar kernel for two-operand operations.
+ *
+ * @param op Operation to apply; only Add, Sub, Mul and Div are recognised.
+ * @param a  Left operand.
+ * @param b  Right operand.
+ * @returns  `a op b`, or 0 when `op` is not recognised.
+ */
+function binaryMathS(op: MathOp, a: f32, b: f32): f32 {
+  // Starts at the fallback value and is overwritten by a recognised op.
+  let result: f32 = 0;
+  switch (op) {
+    case MathOp.Add:
+      result = a + b;
+      break;
+    case MathOp.Sub:
+      result = a - b;
+      break;
+    case MathOp.Mul:
+      result = a * b;
+      break;
+    case MathOp.Div:
+      result = a / b;
+      break;
+  }
+  return result;
+}
+
+/**
+ * SIMD kernel for two-operand operations: applies `op` lane-wise to the four
+ * f32 lanes of `a` and `b`.
+ *
+ * @param op Operation to apply; only Add, Sub, Mul and Div are recognised.
+ * @param a  Left operand vector.
+ * @param b  Right operand vector.
+ * @returns  Four f32 results, or an all-zero vector when `op` is not recognised.
+ */
+function binaryMathV(op: MathOp, a: v128, b: v128): v128 {
+  if (op == MathOp.Add) return v128.add<f32>(a, b);
+  if (op == MathOp.Sub) return v128.sub<f32>(a, b);
+  if (op == MathOp.Mul) return v128.mul<f32>(a, b);
+  if (op == MathOp.Div) return v128.div<f32>(a, b);
+  // Anything not handled above yields a zero vector.
+  return f32x4(0,0,0,0);
+}
+
+// Export the scalar unary kernel under the public name `mathS`.
+export { unaryMathS as mathS };
+
+/**
+ * Applies a unary operation to every element of `x`.
+ *
+ * Elements are processed four at a time with the SIMD kernel; the 0-3
+ * elements left over at the end go through the scalar kernel.
+ *
+ * @param op Operation to apply.
+ * @param x  Input values; not modified.
+ * @returns  A newly allocated array of the same length holding the results.
+ */
+export function mathV(op: MathOp, x: StaticArray<f32>): StaticArray<f32> {
+  const count = x.length;
+  const result = new StaticArray<f32>(count);
+  const dst = changetype<i32>(result);
+  const src = changetype<i32>(x);
+  // Number of leading elements that form whole groups of four.
+  const simdCount = count - count % 4;
+  for (let i = 0; i < simdCount; i += 4) {
+    const byteOffset = i * 4; // 4 bytes per f32
+    v128.store(dst + byteOffset, unaryMathV(op, v128.load(src + byteOffset)));
+  }
+  // Scalar tail for whatever did not fill a full vector.
+  for (let i = simdCount; i < count; i++) {
+    result[i] = unaryMathS(op, x[i]);
+  }
+  return result;
+}
+
+// Export the scalar binary kernel under the public name `mathSS`.
+export { binaryMathS as mathSS };
+
+/**
+ * Applies a binary operation between every element of `a` (left operand) and
+ * the single scalar `b` (right operand).
+ *
+ * Elements are processed four at a time with the SIMD kernel; the 0-3
+ * elements left over at the end go through the scalar kernel.
+ *
+ * @param op Operation to apply.
+ * @param a  Left operands; not modified.
+ * @param b  Right operand shared by all elements.
+ * @returns  A newly allocated array of `a.length` results.
+ */
+export function mathVS(op: MathOp, a: StaticArray<f32>, b: f32): StaticArray<f32> {
+  const count = a.length;
+  const result = new StaticArray<f32>(count);
+  const dst = changetype<i32>(result);
+  const lhs = changetype<i32>(a);
+  // Broadcast the scalar into all four lanes once, outside the loop.
+  const rhs = v128.splat<f32>(b);
+  // Number of leading elements that form whole groups of four.
+  const simdCount = count - count % 4;
+  for (let i = 0; i < simdCount; i += 4) {
+    const byteOffset = i * 4; // 4 bytes per f32
+    v128.store(dst + byteOffset, binaryMathV(op, v128.load(lhs + byteOffset), rhs));
+  }
+  // Scalar tail for whatever did not fill a full vector.
+  for (let i = simdCount; i < count; i++) {
+    result[i] = binaryMathS(op, a[i], b);
+  }
+  return result;
+}
+
+/**
+ * Applies a binary operation between the single scalar `a` (left operand) and
+ * every element of `b` (right operand).
+ *
+ * Elements are processed four at a time with the SIMD kernel; the 0-3
+ * elements left over at the end go through the scalar kernel.
+ *
+ * @param op Operation to apply.
+ * @param a  Left operand shared by all elements.
+ * @param b  Right operands; not modified.
+ * @returns  A newly allocated array of `b.length` results.
+ */
+export function mathSV(op: MathOp, a: f32, b: StaticArray<f32>): StaticArray<f32> {
+  const count = b.length;
+  const result = new StaticArray<f32>(count);
+  const dst = changetype<i32>(result);
+  const rhs = changetype<i32>(b);
+  // Broadcast the scalar into all four lanes once, outside the loop.
+  const lhs = v128.splat<f32>(a);
+  // Number of leading elements that form whole groups of four.
+  const simdCount = count - count % 4;
+  for (let i = 0; i < simdCount; i += 4) {
+    const byteOffset = i * 4; // 4 bytes per f32
+    v128.store(dst + byteOffset, binaryMathV(op, lhs, v128.load(rhs + byteOffset)));
+  }
+  // Scalar tail for whatever did not fill a full vector.
+  for (let i = simdCount; i < count; i++) {
+    result[i] = binaryMathS(op, a, b[i]);
+  }
+  return result;
+}
+
+/**
+ * Applies a binary operation element-wise to `a` (left operands) and `b`
+ * (right operands).
+ *
+ * Only the first `min(a.length, b.length)` element pairs are processed; extra
+ * elements of the longer input are ignored. Pairs are processed four at a
+ * time with the SIMD kernel; the 0-3 pairs left over at the end go through
+ * the scalar kernel.
+ *
+ * @param op Operation to apply.
+ * @param a  Left operands; not modified.
+ * @param b  Right operands; not modified.
+ * @returns  A newly allocated array whose length is that of the shorter input.
+ */
+export function mathVV(op: MathOp, a: StaticArray<f32>, b: StaticArray<f32>): StaticArray<f32> {
+  let count = b.length;
+  if (a.length < count) {
+    count = a.length;
+  }
+  const result = new StaticArray<f32>(count);
+  const dst = changetype<i32>(result);
+  const lhs = changetype<i32>(a);
+  const rhs = changetype<i32>(b);
+  // Number of leading pairs that form whole groups of four.
+  const simdCount = count - count % 4;
+  for (let i = 0; i < simdCount; i += 4) {
+    const byteOffset = i * 4; // 4 bytes per f32
+    v128.store(
+      dst + byteOffset,
+      binaryMathV(op, v128.load(lhs + byteOffset), v128.load(rhs + byteOffset)),
+    );
+  }
+  // Scalar tail for whatever did not fill a full vector.
+  for (let i = simdCount; i < count; i++) {
+    result[i] = binaryMathS(op, a[i], b[i]);
+  }
+  return result;
+}
+
export function linspace(start: f32, stop: f32, n: i32): StaticArray<f32> {
const out = new StaticArray<f32>(n);
const ptr = changetype<i32>(out);