diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 01c0902022..8cedb7ecc0 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -931,7 +931,7 @@ Do not modify directly.*
|LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(float), tensor(float16)|
|||7+|**T** = tensor(float), tensor(float16)|
|LayerNormalization|*in* X:**T**
*in* Scale:**T**
*in* B:**T**
*out* Y:**T**
*out* Mean:**U**
*out* InvStdDev:**U**
or
*in* X:**T**
*in* Scale:**V**
*in* B:**V**
*out* Y:**V**
*out* Mean:**U**
*out* InvStdDev:**U**|17+|**T** = tensor(float), tensor(float16)
**U** = tensor(float)|
-|||1+|**T** = tensor(float), tensor(float16)
**U** = tensor(float)|
+|||1+|**T** = tensor(float), tensor(float16)
**V** = tensor(float), tensor(float16)|
|LeakyRelu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h
index d398233e4c..493cc2e445 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h
@@ -64,8 +64,8 @@ namespace Dml
);
// This method only works with DML_GRAPH.
- // To make it work without DML_GRAPH, we need to add new functionality
- // in DMLX i.e. DMLX should also give access to DML_OPERATOR_DESC
+ // To make it work without DML_GRAPH, we need to add new functionality
+ // in DMLX i.e. DMLX should also give access to DML_OPERATOR_DESC
// rather than IDMLOperator.
void SetDmlOperatorGraphDesc(
const MLOperatorGraphDesc&& operatorGraphDesc,
@@ -103,7 +103,6 @@ namespace Dml
// It returns nullptr if there is no work to do (0 bytes).
//
ComPtr InitializeZeroInt64Tensor(uint64_t tensorSizeInBytes);
-
void ExecuteZeroInt64Tensor(IDMLCompiledOperator* compiledOperator, IMLOperatorTensor* tensor);
TensorDesc CreateTensorDescFromInput(
@@ -140,7 +139,7 @@ namespace Dml
_Inout_ std::vector& dmlInputEdges,
_Inout_ std::vector& dmlOutputEdges,
_Inout_ std::vector& dmlIntermediateEdges);
-
+
};
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp
index 8858529056..6477e0c36b 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp
@@ -18,7 +18,7 @@ public:
// because DML MVN1 has a validation which requires all 3 needs to have same dimension count
// due to historical artifact.
DmlOperator::Initialize(
- kernelCreationContext,
+ kernelCreationContext,
kernelInputIndices,
std::nullopt,
std::nullopt,
@@ -33,22 +33,199 @@ public:
std::vector onnxAxes(inputDimCount - onnxAxis);
std::iota(onnxAxes.begin(), onnxAxes.end(), onnxAxis);
- std::vector inputDescs = GetDmlInputDescs();
- std::vector outputDescs = GetDmlOutputDescs();
+ assert(m_inputTensorDescs.size() == 3);
+ assert(m_outputTensorDescs.size() == 1);
+
+ auto inputDataType = m_inputTensorDescs[0].GetDmlDataType();
+ ORT_THROW_HR_IF(E_INVALIDARG, inputDataType != DML_TENSOR_DATA_TYPE_FLOAT16 && inputDataType != DML_TENSOR_DATA_TYPE_FLOAT32);
+
+ auto scaleDataType = m_inputTensorDescs[1].GetDmlDataType();
+ ORT_THROW_HR_IF(E_INVALIDARG, scaleDataType != DML_TENSOR_DATA_TYPE_FLOAT16 && scaleDataType != DML_TENSOR_DATA_TYPE_FLOAT32);
+
+ // Scale, Bias and Output always have the same data type
+ ORT_THROW_HR_IF(E_INVALIDARG, m_inputTensorDescs[2].GetDmlDataType() != DML_TENSOR_TYPE_INVALID && m_inputTensorDescs[2].GetDmlDataType() != scaleDataType);
+ ORT_THROW_HR_IF(E_INVALIDARG, m_outputTensorDescs[0].GetDmlDataType() != scaleDataType);
+
+ auto inputDesc = m_inputTensorDescs[0].GetDmlDesc();
+ auto scaleDesc = m_inputTensorDescs[1].GetDmlDesc();
+ auto biasDesc = m_inputTensorDescs[2].GetDmlDesc();
+ auto outputDesc = m_outputTensorDescs[0].GetDmlDesc();
+
+ DML_CAST_OPERATOR_DESC inputCastDesc = {};
+ DML_OPERATOR_DESC inputCastOpDesc = { DML_OPERATOR_CAST, nullptr };
+
+ DML_CAST_OPERATOR_DESC scaleCastDesc = {};
+ DML_OPERATOR_DESC scaleCastOpDesc = { DML_OPERATOR_CAST, nullptr };
+
+ DML_CAST_OPERATOR_DESC biasCastDesc = {};
+ DML_OPERATOR_DESC biasCastOpDesc = { DML_OPERATOR_CAST, nullptr };
+
+ // When data types mismatch, we cast to the highest precision to respect DML's requirement that all datatypes must match
+ TensorDesc inputCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[0].GetSizes());
+ DML_TENSOR_DESC inputCastOutputDmlTensorDesc = inputCastOutputTensorDesc.GetDmlDesc();
+
+ TensorDesc scaleCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[1].GetSizes());
+ DML_TENSOR_DESC scaleCastOutputDmlTensorDesc = scaleCastOutputTensorDesc.GetDmlDesc();
+
+ TensorDesc biasCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[2].GetSizes());
+ DML_TENSOR_DESC biasCastOutputDmlTensorDesc = biasCastOutputTensorDesc.GetDmlDesc();
+
+ // Cast all tensors to the highest common precision
+ if (inputDataType == DML_TENSOR_DATA_TYPE_FLOAT16 && scaleDataType == DML_TENSOR_DATA_TYPE_FLOAT32)
+ {
+ inputCastDesc.InputTensor = &inputDesc;
+ inputCastDesc.OutputTensor = &inputCastOutputDmlTensorDesc;
+ inputCastOpDesc.Desc = &inputCastDesc;
+ }
+ else if (inputDataType == DML_TENSOR_DATA_TYPE_FLOAT32 && scaleDataType == DML_TENSOR_DATA_TYPE_FLOAT16)
+ {
+ scaleCastDesc.InputTensor = &scaleDesc;
+ scaleCastDesc.OutputTensor = &scaleCastOutputDmlTensorDesc;
+ scaleCastOpDesc.Desc = &scaleCastDesc;
+
+ if (m_inputTensorDescs[2].GetDmlDataType() != DML_TENSOR_TYPE_INVALID)
+ {
+ biasCastDesc.InputTensor = &biasDesc;
+ biasCastDesc.OutputTensor = &biasCastOutputDmlTensorDesc;
+ biasCastOpDesc.Desc = &biasCastDesc;
+ }
+ }
+
+ // Make sure that the output is the same type as the input
+ DML_CAST_OPERATOR_DESC outputCastDesc = {};
+ DML_OPERATOR_DESC outputCastOpDesc = { DML_OPERATOR_CAST, nullptr };
+
+ auto realInputDataType = inputCastOpDesc.Desc ? inputCastOutputTensorDesc.GetDmlDataType() : m_inputTensorDescs[0].GetDmlDataType();
+ TensorDesc outputCastOutputTensorDesc(realInputDataType, m_outputTensorDescs[0].GetSizes());
+ DML_TENSOR_DESC outputCastOutputDmlTensorDesc = outputCastOutputTensorDesc.GetDmlDesc();
+
+ if (realInputDataType != m_outputTensorDescs[0].GetDmlDataType())
+ {
+ // After the operator has been executed, we need to cast the "casted" output tensor to the original output tensor that TF expects
+ outputCastDesc.InputTensor = &outputCastOutputDmlTensorDesc;
+ outputCastDesc.OutputTensor = &outputDesc;
+ outputCastOpDesc.Desc = &outputCastDesc;
+ }
DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operatorDesc = {};
- operatorDesc.InputTensor = &inputDescs[0];
- operatorDesc.ScaleTensor = &inputDescs[1];
- operatorDesc.BiasTensor = inputDescs[2].Desc != nullptr ? &inputDescs[2] : nullptr;
- operatorDesc.OutputTensor = outputDescs.data();
+ operatorDesc.InputTensor = inputCastOpDesc.Desc ? &inputCastOutputDmlTensorDesc : &inputDesc;
+ operatorDesc.ScaleTensor = scaleCastOpDesc.Desc ? &scaleCastOutputDmlTensorDesc : &scaleDesc;
+ operatorDesc.BiasTensor = biasCastOpDesc.Desc ? &biasCastOutputDmlTensorDesc : (biasDesc.Desc ? &biasDesc : nullptr);
+ operatorDesc.OutputTensor = outputCastOpDesc.Desc ? &outputCastOutputDmlTensorDesc : &outputDesc;
operatorDesc.Axes = onnxAxes.data();
operatorDesc.AxisCount = gsl::narrow_cast(onnxAxes.size());
operatorDesc.NormalizeVariance = true;
operatorDesc.Epsilon = epsilon;
operatorDesc.FusedActivation = nullptr;
-
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operatorDesc };
- SetDmlOperatorDesc(opDesc, kernelCreationContext);
+
+ // Construct the graph
+ std::vector opDescs;
+ opDescs.reserve(5);
+
+ std::vector inputEdges;
+ inputEdges.reserve(3);
+
+ std::vector intermediateEdges;
+ intermediateEdges.reserve(4);
+
+ std::vector outputEdges;
+ outputEdges.reserve(1);
+
+ opDescs.push_back(&opDesc);
+ uint32_t currentNodeIndex = 1;
+
+ DML_INPUT_GRAPH_EDGE_DESC dataInputEdge = {};
+ dataInputEdge.GraphInputIndex = 0;
+ dataInputEdge.ToNodeIndex = inputCastOpDesc.Desc ? currentNodeIndex : 0;
+ dataInputEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(std::move(dataInputEdge));
+
+ if (inputCastOpDesc.Desc)
+ {
+ opDescs.push_back(&inputCastOpDesc);
+
+ // Link the cast op to the MVN op
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
+ intermediateEdge.FromNodeIndex = currentNodeIndex;
+ intermediateEdge.FromNodeOutputIndex = 0;
+ intermediateEdge.ToNodeIndex = 0;
+ intermediateEdge.ToNodeInputIndex = 0;
+ intermediateEdges.push_back(std::move(intermediateEdge));
+ ++currentNodeIndex;
+ }
+
+ DML_INPUT_GRAPH_EDGE_DESC scaleInputEdge = {};
+ scaleInputEdge.GraphInputIndex = 1;
+ scaleInputEdge.ToNodeIndex = scaleCastOpDesc.Desc ? currentNodeIndex : 0;
+ scaleInputEdge.ToNodeInputIndex = scaleCastOpDesc.Desc ? 0 : 1;
+ inputEdges.push_back(std::move(scaleInputEdge));
+
+ if (scaleCastOpDesc.Desc)
+ {
+ opDescs.push_back(&scaleCastOpDesc);
+
+ // Link the cast op to the MVN op
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
+ intermediateEdge.FromNodeIndex = currentNodeIndex;
+ intermediateEdge.FromNodeOutputIndex = 0;
+ intermediateEdge.ToNodeIndex = 0;
+ intermediateEdge.ToNodeInputIndex = 1;
+ intermediateEdges.push_back(std::move(intermediateEdge));
+ ++currentNodeIndex;
+ }
+
+ DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
+ biasInputEdge.GraphInputIndex = 2;
+ biasInputEdge.ToNodeIndex = biasCastOpDesc.Desc ? currentNodeIndex : 0;
+ biasInputEdge.ToNodeInputIndex = biasCastOpDesc.Desc ? 0 : 2;
+ inputEdges.push_back(std::move(biasInputEdge));
+
+ if (biasCastOpDesc.Desc)
+ {
+ opDescs.push_back(&biasCastOpDesc);
+
+ // Link the cast op to the MVN op
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
+ intermediateEdge.FromNodeIndex = currentNodeIndex;
+ intermediateEdge.FromNodeOutputIndex = 0;
+ intermediateEdge.ToNodeIndex = 0;
+ intermediateEdge.ToNodeInputIndex = 2;
+ intermediateEdges.push_back(std::move(intermediateEdge));
+ ++currentNodeIndex;
+ }
+
+ DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
+ outputEdge.GraphOutputIndex = 0;
+ outputEdge.FromNodeIndex = outputCastOpDesc.Desc ? currentNodeIndex : 0;
+ outputEdge.FromNodeOutputIndex = 0;
+ outputEdges.push_back(std::move(outputEdge));
+
+ if (outputCastOpDesc.Desc)
+ {
+ opDescs.push_back(&outputCastOpDesc);
+
+ // Link the MVN op to the cast op
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
+ intermediateEdge.FromNodeIndex = 0;
+ intermediateEdge.FromNodeOutputIndex = 0;
+ intermediateEdge.ToNodeIndex = currentNodeIndex;
+ intermediateEdge.ToNodeInputIndex = 0;
+ intermediateEdges.push_back(std::move(intermediateEdge));
+ ++currentNodeIndex;
+ }
+
+ MLOperatorGraphDesc operatorGraphDesc = {};
+ operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size());
+ operatorGraphDesc.inputEdges = inputEdges.data();
+ operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size());
+ operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+ operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size());
+ operatorGraphDesc.outputEdges = outputEdges.data();
+ operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size());
+ operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+
+ SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
}
};
@@ -57,9 +234,9 @@ void CALLBACK QueryLayerNormalization(IMLOperatorSupportQueryContextPrivate* con
*isSupported = false;
// Mean and InvStdDev are not supported outputs.
- // If only Scale tensor is present then fall back to CPU. This is temporary until
+ // If only Scale tensor is present then fall back to CPU. This is temporary until
// DML1.9.2 or DML1.10 gets released.
- if (context->GetInputCount() < 3 || context->GetOutputCount() > 1)
+ if (context->GetInputCount() < 3 || context->GetOutputCount() > 1)
{
return;
}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index d32ad0cb40..0c0d08628c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -282,6 +282,7 @@ constexpr static std::array typeNameListDefault = {"T"};
constexpr static std::array typeNameListAttention = {"T", "M"};
constexpr static std::array typeNameListTwo = { "T1", "T2" };
constexpr static std::array typeNameListLayerNorm = { "T", "U" };
+constexpr static std::array typeNameListLayerNormContrib = { "T", "V" };
constexpr static std::array typeNameListThree = { "T1", "T2", "T3" };
constexpr static std::array typeNameListFour = { "T1", "T2", "T3", "T4" };
constexpr static std::array typeNameListTopK = { "T", "I" };
@@ -339,6 +340,7 @@ constexpr static std::array supportedTypeListIntege
constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 };
constexpr static std::array supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
constexpr static std::array supportedTypeListArgMinMax = {SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64};
+constexpr static std::array supportedTypeListLayerNormalizationContrib = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
constexpr static std::array supportedTypeListLayerNormalization = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float32};
constexpr static std::array supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
@@ -423,7 +425,6 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
{REG_INFO( 14, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryBatchNormalization)}, // v14 adds training_mode attribute
{REG_INFO( 15, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryBatchNormalization)}, // v15 adds differing types for scale and bias vs input.
- {REG_INFO( 7, LayerNormalization, typeNameListLayerNorm, supportedTypeListLayerNormalization, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO_VER( 17, LayerNormalization, typeNameListLayerNorm, supportedTypeListLayerNormalization, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
@@ -739,6 +740,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
+ {REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
};
template
diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index fb007b5e41..5caad57e9c 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -94,6 +94,57 @@ TEST(LayerNormTest, LayerNorm_Scale) {
test.Run();
}
+TEST(LayerNormTest, LayerNorm_Scale_Float16Input) {
+ // TODO: Unskip when fixed #41968513
+ if (DefaultDmlExecutionProvider().get() != nullptr) {
+ GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
+ }
+
+ OpTester test("LayerNormalization");
+ test.AddAttribute("epsilon", 1e-05f);
+
+ std::vector dims{2, 2, 2};
+ test.AddInput("x", dims, ToFloat16({-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f}));
+ test.AddInput("gamma", {2}, {-0.6953f, 5.1824f});
+ test.AddOutput("output", dims, {0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f});
+ // TRT, DNNL and OpenVINO don't support this combination of datatypes
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) {
+ // TODO: Unskip when fixed #41968513
+ if (DefaultDmlExecutionProvider().get() != nullptr) {
+ GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
+ }
+
+ OpTester test("LayerNormalization");
+ test.AddAttribute("epsilon", 1e-05f);
+
+ std::vector dims{2, 2, 2};
+ test.AddInput("x", dims, {-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f});
+ test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}));
+ test.AddOutput("output", dims, ToFloat16({0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f}));
+ // TRT, DNNL and OpenVINO don't support this combination of datatypes
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) {
+ // TODO: Unskip when fixed #41968513
+ if (DefaultDmlExecutionProvider().get() != nullptr) {
+ GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
+ }
+
+ OpTester test("LayerNormalization");
+ test.AddAttribute("epsilon", 1e-05f);
+
+ std::vector dims{2, 2, 2};
+ test.AddInput("x", dims, ToFloat16({-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f}));
+ test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}));
+ test.AddOutput("output", dims, ToFloat16({0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f}));
+ // TRT, DNNL and OpenVINO don't support this combination of datatypes
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
+}
+
TEST(LayerNormTest, LayerNorm_Scale_Bias) {
OpTester test("LayerNormalization");
test.AddAttribute("epsilon", 1e-05f);
@@ -106,6 +157,45 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias) {
test.Run();
}
+TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) {
+ OpTester test("LayerNormalization");
+ test.AddAttribute("epsilon", 1e-05f);
+
+ std::vector dims{1, 3, 2};
+ test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f}));
+ test.AddInput("gamma", {2}, {-0.6953f, 5.1824f});
+ test.AddInput("bias", {2}, {0.6435f, -0.3964f});
+ test.AddOutput("output", dims, {-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f});
+ // TRT, DNNL and OpenVINO don't support this combination of datatypes
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) {
+ OpTester test("LayerNormalization");
+ test.AddAttribute("epsilon", 1e-05f);
+
+ std::vector dims{1, 3, 2};
+ test.AddInput("x", dims, {1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f});
+ test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}));
+ test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f}));
+ test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f}));
+ // TRT, DNNL and OpenVINO don't support this combination of datatypes
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) {
+ OpTester test("LayerNormalization");
+ test.AddAttribute("epsilon", 1e-05f);
+
+ std::vector dims{1, 3, 2};
+ test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f}));
+ test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}));
+ test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f}));
+ test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f}));
+ // TRT, DNNL and OpenVINO don't support this combination of datatypes
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
+}
+
// LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check.
TEST(LayerNormTest, LayerNorm17_float) {
OpTester test("LayerNormalization", 17);
@@ -151,7 +241,7 @@ TEST(LayerNormTest, LayerNorm_InvalidScaleBias) {
#if defined(USE_DNNL)
TEST(LayerNormTest, LayerNorm17_Scale_Bias_bfloat16) {
#ifdef USE_DNNL
- if (!DnnlHasBF16Support()) {
+ if (!DnnlHasBF16Support()) {
LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16";
return;
}