mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[js/webgpu] support GreaterOrEqual and LessOrEqual operators (#17310)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
eaef485461
commit
9017ea131b
5 changed files with 33 additions and 0 deletions
|
|
@ -44,10 +44,12 @@ Do not modify directly.*
|
|||
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
| Greater | ai.onnx(7-8,9-12,13+) | |
|
||||
| GreaterOrEqual | ai.onnx(12-15,16+) | |
|
||||
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
|
||||
| LayerNormalization | ai.onnx(17+) | |
|
||||
| LeakyRelu | ai.onnx(6-15,16+) | |
|
||||
| Less | ai.onnx(7-8,9-12,13+) | |
|
||||
| LessOrEqual | ai.onnx(12-15,16+) | |
|
||||
| Log | ai.onnx(6-12,13+) | |
|
||||
| MatMul | ai.onnx(1-12,13+) | |
|
||||
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation |
|
||||
|
|
|
|||
|
|
@ -65,10 +65,12 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
|
||||
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
|
||||
['Greater', [binaryOps.greater]],
|
||||
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
|
||||
['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
|
||||
['LayerNormalization', [layerNorm, parseLayerNormAttributes]],
|
||||
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
|
||||
['Less', [binaryOps.less]],
|
||||
['LessOrEqual', [binaryOps.lessOrEqual]],
|
||||
['Log', [unaryOps.log]],
|
||||
['MatMul', [matMul]],
|
||||
// TODO: support new attributes for MaxPool-8 and MaxPool-10
|
||||
|
|
|
|||
|
|
@ -240,3 +240,16 @@ export const less = (context: ComputeContext): void => {
|
|||
context.inputs, 'Less', ({scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4<u32>(${a}<${b})`}),
|
||||
undefined, undefined, DataType.bool));
|
||||
};
|
||||
|
||||
export const greaterOrEqual = (context: ComputeContext): void => {
|
||||
context.compute(createBinaryOpProgramInfoLoader(
|
||||
context.inputs, 'GreaterOrEqual',
|
||||
({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4<u32>(${a}>=${b})`}), undefined, undefined,
|
||||
DataType.bool));
|
||||
};
|
||||
|
||||
export const lessOrEqual = (context: ComputeContext): void => {
|
||||
context.compute(createBinaryOpProgramInfoLoader(
|
||||
context.inputs, 'LessOrEqual', ({scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4<u32>(${a}<=${b})`}),
|
||||
undefined, undefined, DataType.bool));
|
||||
};
|
||||
|
|
|
|||
|
|
@ -205,9 +205,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Equ
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Greater);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Greater);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Greater);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 15, GreaterOrEqual);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, GreaterOrEqual);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Less);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Less);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Less);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 15, LessOrEqual);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, LessOrEqual);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape);
|
||||
|
|
@ -403,9 +407,13 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater),
|
||||
KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater),
|
||||
KERNEL_CREATE_INFO(13, Greater),
|
||||
KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual),
|
||||
KERNEL_CREATE_INFO(16, GreaterOrEqual),
|
||||
KERNEL_CREATE_INFO_VERSIONED(7, 8, Less),
|
||||
KERNEL_CREATE_INFO_VERSIONED(9, 12, Less),
|
||||
KERNEL_CREATE_INFO(13, Less),
|
||||
KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual),
|
||||
KERNEL_CREATE_INFO(16, LessOrEqual),
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
|
||||
|
|
|
|||
|
|
@ -63,10 +63,18 @@ REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 7, 8, Greater);
|
|||
REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 9, 12, Greater);
|
||||
REG_ELEMENTWISE_KERNEL(Greater, 13, Greater);
|
||||
|
||||
JSEP_KERNEL_IMPL(GreaterOrEqual, GreaterOrEqual)
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(GreaterOrEqual, 12, 15, GreaterOrEqual);
|
||||
REG_ELEMENTWISE_KERNEL(GreaterOrEqual, 16, GreaterOrEqual);
|
||||
|
||||
JSEP_KERNEL_IMPL(Less, Less)
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 7, 8, Less);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 9, 12, Less);
|
||||
REG_ELEMENTWISE_KERNEL(Less, 13, Less);
|
||||
|
||||
JSEP_KERNEL_IMPL(LessOrEqual, LessOrEqual)
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(LessOrEqual, 12, 15, LessOrEqual);
|
||||
REG_ELEMENTWISE_KERNEL(LessOrEqual, 16, LessOrEqual);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue