[JS/WebGPU] Support If on WebGPU (#17478)

This commit is contained in:
Hariharan Seshadri 2023-09-19 12:20:18 -07:00 committed by GitHub
parent 152e61da37
commit 460f17fbb8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 108 additions and 4 deletions

View file

@ -46,6 +46,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
| LeakyRelu | ai.onnx(6-15,16+) | |

View file

@ -602,6 +602,11 @@
// // "test_hardsigmoid",
// // "test_hardswish_expanded",
// // "test_hardswish",
"test_if",
// TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra
// supports Sequence and Optional types
// "test_if_seq",
// "test_if_opt",
"test_instancenorm_epsilon",
"test_instancenorm_example",
// "test_isinf_negative",

View file

@ -1030,7 +1030,10 @@ Status SessionState::CreateSubgraphSessionState() {
for (auto& node : graph_.Nodes()) {
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
const auto& ep = node.GetExecutionProviderType();
if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider && ep != kDmlExecutionProvider) {
if (!ep.empty() &&
ep != kCpuExecutionProvider && ep != kCudaExecutionProvider &&
ep != kRocmExecutionProvider && ep != kDmlExecutionProvider &&
ep != kJsExecutionProvider) {
// SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
// node containing the subgraph it will create whatever state it needs internally.
continue;

View file

@ -318,7 +318,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad);
@ -327,6 +326,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If);
std::unique_ptr<KernelRegistry> RegisterKernels() {
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
@ -580,15 +584,17 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -0,0 +1,65 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "if.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
namespace js {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
1, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);
// output shape rules requiring the output shapes of the 'THEN' and 'ELSE'
// branches to be the same were relaxed in opset-11
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
11, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);
// opset-13 supports sequence type for If's subgraph outputs
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
13, 18,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
// Support sequence/optional tensors when all JSEP infra
// (including tests runner) supports it
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);
// opset-19 supports float8
ONNX_OPERATOR_KERNEL_EX(If,
kOnnxDomain,
19,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
// Support sequence/optional tensors when all JSEP infra
// (including tests runner) supports it
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);
Status If::Compute(OpKernelContext* ctx) const {
// call the base CPU version.
return onnxruntime::If::Compute(ctx);
}
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <functional>
#include "core/providers/js/js_kernel.h"
#include "core/common/common.h"
#include "core/providers/cpu/controlflow/if.h"
namespace onnxruntime {
class SessionState;
namespace js {
// Use the CPU implementation for the logic
class If final : public onnxruntime::If {
public:
If(const OpKernelInfo& info) : onnxruntime::If(info) {}
Status Compute(OpKernelContext* ctx) const override;
};
} // namespace js
} // namespace onnxruntime