mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
### Description Support wasm64 ### Motivation and Context Overcome memory limitations --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
341 lines
11 KiB
TypeScript
341 lines
11 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import { DataType } from '../../../wasm-common';
|
|
import { TensorView } from '../../tensor-view';
|
|
import { BroadcastUtil, ShapeUtil } from '../../util';
|
|
import { ComputeContext, ProgramInfo } from '../types';
|
|
|
|
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common';
|
|
|
|
type BuiltinFunctionName = string;
|
|
type BinaryCustomExpression = (expressionA: string, expressionB: string) => string;
|
|
type BinaryFunctionCall =
|
|
| BuiltinFunctionName
|
|
| BinaryCustomExpression
|
|
| {
|
|
scalar: BinaryCustomExpression;
|
|
vector: BinaryCustomExpression;
|
|
};
|
|
|
|
const createBinaryOpProgramShader = (
|
|
shaderHelper: ShaderHelper,
|
|
dimsA: readonly number[],
|
|
dimsB: readonly number[],
|
|
dimsOutput: readonly number[],
|
|
vectorize: boolean,
|
|
doBroadcast: boolean,
|
|
sharedDimensionDivisibleBy4: boolean,
|
|
funcCall: BinaryFunctionCall,
|
|
typeA: number,
|
|
typeB: number,
|
|
typeOutput: number,
|
|
additionalImplementation?: string,
|
|
) => {
|
|
let expressionScalar: BinaryCustomExpression;
|
|
let expressionVector: BinaryCustomExpression;
|
|
if (typeof funcCall === 'string') {
|
|
expressionScalar = expressionVector = (a, b) => `${funcCall}((${a}),(${b}))`;
|
|
} else if (typeof funcCall === 'function') {
|
|
expressionScalar = expressionVector = funcCall;
|
|
} else {
|
|
expressionScalar = funcCall.scalar;
|
|
expressionVector = funcCall.vector;
|
|
}
|
|
|
|
const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4);
|
|
const a = inputVariable('aData', typeA, dimsA.length, 4);
|
|
const b = inputVariable('bData', typeB, dimsB.length, 4);
|
|
|
|
let assignment: string;
|
|
if (vectorize) {
|
|
if (doBroadcast) {
|
|
const isAOneElement = ShapeUtil.size(dimsA) === 1;
|
|
const isBOneElement = ShapeUtil.size(dimsB) === 1;
|
|
const aLastDimDivisibleBy4 = dimsA.length > 0 && dimsA[dimsA.length - 1] % 4 === 0;
|
|
const bLastDimDivisibleBy4 = dimsB.length > 0 && dimsB[dimsB.length - 1] % 4 === 0;
|
|
if (isAOneElement || isBOneElement) {
|
|
assignment = output.setByOffset(
|
|
'global_idx',
|
|
expressionVector(
|
|
isAOneElement ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'),
|
|
isBOneElement ? `${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx'),
|
|
),
|
|
);
|
|
} else {
|
|
assignment = `
|
|
let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
|
|
let offsetA = ${a.broadcastedIndicesToOffset('outputIndices', output)};
|
|
let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)};
|
|
${output.setByOffset(
|
|
'global_idx',
|
|
expressionVector(
|
|
sharedDimensionDivisibleBy4 || aLastDimDivisibleBy4
|
|
? a.getByOffset('offsetA / 4u')
|
|
: `${a.type.value}(${a.getByOffset('offsetA / 4u')}[offsetA % 4u])`,
|
|
sharedDimensionDivisibleBy4 || bLastDimDivisibleBy4
|
|
? b.getByOffset('offsetB / 4u')
|
|
: `${b.type.value}(${b.getByOffset('offsetB / 4u')}[offsetB % 4u])`,
|
|
),
|
|
)}
|
|
`;
|
|
}
|
|
} else {
|
|
assignment = output.setByOffset(
|
|
'global_idx',
|
|
expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx')),
|
|
);
|
|
}
|
|
} else {
|
|
if (!doBroadcast) {
|
|
throw new Error('no necessary to use scalar implementation for element-wise binary op implementation.');
|
|
}
|
|
|
|
const singleAssignment = (resStr: string, x: number, typeCast = '') => {
|
|
const expressionA = `aData[indexA${x}][componentA${x}]`;
|
|
const expressionB = `bData[indexB${x}][componentB${x}]`;
|
|
return `
|
|
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
|
|
let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
|
|
let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
|
|
let indexA${x} = offsetA${x} / 4u;
|
|
let indexB${x} = offsetB${x} / 4u;
|
|
let componentA${x} = offsetA${x} % 4u;
|
|
let componentB${x} = offsetB${x} % 4u;
|
|
${resStr}[${x}] = ${typeCast}(${expressionScalar(expressionA, expressionB)});
|
|
`;
|
|
};
|
|
if (typeOutput === DataType.bool) {
|
|
assignment = `
|
|
var data = vec4<u32>(0);
|
|
${singleAssignment('data', 0, 'u32')}
|
|
${singleAssignment('data', 1, 'u32')}
|
|
${singleAssignment('data', 2, 'u32')}
|
|
${singleAssignment('data', 3, 'u32')}
|
|
outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
|
|
} else {
|
|
assignment = `
|
|
${singleAssignment('outputData[global_idx]', 0)}
|
|
${singleAssignment('outputData[global_idx]', 1)}
|
|
${singleAssignment('outputData[global_idx]', 2)}
|
|
${singleAssignment('outputData[global_idx]', 3)}
|
|
`;
|
|
}
|
|
}
|
|
|
|
return `
|
|
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(a, b, output)}
|
|
|
|
${additionalImplementation ?? ''}
|
|
|
|
${shaderHelper.mainStart()}
|
|
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
|
|
${assignment}
|
|
}`;
|
|
};
|
|
|
|
const createBinaryOpProgramInfo = (
|
|
name: string,
|
|
cacheKey: string,
|
|
a: TensorView,
|
|
b: TensorView,
|
|
funcCall: BinaryFunctionCall,
|
|
additionalImplementation?: string,
|
|
outputDataType: number = a.dataType,
|
|
): ProgramInfo => {
|
|
const aDims = a.dims.map((x) => Number(x) ?? 1);
|
|
const bDims = b.dims.map((x) => Number(x) ?? 1);
|
|
const isBroadcast = !ShapeUtil.areEqual(aDims, bDims);
|
|
let outputShape = aDims;
|
|
let outputSize = ShapeUtil.size(aDims);
|
|
|
|
let vectorize = false;
|
|
let sharedDimensionDivisibleBy4 = false;
|
|
|
|
// TODO: deal with zero-sized tensors (eg. dims=[1,0])
|
|
const cacheKeyAux = [isBroadcast];
|
|
if (isBroadcast) {
|
|
const calculatedShape = BroadcastUtil.calcShape(aDims, bDims, false);
|
|
if (!calculatedShape) {
|
|
throw new Error("Can't perform binary op on the given tensors");
|
|
}
|
|
outputShape = calculatedShape.slice();
|
|
outputSize = ShapeUtil.size(outputShape);
|
|
const isAOneElement = ShapeUtil.size(aDims) === 1;
|
|
const isBOneElement = ShapeUtil.size(bDims) === 1;
|
|
const aLastDimDivisibleBy4 = aDims.length > 0 && aDims[aDims.length - 1] % 4 === 0;
|
|
const bLastDimDivisibleBy4 = bDims.length > 0 && bDims[bDims.length - 1] % 4 === 0;
|
|
cacheKeyAux.push(isAOneElement);
|
|
cacheKeyAux.push(isBOneElement);
|
|
cacheKeyAux.push(aLastDimDivisibleBy4);
|
|
cacheKeyAux.push(bLastDimDivisibleBy4);
|
|
// check whether vectorize can be enabled
|
|
let sharedDimension = 1;
|
|
for (let i = 1; i < outputShape.length; i++) {
|
|
const dimA = aDims[aDims.length - i];
|
|
const dimB = bDims[bDims.length - i];
|
|
if (dimA === dimB) {
|
|
sharedDimension *= dimA;
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
if (sharedDimension % 4 === 0) {
|
|
sharedDimensionDivisibleBy4 = true;
|
|
vectorize = true;
|
|
} else if (isAOneElement || isBOneElement || aLastDimDivisibleBy4 || bLastDimDivisibleBy4) {
|
|
vectorize = true;
|
|
}
|
|
} else {
|
|
// element-wise
|
|
vectorize = true;
|
|
}
|
|
cacheKeyAux.push(vectorize);
|
|
|
|
return {
|
|
name,
|
|
shaderCache: {
|
|
hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'),
|
|
inputDependencies: ['rank', 'rank'],
|
|
},
|
|
getShaderSource: (shaderHelper) =>
|
|
createBinaryOpProgramShader(
|
|
shaderHelper,
|
|
aDims,
|
|
bDims,
|
|
outputShape,
|
|
vectorize,
|
|
isBroadcast,
|
|
sharedDimensionDivisibleBy4,
|
|
funcCall,
|
|
a.dataType,
|
|
b.dataType,
|
|
outputDataType,
|
|
additionalImplementation,
|
|
),
|
|
getRunData: () => ({
|
|
outputs: [{ dims: outputShape, dataType: outputDataType }],
|
|
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */) },
|
|
programUniforms: [
|
|
{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4) },
|
|
...createTensorShapeVariables(aDims, bDims, outputShape),
|
|
],
|
|
}),
|
|
};
|
|
};
|
|
|
|
const runBinaryOp = (
|
|
context: ComputeContext,
|
|
name: string,
|
|
funcCall: BinaryFunctionCall,
|
|
additionalImplementation?: string,
|
|
cacheKey?: string,
|
|
outputDataType?: number,
|
|
): void => {
|
|
context.compute(
|
|
createBinaryOpProgramInfo(
|
|
name,
|
|
cacheKey ?? '',
|
|
context.inputs[0],
|
|
context.inputs[1],
|
|
funcCall,
|
|
additionalImplementation,
|
|
outputDataType,
|
|
),
|
|
);
|
|
};
|
|
|
|
export const add = (context: ComputeContext): void => {
|
|
runBinaryOp(context, 'Add', (a, b) => `${a}+${b}`);
|
|
};
|
|
|
|
export const div = (context: ComputeContext): void => {
|
|
runBinaryOp(context, 'Div', (a, b) => `${a}/${b}`);
|
|
};
|
|
|
|
export const equal = (context: ComputeContext): void => {
|
|
runBinaryOp(
|
|
context,
|
|
'Equal',
|
|
{ scalar: (a, b) => `u32(${a}==${b})`, vector: (a, b) => `vec4<u32>(${a}==${b})` },
|
|
undefined,
|
|
undefined,
|
|
DataType.bool,
|
|
);
|
|
};
|
|
|
|
export const mul = (context: ComputeContext): void => {
|
|
runBinaryOp(context, 'Mul', (a, b) => `${a}*${b}`);
|
|
};
|
|
|
|
export const pow = (context: ComputeContext): void => {
|
|
const type = inputVariable('input', context.inputs[0].dataType, context.inputs[0].dims).type.value;
|
|
const roundStr = type === 'i32' ? 'round' : '';
|
|
runBinaryOp(
|
|
context,
|
|
'Pow',
|
|
{ scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})` },
|
|
`
|
|
fn pow_custom(a : ${type}, b : ${type}) -> ${type} {
|
|
if (b == ${type}(0.0)) {
|
|
return ${type}(1.0);
|
|
} else if (a < ${type}(0.0) && f32(b) != floor(f32(b))) {
|
|
return ${type}(pow(f32(a), f32(b))); // NaN
|
|
}
|
|
return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${roundStr}(pow(f32(abs(a)), f32(b))));
|
|
}
|
|
fn pow_vector_custom(a : vec4<${type}>, b : vec4<${type}>) -> vec4<${type}> {
|
|
// TODO: implement vectorized pow
|
|
return vec4<${type}>(pow_custom(a.x, b.x), pow_custom(a.y, b.y), pow_custom(a.z, b.z), pow_custom(a.w, b.w));
|
|
}
|
|
`,
|
|
);
|
|
};
|
|
|
|
export const sub = (context: ComputeContext): void => {
|
|
runBinaryOp(context, 'Sub', (a, b) => `${a}-${b}`);
|
|
};
|
|
|
|
export const greater = (context: ComputeContext): void => {
|
|
runBinaryOp(
|
|
context,
|
|
'Greater',
|
|
{ scalar: (a, b) => `u32(${a}>${b})`, vector: (a, b) => `vec4<u32>(${a}>${b})` },
|
|
undefined,
|
|
undefined,
|
|
DataType.bool,
|
|
);
|
|
};
|
|
|
|
export const less = (context: ComputeContext): void => {
|
|
runBinaryOp(
|
|
context,
|
|
'Less',
|
|
{ scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4<u32>(${a}<${b})` },
|
|
undefined,
|
|
undefined,
|
|
DataType.bool,
|
|
);
|
|
};
|
|
|
|
export const greaterOrEqual = (context: ComputeContext): void => {
|
|
runBinaryOp(
|
|
context,
|
|
'GreaterOrEqual',
|
|
{ scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4<u32>(${a}>=${b})` },
|
|
undefined,
|
|
undefined,
|
|
DataType.bool,
|
|
);
|
|
};
|
|
|
|
export const lessOrEqual = (context: ComputeContext): void => {
|
|
runBinaryOp(
|
|
context,
|
|
'LessOrEqual',
|
|
{ scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4<u32>(${a}<=${b})` },
|
|
undefined,
|
|
undefined,
|
|
DataType.bool,
|
|
);
|
|
};
|