[DML EP] Add mixed datatype support for DML's LayerNorm contrib op (#13734)

### Description
Add mixed datatype support for DML's LayerNorm contrib op.



### Motivation and Context
The fusion logic removes the casts around LayerNorm in the graph because the
contrib version of the op supports mixed datatypes. Scale, Bias and Output
must all share the same datatype, but the input's datatype may differ from it.
This commit is contained in:
Patrice Vignola 2022-12-01 14:08:18 -08:00 committed by GitHub
parent 82d123b6c9
commit a0b470bc35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 286 additions and 18 deletions

View file

@ -931,7 +931,7 @@ Do not modify directly.*
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(float), tensor(float16)|
|||7+|**T** = tensor(float), tensor(float16)|
|LayerNormalization|*in* X:**T**<br> *in* Scale:**T**<br> *in* B:**T**<br> *out* Y:**T**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**<br><br>or<br><br>*in* X:**T**<br> *in* Scale:**V**<br> *in* B:**V**<br> *out* Y:**V**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**|17+|**T** = tensor(float), tensor(float16)<br/> **U** = tensor(float)|
|||1+|**T** = tensor(float), tensor(float16)<br/> **U** = tensor(float)|
|||1+|**T** = tensor(float), tensor(float16)<br/> **V** = tensor(float), tensor(float16)|
|LeakyRelu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|Less|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|

View file

@ -64,8 +64,8 @@ namespace Dml
);
// This method only works with DML_GRAPH.
// To make it work without DML_GRAPH, we need to add new functionality
// in DMLX i.e. DMLX should also give access to DML_OPERATOR_DESC
// To make it work without DML_GRAPH, we need to add new functionality
// in DMLX i.e. DMLX should also give access to DML_OPERATOR_DESC
// rather than IDMLOperator.
void SetDmlOperatorGraphDesc(
const MLOperatorGraphDesc&& operatorGraphDesc,
@ -103,7 +103,6 @@ namespace Dml
// It returns nullptr if there is no work to do (0 bytes).
//
ComPtr<IDMLCompiledOperator> InitializeZeroInt64Tensor(uint64_t tensorSizeInBytes);
void ExecuteZeroInt64Tensor(IDMLCompiledOperator* compiledOperator, IMLOperatorTensor* tensor);
TensorDesc CreateTensorDescFromInput(
@ -140,7 +139,7 @@ namespace Dml
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
};
} // namespace Dml

View file

