diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index a969e1b86b..ad18302318 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -44,10 +44,12 @@ Do not modify directly.* | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | +| GreaterOrEqual | ai.onnx(12-15,16+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | | LeakyRelu | ai.onnx(6-15,16+) | | | Less | ai.onnx(7-8,9-12,13+) | | +| LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 23aabb6531..a8fbf9c00e 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -65,10 +65,12 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], + ['GreaterOrEqual', [binaryOps.greaterOrEqual]], ['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]], ['LayerNormalization', [layerNorm, parseLayerNormAttributes]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], ['Less', [binaryOps.less]], + ['LessOrEqual', [binaryOps.lessOrEqual]], ['Log', [unaryOps.log]], ['MatMul', [matMul]], // TODO: support new attributes for MaxPool-8 and MaxPool-10 diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 28284554f9..b004ca37a2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -240,3 +240,16 @@ export const less = (context: ComputeContext): void => { context.inputs, 'Less', ({scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4(${a}<${b})`}), undefined, undefined, DataType.bool)); }; + +export const greaterOrEqual = (context: ComputeContext): void => { + context.compute(createBinaryOpProgramInfoLoader( + context.inputs, 'GreaterOrEqual', + ({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4(${a}>=${b})`}), undefined, undefined, + DataType.bool)); +}; + +export const lessOrEqual = (context: ComputeContext): void => { + context.compute(createBinaryOpProgramInfoLoader( + context.inputs, 'LessOrEqual', ({scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4(${a}<=${b})`}), + undefined, undefined, DataType.bool)); +}; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 829f3e5f4f..c5b3b1933e 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -205,9 +205,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Equ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Greater); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Greater); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 15, GreaterOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, GreaterOrEqual); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Less); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Less); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 15, LessOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape); @@ -403,9 +407,13 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater), KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater), KERNEL_CREATE_INFO(13, Greater), + KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual), + KERNEL_CREATE_INFO(16, GreaterOrEqual), KERNEL_CREATE_INFO_VERSIONED(7, 8, Less), KERNEL_CREATE_INFO_VERSIONED(9, 12, Less), KERNEL_CREATE_INFO(13, Less), + KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual), + KERNEL_CREATE_INFO(16, LessOrEqual), BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/binary.cc b/onnxruntime/core/providers/js/operators/binary.cc index 2a96619c2c..98f7ca6e61 100644 --- a/onnxruntime/core/providers/js/operators/binary.cc +++ b/onnxruntime/core/providers/js/operators/binary.cc @@ -63,10 +63,18 @@ REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 7, 8, Greater); REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 9, 12, Greater); REG_ELEMENTWISE_KERNEL(Greater, 13, Greater); +JSEP_KERNEL_IMPL(GreaterOrEqual, GreaterOrEqual) +REG_ELEMENTWISE_VERSIONED_KERNEL(GreaterOrEqual, 12, 15, GreaterOrEqual); +REG_ELEMENTWISE_KERNEL(GreaterOrEqual, 16, GreaterOrEqual); + JSEP_KERNEL_IMPL(Less, Less) REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 7, 8, Less); REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 9, 12, Less); REG_ELEMENTWISE_KERNEL(Less, 13, Less); +JSEP_KERNEL_IMPL(LessOrEqual, LessOrEqual) +REG_ELEMENTWISE_VERSIONED_KERNEL(LessOrEqual, 12, 15, LessOrEqual); +REG_ELEMENTWISE_KERNEL(LessOrEqual, 16, LessOrEqual); + } // namespace js } // namespace onnxruntime