From 5d06d2359e8ad50f7a61ecd1787a0ad558329964 Mon Sep 17 00:00:00 2001 From: Sam Nystrom Date: Fri, 8 Mar 2024 07:55:23 +0000 Subject: Finish vectorized Math node implementation --- assembly/index.ts | 94 ++++++++++++++++++++++++++++++++++++++++++++++++++---- src/nodes/Math.tsx | 55 +++++--------------------------- 2 files changed, 95 insertions(+), 54 deletions(-) diff --git a/assembly/index.ts b/assembly/index.ts index 1a9ca94..fd4fdd5 100644 --- a/assembly/index.ts +++ b/assembly/index.ts @@ -1,3 +1,9 @@ +const zero = f32x4(0,0,0,0); +const one = f32x4(1,1,1,1); +const pi = f32x4(Mathf.PI, Mathf.PI, Mathf.PI, Mathf.PI); +const deg_to_rad = v128.div(pi, f32x4(180,180,180,180)); +const rad_to_deg = v128.div(f32x4(180,180,180,180), pi); + export enum MathOp { Add, Sub, @@ -42,22 +48,65 @@ function unaryMathS(op: MathOp, x: f32): f32 { switch (op) { case MathOp.Sqrt: return Mathf.sqrt(x); case MathOp.Exp: return Mathf.exp(x); + case MathOp.Sign: return Mathf.sign(x); + + case MathOp.Round: return Mathf.round(x); + case MathOp.Floor: return Mathf.floor(x); + case MathOp.Ceil: return Mathf.ceil(x); + case MathOp.Trunc: return Mathf.trunc(x); + case MathOp.Frac: return x - Mathf.trunc(x); + case MathOp.Clamp: return Mathf.max(0, Mathf.min(x, 1)); + + case MathOp.Sin: return Mathf.sin(x); + case MathOp.Cos: return Mathf.cos(x); + case MathOp.Tan: return Mathf.tan(x); + case MathOp.Asin: return Mathf.asin(x); + case MathOp.Acos: return Mathf.acos(x); + case MathOp.Atan: return Mathf.atan(x); + case MathOp.Sinh: return Mathf.sinh(x); + case MathOp.Cosh: return Mathf.cosh(x); + case MathOp.Tanh: return Mathf.tanh(x); + + case MathOp.ToRad: return x / 180 * Mathf.PI; + case MathOp.ToDeg: return x * 180 / Mathf.PI; default: return 0; } } function unaryMathV(op: MathOp, x: v128): v128 { - const zero = f32x4(0,0,0,0); switch (op) { case MathOp.Sqrt: return v128.sqrt(x); - case MathOp.Exp: return f32x4( - Mathf.exp(v128.extract_lane(x, 0)), - Mathf.exp(v128.extract_lane(x, 1)), - Mathf.exp(v128.extract_lane(x, 2)), - Mathf.exp(v128.extract_lane(x, 3)), - ); + case MathOp.Sign: return v128.sub(v128.gt(x, zero), v128.lt(x, zero)); + + case MathOp.Round: return v128.nearest(x); + case MathOp.Floor: return v128.floor(x); + case MathOp.Ceil: return v128.ceil(x); + case MathOp.Trunc: return v128.trunc(x); + case MathOp.Frac: return v128.sub(x, v128.trunc(x)); + case MathOp.Clamp: return v128.max(zero, v128.min(x, one)); + + case MathOp.ToRad: return v128.mul(x, deg_to_rad); + case MathOp.ToDeg: return v128.mul(x, rad_to_deg); + + // fallthrough + case MathOp.Exp: + case MathOp.Sin: + case MathOp.Cos: + case MathOp.Tan: + case MathOp.Asin: + case MathOp.Acos: + case MathOp.Atan: + case MathOp.Sinh: + case MathOp.Cosh: + case MathOp.Tanh: + return f32x4( + unaryMathS(op, v128.extract_lane(x, 0)), + unaryMathS(op, v128.extract_lane(x, 1)), + unaryMathS(op, v128.extract_lane(x, 2)), + unaryMathS(op, v128.extract_lane(x, 3)), + ); default: return zero; } } @@ -68,6 +117,18 @@ function binaryMathS(op: MathOp, a: f32, b: f32): f32 { case MathOp.Sub: return a - b; case MathOp.Mul: return a * b; case MathOp.Div: return a / b; + case MathOp.Pow: return Mathf.pow(a, b); + case MathOp.Log: return Mathf.log(b) / Mathf.log(a); + + case MathOp.Max: return Mathf.max(a, b); + case MathOp.Min: return Mathf.min(a, b); + case MathOp.Lt: return a < b ? 1 : 0; + case MathOp.Gt: return a > b ? 1 : 0; + + case MathOp.Mod: return a % b; + case MathOp.Snap: return Mathf.round(a / b) * b; + + case MathOp.Atan2: return Mathf.atan2(a, b); default: return 0; } }; @@ -78,6 +139,25 @@ function binaryMathV(op: MathOp, a: v128, b: v128): v128 { case MathOp.Sub: return v128.sub(a, b); case MathOp.Mul: return v128.mul(a, b); case MathOp.Div: return v128.div(a, b); + + case MathOp.Max: return v128.max(a, b); + case MathOp.Min: return v128.min(a, b); + case MathOp.Lt: return v128.lt(a, b); + case MathOp.Gt: return v128.gt(a, b); + + case MathOp.Mod: return v128.sub(a, v128.mul(b, v128.trunc(v128.div(a, b)))); + case MathOp.Snap: return v128.mul(v128.nearest(v128.div(a, b)), b); + + // fallthrough + case MathOp.Pow: + case MathOp.Log: + case MathOp.Atan2: + return f32x4( + binaryMathS(op, v128.extract_lane(a, 0), v128.extract_lane(b, 0)), + binaryMathS(op, v128.extract_lane(a, 1), v128.extract_lane(b, 1)), + binaryMathS(op, v128.extract_lane(a, 2), v128.extract_lane(b, 2)), + binaryMathS(op, v128.extract_lane(a, 3), v128.extract_lane(b, 3)), + ); default: return f32x4(0,0,0,0); } }; diff --git a/src/nodes/Math.tsx b/src/nodes/Math.tsx index a10aeaa..fe46fb4 100644 --- a/src/nodes/Math.tsx +++ b/src/nodes/Math.tsx @@ -49,51 +49,12 @@ export enum MathOpConv { export const MathOp = { ...MathOpFunc, ...MathOpCmp, ...MathOpRound, ...MathOpTrig, ...MathOpConv }; export type MathOp = typeof MathOp; -const binaryOps = { - [MathOp.Add]: (a, b) => a + b, - [MathOp.Sub]: (a, b) => a - b, - [MathOp.Mul]: (a, b) => a * b, - [MathOp.Div]: (a, b) => a / b, - [MathOp.Pow]: (a, b) => a ** b, - [MathOp.Log]: (a, b) => Math.log(b) / Math.log(a), - - [MathOp.Max]: Math.max, - [MathOp.Min]: Math.min, - [MathOp.Lt]: (a, b) => a < b, - [MathOp.Gt]: (a, b) => a > b, - - [MathOp.Mod]: (a, b) => a % b, - [MathOp.Snap]: (a, b) => Math.round(a / b) * b, - - [MathOp.Atan2]: Math.atan2, -}; - -const unaryOps = { - [MathOp.Sqrt]: Math.sqrt, - [MathOp.Exp]: Math.exp, - - [MathOp.Sign]: Math.sign, - - [MathOp.Round]: Math.round, - [MathOp.Floor]: Math.floor, - [MathOp.Ceil]: Math.ceil, - [MathOp.Trunc]: Math.trunc, - [MathOp.Frac]: x => x - Math.trunc(x), - [MathOp.Clamp]: x => Math.max(0, Math.min(x, 1)), - - [MathOp.Sin]: Math.sin, - [MathOp.Cos]: Math.cos, - [MathOp.Tan]: Math.tan, - [MathOp.Asin]: Math.asin, - [MathOp.Acos]: Math.acos, - [MathOp.Atan]: Math.atan, - [MathOp.Sinh]: Math.sinh, - [MathOp.Cosh]: Math.cosh, - [MathOp.Tanh]: Math.tanh, - - [MathOp.ToRad]: x => x / 180 * Math.PI, - [MathOp.ToDeg]: x => x * 180 / Math.PI, -}; +const binaryOps = [ + MathOp.Add, MathOp.Sub, MathOp.Mul, MathOp.Div, MathOp.Pow, MathOp.Log, + MathOp.Max, MathOp.Min, MathOp.Lt, MathOp.Gt, + MathOp.Mod, MathOp.Snap, + MathOp.Atan2, +]; export interface MathInputs { op: MathOp, @@ -113,7 +74,7 @@ export const MathC = ({ id, x, y, inputs }: NodeComponentProps) => { 'Trigonometric': Object.values(MathOpTrig), 'Conversion': Object.values(MathOpConv), }; - const isBinary = Object.keys(binaryOps).includes(inputs.op.value); + const isBinary = binaryOps.includes(inputs.op.value); return ( @@ -128,7 +89,7 @@ const mathFunc = ({ op, a, b }: MathInputs): MathOutputs => { const ta = typeof a === 'number'; const tb = typeof b === 'number'; const opNum = Object.values(MathOp).indexOf(op); - if (Object.keys(unaryOps).includes(op)) { + if (!binaryOps.includes(op)) { return { out: ta ? mathS(opNum, a) : mathV(opNum, a) as Float32Array }; } if (ta && tb) { -- cgit v1.2.3