mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
**Description**: LayerNormalization is now part of the ONNX spec as of opset 17. We already had a LayerNormalization contrib op, which was (incorrectly) registered in the ONNX domain; use that implementation for the ONNX operator. Update skip_layer_norm_fusion.cc accordingly. Other optimizers that use LayerNormalization still need similar updates. **Motivation and Context**: #12916
30 lines
1.8 KiB
C++
30 lines
1.8 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// LayerNorm was a contrib op but is now part of the ONNX spec
#include "layer_norm.h"

#include "core/providers/common.h"

namespace onnxruntime {
namespace contrib {

// Registers the CPU kernels of the original LayerNormalization contrib op and of
// SimplifiedLayerNormalization for element type T.
//
// NOTE: these kernels are (incorrectly) registered in the ONNX domain (kOnnxDomain)
// rather than a contrib domain. They are kept there so that existing models continue
// to resolve to them.
//
//  * LayerNormalization is registered as a *versioned* kernel covering versions 1-16
//    only. Presumably the opset 17+ registration, where LayerNormalization is an
//    official ONNX op, lives elsewhere — TODO confirm.
//  * SimplifiedLayerNormalization is registered from version 1 with no upper bound.
//
// All three type constraints ("T", "U", "V") are bound to the same element type T,
// so only homogeneous-type graphs match these kernels.
// LayerNorm<false> backs LayerNormalization; LayerNorm<true> backs
// SimplifiedLayerNormalization.
#define REGISTER_CONTRIB_KERNELS(T)                                                                         \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(LayerNormalization, kOnnxDomain, 1, 16, T, kCpuExecutionProvider, \
                                          KernelDefBuilder()                                                \
                                              .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())        \
                                              .TypeConstraint("U", DataTypeImpl::GetTensorType<T>())        \
                                              .TypeConstraint("V", DataTypeImpl::GetTensorType<T>()),       \
                                          LayerNorm<false>);                                                \
  ONNX_OPERATOR_TYPED_KERNEL_EX(SimplifiedLayerNormalization, kOnnxDomain, 1, T, kCpuExecutionProvider,     \
                                KernelDefBuilder()                                                          \
                                    .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())                  \
                                    .TypeConstraint("U", DataTypeImpl::GetTensorType<T>())                  \
                                    .TypeConstraint("V", DataTypeImpl::GetTensorType<T>()),                 \
                                LayerNorm<true>);

REGISTER_CONTRIB_KERNELS(float)
REGISTER_CONTRIB_KERNELS(double)

// The macro is only meaningful in this translation unit; don't let it leak
// (e.g. into other sources concatenated by a unity/jumbo build).
#undef REGISTER_CONTRIB_KERNELS

}  // namespace contrib
}  // namespace onnxruntime