support for layernorm in webgpu pre opset-17 (#21121)

handled the same way cpu does
This commit is contained in:
Guenther Schmuelling 2024-06-27 10:20:48 -07:00 committed by GitHub
parent 8f738d8e9f
commit 9eb1c2a7a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 19 additions and 1 deletions

View file

@ -58,7 +58,7 @@ Do not modify directly.*
| HardSigmoid | ai.onnx(6+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
| LayerNormalization | ai.onnx(1-16,17+) | |
| LeakyRelu | ai.onnx(6-15,16+) | |
| Less | ai.onnx(7-8,9-12,13+) | |
| LessOrEqual | ai.onnx(12-15,16+) | |

View file

@ -14,6 +14,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGe
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention);
// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu);
@ -23,6 +25,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, Simp
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization);
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
KernelCreateInfo info;
return info;
@ -37,6 +40,8 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu)>,

View file

@ -8,6 +8,19 @@ namespace onnxruntime {
namespace contrib {
namespace js {
// LayerNormalization used to be a contrib op
// that (incorrectly) used kOnnxDomain so we need to version it
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
LayerNormalization,
kOnnxDomain,
1,
16,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", onnxruntime::js::JsepSupportedFloatTypes())
.TypeConstraint("U", onnxruntime::js::JsepSupportedFloatTypes()),
onnxruntime::js::LayerNorm<false>);
ONNX_OPERATOR_KERNEL_EX(
SimplifiedLayerNormalization,
kOnnxDomain,