diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp index 6477e0c36b..30047e750c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp @@ -175,24 +175,27 @@ public: ++currentNodeIndex; } - DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {}; - biasInputEdge.GraphInputIndex = 2; - biasInputEdge.ToNodeIndex = biasCastOpDesc.Desc ? currentNodeIndex : 0; - biasInputEdge.ToNodeInputIndex = biasCastOpDesc.Desc ? 0 : 2; - inputEdges.push_back(std::move(biasInputEdge)); - - if (biasCastOpDesc.Desc) + if (biasDesc.Desc) { - opDescs.push_back(&biasCastOpDesc); + DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {}; + biasInputEdge.GraphInputIndex = 2; + biasInputEdge.ToNodeIndex = biasCastOpDesc.Desc ? currentNodeIndex : 0; + biasInputEdge.ToNodeInputIndex = biasCastOpDesc.Desc ? 0 : 2; + inputEdges.push_back(std::move(biasInputEdge)); - // Link the cast op to the MVN op - DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; - intermediateEdge.FromNodeIndex = currentNodeIndex; - intermediateEdge.FromNodeOutputIndex = 0; - intermediateEdge.ToNodeIndex = 0; - intermediateEdge.ToNodeInputIndex = 2; - intermediateEdges.push_back(std::move(intermediateEdge)); - ++currentNodeIndex; + if (biasCastOpDesc.Desc) + { + opDescs.push_back(&biasCastOpDesc); + + // Link the cast op to the MVN op + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; + intermediateEdge.FromNodeIndex = currentNodeIndex; + intermediateEdge.FromNodeOutputIndex = 0; + intermediateEdge.ToNodeIndex = 0; + intermediateEdge.ToNodeInputIndex = 2; + intermediateEdges.push_back(std::move(intermediateEdge)); + ++currentNodeIndex; + } } DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {}; @@ -231,17 +234,7 @@ public: void CALLBACK QueryLayerNormalization(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported) { - *isSupported = false; - - // Mean and InvStdDev are not supported outputs. - // If only Scale tensor is present then fall back to CPU. This is temporary until - // DML1.9.2 or DML1.10 gets released. - if (context->GetInputCount() < 3 || context->GetOutputCount() > 1) - { - return; - } - - *isSupported = true; + *isSupported = context->GetOutputCount() == 1; } DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization, DmlOperatorLayerNormalization); diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 5caad57e9c..9bfbe781a8 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -45,11 +45,6 @@ TEST(LayerNormTest, BERTLayerNorm) { } TEST(LayerNormTest, BERTLayerNorm_NoBias) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect"; - } - OpTester tester("LayerNormalization", 1 /*opset_version*/); tester.AddAttribute("axis", -1); tester.AddAttribute("epsilon", 1e-12f); @@ -95,11 +90,6 @@ TEST(LayerNormTest, LayerNorm_Scale) { } TEST(LayerNormTest, LayerNorm_Scale_Float16Input) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet"; - } - OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); @@ -112,11 +102,6 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16Input) { } TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet"; - } - OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); @@ -129,11 +114,6 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { } TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because DML's LayerNorm doesn't support less than 3 inputs yet"; - } - OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f);