@ -18,7 +18,7 @@ public:
// because DML MVN1 has a validation which requires all 3 needs to have same dimension count
// due to historical artifact.
DmlOperator::Initialize(
kernelCreationContext,
kernelCreationContext,
kernelInputIndices,
std::nullopt,
std::nullopt,
@ -33,22 +33,199 @@ public:
std::vector<uint32_t> onnxAxes(inputDimCount - onnxAxis);
std::iota(onnxAxes.begin(), onnxAxes.end(), onnxAxis);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(m_inputTensorDescs.size() == 3);
assert(m_outputTensorDescs.size() == 1);
auto inputDataType = m_inputTensorDescs[0].GetDmlDataType();
ORT_THROW_HR_IF(E_INVALIDARG, inputDataType != DML_TENSOR_DATA_TYPE_FLOAT16 && inputDataType != DML_TENSOR_DATA_TYPE_FLOAT32);
auto scaleDataType = m_inputTensorDescs[1].GetDmlDataType();
ORT_THROW_HR_IF(E_INVALIDARG, scaleDataType != DML_TENSOR_DATA_TYPE_FLOAT16 && scaleDataType != DML_TENSOR_DATA_TYPE_FLOAT32);
// Scale, Bias and Output always have the same data type
ORT_THROW_HR_IF(E_INVALIDARG, m_inputTensorDescs[2].GetDmlDataType() != DML_TENSOR_TYPE_INVALID && m_inputTensorDescs[2].GetDmlDataType() != scaleDataType);
ORT_THROW_HR_IF(E_INVALIDARG, m_outputTensorDescs[0].GetDmlDataType() != scaleDataType);
auto inputDesc = m_inputTensorDescs[0].GetDmlDesc();
auto scaleDesc = m_inputTensorDescs[1].GetDmlDesc();
auto biasDesc = m_inputTensorDescs[2].GetDmlDesc();
auto outputDesc = m_outputTensorDescs[0].GetDmlDesc();
DML_CAST_OPERATOR_DESC inputCastDesc = {};
DML_OPERATOR_DESC inputCastOpDesc = { DML_OPERATOR_CAST, nullptr };
DML_CAST_OPERATOR_DESC scaleCastDesc = {};
DML_OPERATOR_DESC scaleCastOpDesc = { DML_OPERATOR_CAST, nullptr };
DML_CAST_OPERATOR_DESC biasCastDesc = {};
DML_OPERATOR_DESC biasCastOpDesc = { DML_OPERATOR_CAST, nullptr };
// When data types mismatch, we cast to the highest precision to respect DML's requirement that all datatypes must match
TensorDesc inputCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[0].GetSizes());
DML_TENSOR_DESC inputCastOutputDmlTensorDesc = inputCastOutputTensorDesc.GetDmlDesc();
TensorDesc scaleCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[1].GetSizes());
DML_TENSOR_DESC scaleCastOutputDmlTensorDesc = scaleCastOutputTensorDesc.GetDmlDesc();
TensorDesc biasCastOutputTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[2].GetSizes());
DML_TENSOR_DESC biasCastOutputDmlTensorDesc = biasCastOutputTensorDesc.GetDmlDesc();
// Cast all tensors to the highest common precision
if (inputDataType == DML_TENSOR_DATA_TYPE_FLOAT16 && scaleDataType == DML_TENSOR_DATA_TYPE_FLOAT32)
{
inputCastDesc.InputTensor = &inputDesc;
inputCastDesc.OutputTensor = &inputCastOutputDmlTensorDesc;
inputCastOpDesc.Desc = &inputCastDesc;
}
else if (inputDataType == DML_TENSOR_DATA_TYPE_FLOAT32 && scaleDataType == DML_TENSOR_DATA_TYPE_FLOAT16)
{
scaleCastDesc.InputTensor = &scaleDesc;
scaleCastDesc.OutputTensor = &scaleCastOutputDmlTensorDesc;
scaleCastOpDesc.Desc = &scaleCastDesc;
if (m_inputTensorDescs[2].GetDmlDataType() != DML_TENSOR_TYPE_INVALID)
{
biasCastDesc.InputTensor = &biasDesc;
biasCastDesc.OutputTensor = &biasCastOutputDmlTensorDesc;
biasCastOpDesc.Desc = &biasCastDesc;
}
}
// Make sure that the MVN output has the same type as the (possibly casted) MVN input; cast back to the real output type afterwards if they differ
DML_CAST_OPERATOR_DESC outputCastDesc = {};
DML_OPERATOR_DESC outputCastOpDesc = { DML_OPERATOR_CAST, nullptr };
auto realInputDataType = inputCastOpDesc.Desc ? inputCastOutputTensorDesc.GetDmlDataType() : m_inputTensorDescs[0].GetDmlDataType();
TensorDesc outputCastOutputTensorDesc(realInputDataType, m_outputTensorDescs[0].GetSizes());
DML_TENSOR_DESC outputCastOutputDmlTensorDesc = outputCastOutputTensorDesc.GetDmlDesc();
if (realInputDataType != m_outputTensorDescs[0].GetDmlDataType())
{
// After the operator has been executed, we need to cast the "casted" output tensor to the original output tensor that ONNX Runtime expects
outputCastDesc.InputTensor = &outputCastOutputDmlTensorDesc;
outputCastDesc.OutputTensor = &outputDesc;
outputCastOpDesc.Desc = &outputCastDesc;
}
DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.ScaleTensor = &inputDescs[1];
operatorDesc.BiasTensor = inputDescs[2].Desc != nullptr ? &inputDescs[2] : nullptr;
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.InputTensor = inputCastOpDesc.Desc ? &inputCastOutputDmlTensorDesc : &inputDesc;
operatorDesc.ScaleTensor = scaleCastOpDesc.Desc ? &scaleCastOutputDmlTensorDesc : &scaleDesc;
operatorDesc.BiasTensor = biasCastOpDesc.Desc ? &biasCastOutputDmlTensorDesc : (biasDesc.Desc ? &biasDesc : nullptr);
operatorDesc.OutputTensor = outputCastOpDesc.Desc ? &outputCastOutputDmlTensorDesc : &outputDesc;
operatorDesc.Axes = onnxAxes.data();
operatorDesc.AxisCount = gsl::narrow_cast<uint32_t>(onnxAxes.size());
operatorDesc.NormalizeVariance = true;
operatorDesc.Epsilon = epsilon;
operatorDesc.FusedActivation = nullptr;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
// Construct the graph
std::vector<const DML_OPERATOR_DESC*> opDescs;
opDescs.reserve(5);
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
inputEdges.reserve(3);
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
intermediateEdges.reserve(4);
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
outputEdges.reserve(1);
opDescs.push_back(&opDesc);
uint32_t currentNodeIndex = 1;
DML_INPUT_GRAPH_EDGE_DESC dataInputEdge = {};
dataInputEdge.GraphInputIndex = 0;
dataInputEdge.ToNodeIndex = inputCastOpDesc.Desc ? currentNodeIndex : 0;
dataInputEdge.ToNodeInputIndex = 0;
inputEdges.push_back(std::move(dataInputEdge));
if (inputCastOpDesc.Desc)
{
opDescs.push_back(&inputCastOpDesc);
// Link the cast op to the MVN op
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
intermediateEdge.FromNodeIndex = currentNodeIndex;
intermediateEdge.FromNodeOutputIndex = 0;
intermediateEdge.ToNodeIndex = 0;
intermediateEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(std::move(intermediateEdge));
++currentNodeIndex;
}
DML_INPUT_GRAPH_EDGE_DESC scaleInputEdge = {};
scaleInputEdge.GraphInputIndex = 1;
scaleInputEdge.ToNodeIndex = scaleCastOpDesc.Desc ? currentNodeIndex : 0;
scaleInputEdge.ToNodeInputIndex = scaleCastOpDesc.Desc ? 0 : 1;
inputEdges.push_back(std::move(scaleInputEdge));
if (scaleCastOpDesc.Desc)
{
opDescs.push_back(&scaleCastOpDesc);
// Link the cast op to the MVN op
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
intermediateEdge.FromNodeIndex = currentNodeIndex;
intermediateEdge.FromNodeOutputIndex = 0;
intermediateEdge.ToNodeIndex = 0;
intermediateEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(std::move(intermediateEdge));
++currentNodeIndex;
}
DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
biasInputEdge.GraphInputIndex = 2;
biasInputEdge.ToNodeIndex = biasCastOpDesc.Desc ? currentNodeIndex : 0;
biasInputEdge.ToNodeInputIndex = biasCastOpDesc.Desc ? 0 : 2;
inputEdges.push_back(std::move(biasInputEdge));
if (biasCastOpDesc.Desc)
{
opDescs.push_back(&biasCastOpDesc);
// Link the cast op to the MVN op
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
intermediateEdge.FromNodeIndex = currentNodeIndex;
intermediateEdge.FromNodeOutputIndex = 0;
intermediateEdge.ToNodeIndex = 0;
intermediateEdge.ToNodeInputIndex = 2;
intermediateEdges.push_back(std::move(intermediateEdge));
++currentNodeIndex;
}
DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
outputEdge.GraphOutputIndex = 0;
outputEdge.FromNodeIndex = outputCastOpDesc.Desc ? currentNodeIndex : 0;
outputEdge.FromNodeOutputIndex = 0;
outputEdges.push_back(std::move(outputEdge));
if (outputCastOpDesc.Desc)
{
opDescs.push_back(&outputCastOpDesc);
// Link the MVN op to the cast op
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
intermediateEdge.FromNodeIndex = 0;
intermediateEdge.FromNodeOutputIndex = 0;
intermediateEdge.ToNodeIndex = currentNodeIndex;
intermediateEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(std::move(intermediateEdge));
++currentNodeIndex;
}
MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
operatorGraphDesc.nodesAsOpDesc = opDescs.data();
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
}
};
@ -57,9 +234,9 @@ void CALLBACK QueryLayerNormalization(IMLOperatorSupportQueryContextPrivate* con
*isSupported = false;
// Mean and InvStdDev are not supported outputs.
// If only Scale tensor is present then fall back to CPU. This is temporary until
// If only Scale tensor is present then fall back to CPU. This is temporary until
// DML1.9.2 or DML1.10 gets released.
if (context->GetInputCount() < 3 || context->GetOutputCount() > 1)
if (context->GetInputCount() < 3 || context->GetOutputCount() > 1)
{
return;
}

View file

@ -282,6 +282,7 @@ constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
// Lists of type-constraint names referenced by operator registrations.
// NOTE(review): each list is presumably matched index-by-index with a
// supportedTypeList* array of the same length at the registration site — confirm.
constexpr static std::array<const char*, 2> typeNameListAttention = {"T", "M"};
constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
// LayerNormalization: constraint names "T" and "U".
constexpr static std::array<const char*, 2> typeNameListLayerNorm = { "T", "U" };
// LayerNormalization (contrib variant): constraint names "T" and "V".
constexpr static std::array<const char*, 2> typeNameListLayerNormContrib = { "T", "V" };
constexpr static std::array<const char*, 3> typeNameListThree = { "T1", "T2", "T3" };
constexpr static std::array<const char*, 4> typeNameListFour = { "T1", "T2", "T3", "T4" };
constexpr static std::array<const char*, 2> typeNameListTopK = { "T", "I" };
@ -339,6 +340,7 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListIntege
// Lists of allowed tensor data types, one entry per type constraint of the operator
// they are registered with.
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListArgMinMax = {SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64};
// Contrib LayerNormalization: both entries are Float16to32, so the two constraints
// may each take either of those types.
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerNormalizationContrib = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
// LayerNormalization: first entry is Float16to32, second entry is Float32 only.
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerNormalization = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
@ -423,7 +425,6 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
{REG_INFO( 14, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryBatchNormalization)}, // v14 adds training_mode attribute
{REG_INFO( 15, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryBatchNormalization)}, // v15 adds differing types for scale and bias vs input.
{REG_INFO( 7, LayerNormalization, typeNameListLayerNorm, supportedTypeListLayerNormalization, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO_VER( 17, LayerNormalization, typeNameListLayerNorm, supportedTypeListLayerNormalization, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
@ -739,6 +740,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
};
template<typename T>

View file

@ -94,6 +94,57 @@ TEST(LayerNormTest, LayerNorm_Scale) {
test.Run();
}
// Scale-only LayerNormalization with a float16 input and float32 scale/output.
TEST(LayerNormTest, LayerNorm_Scale_Float16Input) {
  // TODO: Re-enable for DML once #41968513 is fixed.
  const bool dml_available = DefaultDmlExecutionProvider().get() != nullptr;
  if (dml_available) {
    GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
  }

  const std::vector<int64_t> shape{2, 2, 2};
  const std::vector<float> x_values{-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f};
  const std::vector<float> gamma_values{-0.6953f, 5.1824f};
  const std::vector<float> expected{0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f};

  OpTester tester("LayerNormalization");
  tester.AddAttribute<float>("epsilon", 1e-05f);
  tester.AddInput<MLFloat16>("x", shape, ToFloat16(x_values));
  tester.AddInput<float>("gamma", {2}, gamma_values);
  tester.AddOutput<float>("output", shape, expected);

  // The TRT, DNNL and OpenVINO providers are excluded: they don't support this combination of datatypes.
  tester.Run(OpTester::ExpectResult::kExpectSuccess,
             "",
             {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
}
// Scale-only LayerNormalization with a float32 input and float16 scale/output.
TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) {
  // TODO: Re-enable for DML once #41968513 is fixed.
  const bool dml_available = DefaultDmlExecutionProvider().get() != nullptr;
  if (dml_available) {
    GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
  }

  const std::vector<int64_t> shape{2, 2, 2};
  const std::vector<float> x_values{-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f};
  const std::vector<float> gamma_values{-0.6953f, 5.1824f};
  const std::vector<float> expected{0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f};

  OpTester tester("LayerNormalization");
  tester.AddAttribute<float>("epsilon", 1e-05f);
  tester.AddInput<float>("x", shape, x_values);
  tester.AddInput<MLFloat16>("gamma", {2}, ToFloat16(gamma_values));
  tester.AddOutput<MLFloat16>("output", shape, ToFloat16(expected));

  // The TRT, DNNL and OpenVINO providers are excluded: they don't support this combination of datatypes.
  tester.Run(OpTester::ExpectResult::kExpectSuccess,
             "",
             {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
}
// Scale-only LayerNormalization where input, scale and output are all float16.
TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) {
  // TODO: Re-enable for DML once #41968513 is fixed.
  const bool dml_available = DefaultDmlExecutionProvider().get() != nullptr;
  if (dml_available) {
    GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
  }

  const std::vector<int64_t> shape{2, 2, 2};
  const std::vector<float> x_values{-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f};
  const std::vector<float> gamma_values{-0.6953f, 5.1824f};
  const std::vector<float> expected{0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f};

  OpTester tester("LayerNormalization");
  tester.AddAttribute<float>("epsilon", 1e-05f);
  tester.AddInput<MLFloat16>("x", shape, ToFloat16(x_values));
  tester.AddInput<MLFloat16>("gamma", {2}, ToFloat16(gamma_values));
  tester.AddOutput<MLFloat16>("output", shape, ToFloat16(expected));

  // The TRT, DNNL and OpenVINO providers are excluded: they don't support this combination of datatypes.
  tester.Run(OpTester::ExpectResult::kExpectSuccess,
             "",
             {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(LayerNormTest, LayerNorm_Scale_Bias) {
OpTester test("LayerNormalization");
test.AddAttribute<float>("epsilon", 1e-05f);
@ -106,6 +157,45 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias) {
test.Run();
}
// LayerNormalization with scale and bias: float16 input, float32 scale/bias/output.
TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) {
  const std::vector<int64_t> shape{1, 3, 2};
  const std::vector<float> x_values{1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f};
  const std::vector<float> gamma_values{-0.6953f, 5.1824f};
  const std::vector<float> bias_values{0.6435f, -0.3964f};
  const std::vector<float> expected{-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f};

  OpTester tester("LayerNormalization");
  tester.AddAttribute<float>("epsilon", 1e-05f);
  tester.AddInput<MLFloat16>("x", shape, ToFloat16(x_values));
  tester.AddInput<float>("gamma", {2}, gamma_values);
  tester.AddInput<float>("bias", {2}, bias_values);
  tester.AddOutput<float>("output", shape, expected);

  // The TRT, DNNL and OpenVINO providers are excluded: they don't support this combination of datatypes.
  tester.Run(OpTester::ExpectResult::kExpectSuccess,
             "",
             {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
}
// LayerNormalization with scale and bias: float32 input, float16 scale/bias/output.
TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) {
  const std::vector<int64_t> shape{1, 3, 2};
  const std::vector<float> x_values{1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f};
  const std::vector<float> gamma_values{-0.6953f, 5.1824f};
  const std::vector<float> bias_values{0.6435f, -0.3964f};
  const std::vector<float> expected{-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f};

  OpTester tester("LayerNormalization");
  tester.AddAttribute<float>("epsilon", 1e-05f);
  tester.AddInput<float>("x", shape, x_values);
  tester.AddInput<MLFloat16>("gamma", {2}, ToFloat16(gamma_values));
  tester.AddInput<MLFloat16>("bias", {2}, ToFloat16(bias_values));
  tester.AddOutput<MLFloat16>("output", shape, ToFloat16(expected));

  // The TRT, DNNL and OpenVINO providers are excluded: they don't support this combination of datatypes.
  tester.Run(OpTester::ExpectResult::kExpectSuccess,
             "",
             {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
}
// LayerNormalization with scale and bias where every tensor is float16.
TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) {
  const std::vector<int64_t> shape{1, 3, 2};
  const std::vector<float> x_values{1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f};
  const std::vector<float> gamma_values{-0.6953f, 5.1824f};
  const std::vector<float> bias_values{0.6435f, -0.3964f};
  const std::vector<float> expected{-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f};

  OpTester tester("LayerNormalization");
  tester.AddAttribute<float>("epsilon", 1e-05f);
  tester.AddInput<MLFloat16>("x", shape, ToFloat16(x_values));
  tester.AddInput<MLFloat16>("gamma", {2}, ToFloat16(gamma_values));
  tester.AddInput<MLFloat16>("bias", {2}, ToFloat16(bias_values));
  tester.AddOutput<MLFloat16>("output", shape, ToFloat16(expected));

  // The TRT, DNNL and OpenVINO providers are excluded: they don't support this combination of datatypes.
  tester.Run(OpTester::ExpectResult::kExpectSuccess,
             "",
             {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider});
}
// LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check.
TEST(LayerNormTest, LayerNorm17_float) {
OpTester test("LayerNormalization", 17);
@ -151,7 +241,7 @@ TEST(LayerNormTest, LayerNorm_InvalidScaleBias) {
#if defined(USE_DNNL)
TEST(LayerNormTest, LayerNorm17_Scale_Bias_bfloat16) {
#ifdef USE_DNNL
if (!DnnlHasBF16Support()) {
if (!DnnlHasBF16Support()) {
LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16";
return;
}