onnxruntime/onnxruntime/contrib_ops/cpu/layer_norm.cc
Scott McKay 394c249c7c
Add ONNX LayerNormalization(17) (#12978)
**Description**: LayerNormalization is now part of the ONNX spec as of
opset 17.
We had a LayerNormalization contrib op, which (incorrectly) was
registered in the ONNX domain. Use that implementation for the ONNX
operator.

Update skip_layer_norm_fusion.cc. There are other optimizers that use
LayerNormalization that need updates as well.

**Motivation and Context**
#12916
2022-09-23 09:49:27 +10:00

30 lines
1.8 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// LayerNorm was a contrib op but is now part of the ONNX spec
#include "layer_norm.h"
#include "core/providers/common.h"
namespace onnxruntime {
namespace contrib {
// CPU kernel registrations for the layer normalization ops that originated as contrib ops.
//
// REGISTER_CONTRIB_KERNELS(T) expands to two kernel registrations for element type T, both
// in kOnnxDomain on kCpuExecutionProvider, with the "T", "U" and "V" type constraints all
// bound to tensors of T:
//   * LayerNormalization, versions 1 through 16, implemented by LayerNorm<false>.
//     This is the original contrib op; it was (incorrectly) registered in the ONNX domain,
//     and the version range stops at 16 because ONNX defines LayerNormalization from opset 17.
//   * SimplifiedLayerNormalization, version 1, implemented by LayerNorm<true>.
// NOTE(review): the bool template argument of LayerNorm presumably selects the simplified
// variant - confirm against layer_norm.h.
#define REGISTER_CONTRIB_KERNELS(T) \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
      LayerNormalization, kOnnxDomain, 1, 16, T, kCpuExecutionProvider, \
      KernelDefBuilder() \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
          .TypeConstraint("U", DataTypeImpl::GetTensorType<T>()) \
          .TypeConstraint("V", DataTypeImpl::GetTensorType<T>()), \
      LayerNorm<false>); \
  ONNX_OPERATOR_TYPED_KERNEL_EX( \
      SimplifiedLayerNormalization, kOnnxDomain, 1, T, kCpuExecutionProvider, \
      KernelDefBuilder() \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
          .TypeConstraint("U", DataTypeImpl::GetTensorType<T>()) \
          .TypeConstraint("V", DataTypeImpl::GetTensorType<T>()), \
      LayerNorm<true>);
// Emit the LayerNormalization and SimplifiedLayerNormalization CPU kernel registrations
// for the float and double element types.
REGISTER_CONTRIB_KERNELS(float)
REGISTER_CONTRIB_KERNELS(double)
} // namespace contrib
} // namespace onnxruntime