diff options
| author | Sam Nystrom <sam@samnystrom.dev> | 2024-03-07 21:47:56 -0500 |
|---|---|---|
| committer | Sam Nystrom <sam@samnystrom.dev> | 2024-03-07 21:47:56 -0500 |
| commit | 1db01eaf7f1a0e06abe689c55f746518598d5f4c (patch) | |
| tree | 7e1056da01130d6c1cbc6d9fbf4ec58910eb9184 | |
| parent | f193348ec5086619a4ffe4c08ebd7ef2dcc901b3 (diff) | |
Begin vectorizing the Math node
| -rw-r--r-- | assembly/index.ts | 159 | ||||
| -rw-r--r-- | src/nodes/Math.tsx | 144 | ||||
| -rw-r--r-- | src/nodes/Plot.tsx | 2 | ||||
| -rw-r--r-- | src/nodes/index.ts | 1 | ||||
| -rw-r--r-- | src/wasm.ts | 3 |
5 files changed, 231 insertions, 78 deletions
diff --git a/assembly/index.ts b/assembly/index.ts index 4bdb03f..1a9ca94 100644 --- a/assembly/index.ts +++ b/assembly/index.ts @@ -1,3 +1,162 @@ +export enum MathOp { + Add, + Sub, + Mul, + Div, + Pow, + Log, + Sqrt, + Exp, + + Min, + Max, + Lt, + Gt, + Sign, + + Round, + Floor, + Ceil, + Trunc, + Frac, + Mod, + Snap, + Clamp, + + Sin, + Cos, + Tan, + Asin, + Acos, + Atan, + Atan2, + Sinh, + Cosh, + Tanh, + + ToRad, + ToDeg, +} + +function unaryMathS(op: MathOp, x: f32): f32 { + switch (op) { + case MathOp.Sqrt: return Mathf.sqrt(x); + case MathOp.Exp: return Mathf.exp(x); + case MathOp.Sign: return Mathf.sign(x); + default: return 0; + } +} + +function unaryMathV(op: MathOp, x: v128): v128 { + const zero = f32x4(0,0,0,0); + switch (op) { + case MathOp.Sqrt: return v128.sqrt<f32>(x); + case MathOp.Exp: return f32x4( + Mathf.exp(v128.extract_lane<f32>(x, 0)), + Mathf.exp(v128.extract_lane<f32>(x, 1)), + Mathf.exp(v128.extract_lane<f32>(x, 2)), + Mathf.exp(v128.extract_lane<f32>(x, 3)), + ); + case MathOp.Sign: return v128.sub<f32>(v128.gt<f32>(x, zero), v128.lt<f32>(x, zero)); + default: return zero; + } +} + +function binaryMathS(op: MathOp, a: f32, b: f32): f32 { + switch (op) { + case MathOp.Add: return a + b; + case MathOp.Sub: return a - b; + case MathOp.Mul: return a * b; + case MathOp.Div: return a / b; + default: return 0; + } +}; + +function binaryMathV(op: MathOp, a: v128, b: v128): v128 { + switch (op) { + case MathOp.Add: return v128.add<f32>(a, b); + case MathOp.Sub: return v128.sub<f32>(a, b); + case MathOp.Mul: return v128.mul<f32>(a, b); + case MathOp.Div: return v128.div<f32>(a, b); + default: return f32x4(0,0,0,0); + } +}; + +export { unaryMathS as mathS }; + +export function mathV(op: MathOp, x: StaticArray<f32>): StaticArray<f32> { + const len = x.length; + const out = new StaticArray<f32>(len); + const outptr = changetype<i32>(out); + const xptr = changetype<i32>(x); + for (let i = 0; i < len - len % 4; i += 4) { + v128.store(outptr + i*4, unaryMathV(op, v128.load(xptr + i*4))); + } + // fallthrough + switch (len % 4) { + case 3: out[len-3] = unaryMathS(op, x[len-3]); + case 2: out[len-2] = unaryMathS(op, x[len-2]); + case 1: out[len-1] = unaryMathS(op, x[len-1]); + } + return out; +} + +export { binaryMathS as mathSS }; + +export function mathVS(op: MathOp, a: StaticArray<f32>, b: f32): StaticArray<f32> { + const len = a.length; + const out = new StaticArray<f32>(len); + const outptr = changetype<i32>(out); + const aptr = changetype<i32>(a); + const vb = f32x4(b,b,b,b); + for (let i = 0; i < len - len % 4; i += 4) { + v128.store(outptr + i*4, binaryMathV(op, v128.load(aptr + i*4), vb)); + } + // fallthrough + switch (len % 4) { + case 3: out[len-3] = binaryMathS(op, a[len-3], b); + case 2: out[len-2] = binaryMathS(op, a[len-2], b); + case 1: out[len-1] = binaryMathS(op, a[len-1], b); + } + return out; +} + +export function mathSV(op: MathOp, a: f32, b: StaticArray<f32>): StaticArray<f32> { + const len = b.length; + const out = new StaticArray<f32>(len); + const outptr = changetype<i32>(out); + const bptr = changetype<i32>(b); + const va = f32x4(a,a,a,a); + for (let i = 0; i < len - len % 4; i += 4) { + v128.store(outptr + i*4, binaryMathV(op, va, v128.load(bptr + i*4))); + } + // fallthrough + switch (len % 4) { + case 3: out[len-3] = binaryMathS(op, a, b[len-3]); + case 2: out[len-2] = binaryMathS(op, a, b[len-2]); + case 1: out[len-1] = binaryMathS(op, a, b[len-1]); + } + return out; +} + +export function mathVV(op: MathOp, a: StaticArray<f32>, b: StaticArray<f32>): StaticArray<f32> { + const len = a.length < b.length ? a.length : b.length; + const out = new StaticArray<f32>(len); + const outptr = changetype<i32>(out); + const aptr = changetype<i32>(a); + const bptr = changetype<i32>(b); + for (let i = 0; i < len - len % 4; i += 4) { + v128.store(outptr + i*4, binaryMathV(op, v128.load(aptr + i*4), v128.load(bptr + i*4))); + } + // fallthrough + switch (len % 4) { + case 3: out[len-3] = binaryMathS(op, a[len-3], b[len-3]); + case 2: out[len-2] = binaryMathS(op, a[len-2], b[len-2]); + case 1: out[len-1] = binaryMathS(op, a[len-1], b[len-1]); + } + return out; +} + export function linspace(start: f32, stop: f32, n: i32): StaticArray<f32> { const out = new StaticArray<f32>(n); const ptr = changetype<i32>(out); diff --git a/src/nodes/Math.tsx b/src/nodes/Math.tsx index ec11277..a10aeaa 100644 --- a/src/nodes/Math.tsx +++ b/src/nodes/Math.tsx @@ -1,3 +1,4 @@ +import { mathS, mathV, mathSS, mathSV, mathVS, mathVV } from '../wasm.ts'; import { NodeShell, InputNumber, InputSelect, OutputNumber, NodeComponentProps, NodeInfo } from '../node.tsx'; export enum MathOpFunc { @@ -5,7 +6,7 @@ export enum MathOpFunc { Sub = 'Subtract', Mul = 'Multiply', Div = 'Divide', - Power = 'Power', + Pow = 'Power', Log = 'Logarithm', Sqrt = 'Square Root', Exp = 'Exponent', @@ -13,8 +14,8 @@ export enum MathOpFunc { export enum MathOpCmp { Min = 'Minimum', Max = 'Maximum', - Less = 'Less Than', - Greater = 'Greater Than', + Lt = 'Less Than', + Gt = 'Greater Than', Sign = 'Sign', } export enum MathOpRound { @@ -48,6 +49,52 @@ export enum MathOpConv { export const MathOp = { ...MathOpFunc, ...MathOpCmp, ...MathOpRound, ...MathOpTrig, ...MathOpConv }; export type MathOp = typeof MathOp; +const binaryOps = { + [MathOp.Add]: (a, b) => a + b, + [MathOp.Sub]: (a, b) => a - b, + [MathOp.Mul]: (a, b) => a * b, + [MathOp.Div]: (a, b) => a / b, + [MathOp.Pow]: (a, b) => a ** b, + [MathOp.Log]: (a, b) => Math.log(b) / Math.log(a), + + [MathOp.Max]: Math.max, + [MathOp.Min]: Math.min, + [MathOp.Lt]: (a, b) => a < b, + [MathOp.Gt]: (a, b) => a > b, + + [MathOp.Mod]: (a, b) => a % b, + [MathOp.Snap]: (a, b) => Math.round(a / b) * b, + + [MathOp.Atan2]: Math.atan2, +}; + +const unaryOps = { + [MathOp.Sqrt]: Math.sqrt, + [MathOp.Exp]: Math.exp, + + [MathOp.Sign]: Math.sign, + + [MathOp.Round]: Math.round, + [MathOp.Floor]: Math.floor, + [MathOp.Ceil]: Math.ceil, + [MathOp.Trunc]: Math.trunc, + [MathOp.Frac]: x => x - Math.trunc(x), + [MathOp.Clamp]: x => Math.max(0, Math.min(x, 1)), + + [MathOp.Sin]: Math.sin, + [MathOp.Cos]: Math.cos, + [MathOp.Tan]: Math.tan, + [MathOp.Asin]: Math.asin, + [MathOp.Acos]: Math.acos, + [MathOp.Atan]: Math.atan, + [MathOp.Sinh]: Math.sinh, + [MathOp.Cosh]: Math.cosh, + [MathOp.Tanh]: Math.tanh, + + [MathOp.ToRad]: x => x / 180 * Math.PI, + [MathOp.ToDeg]: x => x * 180 / Math.PI, +}; + export interface MathInputs { op: MathOp, a: number | Float32Array, @@ -55,7 +102,7 @@ export interface MathInputs { } export interface MathOutputs { - out: boolean | number | Float32Array, + out: number | Float32Array, } export const MathC = ({ id, x, y, inputs }: NodeComponentProps<MathInputs>) => { @@ -66,87 +113,32 @@ export const MathC = ({ id, x, y, inputs }: NodeComponentProps<MathInputs>) => { 'Trigonometric': Object.values(MathOpTrig), 'Conversion': Object.values(MathOpConv), }; + const isBinary = Object.keys(binaryOps).includes(inputs.op.value); return ( <NodeShell name="Math" id={id} x={x} y={y}> <OutputNumber name="out" label="Value" /> <InputSelect name="op" label="Operation" value={inputs.op} options={options} /> - <InputNumber name="a" label="a" value={inputs.a} /> - <InputNumber name="b" label="b" value={inputs.b}/> + <InputNumber name="a" label={isBinary ? 'a' : 'x'} value={inputs.a} /> + {isBinary && <InputNumber name="b" label="b" value={inputs.b} />} </NodeShell> ); }; -const doMathOp = (op: MathOp, a: number, b: number): number => { - switch (op) { - case MathOp.Add: return a + b; - case MathOp.Sub: return a - b; - case MathOp.Mul: return a * b; - case MathOp.Div: return a / b; - case MathOp.Power: return a ** b; - case MathOp.Log: return Math.log(a) / Math.log(b); - case MathOp.Sqrt: return Math.sqrt(a); - case MathOp.Exp: return Math.exp(a); - - case MathOp.Min: return Math.min(a, b); - case MathOp.Max: return Math.max(a, b); - case MathOp.Less: return a < b; - case MathOp.Greater: return a > b; - case MathOp.Sign: return Math.sign(a); - - case MathOp.Round: return Math.round(a); - case MathOp.Floor: return Math.floor(a); - case MathOp.Ceil: return Math.ceil(a); - case MathOp.Trunc: return Math.trunc(a); - case MathOp.Frac: return a - Math.trunc(a); - case MathOp.Mod: return a % b; - case MathOp.Snap: return Math.round(a * b) / b; - case MathOp.Clamp: return Math.max(0, Math.min(a, 1)); - - case MathOp.Sin: return Math.sin(a); - case MathOp.Cos: return Math.cos(a); - case MathOp.Tan: return Math.tan(a); - case MathOp.Asin: return Math.asin(a); - case MathOp.Acos: return Math.acos(a); - case MathOp.Atan: return Math.atan(a); - case MathOp.Atan2: return Math.atan2(a, b); - - case MathOp.Sinh: return Math.sinh(a); - case MathOp.Cosh: return Math.cosh(a); - case MathOp.Tanh: return Math.tanh(a); - - case MathOp.ToRad: return a / 180 * Math.PI; - case MathOp.ToDeg: return a * 180 / Math.PI; - - default: throw new TypeError(); - } -}; - const mathFunc = ({ op, a, b }: MathInputs): MathOutputs => { - const countScalar = (typeof a === 'number' ? 1 : 0) + (typeof b === 'number' ? 1 : 0); - if (typeof a === 'number') { - if (typeof b === 'number') { - return { out: doMathOp(op, a, b) }; - } else { - const out = new Float32Array(b.length); - for (let i = 0; i < out.length; i++) { - out[i] = doMathOp(op, a, b[i]); - } - return { out }; - } + const ta = typeof a === 'number'; + const tb = typeof b === 'number'; + const opNum = Object.values(MathOp).indexOf(op); + if (Object.keys(unaryOps).includes(op)) { + return { out: ta ? mathS(opNum, a) : mathV(opNum, a) as Float32Array }; + } + if (ta && tb) { + return { out: mathSS(opNum, a, b) }; + } else if (ta && !tb) { + return { out: mathSV(opNum, a, b) as Float32Array }; + } else if (!ta && tb) { + return { out: mathVS(opNum, a, b) as Float32Array }; } else { - if (typeof b === 'number') { - const out = new Float32Array(a.length); - for (let i = 0; i < out.length; i++) { - out[i] = doMathOp(op, a[i], b); - } - return { out }; - } else { - const out = new Float32Array(Math.min(a.length, b.length)); - for (let i = 0; i < out.length; i++) { - out[i] = doMathOp(op, a[i], b[i]); - } - return { out }; - } + return { out: mathVV(opNum, a, b) as Float32Array }; } }; @@ -154,4 +146,4 @@ export const MathNode: NodeInfo<MathInputs, MathOutputs> = { component: MathC, func: mathFunc, inputs: { op: MathOp.Add, a: 0, b: 0 }, -}; +};
\ No newline at end of file diff --git a/src/nodes/Plot.tsx b/src/nodes/Plot.tsx index eaba592..d7be295 100644 --- a/src/nodes/Plot.tsx +++ b/src/nodes/Plot.tsx @@ -13,7 +13,7 @@ export const Plot = ({ id, x, y, inputs }: NodeComponentProps<PlotInputs>) => { const dy = 0; let path = ''; if (data !== null && data.length > 3) { - for (let i = 0; i < data.length; i += Math.max(1, Math.floor(data.length / 1000))) { + for (let i = 0; i < data.length; i += Math.max(2, Math.floor(data.length / 1000))) { if (i >= data.length) break; path += (i ? 'L' : 'M') + (data[i] * scale + dx) + ' ' + (data[i+1] * scale + dy); } diff --git a/src/nodes/index.ts b/src/nodes/index.ts index 0e63d3a..a5678fa 100644 --- a/src/nodes/index.ts +++ b/src/nodes/index.ts @@ -1,3 +1,4 @@ +import type { NodeInfo } from '../node.tsx'; import { CombineXYZNode } from './CombineXYZ.tsx'; import { SeparateXYZNode } from './SeparateXYZ.tsx'; import { ViewerNode } from './Viewer.tsx'; diff --git a/src/wasm.ts b/src/wasm.ts index 3a308b2..311a088 100644 --- a/src/wasm.ts +++ b/src/wasm.ts @@ -2,8 +2,9 @@ import { instantiate } from '../build/out.js'; import url from '../build/out.wasm'; export const { memory, + mathS, mathV, mathSS, mathSV, mathVS, mathVV, linspace, intersperse, dft, fft, -} = await instantiate(await WebAssembly.compileStreaming(fetch(url)));
\ No newline at end of file +} = await instantiate(await WebAssembly.compileStreaming(fetch(url)), {});
\ No newline at end of file |
