From a0b470bc35cb2df26d3dc8c75c5261f92e8cf586 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Thu, 1 Dec 2022 14:08:18 -0800 Subject: [PATCH] [DML EP] Add mixed datatype support for DML's LayerNorm contrib op (#13734) ### Description Add mixed datatype support for DML's LayerNorm contrib op. ### Motivation and Context The fusion logic removes casts around LayerNorm in the graph because the contrib version of the op supports mixed datatypes. Scale, Bias and Output's datatypes must match, but input's datatype can be different. --- docs/OperatorKernels.md | 2 +- .../src/Operators/DmlOperator.h | 7 +- .../DmlOperatorLayerNormalization.cpp | 199 +++++++++++++++++- .../src/Operators/OperatorRegistration.cpp | 4 +- .../test/contrib_ops/layer_norm_op_test.cc | 92 +++++++- 5 files changed, 286 insertions(+), 18 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 01c0902022..8cedb7ecc0 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -931,7 +931,7 @@ Do not modify directly.* |LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(float), tensor(float16)| |||7+|**T** = tensor(float), tensor(float16)| |LayerNormalization|*in* X:**T**
*in* Scale:**T**
*in* B:**T**
*out* Y:**T**
*out* Mean:**U**
*out* InvStdDev:**U**

or

*in* X:**T**
*in* Scale:**V**
*in* B:**V**
*out* Y:**V**
*out* Mean:**U**
*out* InvStdDev:**U**|17+|**T** = tensor(float), tensor(float16)
**U** = tensor(float)| -|||1+|**T** = tensor(float), tensor(float16)
**U** = tensor(float)| +|||1+|**T** = tensor(float), tensor(float16)
**V** = tensor(float), tensor(float16)| |LeakyRelu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| |Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h index d398233e4c..493cc2e445 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h @@ -64,8 +64,8 @@ namespace Dml ); // This method only works with DML_GRAPH. - // To make it work without DML_GRAPH, we need to add new functionality - // in DMLX i.e. DMLX should also give access to DML_OPERATOR_DESC + // To make it work without DML_GRAPH, we need to add new functionality + // in DMLX i.e. DMLX should also give access to DML_OPERATOR_DESC // rather than IDMLOperator. void SetDmlOperatorGraphDesc( const MLOperatorGraphDesc&& operatorGraphDesc, @@ -103,7 +103,6 @@ namespace Dml // It returns nullptr if there is no work to do (0 bytes). // ComPtr InitializeZeroInt64Tensor(uint64_t tensorSizeInBytes); - void ExecuteZeroInt64Tensor(IDMLCompiledOperator* compiledOperator, IMLOperatorTensor* tensor); TensorDesc CreateTensorDescFromInput( @@ -140,7 +139,7 @@ namespace Dml _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges); - + }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp index 8858529056..6477e0c36b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp @@ -18,7 +18,7 @@ public: // because DML MVN1 has a validation which requires all 3 needs to have same dimension count // due to historical artifact. DmlOperator::Initialize( - kernelCreationContext, + kernelCreationContext, kernelInputIndices, std::nullopt, std::nullopt, @@ -33,22 +33,199 @@ public: std::vector onnxAxes(inputDimCount - onnxAxis); std::iota(onnxAxes.begin(), onnxAxes.end(), onnxAxis); - std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + assert(m_inputTensorDescs.size() == 3); + assert(m_outputTensorDescs.size() == 1); + + auto inputDataType = m_inputTensorDescs[0].GetDmlDataType(); + ORT_THROW_HR_IF(E_INVALIDARG, inputDataType != DML_TENSOR_DATA_TYPE_FLOAT16 && inputDataType != DML_TENSOR_DATA_TYPE_FLOAT32); + + auto scaleDataType = m_inputTensorDescs[1].GetDmlDataType(); + ORT_THROW_HR_IF(E_INVALIDARG, scaleDataType != DML_TENSOR_DATA_TYPE_FLOAT16 && scaleDataType != DML_TENSOR_DATA_TYPE_FLOAT32); + + // Scale, Bias and Output always have the same data type + ORT_THROW_HR_IF(E_INVALIDARG, m_inputTensorDescs[2].GetDmlDataType() != DML_TENSOR_TYPE_INVALID && m_inputTensorDescs[2].GetDmlDataType() != scaleDataType); + ORT_THROW_HR_IF(E_INVALIDARG, m_outputTensorDescs[0].GetDmlDataType() != scaleDataType); + + auto inputDesc = m_inputTensorDescs[0].GetDmlDesc(); + auto scaleDesc = m_inputTensorDescs[1].GetDmlDesc(); + auto biasDesc = m_inputTensorDescs[2].GetDmlDesc(); + auto outputDesc = m_outputTensorDescs[0].GetDmlDesc(); + + DML_CAST_OPERATOR_DESC inputCastDesc = {}; + DML_OPERATOR_DESC inputCastOpDesc = { DML_OPERATOR_CAST, nullptr }; + + DML_CAST_OPERATOR_DESC scaleCastDesc = {}; + DML_OPERATOR_DESC scaleCastOpDesc = { DML_OPERATOR_CAST, nullptr }; + + DML_CAST_OPERATOR_DESC biasCastDesc = {}; + DML_OPERATOR_DESC biasCastOpDesc = { DML_OPERATOR_CAST, nullptr }; + + // When data types mismatch, we cast to the highest precision to respect DML's requirement that all datatypes must match + TensorDesc inputCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[0].GetSizes()); + DML_TENSOR_DESC inputCastOutputDmlTensorDesc = inputCastOutputTensorDesc.GetDmlDesc(); + + TensorDesc scaleCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[1].GetSizes()); + DML_TENSOR_DESC scaleCastOutputDmlTensorDesc = scaleCastOutputTensorDesc.GetDmlDesc(); + + TensorDesc biasCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[2].GetSizes()); + DML_TENSOR_DESC biasCastOutputDmlTensorDesc = biasCastOutputTensorDesc.GetDmlDesc(); + + // Cast all tensors to the highest common precision + if (inputDataType == DML_TENSOR_DATA_TYPE_FLOAT16 && scaleDataType == DML_TENSOR_DATA_TYPE_FLOAT32) + { + inputCastDesc.InputTensor = &inputDesc; + inputCastDesc.OutputTensor = &inputCastOutputDmlTensorDesc; + inputCastOpDesc.Desc = &inputCastDesc; + } + else if (inputDataType == DML_TENSOR_DATA_TYPE_FLOAT32 && scaleDataType == DML_TENSOR_DATA_TYPE_FLOAT16) + { + scaleCastDesc.InputTensor = &scaleDesc; + scaleCastDesc.OutputTensor = &scaleCastOutputDmlTensorDesc; + scaleCastOpDesc.Desc = &scaleCastDesc; + + if (m_inputTensorDescs[2].GetDmlDataType() != DML_TENSOR_TYPE_INVALID) + { + biasCastDesc.InputTensor = &biasDesc; + biasCastDesc.OutputTensor = &biasCastOutputDmlTensorDesc; + biasCastOpDesc.Desc = &biasCastDesc; + } + } + + // Make sure that the output is the same type as the input + DML_CAST_OPERATOR_DESC outputCastDesc = {}; + DML_OPERATOR_DESC outputCastOpDesc = { DML_OPERATOR_CAST, nullptr }; + + auto realInputDataType = inputCastOpDesc.Desc ? inputCastOutputTensorDesc.GetDmlDataType() : m_inputTensorDescs[0].GetDmlDataType(); + TensorDesc outputCastOutputTensorDesc(realInputDataType, m_outputTensorDescs[0].GetSizes()); + DML_TENSOR_DESC outputCastOutputDmlTensorDesc = outputCastOutputTensorDesc.GetDmlDesc(); + + if (realInputDataType != m_outputTensorDescs[0].GetDmlDataType()) + { + // After the operator has been executed, we need to cast the "casted" output tensor to the original output tensor that TF expects + outputCastDesc.InputTensor = &outputCastOutputDmlTensorDesc; + outputCastDesc.OutputTensor = &outputDesc; + outputCastOpDesc.Desc = &outputCastDesc; + } DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operatorDesc = {}; - operatorDesc.InputTensor = &inputDescs[0]; - operatorDesc.ScaleTensor = &inputDescs[1]; - operatorDesc.BiasTensor = inputDescs[2].Desc != nullptr ? &inputDescs[2] : nullptr; - operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.InputTensor = inputCastOpDesc.Desc ? &inputCastOutputDmlTensorDesc : &inputDesc; + operatorDesc.ScaleTensor = scaleCastOpDesc.Desc ? &scaleCastOutputDmlTensorDesc : &scaleDesc; + operatorDesc.BiasTensor = biasCastOpDesc.Desc ? &biasCastOutputDmlTensorDesc : (biasDesc.Desc ? &biasDesc : nullptr); + operatorDesc.OutputTensor = outputCastOpDesc.Desc ? &outputCastOutputDmlTensorDesc : &outputDesc; operatorDesc.Axes = onnxAxes.data(); operatorDesc.AxisCount = gsl::narrow_cast(onnxAxes.size()); operatorDesc.NormalizeVariance = true; operatorDesc.Epsilon = epsilon; operatorDesc.FusedActivation = nullptr; - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operatorDesc }; - SetDmlOperatorDesc(opDesc, kernelCreationContext); + + // Construct the graph + std::vector opDescs; + opDescs.reserve(5); + + std::vector inputEdges; + inputEdges.reserve(3); + + std::vector intermediateEdges; + intermediateEdges.reserve(4); + + std::vector outputEdges; + outputEdges.reserve(1); + + opDescs.push_back(&opDesc); + uint32_t currentNodeIndex = 1; + + DML_INPUT_GRAPH_EDGE_DESC dataInputEdge = {}; + dataInputEdge.GraphInputIndex = 0; + dataInputEdge.ToNodeIndex = inputCastOpDesc.Desc ? currentNodeIndex : 0; + dataInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(std::move(dataInputEdge)); + + if (inputCastOpDesc.Desc) + { + opDescs.push_back(&inputCastOpDesc); + + // Link the cast op to the MVN op + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; + intermediateEdge.FromNodeIndex = currentNodeIndex; + intermediateEdge.FromNodeOutputIndex = 0; + intermediateEdge.ToNodeIndex = 0; + intermediateEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(std::move(intermediateEdge)); + ++currentNodeIndex; + } + + DML_INPUT_GRAPH_EDGE_DESC scaleInputEdge = {}; + scaleInputEdge.GraphInputIndex = 1; + scaleInputEdge.ToNodeIndex = scaleCastOpDesc.Desc ? currentNodeIndex : 0; + scaleInputEdge.ToNodeInputIndex = scaleCastOpDesc.Desc ? 0 : 1; + inputEdges.push_back(std::move(scaleInputEdge)); + + if (scaleCastOpDesc.Desc) + { + opDescs.push_back(&scaleCastOpDesc); + + // Link the cast op to the MVN op + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; + intermediateEdge.FromNodeIndex = currentNodeIndex; + intermediateEdge.FromNodeOutputIndex = 0; + intermediateEdge.ToNodeIndex = 0; + intermediateEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(std::move(intermediateEdge)); + ++currentNodeIndex; + } + + DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {}; + biasInputEdge.GraphInputIndex = 2; + biasInputEdge.ToNodeIndex = biasCastOpDesc.Desc ? currentNodeIndex : 0; + biasInputEdge.ToNodeInputIndex = biasCastOpDesc.Desc ? 0 : 2; + inputEdges.push_back(std::move(biasInputEdge)); + + if (biasCastOpDesc.Desc) + { + opDescs.push_back(&biasCastOpDesc); + + // Link the cast op to the MVN op + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; + intermediateEdge.FromNodeIndex = currentNodeIndex; + intermediateEdge.FromNodeOutputIndex = 0; + intermediateEdge.ToNodeIndex = 0; + intermediateEdge.ToNodeInputIndex = 2; + intermediateEdges.push_back(std::move(intermediateEdge)); + ++currentNodeIndex; + } + + DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {}; + outputEdge.GraphOutputIndex = 0; + outputEdge.FromNodeIndex = outputCastOpDesc.Desc ? currentNodeIndex : 0; + outputEdge.FromNodeOutputIndex = 0; + outputEdges.push_back(std::move(outputEdge)); + + if (outputCastOpDesc.Desc) + { + opDescs.push_back(&outputCastOpDesc); + + // Link the MVN op to the cast op + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; + intermediateEdge.FromNodeIndex = 0; + intermediateEdge.FromNodeOutputIndex = 0; + intermediateEdge.ToNodeIndex = currentNodeIndex; + intermediateEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(std::move(intermediateEdge)); + ++currentNodeIndex; + } + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; @@ -57,9 +234,9 @@ void CALLBACK QueryLayerNormalization(IMLOperatorSupportQueryContextPrivate* con *isSupported = false; // Mean and InvStdDev are not supported outputs. - // If only Scale tensor is present then fall back to CPU. This is temporary until + // If only Scale tensor is present then fall back to CPU. This is temporary until // DML1.9.2 or DML1.10 gets released. - if (context->GetInputCount() < 3 || context->GetOutputCount() > 1) + if (context->GetInputCount() < 3 || context->GetOutputCount() > 1) { return; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index d32ad0cb40..0c0d08628c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -282,6 +282,7 @@ constexpr static std::array typeNameListDefault = {"T"}; constexpr static std::array typeNameListAttention = {"T", "M"}; constexpr static std::array typeNameListTwo = { "T1", "T2" }; constexpr static std::array typeNameListLayerNorm = { "T", "U" }; +constexpr static std::array typeNameListLayerNormContrib = { "T", "V" }; constexpr static std::array typeNameListThree = { "T1", "T2", "T3" }; constexpr static std::array typeNameListFour = { "T1", "T2", "T3", "T4" }; constexpr static std::array typeNameListTopK = { "T", "I" }; @@ -339,6 +340,7 @@ constexpr static std::array supportedTypeListIntege constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListArgMinMax = {SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64}; +constexpr static std::array supportedTypeListLayerNormalizationContrib = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32}; constexpr static std::array supportedTypeListLayerNormalization = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float32}; constexpr static std::array supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64}; @@ -423,7 +425,6 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute. {REG_INFO( 14, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryBatchNormalization)}, // v14 adds training_mode attribute {REG_INFO( 15, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryBatchNormalization)}, // v15 adds differing types for scale and bias vs input. - {REG_INFO( 7, LayerNormalization, typeNameListLayerNorm, supportedTypeListLayerNormalization, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, {REG_INFO_VER( 17, LayerNormalization, typeNameListLayerNorm, supportedTypeListLayerNormalization, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, {REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 13, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -739,6 +740,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, {REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, }; template diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index fb007b5e41..5caad57e9c 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -94,6 +94,57 @@ TEST(LayerNormTest, LayerNorm_Scale) { test.Run(); } +TEST(LayerNormTest, LayerNorm_Scale_Float16Input) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet"; + } + + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 2, 2}; + test.AddInput("x", dims, ToFloat16({-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f})); + test.AddInput("gamma", {2}, {-0.6953f, 5.1824f}); + test.AddOutput("output", dims, {0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f}); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet"; + } + + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 2, 2}; + test.AddInput("x", dims, {-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f}); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); + test.AddOutput("output", dims, ToFloat16({0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f})); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet"; + } + + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 2, 2}; + test.AddInput("x", dims, ToFloat16({-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); + test.AddOutput("output", dims, ToFloat16({0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f})); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + TEST(LayerNormTest, LayerNorm_Scale_Bias) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); @@ -106,6 +157,45 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias) { test.Run(); } +TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); + test.AddInput("gamma", {2}, {-0.6953f, 5.1824f}); + test.AddInput("bias", {2}, {0.6435f, -0.3964f}); + test.AddOutput("output", dims, {-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f}); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, {1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f}); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); + test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f})); + test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); + test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f})); + test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + // LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check. TEST(LayerNormTest, LayerNorm17_float) { OpTester test("LayerNormalization", 17); @@ -151,7 +241,7 @@ TEST(LayerNormTest, LayerNorm_InvalidScaleBias) { #if defined(USE_DNNL) TEST(LayerNormTest, LayerNorm17_Scale_Bias_bfloat16) { #ifdef USE_DNNL - if (!DnnlHasBF16Support()) { + if (!DnnlHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; return; }