diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 01c0902022..8cedb7ecc0 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -931,7 +931,7 @@ Do not modify directly.* |LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(float), tensor(float16)| |||7+|**T** = tensor(float), tensor(float16)| |LayerNormalization|*in* X:**T**
*in* Scale:**T**
*in* B:**T**
*out* Y:**T**
*out* Mean:**U**
*out* InvStdDev:**U**

or

*in* X:**T**
*in* Scale:**V**
*in* B:**V**
*out* Y:**V**
*out* Mean:**U**
*out* InvStdDev:**U**|17+|**T** = tensor(float), tensor(float16)
**U** = tensor(float)| -|||1+|**T** = tensor(float), tensor(float16)
**U** = tensor(float)| +|||1+|**T** = tensor(float), tensor(float16)
**V** = tensor(float), tensor(float16)| |LeakyRelu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| |Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h index d398233e4c..493cc2e445 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h @@ -64,8 +64,8 @@ namespace Dml ); // This method only works with DML_GRAPH. - // To make it work without DML_GRAPH, we need to add new functionality - // in DMLX i.e. DMLX should also give access to DML_OPERATOR_DESC + // To make it work without DML_GRAPH, we need to add new functionality + // in DMLX i.e. DMLX should also give access to DML_OPERATOR_DESC // rather than IDMLOperator. void SetDmlOperatorGraphDesc( const MLOperatorGraphDesc&& operatorGraphDesc, @@ -103,7 +103,6 @@ namespace Dml // It returns nullptr if there is no work to do (0 bytes). // ComPtr InitializeZeroInt64Tensor(uint64_t tensorSizeInBytes); - void ExecuteZeroInt64Tensor(IDMLCompiledOperator* compiledOperator, IMLOperatorTensor* tensor); TensorDesc CreateTensorDescFromInput( @@ -140,7 +139,7 @@ namespace Dml _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges); - + }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp index 8858529056..6477e0c36b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp @@ -18,7 +18,7 @@ public: // because DML MVN1 has a validation which requires all 3 needs to have same dimension count // due to historical artifact. 
DmlOperator::Initialize( - kernelCreationContext, + kernelCreationContext, kernelInputIndices, std::nullopt, std::nullopt, @@ -33,22 +33,199 @@ public: std::vector onnxAxes(inputDimCount - onnxAxis); std::iota(onnxAxes.begin(), onnxAxes.end(), onnxAxis); - std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + assert(m_inputTensorDescs.size() == 3); + assert(m_outputTensorDescs.size() == 1); + + auto inputDataType = m_inputTensorDescs[0].GetDmlDataType(); + ORT_THROW_HR_IF(E_INVALIDARG, inputDataType != DML_TENSOR_DATA_TYPE_FLOAT16 && inputDataType != DML_TENSOR_DATA_TYPE_FLOAT32); + + auto scaleDataType = m_inputTensorDescs[1].GetDmlDataType(); + ORT_THROW_HR_IF(E_INVALIDARG, scaleDataType != DML_TENSOR_DATA_TYPE_FLOAT16 && scaleDataType != DML_TENSOR_DATA_TYPE_FLOAT32); + + // Scale, Bias and Output always have the same data type + ORT_THROW_HR_IF(E_INVALIDARG, m_inputTensorDescs[2].GetDmlDataType() != DML_TENSOR_TYPE_INVALID && m_inputTensorDescs[2].GetDmlDataType() != scaleDataType); + ORT_THROW_HR_IF(E_INVALIDARG, m_outputTensorDescs[0].GetDmlDataType() != scaleDataType); + + auto inputDesc = m_inputTensorDescs[0].GetDmlDesc(); + auto scaleDesc = m_inputTensorDescs[1].GetDmlDesc(); + auto biasDesc = m_inputTensorDescs[2].GetDmlDesc(); + auto outputDesc = m_outputTensorDescs[0].GetDmlDesc(); + + DML_CAST_OPERATOR_DESC inputCastDesc = {}; + DML_OPERATOR_DESC inputCastOpDesc = { DML_OPERATOR_CAST, nullptr }; + + DML_CAST_OPERATOR_DESC scaleCastDesc = {}; + DML_OPERATOR_DESC scaleCastOpDesc = { DML_OPERATOR_CAST, nullptr }; + + DML_CAST_OPERATOR_DESC biasCastDesc = {}; + DML_OPERATOR_DESC biasCastOpDesc = { DML_OPERATOR_CAST, nullptr }; + + // When data types mismatch, we cast to the highest precision to respect DML's requirement that all datatypes must match + TensorDesc inputCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[0].GetSizes()); + DML_TENSOR_DESC inputCastOutputDmlTensorDesc = 
inputCastOutputTensorDesc.GetDmlDesc(); + + TensorDesc scaleCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[1].GetSizes()); + DML_TENSOR_DESC scaleCastOutputDmlTensorDesc = scaleCastOutputTensorDesc.GetDmlDesc(); + + TensorDesc biasCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[2].GetSizes()); + DML_TENSOR_DESC biasCastOutputDmlTensorDesc = biasCastOutputTensorDesc.GetDmlDesc(); + + // Cast all tensors to the highest common precision + if (inputDataType == DML_TENSOR_DATA_TYPE_FLOAT16 && scaleDataType == DML_TENSOR_DATA_TYPE_FLOAT32) + { + inputCastDesc.InputTensor = &inputDesc; + inputCastDesc.OutputTensor = &inputCastOutputDmlTensorDesc; + inputCastOpDesc.Desc = &inputCastDesc; + } + else if (inputDataType == DML_TENSOR_DATA_TYPE_FLOAT32 && scaleDataType == DML_TENSOR_DATA_TYPE_FLOAT16) + { + scaleCastDesc.InputTensor = &scaleDesc; + scaleCastDesc.OutputTensor = &scaleCastOutputDmlTensorDesc; + scaleCastOpDesc.Desc = &scaleCastDesc; + + if (m_inputTensorDescs[2].GetDmlDataType() != DML_TENSOR_TYPE_INVALID) + { + biasCastDesc.InputTensor = &biasDesc; + biasCastDesc.OutputTensor = &biasCastOutputDmlTensorDesc; + biasCastOpDesc.Desc = &biasCastDesc; + } + } + + // Make sure that the output is the same type as the input + DML_CAST_OPERATOR_DESC outputCastDesc = {}; + DML_OPERATOR_DESC outputCastOpDesc = { DML_OPERATOR_CAST, nullptr }; + + auto realInputDataType = inputCastOpDesc.Desc ? 
inputCastOutputTensorDesc.GetDmlDataType() : m_inputTensorDescs[0].GetDmlDataType(); + TensorDesc outputCastOutputTensorDesc(realInputDataType, m_outputTensorDescs[0].GetSizes()); + DML_TENSOR_DESC outputCastOutputDmlTensorDesc = outputCastOutputTensorDesc.GetDmlDesc(); + + if (realInputDataType != m_outputTensorDescs[0].GetDmlDataType()) + { + // After the operator has been executed, we need to cast the "casted" output tensor to the original output tensor that ORT expects + outputCastDesc.InputTensor = &outputCastOutputDmlTensorDesc; + outputCastDesc.OutputTensor = &outputDesc; + outputCastOpDesc.Desc = &outputCastDesc; + } DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operatorDesc = {}; - operatorDesc.InputTensor = &inputDescs[0]; - operatorDesc.ScaleTensor = &inputDescs[1]; - operatorDesc.BiasTensor = inputDescs[2].Desc != nullptr ? &inputDescs[2] : nullptr; - operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.InputTensor = inputCastOpDesc.Desc ? &inputCastOutputDmlTensorDesc : &inputDesc; + operatorDesc.ScaleTensor = scaleCastOpDesc.Desc ? &scaleCastOutputDmlTensorDesc : &scaleDesc; + operatorDesc.BiasTensor = biasCastOpDesc.Desc ? &biasCastOutputDmlTensorDesc : (biasDesc.Desc ? &biasDesc : nullptr); + operatorDesc.OutputTensor = outputCastOpDesc.Desc ? 
&outputCastOutputDmlTensorDesc : &outputDesc; operatorDesc.Axes = onnxAxes.data(); operatorDesc.AxisCount = gsl::narrow_cast(onnxAxes.size()); operatorDesc.NormalizeVariance = true; operatorDesc.Epsilon = epsilon; operatorDesc.FusedActivation = nullptr; - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operatorDesc }; - SetDmlOperatorDesc(opDesc, kernelCreationContext); + + // Construct the graph + std::vector opDescs; + opDescs.reserve(5); + + std::vector inputEdges; + inputEdges.reserve(3); + + std::vector intermediateEdges; + intermediateEdges.reserve(4); + + std::vector outputEdges; + outputEdges.reserve(1); + + opDescs.push_back(&opDesc); + uint32_t currentNodeIndex = 1; + + DML_INPUT_GRAPH_EDGE_DESC dataInputEdge = {}; + dataInputEdge.GraphInputIndex = 0; + dataInputEdge.ToNodeIndex = inputCastOpDesc.Desc ? currentNodeIndex : 0; + dataInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(std::move(dataInputEdge)); + + if (inputCastOpDesc.Desc) + { + opDescs.push_back(&inputCastOpDesc); + + // Link the cast op to the MVN op + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; + intermediateEdge.FromNodeIndex = currentNodeIndex; + intermediateEdge.FromNodeOutputIndex = 0; + intermediateEdge.ToNodeIndex = 0; + intermediateEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(std::move(intermediateEdge)); + ++currentNodeIndex; + } + + DML_INPUT_GRAPH_EDGE_DESC scaleInputEdge = {}; + scaleInputEdge.GraphInputIndex = 1; + scaleInputEdge.ToNodeIndex = scaleCastOpDesc.Desc ? currentNodeIndex : 0; + scaleInputEdge.ToNodeInputIndex = scaleCastOpDesc.Desc ? 
0 : 1; + inputEdges.push_back(std::move(scaleInputEdge)); + + if (scaleCastOpDesc.Desc) + { + opDescs.push_back(&scaleCastOpDesc); + + // Link the cast op to the MVN op + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; + intermediateEdge.FromNodeIndex = currentNodeIndex; + intermediateEdge.FromNodeOutputIndex = 0; + intermediateEdge.ToNodeIndex = 0; + intermediateEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(std::move(intermediateEdge)); + ++currentNodeIndex; + } + + DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {}; + biasInputEdge.GraphInputIndex = 2; + biasInputEdge.ToNodeIndex = biasCastOpDesc.Desc ? currentNodeIndex : 0; + biasInputEdge.ToNodeInputIndex = biasCastOpDesc.Desc ? 0 : 2; + inputEdges.push_back(std::move(biasInputEdge)); + + if (biasCastOpDesc.Desc) + { + opDescs.push_back(&biasCastOpDesc); + + // Link the cast op to the MVN op + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; + intermediateEdge.FromNodeIndex = currentNodeIndex; + intermediateEdge.FromNodeOutputIndex = 0; + intermediateEdge.ToNodeIndex = 0; + intermediateEdge.ToNodeInputIndex = 2; + intermediateEdges.push_back(std::move(intermediateEdge)); + ++currentNodeIndex; + } + + DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {}; + outputEdge.GraphOutputIndex = 0; + outputEdge.FromNodeIndex = outputCastOpDesc.Desc ? 
currentNodeIndex : 0; + outputEdge.FromNodeOutputIndex = 0; + outputEdges.push_back(std::move(outputEdge)); + + if (outputCastOpDesc.Desc) + { + opDescs.push_back(&outputCastOpDesc); + + // Link the MVN op to the cast op + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; + intermediateEdge.FromNodeIndex = 0; + intermediateEdge.FromNodeOutputIndex = 0; + intermediateEdge.ToNodeIndex = currentNodeIndex; + intermediateEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(std::move(intermediateEdge)); + ++currentNodeIndex; + } + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; @@ -57,9 +234,9 @@ void CALLBACK QueryLayerNormalization(IMLOperatorSupportQueryContextPrivate* con *isSupported = false; // Mean and InvStdDev are not supported outputs. - // If only Scale tensor is present then fall back to CPU. This is temporary until + // If only Scale tensor is present then fall back to CPU. This is temporary until // DML1.9.2 or DML1.10 gets released. 
- if (context->GetInputCount() < 3 || context->GetOutputCount() > 1) + if (context->GetInputCount() < 3 || context->GetOutputCount() > 1) { return; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index d32ad0cb40..0c0d08628c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -282,6 +282,7 @@ constexpr static std::array typeNameListDefault = {"T"}; constexpr static std::array typeNameListAttention = {"T", "M"}; constexpr static std::array typeNameListTwo = { "T1", "T2" }; constexpr static std::array typeNameListLayerNorm = { "T", "U" }; +constexpr static std::array typeNameListLayerNormContrib = { "T", "V" }; constexpr static std::array typeNameListThree = { "T1", "T2", "T3" }; constexpr static std::array typeNameListFour = { "T1", "T2", "T3", "T4" }; constexpr static std::array typeNameListTopK = { "T", "I" }; @@ -339,6 +340,7 @@ constexpr static std::array supportedTypeListIntege constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListArgMinMax = {SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64}; +constexpr static std::array supportedTypeListLayerNormalizationContrib = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32}; constexpr static std::array supportedTypeListLayerNormalization = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float32}; constexpr static std::array supportedTypeListShape = {SupportedTensorDataTypes::All, 
SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64}; @@ -423,7 +425,6 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute. {REG_INFO( 14, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryBatchNormalization)}, // v14 adds training_mode attribute {REG_INFO( 15, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryBatchNormalization)}, // v15 adds differing types for scale and bias vs input. - {REG_INFO( 7, LayerNormalization, typeNameListLayerNorm, supportedTypeListLayerNormalization, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, {REG_INFO_VER( 17, LayerNormalization, typeNameListLayerNorm, supportedTypeListLayerNormalization, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, {REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 13, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -739,6 +740,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, {REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, 
supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, }; template diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index fb007b5e41..5caad57e9c 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -94,6 +94,57 @@ TEST(LayerNormTest, LayerNorm_Scale) { test.Run(); } +TEST(LayerNormTest, LayerNorm_Scale_Float16Input) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet"; + } + + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 2, 2}; + test.AddInput("x", dims, ToFloat16({-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f})); + test.AddInput("gamma", {2}, {-0.6953f, 5.1824f}); + test.AddOutput("output", dims, {0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f}); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet"; + } + + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 2, 2}; + test.AddInput("x", dims, {-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f}); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); + test.AddOutput("output", dims, ToFloat16({0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, 
-5.1824f})); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet"; + } + + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 2, 2}; + test.AddInput("x", dims, ToFloat16({-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); + test.AddOutput("output", dims, ToFloat16({0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f})); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + TEST(LayerNormTest, LayerNorm_Scale_Bias) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); @@ -106,6 +157,45 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias) { test.Run(); } +TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); + test.AddInput("gamma", {2}, {-0.6953f, 5.1824f}); + test.AddInput("bias", {2}, {0.6435f, -0.3964f}); + test.AddOutput("output", dims, {-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f}); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + 
+TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, {1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f}); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); + test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f})); + test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); + test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f})); + test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); + // TRT, DNNL and OpenVINO don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider}); +} + // LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check. TEST(LayerNormTest, LayerNorm17_float) { OpTester test("LayerNormalization", 17); @@ -151,7 +241,7 @@ TEST(LayerNormTest, LayerNorm_InvalidScaleBias) { #if defined(USE_DNNL) TEST(LayerNormTest, LayerNorm17_Scale_Bias_bfloat16) { #ifdef USE_DNNL - if (!DnnlHasBF16Support()) { + if (!DnnlHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; return; }