mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
support for layernorm in webgpu pre opset-17 (#21121)
handled the same way cpu does
This commit is contained in:
parent
8f738d8e9f
commit
9eb1c2a7a3
3 changed files with 19 additions and 1 deletions
|
|
@ -58,7 +58,7 @@ Do not modify directly.*
|
|||
| HardSigmoid | ai.onnx(6+) | |
|
||||
| If | ai.onnx(1-10,11-12,13-18,19+) | |
|
||||
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
|
||||
| LayerNormalization | ai.onnx(17+) | |
|
||||
| LayerNormalization | ai.onnx(1-16,17+) | |
|
||||
| LeakyRelu | ai.onnx(6-15,16+) | |
|
||||
| Less | ai.onnx(7-8,9-12,13+) | |
|
||||
| LessOrEqual | ai.onnx(12-15,16+) | |
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGe
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention);
|
||||
// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu);
|
||||
|
|
@ -23,6 +25,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, Simp
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization);
|
||||
|
||||
template <>
|
||||
|
||||
KernelCreateInfo BuildKernelCreateInfo<void>() {
|
||||
KernelCreateInfo info;
|
||||
return info;
|
||||
|
|
@ -37,6 +40,8 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
|
||||
// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu)>,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,19 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
// LayerNormalization used to be a contrib op
|
||||
// that (incorrectly) used kOnnxDomain so we need to version it
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
LayerNormalization,
|
||||
kOnnxDomain,
|
||||
1,
|
||||
16,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", onnxruntime::js::JsepSupportedFloatTypes())
|
||||
.TypeConstraint("U", onnxruntime::js::JsepSupportedFloatTypes()),
|
||||
onnxruntime::js::LayerNorm<false>);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
SimplifiedLayerNormalization,
|
||||
kOnnxDomain,
|
||||
|
|
|
|||
Loading…
Reference in a new issue