[DML EP] Add support for LayerNorm (scale == nullptr) != (bias == nullptr) (#13818)

### Description
Add support for LayerNorm scale == nullptr != bias == nullptr
This commit is contained in:
Patrice Vignola 2022-12-02 13:19:53 -08:00 committed by GitHub
parent a0b470bc35
commit c2d08fd73a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 47 deletions

View file

@ -175,24 +175,27 @@ public:
++currentNodeIndex;
}
DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
biasInputEdge.GraphInputIndex = 2;
biasInputEdge.ToNodeIndex = biasCastOpDesc.Desc ? currentNodeIndex : 0;
biasInputEdge.ToNodeInputIndex = biasCastOpDesc.Desc ? 0 : 2;
inputEdges.push_back(std::move(biasInputEdge));
if (biasCastOpDesc.Desc)
if (biasDesc.Desc)
{
opDescs.push_back(&biasCastOpDesc);
DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
biasInputEdge.GraphInputIndex = 2;
biasInputEdge.ToNodeIndex = biasCastOpDesc.Desc ? currentNodeIndex : 0;
biasInputEdge.ToNodeInputIndex = biasCastOpDesc.Desc ? 0 : 2;
inputEdges.push_back(std::move(biasInputEdge));
// Link the cast op to the MVN op
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
intermediateEdge.FromNodeIndex = currentNodeIndex;
intermediateEdge.FromNodeOutputIndex = 0;
intermediateEdge.ToNodeIndex = 0;
intermediateEdge.ToNodeInputIndex = 2;
intermediateEdges.push_back(std::move(intermediateEdge));
++currentNodeIndex;
if (biasCastOpDesc.Desc)
{
opDescs.push_back(&biasCastOpDesc);
// Link the cast op to the MVN op
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
intermediateEdge.FromNodeIndex = currentNodeIndex;
intermediateEdge.FromNodeOutputIndex = 0;
intermediateEdge.ToNodeIndex = 0;
intermediateEdge.ToNodeInputIndex = 2;
intermediateEdges.push_back(std::move(intermediateEdge));
++currentNodeIndex;
}
}
DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
@ -231,17 +234,7 @@ public:
void CALLBACK QueryLayerNormalization(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported)
{
*isSupported = false;
// Mean and InvStdDev are not supported outputs.
// If only Scale tensor is present then fall back to CPU. This is temporary until
// DML1.9.2 or DML1.10 gets released.
if (context->GetInputCount() < 3 || context->GetOutputCount() > 1)
{
return;
}
*isSupported = true;
*isSupported = context->GetOutputCount() == 1;
}
DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization, DmlOperatorLayerNormalization);

View file

@ -45,11 +45,6 @@ TEST(LayerNormTest, BERTLayerNorm) {
}
TEST(LayerNormTest, BERTLayerNorm_NoBias) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect";
}
OpTester tester("LayerNormalization", 1 /*opset_version*/);
tester.AddAttribute<int64_t>("axis", -1);
tester.AddAttribute<float>("epsilon", 1e-12f);
@ -95,11 +90,6 @@ TEST(LayerNormTest, LayerNorm_Scale) {
}
TEST(LayerNormTest, LayerNorm_Scale_Float16Input) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
}
OpTester test("LayerNormalization");
test.AddAttribute<float>("epsilon", 1e-05f);
@ -112,11 +102,6 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16Input) {
}
TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
}
OpTester test("LayerNormalization");
test.AddAttribute<float>("epsilon", 1e-05f);
@ -129,11 +114,6 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) {
}
TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
}
OpTester test("LayerNormalization");
test.AddAttribute<float>("epsilon", 1e-05f);