[js/webgpu] support Greater and Less operators (#17296)

### Description
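<!-- Describe your changes. -->
Add the `Greater` and `Less` operators to the JS/WebGPU execution provider. Both are implemented as element-wise binary ops that request a `DataType.bool` output; in the scalar (broadcast) path the four per-element results are packed into a single `u32`, one byte each. The vectorization check loop now starts at index 1 instead of 0. The kernels are registered for opsets 7-8, 9-12 and 13+, and the corresponding `greater`/`less` test entries are enabled.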



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
xhcao 2023-08-26 03:11:25 +08:00 committed by GitHub
parent 5a83a67f32
commit 5e8d94cec8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 110 additions and 51 deletions

View file

@ -41,9 +41,11 @@ Do not modify directly.*
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
| LeakyRelu | ai.onnx(6-15,16+) | |
| Less | ai.onnx(7-8,9-12,13+) | |
| Log | ai.onnx(6-12,13+) | |
| MatMul | ai.onnx(1-12,13+) | |
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation |

View file

@ -61,9 +61,11 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
['LayerNormalization', [layerNorm, parseLayerNormAttributes]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
['Less', [binaryOps.less]],
['Log', [unaryOps.log]],
['MatMul', [matMul]],
// TODO: support new attributes for MaxPool-8 and MaxPool-10

View file

@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
@ -50,26 +51,27 @@ const createBinaryOpProgramShader =
};
broadcastImpl = `
fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 {
return ${calcOffsetImpl(dimsA)};
}
fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 {
return ${calcOffsetImpl(dimsA)};
}
fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 {
return ${calcOffsetImpl(dimsB)};
}
`;
fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 {
return ${calcOffsetImpl(dimsB)};
}
`;
}
let assignment: string;
if (vectorize) {
if (doBroadcast) {
assignment = `
let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
let offsetA = calcOffsetA(outputIndices);
let offsetB = calcOffsetB(outputIndices);
${
let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
let offsetA = calcOffsetA(outputIndices);
let offsetB = calcOffsetB(outputIndices);
${
output.setByOffset(
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}`;
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
`;
} else {
assignment = output.setByOffset(
'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx')));
@ -78,37 +80,49 @@ const createBinaryOpProgramShader =
if (!doBroadcast) {
throw new Error('no necessary to use scalar implementation for element-wise binary op implementation.');
}
const singleAssignment = (x: number) => {
const singleAssignment = (resStr: string, x: number, typeCast = '') => {
const expressionA = `aData[indexA${x}][componentA${x}]`;
const expressionB = `bData[indexB${x}][componentB${x}]`;
return `
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offsetA${x} = calcOffsetA(outputIndices${x});
let offsetB${x} = calcOffsetB(outputIndices${x});
let indexA${x} = offsetA${x} / 4u;
let indexB${x} = offsetB${x} / 4u;
let componentA${x} = offsetA${x} % 4u;
let componentB${x} = offsetB${x} % 4u;
outputData[global_idx][${x}] = ${expressionScalar(expressionA, expressionB)};`;
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offsetA${x} = calcOffsetA(outputIndices${x});
let offsetB${x} = calcOffsetB(outputIndices${x});
let indexA${x} = offsetA${x} / 4u;
let indexB${x} = offsetB${x} / 4u;
let componentA${x} = offsetA${x} % 4u;
let componentB${x} = offsetB${x} % 4u;
${resStr}[${x}] = ${typeCast}(${expressionScalar(expressionA, expressionB)});
`;
};
assignment = `
${singleAssignment(0)}
${singleAssignment(1)}
${singleAssignment(2)}
${singleAssignment(3)}`;
if (typeOutput === DataType.bool) {
assignment = `
var data = vec4<u32>(0);
${singleAssignment('data', 0, 'u32')}
${singleAssignment('data', 1, 'u32')}
${singleAssignment('data', 2, 'u32')}
${singleAssignment('data', 3, 'u32')}
outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
} else {
assignment = `
${singleAssignment('outputData[global_idx]', 0)}
${singleAssignment('outputData[global_idx]', 1)}
${singleAssignment('outputData[global_idx]', 2)}
${singleAssignment('outputData[global_idx]', 3)}
`;
}
}
return `
${shaderHelper.declareVariables(a, b, output)}
${shaderHelper.declareVariables(a, b, output)}
${additionalImplementation ?? ''}
${broadcastImpl}
${additionalImplementation ?? ''}
${broadcastImpl}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
${assignment}
}`;
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
${assignment}
}`;
};
const createBinaryOpProgramInfo =
@ -132,7 +146,7 @@ const createBinaryOpProgramInfo =
// check whether vectorize can be enabled
let sharedDimension = 1;
for (let i = 0; i < outputShape.length; i++) {
for (let i = 1; i < outputShape.length; i++) {
const dimA = a.dims[a.dims.length - i] ?? 1;
const dimB = b.dims[b.dims.length - i] ?? 1;
if (dimA === dimB) {
@ -162,12 +176,13 @@ const createBinaryOpProgramInfo =
const createBinaryOpProgramInfoLoader =
(inputs: readonly TensorView[], name: string, funcCall: BinaryFunctionCall, additionalImplementation?: string,
cacheKey?: string): ProgramInfoLoader => {
cacheKey?: string, outputDataType?: number): ProgramInfoLoader => {
const metadata:
ProgramMetadata = {name, inputTypes: [GpuDataType.default, GpuDataType.default], cacheHint: cacheKey};
return {
...metadata,
get: () => createBinaryOpProgramInfo(metadata, inputs[0], inputs[1], funcCall, additionalImplementation)
get: () => createBinaryOpProgramInfo(
metadata, inputs[0], inputs[1], funcCall, additionalImplementation, outputDataType)
};
};
@ -209,3 +224,21 @@ export const pow = (context: ComputeContext): void => {
/**
 * Runs the element-wise 'Sub' binary program on the inputs of `context`.
 *
 * The generated shader expression is the two operand expressions joined by `-`.
 */
export const sub = (context: ComputeContext): void => {
  const loader = createBinaryOpProgramInfoLoader(context.inputs, 'Sub', (lhs, rhs) => {
    return `${lhs}-${rhs}`;
  });
  context.compute(loader);
};
/**
 * Runs the element-wise 'Greater' binary program on the inputs of `context`.
 *
 * Separate scalar and vector shader expressions are supplied; each uses WGSL
 * `select` to yield 1 where `a > b` holds and 0 otherwise. No additional
 * implementation and no cache key are passed, and `DataType.bool` is given as
 * the output data type.
 */
export const greater = (context: ComputeContext): void => {
  const loader = createBinaryOpProgramInfoLoader(
      context.inputs, 'Greater', {
        scalar: (lhs, rhs) => {
          return `select(0, 1, ${lhs}>${rhs})`;
        },
        vector: (lhs, rhs) => {
          return `select(vec4<u32>(0), vec4<u32>(1), ${lhs}>${rhs})`;
        },
      },
      undefined, undefined, DataType.bool);
  context.compute(loader);
};
/**
 * Runs the element-wise 'Less' binary program on the inputs of `context`.
 *
 * Separate scalar and vector shader expressions are supplied; each uses WGSL
 * `select` to yield 1 where `a < b` holds and 0 otherwise. No additional
 * implementation and no cache key are passed, and `DataType.bool` is given as
 * the output data type.
 */
export const less = (context: ComputeContext): void => {
  const loader = createBinaryOpProgramInfoLoader(
      context.inputs, 'Less', {
        scalar: (lhs, rhs) => {
          return `select(0, 1, ${lhs}<${rhs})`;
        },
        vector: (lhs, rhs) => {
          return `select(vec4<u32>(0), vec4<u32>(1), ${lhs}<${rhs})`;
        },
      },
      undefined, undefined, DataType.bool);
  context.compute(loader);
};

View file

@ -563,12 +563,12 @@
"test_globalaveragepool",
"test_globalmaxpool_precomputed",
"test_globalmaxpool",
// "test_greater_bcast",
// "test_greater_equal_bcast_expanded",
// "test_greater_equal_bcast",
// "test_greater_equal_expanded",
// "test_greater_equal",
// "test_greater",
"test_greater_bcast",
"test_greater_equal_bcast_expanded",
"test_greater_equal_bcast",
"test_greater_equal_expanded",
"test_greater_equal",
"test_greater",
// // "test_gridsample_aligncorners_true",
// // "test_gridsample_bicubic",
// // "test_gridsample_bilinear",
@ -648,12 +648,12 @@
"test_leakyrelu_default",
"test_leakyrelu_example",
"test_leakyrelu",
// "test_less_bcast",
// "test_less_equal_bcast_expanded",
// "test_less_equal_bcast",
// "test_less_equal_expanded",
// "test_less_equal",
// "test_less",
"test_less_bcast",
"test_less_equal_bcast_expanded",
"test_less_equal_bcast",
"test_less_equal_expanded",
"test_less_equal",
"test_less",
"test_log_example",
"test_log",
// // "test_logsoftmax_axis_0_expanded",
@ -1341,8 +1341,8 @@
"floor.jsonc",
"gemm.jsonc",
"global-average-pool.jsonc",
//"greater.jsonc",
//"less.jsonc",
"greater.jsonc",
"less.jsonc",
"log.jsonc",
//"matmul.jsonc", // <--- some tests fail (when input is 3D/4D/5D)
"mul.jsonc",

View file

@ -197,6 +197,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, Pow);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Pow);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, Pow);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Greater);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Greater);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Greater);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Less);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Less);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Less);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape);
@ -381,6 +387,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow),
KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow),
KERNEL_CREATE_INFO(15, Pow),
KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater),
KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater),
KERNEL_CREATE_INFO(13, Greater),
KERNEL_CREATE_INFO_VERSIONED(7, 8, Less),
KERNEL_CREATE_INFO_VERSIONED(9, 12, Less),
KERNEL_CREATE_INFO(13, Less),
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape)>,

View file

@ -52,5 +52,15 @@ REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 12, 12, Pow);
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 13, 14, Pow);
REG_ELEMENTWISE_KERNEL(Pow, 15, Pow);
// Greater: kernel implementation plus registrations for the version ranges
// 7-8, 9-12 and 13 (presumably ONNX opset versions, with 13 meaning "13 and
// later" -- the macro definitions are not visible here; confirm).
JSEP_KERNEL_IMPL(Greater, Greater)
REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 7, 8, Greater);
REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 9, 12, Greater);
REG_ELEMENTWISE_KERNEL(Greater, 13, Greater);
// Less: same registration pattern and version ranges as Greater above.
JSEP_KERNEL_IMPL(Less, Less)
REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 7, 8, Less);
REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 9, 12, Less);
REG_ELEMENTWISE_KERNEL(Less, 13, Less);
} // namespace js
} // namespace onnxruntime