From 057de97d92a146f5e1d676d51110dee6652cd246 Mon Sep 17 00:00:00 2001
From: Jeff Bloomfield
Date: Thu, 1 Apr 2021 00:24:38 +0000
Subject: [PATCH] Merged PR 5866812: Decompose unsupported QLinearSigmoid
 operation in DML EP

Related work items: #32220862
---
 .../src/GraphTransformer.cpp                  | 146 ++++++++++++++++--
 .../src/GraphTransformer.h                    |   2 +
 2 files changed, 140 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp
index 112e7279d0..220de8fb60 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp
@@ -27,7 +27,8 @@ namespace Dml
     onnxruntime::common::Status GraphTransformer::ApplyImpl(
         onnxruntime::Graph& graph,
         bool& modified,
-        int graph_level, const onnxruntime::logging::Logger&) const {
+        int graph_level, const onnxruntime::logging::Logger&) const
+    {
         modified = false;
 
         // Perform fusion
@@ -36,8 +37,13 @@ namespace Dml
             PerformOperatorFusion(&graph, &transformModifiedGraph);
             modified |= transformModifiedGraph;
 
-            if (modified) {
-              ORT_RETURN_IF_ERROR(graph.Resolve());
+            transformModifiedGraph = false;
+            PerformQuantizedOperatorDecomposition(&graph, &transformModifiedGraph);
+            modified |= transformModifiedGraph;
+
+            if (modified)
+            {
+                ORT_RETURN_IF_ERROR(graph.Resolve());
             }
         }
 
@@ -110,9 +116,10 @@ namespace Dml
             // We need to predict whether the nodes will be assigned to the DML transformer by Lotus,
             // which occurs in IExecutionProvider::GetCapability.
 
-            if (!onnxruntime::KernelRegistry::HasImplementationOf(*registry, outputNode, onnxruntime::kDmlExecutionProvider)) {
-              // Can't fuse nodes that don't belong to this execution provider
-              continue;
+            if (!onnxruntime::KernelRegistry::HasImplementationOf(*registry, outputNode, onnxruntime::kDmlExecutionProvider))
+            {
+                // Can't fuse nodes that don't belong to this execution provider
+                continue;
             }
 
             if (outputNode.InputDefs().size() != 1)
@@ -160,12 +167,14 @@ namespace Dml
             fusedNode.activationAttributes = activationNode.GetAttributes();
 
             // Inputs to the fused node are the inputs to the fuseable node
-            for (const auto *input : fuseableNode.InputDefs()) {
+            for (const auto *input : fuseableNode.InputDefs())
+            {
                 fusedNode.inputs.push_back(graph->GetNodeArg(input->Name()));
             }
 
             // Outputs from the fused node are the outputs to the activation node
-            for (const auto *output : activationNode.OutputDefs()){
+            for (const auto *output : activationNode.OutputDefs())
+            {
                 fusedNode.outputs.push_back(graph->GetNodeArg(output->Name()));
             }
 
@@ -210,4 +219,125 @@ namespace Dml
         }
     }
 
+    // Converts certain QLinear operations unsupported by the DML API into a sequence of DequantizeLinear, 32-bit operator,
+    // QuantizeLinear. Sets *modified to true when at least one node is decomposed and leaves it untouched otherwise.
+    void GraphTransformer::PerformQuantizedOperatorDecomposition(onnxruntime::Graph* graph, bool* modified) const
+    {
+        struct NodeToAdd
+        {
+            std::string name;
+            std::string description;
+            std::string opType;
+            std::string domain;
+            onnxruntime::NodeAttributes attributes;
+            std::vector<onnxruntime::NodeArg*> inputs;
+            std::vector<onnxruntime::NodeArg*> outputs;
+        };
+
+        // The decomposition forwards inputs 0-2 to DequantizeLinear and inputs 3-4 to QuantizeLinear, so it needs all five
+        constexpr size_t requiredInputCount = 5;
+
+        // Defer adding and removing nodes in the graph until after we're done iterating over it, because we can't mutate the
+        // graph while iterating over it
+        std::vector<NodeToAdd> nodesToAdd;
+        std::vector<onnxruntime::NodeIndex> nodesToRemove;
+
+        for (auto& node : graph->Nodes())
+        {
+            // For now, only QLinearSigmoid is handled
+            if (node.Domain() == onnxruntime::kMSDomain &&
+                node.OpType() == "QLinearSigmoid")
+            {
+                // Leave the node untouched when an input is omitted, rather than indexing past the end of the input
+                // list or wiring a missing argument into the replacement nodes
+                const auto& inputDefs = node.InputDefs();
+                bool hasAllInputs = inputDefs.size() >= requiredInputCount;
+                for (size_t i = 0; hasAllInputs && i < requiredInputCount; ++i)
+                {
+                    hasAllInputs = inputDefs[i] != nullptr && inputDefs[i]->Exists();
+                }
+
+                if (!hasAllInputs)
+                {
+                    continue;
+                }
+
+                const std::string uniqueName = GetUniqueNodeName(&node);
+
+                // Intermediate node arg type proto with floating point format
+                onnx::TypeProto floatTensorProto;
+                floatTensorProto.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_FLOAT);
+
+                // Add intermediate graph edges for the input and output of the FP32 sigmoid operator
+                auto* sigmoidInputArg = &graph->GetOrCreateNodeArg("decomposed_QLinearSigmoid_input_" + uniqueName, &floatTensorProto);
+                auto* sigmoidOutputArg = &graph->GetOrCreateNodeArg("decomposed_QLinearSigmoid_output_" + uniqueName, &floatTensorProto);
+
+                {
+                    NodeToAdd dequantizeNode;
+                    dequantizeNode.name = "decomposed_QLinearSigmoid_DequantizeLinear_" + uniqueName;
+                    dequantizeNode.description = "";
+                    dequantizeNode.opType = "DequantizeLinear";
+                    dequantizeNode.domain = "";
+
+                    dequantizeNode.inputs.push_back(graph->GetNodeArg(inputDefs[0]->Name()));
+                    dequantizeNode.inputs.push_back(graph->GetNodeArg(inputDefs[1]->Name()));
+                    dequantizeNode.inputs.push_back(graph->GetNodeArg(inputDefs[2]->Name()));
+                    dequantizeNode.outputs.push_back(sigmoidInputArg);
+
+                    nodesToAdd.push_back(std::move(dequantizeNode));
+                }
+
+                {
+                    NodeToAdd sigmoidNode;
+                    sigmoidNode.name = "decomposed_QLinearSigmoid_Sigmoid_" + uniqueName;
+                    sigmoidNode.description = "";
+                    sigmoidNode.opType = "Sigmoid";
+                    sigmoidNode.domain = "";
+
+                    sigmoidNode.inputs.push_back(sigmoidInputArg);
+                    sigmoidNode.outputs.push_back(sigmoidOutputArg);
+
+                    nodesToAdd.push_back(std::move(sigmoidNode));
+                }
+
+                {
+                    NodeToAdd quantizeNode;
+                    quantizeNode.name = "decomposed_QLinearSigmoid_QuantizeLinear_" + uniqueName;
+                    quantizeNode.description = "";
+                    quantizeNode.opType = "QuantizeLinear";
+                    quantizeNode.domain = "";
+
+                    quantizeNode.inputs.push_back(sigmoidOutputArg);
+                    quantizeNode.inputs.push_back(graph->GetNodeArg(inputDefs[3]->Name()));
+                    quantizeNode.inputs.push_back(graph->GetNodeArg(inputDefs[4]->Name()));
+                    quantizeNode.outputs.push_back(graph->GetNodeArg(node.OutputDefs()[0]->Name()));
+
+                    nodesToAdd.push_back(std::move(quantizeNode));
+                }
+
+                nodesToRemove.push_back(node.Index());
+                *modified = true;
+            }
+        }
+
+        for (auto& nodeToAdd : nodesToAdd)
+        {
+            graph->AddNode(
+                nodeToAdd.name,
+                nodeToAdd.opType,
+                nodeToAdd.description,
+                nodeToAdd.inputs,
+                nodeToAdd.outputs,
+                &nodeToAdd.attributes,
+                nodeToAdd.domain);
+        }
+
+        for (const auto& nodeIndex : nodesToRemove)
+        {
+            onnxruntime::Node* node = graph->GetNode(nodeIndex);
+            onnxruntime::graph_utils::RemoveNodeOutputEdges(*graph, *node);
+            graph->RemoveNode(node->Index());
+        }
+    }
+
 } // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h
index af2cc456d5..a87c9b2314 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h
@@ -28,6 +28,8 @@ namespace Dml
     private:
         void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const;
 
+        void PerformQuantizedOperatorDecomposition(onnxruntime::Graph* graph, bool* modified) const;
+
         std::shared_ptr<onnxruntime::KernelRegistry> m_registry;
         uint32_t m_supportedDataTypeMask = 0;
         const ExecutionProviderImpl* m_providerImpl = nullptr;