mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
[js/webgpu] support Greater and Less operators (#17296)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
5a83a67f32
commit
5e8d94cec8
6 changed files with 110 additions and 51 deletions
|
|
@ -41,9 +41,11 @@ Do not modify directly.*
|
|||
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
|
||||
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
| Greater | ai.onnx(7-8,9-12,13+) | |
|
||||
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
|
||||
| LayerNormalization | ai.onnx(17+) | |
|
||||
| LeakyRelu | ai.onnx(6-15,16+) | |
|
||||
| Less | ai.onnx(7-8,9-12,13+) | |
|
||||
| Log | ai.onnx(6-12,13+) | |
|
||||
| MatMul | ai.onnx(1-12,13+) | |
|
||||
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation |
|
||||
|
|
|
|||
|
|
@ -61,9 +61,11 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Gemm', [gemm, parseGemmAttributes]],
|
||||
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
|
||||
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
|
||||
['Greater', [binaryOps.greater]],
|
||||
['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
|
||||
['LayerNormalization', [layerNorm, parseLayerNormAttributes]],
|
||||
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
|
||||
['Less', [binaryOps.less]],
|
||||
['Log', [unaryOps.log]],
|
||||
['MatMul', [matMul]],
|
||||
// TODO: support new attributes for MaxPool-8 and MaxPool-10
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {BroadcastUtil, ShapeUtil} from '../../util';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
|
||||
|
|
@ -50,26 +51,27 @@ const createBinaryOpProgramShader =
|
|||
};
|
||||
|
||||
broadcastImpl = `
|
||||
fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 {
|
||||
return ${calcOffsetImpl(dimsA)};
|
||||
}
|
||||
fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 {
|
||||
return ${calcOffsetImpl(dimsA)};
|
||||
}
|
||||
|
||||
fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 {
|
||||
return ${calcOffsetImpl(dimsB)};
|
||||
}
|
||||
`;
|
||||
fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 {
|
||||
return ${calcOffsetImpl(dimsB)};
|
||||
}
|
||||
`;
|
||||
}
|
||||
|
||||
let assignment: string;
|
||||
if (vectorize) {
|
||||
if (doBroadcast) {
|
||||
assignment = `
|
||||
let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
|
||||
let offsetA = calcOffsetA(outputIndices);
|
||||
let offsetB = calcOffsetB(outputIndices);
|
||||
${
|
||||
let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
|
||||
let offsetA = calcOffsetA(outputIndices);
|
||||
let offsetB = calcOffsetB(outputIndices);
|
||||
${
|
||||
output.setByOffset(
|
||||
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}`;
|
||||
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
|
||||
`;
|
||||
} else {
|
||||
assignment = output.setByOffset(
|
||||
'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx')));
|
||||
|
|
@ -78,37 +80,49 @@ const createBinaryOpProgramShader =
|
|||
if (!doBroadcast) {
|
||||
throw new Error('no necessary to use scalar implementation for element-wise binary op implementation.');
|
||||
}
|
||||
const singleAssignment = (x: number) => {
|
||||
|
||||
const singleAssignment = (resStr: string, x: number, typeCast = '') => {
|
||||
const expressionA = `aData[indexA${x}][componentA${x}]`;
|
||||
const expressionB = `bData[indexB${x}][componentB${x}]`;
|
||||
return `
|
||||
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
|
||||
let offsetA${x} = calcOffsetA(outputIndices${x});
|
||||
let offsetB${x} = calcOffsetB(outputIndices${x});
|
||||
let indexA${x} = offsetA${x} / 4u;
|
||||
let indexB${x} = offsetB${x} / 4u;
|
||||
let componentA${x} = offsetA${x} % 4u;
|
||||
let componentB${x} = offsetB${x} % 4u;
|
||||
outputData[global_idx][${x}] = ${expressionScalar(expressionA, expressionB)};`;
|
||||
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
|
||||
let offsetA${x} = calcOffsetA(outputIndices${x});
|
||||
let offsetB${x} = calcOffsetB(outputIndices${x});
|
||||
let indexA${x} = offsetA${x} / 4u;
|
||||
let indexB${x} = offsetB${x} / 4u;
|
||||
let componentA${x} = offsetA${x} % 4u;
|
||||
let componentB${x} = offsetB${x} % 4u;
|
||||
${resStr}[${x}] = ${typeCast}(${expressionScalar(expressionA, expressionB)});
|
||||
`;
|
||||
};
|
||||
|
||||
assignment = `
|
||||
${singleAssignment(0)}
|
||||
${singleAssignment(1)}
|
||||
${singleAssignment(2)}
|
||||
${singleAssignment(3)}`;
|
||||
if (typeOutput === DataType.bool) {
|
||||
assignment = `
|
||||
var data = vec4<u32>(0);
|
||||
${singleAssignment('data', 0, 'u32')}
|
||||
${singleAssignment('data', 1, 'u32')}
|
||||
${singleAssignment('data', 2, 'u32')}
|
||||
${singleAssignment('data', 3, 'u32')}
|
||||
outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
|
||||
} else {
|
||||
assignment = `
|
||||
${singleAssignment('outputData[global_idx]', 0)}
|
||||
${singleAssignment('outputData[global_idx]', 1)}
|
||||
${singleAssignment('outputData[global_idx]', 2)}
|
||||
${singleAssignment('outputData[global_idx]', 3)}
|
||||
`;
|
||||
}
|
||||
}
|
||||
|
||||
return `
|
||||
${shaderHelper.declareVariables(a, b, output)}
|
||||
${shaderHelper.declareVariables(a, b, output)}
|
||||
|
||||
${additionalImplementation ?? ''}
|
||||
${broadcastImpl}
|
||||
${additionalImplementation ?? ''}
|
||||
${broadcastImpl}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
|
||||
${assignment}
|
||||
}`;
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
|
||||
${assignment}
|
||||
}`;
|
||||
};
|
||||
|
||||
const createBinaryOpProgramInfo =
|
||||
|
|
@ -132,7 +146,7 @@ const createBinaryOpProgramInfo =
|
|||
|
||||
// check whether vectorize can be enabled
|
||||
let sharedDimension = 1;
|
||||
for (let i = 0; i < outputShape.length; i++) {
|
||||
for (let i = 1; i < outputShape.length; i++) {
|
||||
const dimA = a.dims[a.dims.length - i] ?? 1;
|
||||
const dimB = b.dims[b.dims.length - i] ?? 1;
|
||||
if (dimA === dimB) {
|
||||
|
|
@ -162,12 +176,13 @@ const createBinaryOpProgramInfo =
|
|||
|
||||
const createBinaryOpProgramInfoLoader =
|
||||
(inputs: readonly TensorView[], name: string, funcCall: BinaryFunctionCall, additionalImplementation?: string,
|
||||
cacheKey?: string): ProgramInfoLoader => {
|
||||
cacheKey?: string, outputDataType?: number): ProgramInfoLoader => {
|
||||
const metadata:
|
||||
ProgramMetadata = {name, inputTypes: [GpuDataType.default, GpuDataType.default], cacheHint: cacheKey};
|
||||
return {
|
||||
...metadata,
|
||||
get: () => createBinaryOpProgramInfo(metadata, inputs[0], inputs[1], funcCall, additionalImplementation)
|
||||
get: () => createBinaryOpProgramInfo(
|
||||
metadata, inputs[0], inputs[1], funcCall, additionalImplementation, outputDataType)
|
||||
};
|
||||
};
|
||||
|
||||
|
|
@ -209,3 +224,21 @@ export const pow = (context: ComputeContext): void => {
|
|||
/**
 * Entry point for the 'Sub' binary operator.
 *
 * Builds a binary-op program loader from `context.inputs`, using the
 * expression `a-b` for each pair of operands, and hands it to
 * `context.compute`.
 *
 * @param context compute context supplying the input tensors and the `compute` call.
 */
export const sub = (context: ComputeContext): void => {
|
||||
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`));
|
||||
};
|
||||
|
||||
/**
 * Entry point for the 'Greater' comparison operator.
 *
 * Builds a binary-op program loader from `context.inputs` and hands it to
 * `context.compute`. The shader expression is supplied in two forms:
 *  - `scalar`: `select(0, 1, a>b)`
 *  - `vector`: `select(vec4<u32>(0), vec4<u32>(1), a>b)`
 * so the comparison result is emitted as 1 where `a>b` holds and 0 otherwise.
 *
 * @param context compute context supplying the input tensors and the `compute` call.
 */
