mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[DML EP] Add subgraph fusion support (#17504)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
4e37c5d1f0
commit
7edff1c2bf
4 changed files with 76 additions and 33 deletions
|
|
@ -38,6 +38,16 @@ namespace Dml
|
|||
bool& modified,
|
||||
int graph_level,
|
||||
const onnxruntime::logging::Logger& logger) const
|
||||
{
|
||||
return ApplyImplHelper(graph, modified, graph_level, logger, {});
|
||||
}
|
||||
|
||||
onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImplHelper(
|
||||
onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
int graph_level,
|
||||
const onnxruntime::logging::Logger& logger,
|
||||
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputDefs) const
|
||||
{
|
||||
onnxruntime::ProviderType provider_type = onnxruntime::kDmlExecutionProvider;
|
||||
const gsl::not_null<const onnxruntime::KernelRegistry*> registry = m_providerImpl->GetKernelRegistry().get();
|
||||
|
|
@ -49,6 +59,30 @@ namespace Dml
|
|||
std::vector<std::shared_ptr<CompiledPartitionInfo>> compiledPartitionInfos;
|
||||
std::vector<onnxruntime::NodeIndex> additionalSplittingNodes;
|
||||
|
||||
onnxruntime::GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
for (auto node_index : node_topology_list)
|
||||
{
|
||||
auto* node = graph.GetNode(node_index);
|
||||
if (!node)
|
||||
{
|
||||
continue; // node was removed
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, const onnxruntime::NodeArg*> subgraphImplicitInputDefs;
|
||||
for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs())
|
||||
{
|
||||
subgraphImplicitInputDefs[inputDef->Name()] = inputDef;
|
||||
}
|
||||
|
||||
for (auto& entry : node->GetAttributeNameToMutableSubgraphMap())
|
||||
{
|
||||
auto& subgraph = *entry.second;
|
||||
ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graph_level + 1, logger, subgraphImplicitInputDefs));
|
||||
}
|
||||
}
|
||||
|
||||
do
|
||||
{
|
||||
// Initializers needed by any graph partition
|
||||
|
|
@ -62,7 +96,8 @@ namespace Dml
|
|||
m_providerImpl->GetSupportedDeviceDataTypeMask(),
|
||||
graphNodePropertyMap,
|
||||
requiredInitializerMap,
|
||||
additionalSplittingNodes);
|
||||
additionalSplittingNodes,
|
||||
implicitInputDefs);
|
||||
|
||||
// Reset the splitting nodes for the current iteration
|
||||
additionalSplittingNodes.clear();
|
||||
|
|
|
|||
|
|
@ -2,32 +2,41 @@
|
|||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/framework/execution_providers.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
class ExecutionProviderImpl;
|
||||
class ExecutionProviderImpl;
|
||||
|
||||
class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer
|
||||
{
|
||||
public:
|
||||
DmlGraphFusionTransformer(
|
||||
const std::string& name,
|
||||
const onnxruntime::IExecutionProvider* provider
|
||||
);
|
||||
class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer
|
||||
{
|
||||
public:
|
||||
DmlGraphFusionTransformer(
|
||||
const std::string& name,
|
||||
const onnxruntime::IExecutionProvider* provider
|
||||
);
|
||||
|
||||
public:
|
||||
inline const static char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlFusedNode_";
|
||||
inline const static char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlFusedNodeDomain";
|
||||
public:
|
||||
static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlFusedNode_";
|
||||
static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlFusedNodeDomain";
|
||||
|
||||
private:
|
||||
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
int graph_level,
|
||||
const onnxruntime::logging::Logger& logger) const final;
|
||||
private:
|
||||
const ExecutionProviderImpl* m_providerImpl = nullptr;
|
||||
};
|
||||
private:
|
||||
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
int graph_level,
|
||||
const onnxruntime::logging::Logger& logger) const final;
|
||||
|
||||
onnxruntime::common::Status ApplyImplHelper(
|
||||
onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
int graph_level,
|
||||
const onnxruntime::logging::Logger& logger,
|
||||
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputDefs) const;
|
||||
|
||||
private:
|
||||
const ExecutionProviderImpl* m_providerImpl = nullptr;
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -345,13 +345,8 @@ namespace Dml
|
|||
// Whether any operator in the model contains a subgraph. This is true
|
||||
// if the graph being partitioned is itself within a subgraph, or contains
|
||||
// an operator with a subgraph.
|
||||
bool ModelUsesSubgraph(const onnxruntime::GraphViewer& graph)
|
||||
bool ContainsSubgraph(const onnxruntime::GraphViewer& graph)
|
||||
{
|
||||
if (graph.IsSubgraph())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::vector<onnxruntime::NodeIndex>& toplogicalOrder = graph.GetNodesInTopologicalOrder();
|
||||
|
||||
for (size_t nodeIndex : toplogicalOrder)
|
||||
|
|
@ -384,7 +379,8 @@ namespace Dml
|
|||
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
|
||||
std::unordered_set<std::string>& requiredInitializerMap,
|
||||
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes)
|
||||
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes,
|
||||
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs)
|
||||
{
|
||||
// Nodes are uniquely identified by the name of their first output argument
|
||||
std::vector<std::unique_ptr<GraphPartition>> partitions;
|
||||
|
|
@ -419,7 +415,7 @@ namespace Dml
|
|||
}
|
||||
|
||||
// Check whether this graph is a subgraph, or contains any node with a subgraph.
|
||||
bool modelUsesSubgraph = ModelUsesSubgraph(graph);
|
||||
bool containsSubgraph = ContainsSubgraph(graph);
|
||||
|
||||
uint32_t splittingNodeIndex = 0;
|
||||
|
||||
|
|
@ -454,10 +450,10 @@ namespace Dml
|
|||
// Add a unique partition if graph node usage is not supported.
|
||||
//
|
||||
// Partitioning is disabled in models with subgraphs to work around issues with implicit inputs.
|
||||
// The partitioning algorithm does not currently consider such inputs. Transfering shared initializers
|
||||
// The partitioning algorithm does not currently consider such inputs. Transferring shared initializers
|
||||
// for partitions could also cause problems. Note, operators with subgraphs are currently not efficient
|
||||
// anyhow due to CPU/GPU copies.
|
||||
if (modelUsesSubgraph || !isDmlGraphNode)
|
||||
if (containsSubgraph || !isDmlGraphNode)
|
||||
{
|
||||
partitions.push_back(CreatePartitionAndFinalizeInputs(node, isDmlNode, false, nodeNameToPartitionMap));
|
||||
continue;
|
||||
|
|
@ -505,7 +501,7 @@ namespace Dml
|
|||
firstNonFinalInputPartition->AddInput(arg->Name());
|
||||
}
|
||||
|
||||
if (graphInputs.find(arg->Name()) != graphInputs.end())
|
||||
if (graphInputs.find(arg->Name()) != graphInputs.end() || implicitInputs.find(arg->Name()) != implicitInputs.end())
|
||||
{
|
||||
firstNonFinalInputPartition->AddInput(arg->Name());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h"
|
||||
|
||||
namespace Dml
|
||||
|
|
@ -48,5 +50,6 @@ namespace Dml
|
|||
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
|
||||
std::unordered_set<std::string>& requiredInitializerMap,
|
||||
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes);
|
||||
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes,
|
||||
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs);
|
||||
} // namespace Dml
|
||||
|
|
|
|||
Loading…
Reference in a new issue