mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[ROCm] enable LayerNorm opset Ver17 for ROCm EP (#15601)
enable LayerNorm opset Ver17 for ROCm EP.
This commit is contained in:
parent
45c82eefb4
commit
9df1a5e605
2 changed files with 8 additions and 13 deletions
|
|
@ -1195,12 +1195,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
|
||||
|
||||
// Opset 17
|
||||
// TODO: Enable LayerNormalization. It uses the same implementation as the old contrib op.
|
||||
// See https://github.com/microsoft/onnxruntime/pull/13066
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, double, LayerNormalization);
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, BFloat16, LayerNormalization);
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, double, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, BFloat16, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization);
|
||||
|
||||
// Opset 18
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split);
|
||||
|
|
@ -2128,10 +2126,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
|
||||
|
||||
// Opset 17
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, double, LayerNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, BFloat16, LayerNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, double, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, BFloat16, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization)>,
|
||||
|
||||
// Opset 18
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split)>,
|
||||
|
|
|
|||
|
|
@ -150,14 +150,11 @@ TEST(CudaKernelTest, SimplifiedLayerNorm_LargeSizeTensor) {
|
|||
TestLayerNorm(X_dims, SIMPLIFIED_LAYER_NORM_OP, k_epsilon_default);
|
||||
}
|
||||
|
||||
// TODO: Generate the ROCM implementation of ONNX LayerNorm
|
||||
#if defined(USE_CUDA) || defined(USE_DML)
|
||||
// LayerNormalization is an ONNX operator in opset 17. It uses the same implementation so this is just a sanity check.
|
||||
TEST(CudaKernelTest, LayerNorm_SmallSizeTensor_Opset17) {
|
||||
const std::vector<int64_t> X_dims{4, 20, 128};
|
||||
TestLayerNorm(X_dims, LAYER_NORM_OP, k_epsilon_default, -1, 1, false, 17);
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue