mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
[DML EP] Add support for LayerNorm (scale == nullptr) != (bias == nullptr) (#13818)
### Description Add support for LayerNorm scale == nullptr != bias == nullptr
This commit is contained in:
parent
a0b470bc35
commit
c2d08fd73a
2 changed files with 20 additions and 47 deletions
|
|
@ -175,24 +175,27 @@ public:
|
|||
++currentNodeIndex;
|
||||
}
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
|
||||
biasInputEdge.GraphInputIndex = 2;
|
||||
biasInputEdge.ToNodeIndex = biasCastOpDesc.Desc ? currentNodeIndex : 0;
|
||||
biasInputEdge.ToNodeInputIndex = biasCastOpDesc.Desc ? 0 : 2;
|
||||
inputEdges.push_back(std::move(biasInputEdge));
|
||||
|
||||
if (biasCastOpDesc.Desc)
|
||||
if (biasDesc.Desc)
|
||||
{
|
||||
opDescs.push_back(&biasCastOpDesc);
|
||||
DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
|
||||
biasInputEdge.GraphInputIndex = 2;
|
||||
biasInputEdge.ToNodeIndex = biasCastOpDesc.Desc ? currentNodeIndex : 0;
|
||||
biasInputEdge.ToNodeInputIndex = biasCastOpDesc.Desc ? 0 : 2;
|
||||
inputEdges.push_back(std::move(biasInputEdge));
|
||||
|
||||
// Link the cast op to the MVN op
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
|
||||
intermediateEdge.FromNodeIndex = currentNodeIndex;
|
||||
intermediateEdge.FromNodeOutputIndex = 0;
|
||||
intermediateEdge.ToNodeIndex = 0;
|
||||
intermediateEdge.ToNodeInputIndex = 2;
|
||||
intermediateEdges.push_back(std::move(intermediateEdge));
|
||||
++currentNodeIndex;
|
||||
if (biasCastOpDesc.Desc)
|
||||
{
|
||||
opDescs.push_back(&biasCastOpDesc);
|
||||
|
||||
// Link the cast op to the MVN op
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
|
||||
intermediateEdge.FromNodeIndex = currentNodeIndex;
|
||||
intermediateEdge.FromNodeOutputIndex = 0;
|
||||
intermediateEdge.ToNodeIndex = 0;
|
||||
intermediateEdge.ToNodeInputIndex = 2;
|
||||
intermediateEdges.push_back(std::move(intermediateEdge));
|
||||
++currentNodeIndex;
|
||||
}
|
||||
}
|
||||
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
|
||||
|
|
@ -231,17 +234,7 @@ public:
|
|||
|
||||
void CALLBACK QueryLayerNormalization(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported)
|
||||
{
|
||||
*isSupported = false;
|
||||
|
||||
// Mean and InvStdDev are not supported outputs.
|
||||
// If only Scale tensor is present then fall back to CPU. This is temporary until
|
||||
// DML1.9.2 or DML1.10 gets released.
|
||||
if (context->GetInputCount() < 3 || context->GetOutputCount() > 1)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
*isSupported = true;
|
||||
*isSupported = context->GetOutputCount() == 1;
|
||||
}
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization, DmlOperatorLayerNormalization);
|
||||
|
|
|
|||
|
|
@ -45,11 +45,6 @@ TEST(LayerNormTest, BERTLayerNorm) {
|
|||
}
|
||||
|
||||
TEST(LayerNormTest, BERTLayerNorm_NoBias) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect";
|
||||
}
|
||||
|
||||
OpTester tester("LayerNormalization", 1 /*opset_version*/);
|
||||
tester.AddAttribute<int64_t>("axis", -1);
|
||||
tester.AddAttribute<float>("epsilon", 1e-12f);
|
||||
|
|
@ -95,11 +90,6 @@ TEST(LayerNormTest, LayerNorm_Scale) {
|
|||
}
|
||||
|
||||
TEST(LayerNormTest, LayerNorm_Scale_Float16Input) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
|
||||
}
|
||||
|
||||
OpTester test("LayerNormalization");
|
||||
test.AddAttribute<float>("epsilon", 1e-05f);
|
||||
|
||||
|
|
@ -112,11 +102,6 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16Input) {
|
|||
}
|
||||
|
||||
TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
|
||||
}
|
||||
|
||||
OpTester test("LayerNormalization");
|
||||
test.AddAttribute<float>("epsilon", 1e-05f);
|
||||
|
||||
|
|
@ -129,11 +114,6 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) {
|
|||
}
|
||||
|
||||
TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet";
|
||||
}
|
||||
|
||||
OpTester test("LayerNormalization");
|
||||
test.AddAttribute<float>("epsilon", 1e-05f);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue