[js/webgpu] support GreaterOrEqual and LessOrEqual operators (#17310)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
xhcao 2023-09-08 08:41:16 +08:00 committed by GitHub
parent eaef485461
commit 9017ea131b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 33 additions and 0 deletions

View file

@ -44,10 +44,12 @@ Do not modify directly.*
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
| LeakyRelu | ai.onnx(6-15,16+) | |
| Less | ai.onnx(7-8,9-12,13+) | |
| LessOrEqual | ai.onnx(12-15,16+) | |
| Log | ai.onnx(6-12,13+) | |
| MatMul | ai.onnx(1-12,13+) | |
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation |

View file

@ -65,10 +65,12 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
['LayerNormalization', [layerNorm, parseLayerNormAttributes]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
['Less', [binaryOps.less]],
['LessOrEqual', [binaryOps.lessOrEqual]],
['Log', [unaryOps.log]],
['MatMul', [matMul]],
// TODO: support new attributes for MaxPool-8 and MaxPool-10

View file

@ -240,3 +240,16 @@ export const less = (context: ComputeContext): void => {
context.inputs, 'Less', ({scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4<u32>(${a}<${b})`}),
undefined, undefined, DataType.bool));
};
export const greaterOrEqual = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(
context.inputs, 'GreaterOrEqual',
({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4<u32>(${a}>=${b})`}), undefined, undefined,
DataType.bool));
};
export const lessOrEqual = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(
context.inputs, 'LessOrEqual', ({scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4<u32>(${a}<=${b})`}),
undefined, undefined, DataType.bool));
};

View file

@ -205,9 +205,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Equ
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Greater);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Greater);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Greater);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 15, GreaterOrEqual);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, GreaterOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Less);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Less);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Less);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 15, LessOrEqual);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape);
@ -403,9 +407,13 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater),
KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater),
KERNEL_CREATE_INFO(13, Greater),
KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual),
KERNEL_CREATE_INFO(16, GreaterOrEqual),
KERNEL_CREATE_INFO_VERSIONED(7, 8, Less),
KERNEL_CREATE_INFO_VERSIONED(9, 12, Less),
KERNEL_CREATE_INFO(13, Less),
KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual),
KERNEL_CREATE_INFO(16, LessOrEqual),
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape)>,

View file

@ -63,10 +63,18 @@ REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 7, 8, Greater);
REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 9, 12, Greater);
REG_ELEMENTWISE_KERNEL(Greater, 13, Greater);
JSEP_KERNEL_IMPL(GreaterOrEqual, GreaterOrEqual)
REG_ELEMENTWISE_VERSIONED_KERNEL(GreaterOrEqual, 12, 15, GreaterOrEqual);
REG_ELEMENTWISE_KERNEL(GreaterOrEqual, 16, GreaterOrEqual);
JSEP_KERNEL_IMPL(Less, Less)
REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 7, 8, Less);
REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 9, 12, Less);
REG_ELEMENTWISE_KERNEL(Less, 13, Less);
JSEP_KERNEL_IMPL(LessOrEqual, LessOrEqual)
REG_ELEMENTWISE_VERSIONED_KERNEL(LessOrEqual, 12, 15, LessOrEqual);
REG_ELEMENTWISE_KERNEL(LessOrEqual, 16, LessOrEqual);
} // namespace js
} // namespace onnxruntime