mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Merged PR 5866812: Decompose unsupported QLinearSigmoid operation in DML EP
Related work items: #32220862
This commit is contained in:
parent
56d2c4baa2
commit
057de97d92
2 changed files with 118 additions and 8 deletions
|
|
@ -27,7 +27,8 @@ namespace Dml
|
|||
onnxruntime::common::Status GraphTransformer::ApplyImpl(
|
||||
onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
int graph_level, const onnxruntime::logging::Logger&) const {
|
||||
int graph_level, const onnxruntime::logging::Logger&) const
|
||||
{
|
||||
modified = false;
|
||||
|
||||
// Perform fusion
|
||||
|
|
@ -36,8 +37,13 @@ namespace Dml
|
|||
PerformOperatorFusion(&graph, &transformModifiedGraph);
|
||||
modified |= transformModifiedGraph;
|
||||
|
||||
if (modified) {
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
transformModifiedGraph = false;
|
||||
PerformQuantizedOperatorDecomposition(&graph, &transformModifiedGraph);
|
||||
modified |= transformModifiedGraph;
|
||||
|
||||
if (modified)
|
||||
{
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -110,9 +116,10 @@ namespace Dml
|
|||
|
||||
// We need to predict whether the nodes will be assigned to the DML transformer by Lotus,
|
||||
// which occurs in IExecutionProvider::GetCapability.
|
||||
if (!onnxruntime::KernelRegistry::HasImplementationOf(*registry, outputNode, onnxruntime::kDmlExecutionProvider)) {
|
||||
// Can't fuse nodes that don't belong to this execution provider
|
||||
continue;
|
||||
if (!onnxruntime::KernelRegistry::HasImplementationOf(*registry, outputNode, onnxruntime::kDmlExecutionProvider))
|
||||
{
|
||||
// Can't fuse nodes that don't belong to this execution provider
|
||||
continue;
|
||||
}
|
||||
|
||||
if (outputNode.InputDefs().size() != 1)
|
||||
|
|
@ -160,12 +167,14 @@ namespace Dml
|
|||
fusedNode.activationAttributes = activationNode.GetAttributes();
|
||||
|
||||
// Inputs to the fused node are the inputs to the fuseable node
|
||||
for (const auto *input : fuseableNode.InputDefs()) {
|
||||
for (const auto *input : fuseableNode.InputDefs())
|
||||
{
|
||||
fusedNode.inputs.push_back(graph->GetNodeArg(input->Name()));
|
||||
}
|
||||
|
||||
// Outputs from the fused node are the outputs to the activation node
|
||||
for (const auto *output : activationNode.OutputDefs()){
|
||||
for (const auto *output : activationNode.OutputDefs())
|
||||
{
|
||||
fusedNode.outputs.push_back(graph->GetNodeArg(output->Name()));
|
||||
}
|
||||
|
||||
|
|
@ -210,4 +219,103 @@ namespace Dml
|
|||
}
|
||||
}
|
||||
|
||||
// Converts certain QLinear operations unsupported by the DML API into a sequence of DeQuantizeLinear, 32-bit operator, QuantizeLinear
|
||||
void GraphTransformer::PerformQuantizedOperatorDecomposition(onnxruntime::Graph* graph, bool* modified) const
|
||||
{
|
||||
struct NodeToAdd
|
||||
{
|
||||
std::string name;
|
||||
std::string description;
|
||||
std::string opType;
|
||||
std::string domain;
|
||||
onnxruntime::NodeAttributes attributes;
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
};
|
||||
|
||||
// Defer adding and removing nodes in the graph until after we're done iterating over it, because we can't mutate the
|
||||
// graph while iterating over it
|
||||
std::vector<NodeToAdd> nodesToAdd;
|
||||
std::vector<onnxruntime::NodeIndex> nodesToRemove;
|
||||
|
||||
for (auto& node : graph->Nodes())
|
||||
{
|
||||
// For now, only QLinearSigmoid is handled
|
||||
if (node.Domain() == onnxruntime::kMSDomain &&
|
||||
node.OpType() == "QLinearSigmoid")
|
||||
{
|
||||
// Intermediate node arg type proto with floating point format
|
||||
onnx::TypeProto floatTensorProto;
|
||||
floatTensorProto.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_FLOAT);
|
||||
|
||||
// Add intermediate graph edges for the input and output of the FP32 sigmoid operator
|
||||
auto* sigmoidInputArg = &graph->GetOrCreateNodeArg("decomposed_QLinearSigmoid_input_" + GetUniqueNodeName(&node), &floatTensorProto);
|
||||
auto* sigmoidOutputArg = &graph->GetOrCreateNodeArg("decomposed_QLinearSigmoid_output_" + GetUniqueNodeName(&node), &floatTensorProto);
|
||||
|
||||
{
|
||||
NodeToAdd dequantizeNode;
|
||||
dequantizeNode.name = "decomposed_QLinearSigmoid_DequantizeLinear_" + GetUniqueNodeName(&node);
|
||||
dequantizeNode.description = "";
|
||||
dequantizeNode.opType = "DequantizeLinear";
|
||||
dequantizeNode.domain = "";
|
||||
|
||||
dequantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[0]->Name()));
|
||||
dequantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[1]->Name()));
|
||||
dequantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[2]->Name()));
|
||||
dequantizeNode.outputs.push_back(sigmoidInputArg);
|
||||
|
||||
nodesToAdd.push_back(std::move(dequantizeNode));
|
||||
}
|
||||
|
||||
{
|
||||
NodeToAdd sigmoidNode;
|
||||
sigmoidNode.name = "decomposed_QLinearSigmoid_Sigmoid_" + GetUniqueNodeName(&node);
|
||||
sigmoidNode.description = "";
|
||||
sigmoidNode.opType = "Sigmoid";
|
||||
sigmoidNode.domain = "";
|
||||
sigmoidNode.inputs.push_back(sigmoidInputArg);
|
||||
sigmoidNode.outputs.push_back(sigmoidOutputArg);
|
||||
nodesToAdd.push_back(std::move(sigmoidNode));
|
||||
}
|
||||
|
||||
{
|
||||
NodeToAdd quantizeNode;
|
||||
quantizeNode.name = "decomposed_QLinearSigmoid_QuantizeLinear_" + GetUniqueNodeName(&node);
|
||||
quantizeNode.description = "";
|
||||
quantizeNode.opType = "QuantizeLinear";
|
||||
quantizeNode.domain = "";
|
||||
|
||||
quantizeNode.inputs.push_back(sigmoidOutputArg);
|
||||
quantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[3]->Name()));
|
||||
quantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[4]->Name()));
|
||||
quantizeNode.outputs.push_back(graph->GetNodeArg(node.OutputDefs()[0]->Name()));
|
||||
|
||||
nodesToAdd.push_back(std::move(quantizeNode));
|
||||
}
|
||||
|
||||
nodesToRemove.push_back(node.Index());
|
||||
*modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& nodeToAdd : nodesToAdd)
|
||||
{
|
||||
auto& node = graph->AddNode(
|
||||
nodeToAdd.name,
|
||||
nodeToAdd.opType,
|
||||
nodeToAdd.description,
|
||||
nodeToAdd.inputs,
|
||||
nodeToAdd.outputs,
|
||||
&nodeToAdd.attributes,
|
||||
nodeToAdd.domain);
|
||||
}
|
||||
|
||||
for (const auto& nodeIndex : nodesToRemove)
|
||||
{
|
||||
onnxruntime::Node* node = graph->GetNode(nodeIndex);
|
||||
onnxruntime::graph_utils::RemoveNodeOutputEdges(*graph, *node);
|
||||
graph->RemoveNode(node->Index());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ namespace Dml
|
|||
|
||||
private:
|
||||
void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const;
|
||||
void PerformQuantizedOperatorDecomposition(onnxruntime::Graph* graph, bool* modified) const;
|
||||
|
||||
std::shared_ptr<onnxruntime::KernelRegistry> m_registry;
|
||||
uint32_t m_supportedDataTypeMask = 0;
|
||||
const ExecutionProviderImpl* m_providerImpl = nullptr;
|
||||
|
|
|
|||
Loading…
Reference in a new issue