Merged PR 5866812: Decompose unsupported QLinearSigmoid operation in DML EP

Related work items: #32220862
This commit is contained in:
Jeff Bloomfield 2021-04-01 00:24:38 +00:00
parent 56d2c4baa2
commit 057de97d92
2 changed files with 118 additions and 8 deletions

View file

@ -27,7 +27,8 @@ namespace Dml
onnxruntime::common::Status GraphTransformer::ApplyImpl(
onnxruntime::Graph& graph,
bool& modified,
int graph_level, const onnxruntime::logging::Logger&) const {
int graph_level, const onnxruntime::logging::Logger&) const
{
modified = false;
// Perform fusion
@ -36,8 +37,13 @@ namespace Dml
PerformOperatorFusion(&graph, &transformModifiedGraph);
modified |= transformModifiedGraph;
if (modified) {
ORT_RETURN_IF_ERROR(graph.Resolve());
transformModifiedGraph = false;
PerformQuantizedOperatorDecomposition(&graph, &transformModifiedGraph);
modified |= transformModifiedGraph;
if (modified)
{
ORT_RETURN_IF_ERROR(graph.Resolve());
}
}
@ -110,9 +116,10 @@ namespace Dml
// We need to predict whether the nodes will be assigned to the DML transformer by Lotus,
// which occurs in IExecutionProvider::GetCapability.
if (!onnxruntime::KernelRegistry::HasImplementationOf(*registry, outputNode, onnxruntime::kDmlExecutionProvider)) {
// Can't fuse nodes that don't belong to this execution provider
continue;
if (!onnxruntime::KernelRegistry::HasImplementationOf(*registry, outputNode, onnxruntime::kDmlExecutionProvider))
{
// Can't fuse nodes that don't belong to this execution provider
continue;
}
if (outputNode.InputDefs().size() != 1)
@ -160,12 +167,14 @@ namespace Dml
fusedNode.activationAttributes = activationNode.GetAttributes();
// Inputs to the fused node are the inputs to the fuseable node
for (const auto *input : fuseableNode.InputDefs()) {
for (const auto *input : fuseableNode.InputDefs())
{
fusedNode.inputs.push_back(graph->GetNodeArg(input->Name()));
}
// Outputs from the fused node are the outputs to the activation node
for (const auto *output : activationNode.OutputDefs()){
for (const auto *output : activationNode.OutputDefs())
{
fusedNode.outputs.push_back(graph->GetNodeArg(output->Name()));
}
@ -210,4 +219,103 @@ namespace Dml
}
}
// Converts certain QLinear operations unsupported by the DML API into a sequence of DeQuantizeLinear, 32-bit operator, QuantizeLinear
void GraphTransformer::PerformQuantizedOperatorDecomposition(onnxruntime::Graph* graph, bool* modified) const
{
struct NodeToAdd
{
std::string name;
std::string description;
std::string opType;
std::string domain;
onnxruntime::NodeAttributes attributes;
std::vector<onnxruntime::NodeArg*> inputs;
std::vector<onnxruntime::NodeArg*> outputs;
};
// Defer adding and removing nodes in the graph until after we're done iterating over it, because we can't mutate the
// graph while iterating over it
std::vector<NodeToAdd> nodesToAdd;
std::vector<onnxruntime::NodeIndex> nodesToRemove;
for (auto& node : graph->Nodes())
{
// For now, only QLinearSigmoid is handled
if (node.Domain() == onnxruntime::kMSDomain &&
node.OpType() == "QLinearSigmoid")
{
// Intermediate node arg type proto with floating point format
onnx::TypeProto floatTensorProto;
floatTensorProto.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_FLOAT);
// Add intermediate graph edges for the input and output of the FP32 sigmoid operator
auto* sigmoidInputArg = &graph->GetOrCreateNodeArg("decomposed_QLinearSigmoid_input_" + GetUniqueNodeName(&node), &floatTensorProto);
auto* sigmoidOutputArg = &graph->GetOrCreateNodeArg("decomposed_QLinearSigmoid_output_" + GetUniqueNodeName(&node), &floatTensorProto);
{
NodeToAdd dequantizeNode;
dequantizeNode.name = "decomposed_QLinearSigmoid_DequantizeLinear_" + GetUniqueNodeName(&node);
dequantizeNode.description = "";
dequantizeNode.opType = "DequantizeLinear";
dequantizeNode.domain = "";
dequantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[0]->Name()));
dequantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[1]->Name()));
dequantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[2]->Name()));
dequantizeNode.outputs.push_back(sigmoidInputArg);
nodesToAdd.push_back(std::move(dequantizeNode));
}
{
NodeToAdd sigmoidNode;
sigmoidNode.name = "decomposed_QLinearSigmoid_Sigmoid_" + GetUniqueNodeName(&node);
sigmoidNode.description = "";
sigmoidNode.opType = "Sigmoid";
sigmoidNode.domain = "";
sigmoidNode.inputs.push_back(sigmoidInputArg);
sigmoidNode.outputs.push_back(sigmoidOutputArg);
nodesToAdd.push_back(std::move(sigmoidNode));
}
{
NodeToAdd quantizeNode;
quantizeNode.name = "decomposed_QLinearSigmoid_QuantizeLinear_" + GetUniqueNodeName(&node);
quantizeNode.description = "";
quantizeNode.opType = "QuantizeLinear";
quantizeNode.domain = "";
quantizeNode.inputs.push_back(sigmoidOutputArg);
quantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[3]->Name()));
quantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[4]->Name()));
quantizeNode.outputs.push_back(graph->GetNodeArg(node.OutputDefs()[0]->Name()));
nodesToAdd.push_back(std::move(quantizeNode));
}
nodesToRemove.push_back(node.Index());
*modified = true;
}
}
for (auto& nodeToAdd : nodesToAdd)
{
auto& node = graph->AddNode(
nodeToAdd.name,
nodeToAdd.opType,
nodeToAdd.description,
nodeToAdd.inputs,
nodeToAdd.outputs,
&nodeToAdd.attributes,
nodeToAdd.domain);
}
for (const auto& nodeIndex : nodesToRemove)
{
onnxruntime::Node* node = graph->GetNode(nodeIndex);
onnxruntime::graph_utils::RemoveNodeOutputEdges(*graph, *node);
graph->RemoveNode(node->Index());
}
}
} // namespace Dml

View file

@ -28,6 +28,8 @@ namespace Dml
private:
void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const;
void PerformQuantizedOperatorDecomposition(onnxruntime::Graph* graph, bool* modified) const;
std::shared_ptr<onnxruntime::KernelRegistry> m_registry;
uint32_t m_supportedDataTypeMask = 0;
const ExecutionProviderImpl* m_providerImpl = nullptr;