export const greater = (context: ComputeContext): void => {
|
||||
context.compute(createBinaryOpProgramInfoLoader(
|
||||
context.inputs, 'Greater', ({
|
||||
scalar: (a, b) => `select(0, 1, ${a}>${b})`,
|
||||
vector: (a, b) => `select(vec4<u32>(0), vec4<u32>(1), ${a}>${b})`
|
||||
}),
|
||||
// No additional shader implementation and no cache key; the output data type is requested as bool.
undefined, undefined, DataType.bool));
|
||||
};
|
||||
|
||||
/**
 * Entry point for the 'Less' comparison operator.
 *
 * Builds a binary-op program loader from `context.inputs` and hands it to
 * `context.compute`. The shader expression is supplied in two forms:
 *  - `scalar`: `select(0, 1, a<b)`
 *  - `vector`: `select(vec4<u32>(0), vec4<u32>(1), a<b)`
 * so the comparison result is emitted as 1 where `a<b` holds and 0 otherwise.
 *
 * @param context compute context supplying the input tensors and the `compute` call.
 */
export const less = (context: ComputeContext): void => {
|
||||
context.compute(createBinaryOpProgramInfoLoader(
|
||||
context.inputs, 'Less', ({
|
||||
scalar: (a, b) => `select(0, 1, ${a}<${b})`,
|
||||
vector: (a, b) => `select(vec4<u32>(0), vec4<u32>(1), ${a}<${b})`
|
||||
}),
|
||||
// No additional shader implementation and no cache key; the output data type is requested as bool.
undefined, undefined, DataType.bool));
|
||||
};
|
||||
|
|
|
|||
|
|
@ -563,12 +563,12 @@
|
|||
"test_globalaveragepool",
|
||||
"test_globalmaxpool_precomputed",
|
||||
"test_globalmaxpool",
|
||||
// "test_greater_bcast",
|
||||
// "test_greater_equal_bcast_expanded",
|
||||
// "test_greater_equal_bcast",
|
||||
// "test_greater_equal_expanded",
|
||||
// "test_greater_equal",
|
||||
// "test_greater",
|
||||
"test_greater_bcast",
|
||||
"test_greater_equal_bcast_expanded",
|
||||
"test_greater_equal_bcast",
|
||||
"test_greater_equal_expanded",
|
||||
"test_greater_equal",
|
||||
"test_greater",
|
||||
// // "test_gridsample_aligncorners_true",
|
||||
// // "test_gridsample_bicubic",
|
||||
// // "test_gridsample_bilinear",
|
||||
|
|
@ -648,12 +648,12 @@
|
|||
"test_leakyrelu_default",
|
||||
"test_leakyrelu_example",
|
||||
"test_leakyrelu",
|
||||
// "test_less_bcast",
|
||||
// "test_less_equal_bcast_expanded",
|
||||
// "test_less_equal_bcast",
|
||||
// "test_less_equal_expanded",
|
||||
// "test_less_equal",
|
||||
// "test_less",
|
||||
"test_less_bcast",
|
||||
"test_less_equal_bcast_expanded",
|
||||
"test_less_equal_bcast",
|
||||
"test_less_equal_expanded",
|
||||
"test_less_equal",
|
||||
"test_less",
|
||||
"test_log_example",
|
||||
"test_log",
|
||||
// // "test_logsoftmax_axis_0_expanded",
|
||||
|
|
@ -1341,8 +1341,8 @@
|
|||
"floor.jsonc",
|
||||
"gemm.jsonc",
|
||||
"global-average-pool.jsonc",
|
||||
//"greater.jsonc",
|
||||
//"less.jsonc",
|
||||
"greater.jsonc",
|
||||
"less.jsonc",
|
||||
"log.jsonc",
|
||||
//"matmul.jsonc", // <--- some tests fail (when input is 3D/4D/5D)
|
||||
"mul.jsonc",
|
||||
|
|
|
|||
|
|
@ -197,6 +197,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Pow);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Greater);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Greater);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Greater);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Less);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Less);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Less);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape);
|
||||
|
|
@ -381,6 +387,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow),
|
||||
KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow),
|
||||
KERNEL_CREATE_INFO(15, Pow),
|
||||
KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater),
|
||||
KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater),
|
||||
KERNEL_CREATE_INFO(13, Greater),
|
||||
KERNEL_CREATE_INFO_VERSIONED(7, 8, Less),
|
||||
KERNEL_CREATE_INFO_VERSIONED(9, 12, Less),
|
||||
KERNEL_CREATE_INFO(13, Less),
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
|
||||
|
|
|
|||
|
|
@ -52,5 +52,15 @@ REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 12, 12, Pow);
|
|||
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 13, 14, Pow);
|
||||
REG_ELEMENTWISE_KERNEL(Pow, 15, Pow);
|
||||
|
||||
JSEP_KERNEL_IMPL(Greater, Greater)
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 7, 8, Greater);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 9, 12, Greater);
|
||||
REG_ELEMENTWISE_KERNEL(Greater, 13, Greater);
|
||||
|
||||
JSEP_KERNEL_IMPL(Less, Less)
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 7, 8, Less);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 9, 12, Less);
|
||||
REG_ELEMENTWISE_KERNEL(Less, 13, Less);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue