mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
[JS/WebGPU] Support If on WebGPU (#17478)
This commit is contained in:
parent
152e61da37
commit
460f17fbb8
6 changed files with 108 additions and 4 deletions
|
|
@ -46,6 +46,7 @@ Do not modify directly.*
|
|||
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
| Greater | ai.onnx(7-8,9-12,13+) | |
|
||||
| GreaterOrEqual | ai.onnx(12-15,16+) | |
|
||||
| If | ai.onnx(1-10,11-12,13-18,19+) | |
|
||||
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
|
||||
| LayerNormalization | ai.onnx(17+) | |
|
||||
| LeakyRelu | ai.onnx(6-15,16+) | |
|
||||
|
|
|
|||
|
|
@ -602,6 +602,11 @@
|
|||
// // "test_hardsigmoid",
|
||||
// // "test_hardswish_expanded",
|
||||
// // "test_hardswish",
|
||||
"test_if",
|
||||
// TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra
|
||||
// supports Sequence and Optional types
|
||||
// "test_if_seq",
|
||||
// "test_if_opt",
|
||||
"test_instancenorm_epsilon",
|
||||
"test_instancenorm_example",
|
||||
// "test_isinf_negative",
|
||||
|
|
|
|||
|
|
@ -1030,7 +1030,10 @@ Status SessionState::CreateSubgraphSessionState() {
|
|||
for (auto& node : graph_.Nodes()) {
|
||||
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
const auto& ep = node.GetExecutionProviderType();
|
||||
if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider && ep != kDmlExecutionProvider) {
|
||||
if (!ep.empty() &&
|
||||
ep != kCpuExecutionProvider && ep != kCudaExecutionProvider &&
|
||||
ep != kRocmExecutionProvider && ep != kDmlExecutionProvider &&
|
||||
ep != kJsExecutionProvider) {
|
||||
// SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
|
||||
// node containing the subgraph it will create whatever state it needs internally.
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -318,7 +318,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad);
|
||||
|
|
@ -327,6 +326,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If);
|
||||
|
||||
std::unique_ptr<KernelRegistry> RegisterKernels() {
|
||||
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
|
||||
|
||||
|
|
@ -580,15 +584,17 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
65
onnxruntime/core/providers/js/operators/if.cc
Normal file
65
onnxruntime/core/providers/js/operators/if.cc
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "if.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
|
||||
kOnnxDomain,
|
||||
1, 10,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
|
||||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
If);
|
||||
// output shape rules requiring the output shapes of the 'THEN' and 'ELSE'
|
||||
// branches to be the same were relaxed in opset-11
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
|
||||
kOnnxDomain,
|
||||
11, 12,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
|
||||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
If);
|
||||
|
||||
// opset-13 supports sequence type for If's subgraph outputs
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
|
||||
kOnnxDomain,
|
||||
13, 18,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
|
||||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
|
||||
// Support sequence/optional tensors when all JSEP infra
|
||||
// (including tests runner) supports it
|
||||
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
If);
|
||||
|
||||
// opset-19 supports float8
|
||||
ONNX_OPERATOR_KERNEL_EX(If,
|
||||
kOnnxDomain,
|
||||
19,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
|
||||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
|
||||
// Support sequence/optional tensors when all JSEP infra
|
||||
// (including tests runner) supports it
|
||||
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
If);
|
||||
|
||||
Status If::Compute(OpKernelContext* ctx) const {
|
||||
// call the base CPU version.
|
||||
return onnxruntime::If::Compute(ctx);
|
||||
}
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
24
onnxruntime/core/providers/js/operators/if.h
Normal file
24
onnxruntime/core/providers/js/operators/if.h
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <functional>
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cpu/controlflow/if.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class SessionState;
|
||||
|
||||
namespace js {
|
||||
|
||||
// Use the CPU implementation for the logic
|
||||
class If final : public onnxruntime::If {
|
||||
public:
|
||||
If(const OpKernelInfo& info) : onnxruntime::If(info) {}
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
};
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue