mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
[DML EP] Add dynamic graph compilation (#17876)
Historically, DML was only able to fuse partitions when all sizes are known in advance or when we were overriding them at session creation time. But in practice, it should be possible to compile partitions at compute time if the caller knows that the dimensions won't be changed for every inference (e.g. resizing a webcam window, or padding the input to powers of 2). This graph will be cached and reused until the sizes change. This is an opt-in option gated under the `enable_dynamic_graph_fusion` option, which means that it will only be enabled when the caller requests it since they have more context on how their model will be called between inferences. This PR also adds the option to disable metacommands from the python API, which is an option for the C API but was lacking for python.
This commit is contained in:
parent
d30d4d372a
commit
538e97cbda
26 changed files with 1127 additions and 143 deletions
|
|
@ -128,7 +128,7 @@ struct OrtDmlApi {
|
|||
/**
|
||||
* SessionOptionsAppendExecutionProvider_DML2
|
||||
* Creates a DirectML Execution Provider given the supplied device options that contain a performance preference
|
||||
* (high power, low power, or defult) and a device filter (None, GPU, or NPU).
|
||||
* (high power, low power, or default) and a device filter (None, GPU, or NPU).
|
||||
*/
|
||||
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@
|
|||
|
||||
#pragma once
|
||||
interface IMLOperatorRegistry;
|
||||
interface IDMLDevice;
|
||||
interface ID3D12CommandQueue;
|
||||
interface ID3D12Resource;
|
||||
|
||||
#include "core/common/status.h"
|
||||
#include "core/framework/data_transfer.h"
|
||||
|
|
@ -28,7 +31,8 @@ namespace Dml
|
|||
std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
|
||||
IDMLDevice* dmlDevice,
|
||||
ID3D12CommandQueue* commandQueue,
|
||||
bool enableMetacommands = true);
|
||||
bool enableMetacommands,
|
||||
bool enableDynamicGraphFusion);
|
||||
|
||||
ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr);
|
||||
void FlushContext(onnxruntime::IExecutionProvider* provider);
|
||||
|
|
|
|||
|
|
@ -7,11 +7,14 @@
|
|||
#include <functional>
|
||||
#include <variant>
|
||||
#include <optional>
|
||||
#include <wrl/client.h>
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h"
|
||||
|
||||
struct AbstractOperatorDesc;
|
||||
interface IMLOperatorTensor;
|
||||
interface IDMLOperator;
|
||||
struct DML_INPUT_GRAPH_EDGE_DESC;
|
||||
struct DML_OUTPUT_GRAPH_EDGE_DESC;
|
||||
struct DML_INTERMEDIATE_GRAPH_EDGE_DESC;
|
||||
|
|
@ -92,6 +95,8 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
const onnxruntime::Node& node,
|
||||
MLOperatorTensorGetter& constantInputGetter,
|
||||
const void* executionHandle,
|
||||
const EdgeShapes* inputShapesOverrides,
|
||||
/*out*/ EdgeShapes* outputShapes,
|
||||
/*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
|
||||
)>;
|
||||
|
||||
|
|
|
|||
|
|
@ -491,6 +491,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
const onnxruntime::Node& node,
|
||||
MLOperatorTensorGetter& constantInputGetter,
|
||||
const void* executionHandle,
|
||||
const EdgeShapes* inputShapesOverrides,
|
||||
/*out*/ EdgeShapes* outputShapes,
|
||||
/*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
|
||||
)
|
||||
{
|
||||
|
|
@ -498,15 +500,15 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> protoHelper(&nodeContext);
|
||||
|
||||
// Use the same list of required constant inputs for the shape inferrer and the kernel.
|
||||
EdgeShapes outputShapes;
|
||||
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes);
|
||||
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes);
|
||||
|
||||
// Create the kernel while allowing input shape and output shape queries according to options
|
||||
ComPtr<DmlGraphOpKernelInfoWrapper> kernelInfoWrapper = wil::MakeOrThrow<DmlGraphOpKernelInfoWrapper>(
|
||||
&protoHelper,
|
||||
executionHandle,
|
||||
true,
|
||||
&outputShapes,
|
||||
inputShapesOverrides,
|
||||
outputShapes,
|
||||
&defaultAttributesCapture,
|
||||
graphNodeCreateInfo,
|
||||
constantCpuInputCapture,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace Windows::AI::MachineLearning::Adapter
|
||||
{
|
||||
// edges and unused edges have an empty array of dimensions.
|
||||
class EdgeShapes
|
||||
{
|
||||
public:
|
||||
EdgeShapes() = default;
|
||||
|
||||
EdgeShapes(size_t count) : m_shapes(count) {}
|
||||
|
||||
const std::vector<uint32_t>& GetShape(size_t edgeIndex) const
|
||||
{
|
||||
return m_shapes[edgeIndex];
|
||||
}
|
||||
|
||||
std::vector<uint32_t>& GetMutableShape(size_t edgeIndex)
|
||||
{
|
||||
return m_shapes[edgeIndex];
|
||||
}
|
||||
|
||||
size_t EdgeCount() const { return m_shapes.size(); }
|
||||
|
||||
void Reset(size_t edge_count)
|
||||
{
|
||||
m_shapes.clear();
|
||||
m_shapes.resize(edge_count);
|
||||
}
|
||||
|
||||
bool operator!=(const EdgeShapes& other) const noexcept
|
||||
{
|
||||
return (m_shapes != other.m_shapes);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::vector<uint32_t>> m_shapes;
|
||||
};
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "DmlGraphFusionHelper.h"
|
||||
|
||||
#include "DmlRuntimeFusedGraphKernel.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
|
@ -501,5 +501,171 @@ namespace DmlGraphFusionHelper
|
|||
|
||||
graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode);
|
||||
}
|
||||
|
||||
void RegisterDynamicKernel(
|
||||
onnxruntime::Graph& graph,
|
||||
onnxruntime::KernelRegistry* registryForPartitionKernels,
|
||||
const ExecutionProviderImpl* providerImpl,
|
||||
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap,
|
||||
const std::unordered_set<std::string>& dynamicCpuInputMap,
|
||||
std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
|
||||
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>&& isInitializerTransferable)
|
||||
{
|
||||
struct NodeInfo
|
||||
{
|
||||
std::string name;
|
||||
std::string opType;
|
||||
std::string description;
|
||||
std::string domain;
|
||||
onnxruntime::NodeAttributes attributes;
|
||||
std::vector<onnxruntime::NodeArg*> inputDefPointers;
|
||||
std::vector<onnxruntime::NodeArg*> outputDefPointers;
|
||||
};
|
||||
|
||||
auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap(
|
||||
graph,
|
||||
*indexedSubGraph,
|
||||
std::move(graphNodePropertyMap));
|
||||
|
||||
auto modelPath = graph.ModelPath();
|
||||
|
||||
const gsl::span<const std::string> subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs;
|
||||
const gsl::span<const std::string> subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs;
|
||||
|
||||
std::vector<NodeInfo> nodesInfo;
|
||||
nodesInfo.reserve(indexedSubGraph->nodes.size());
|
||||
|
||||
std::vector<const onnxruntime::NodeArg*> subgraphInputs;
|
||||
subgraphInputs.reserve(subGraphInputArgNames.size());
|
||||
|
||||
std::vector<const onnxruntime::NodeArg*> subgraphOutputs;
|
||||
subgraphOutputs.reserve(subGraphOutputArgNames.size());
|
||||
|
||||
std::vector<onnxruntime::NodeAttributes> nodeAttributes;
|
||||
nodeAttributes.reserve(indexedSubGraph->nodes.size());
|
||||
|
||||
std::vector<std::shared_ptr<onnxruntime::NodeArg>> intermediateNodeArgs;
|
||||
|
||||
for (size_t sortedNodeIndex : indexedSubGraph->nodes)
|
||||
{
|
||||
auto node = graph.GetNode(sortedNodeIndex);
|
||||
|
||||
nodeAttributes.push_back(node->GetAttributes());
|
||||
|
||||
NodeInfo nodeInfo{};
|
||||
nodeInfo.name = node->Name();
|
||||
nodeInfo.opType = node->OpType();
|
||||
nodeInfo.description = node->Description();
|
||||
nodeInfo.domain = node->Domain();
|
||||
nodeInfo.attributes = node->GetAttributes();
|
||||
nodeInfo.inputDefPointers.reserve(node->InputDefs().size());
|
||||
nodeInfo.outputDefPointers.reserve(node->OutputDefs().size());
|
||||
|
||||
for (const onnxruntime::NodeArg* inputDef : node->InputDefs())
|
||||
{
|
||||
intermediateNodeArgs.emplace_back(std::make_shared<onnxruntime::NodeArg>(inputDef->Name(), inputDef->TypeAsProto()));
|
||||
nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get());
|
||||
}
|
||||
|
||||
for (const onnxruntime::NodeArg* outputDef : node->OutputDefs())
|
||||
{
|
||||
intermediateNodeArgs.emplace_back(std::make_shared<onnxruntime::NodeArg>(outputDef->Name(), outputDef->TypeAsProto()));
|
||||
nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get());
|
||||
}
|
||||
|
||||
nodesInfo.push_back(std::move(nodeInfo));
|
||||
}
|
||||
|
||||
for (const std::string& graphInputName : subGraphInputArgNames)
|
||||
{
|
||||
subgraphInputs.push_back(graph.GetNodeArg(graphInputName));
|
||||
}
|
||||
|
||||
for (const std::string& graphOutputName : subGraphOutputArgNames)
|
||||
{
|
||||
subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName));
|
||||
}
|
||||
|
||||
// We need to keep the initializers alive since they will be freed once the nodes are removed from the graph
|
||||
std::vector<ONNX_NAMESPACE::TensorProto> ownedInitializers;
|
||||
ownedInitializers.reserve(isInitializerTransferable.size());
|
||||
|
||||
for (auto& kvp : isInitializerTransferable)
|
||||
{
|
||||
ONNX_NAMESPACE::TensorProto tensorProto;
|
||||
tensorProto.set_data_type(kvp.second.first->data_type());
|
||||
tensorProto.set_raw_data(kvp.second.first->raw_data());
|
||||
tensorProto.set_name(kvp.second.first->name());
|
||||
|
||||
for (int i = 0; i < kvp.second.first->dims_size(); ++i)
|
||||
{
|
||||
tensorProto.add_dims(kvp.second.first->dims(i));
|
||||
}
|
||||
ownedInitializers.push_back(std::move(tensorProto));
|
||||
kvp.second.first = &ownedInitializers.back();
|
||||
}
|
||||
|
||||
// lamda captures for the kernel registration
|
||||
auto fused_kernel_func = [
|
||||
indexedSubGraph,
|
||||
&modelPath,
|
||||
nodesInfo = std::move(nodesInfo),
|
||||
intermediateNodeArgs = std::move(intermediateNodeArgs),
|
||||
subgraphInputs = std::move(subgraphInputs),
|
||||
subgraphOutputs = std::move(subgraphOutputs),
|
||||
partitionNodePropsMap = std::move(partitionNodePropsMap),
|
||||
ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr<onnxruntime::OpKernel>& out) mutable ->onnxruntime::Status
|
||||
{
|
||||
std::vector<std::shared_ptr<onnxruntime::Node>> subgraphNodes;
|
||||
subgraphNodes.reserve(nodesInfo.size());
|
||||
|
||||
for (const NodeInfo& nodeInfo : nodesInfo)
|
||||
{
|
||||
subgraphNodes.emplace_back(std::make_shared<onnxruntime::Node>(
|
||||
nodeInfo.name,
|
||||
nodeInfo.opType,
|
||||
nodeInfo.description,
|
||||
nodeInfo.inputDefPointers,
|
||||
nodeInfo.outputDefPointers,
|
||||
&nodeInfo.attributes,
|
||||
nodeInfo.domain));
|
||||
}
|
||||
|
||||
out.reset(CreateRuntimeFusedGraphKernel(
|
||||
info,
|
||||
indexedSubGraph,
|
||||
modelPath,
|
||||
std::move(subgraphNodes),
|
||||
std::move(subgraphInputs),
|
||||
std::move(subgraphOutputs),
|
||||
std::move(intermediateNodeArgs),
|
||||
std::move(partitionNodePropsMap),
|
||||
std::move(ownedInitializers)));
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// build the kernel definition on the fly, and register it to the fused_kernel_regisitry.
|
||||
onnxruntime::KernelDefBuilder builder;
|
||||
builder.SetName(indexedSubGraph->GetMetaDef()->name)
|
||||
.SetDomain(indexedSubGraph->GetMetaDef()->domain)
|
||||
.SinceVersion(indexedSubGraph->GetMetaDef()->since_version)
|
||||
.Provider(onnxruntime::kDmlExecutionProvider);
|
||||
|
||||
// Force the CPU inputs to be allocated on the CPU
|
||||
for (int i = 0; i < subGraphInputArgNames.size(); ++i)
|
||||
{
|
||||
if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end())
|
||||
{
|
||||
builder.InputMemoryType(OrtMemTypeCPUInput, i);
|
||||
}
|
||||
}
|
||||
|
||||
ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func));
|
||||
|
||||
auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name);
|
||||
fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider);
|
||||
|
||||
graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper
|
|||
std::vector<uint8_t>&& isInputsUploadedByDmlEP,
|
||||
const GraphDescBuilder::GraphDesc& graphDesc,
|
||||
Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledExecutionPlanOperator);
|
||||
|
||||
void RegisterDynamicKernel(
|
||||
onnxruntime::Graph& graph,
|
||||
onnxruntime::KernelRegistry* registryForPartitionKernels,
|
||||
const ExecutionProviderImpl* providerImpl,
|
||||
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap,
|
||||
const std::unordered_set<std::string>& dynamicCpuInputMap,
|
||||
std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
|
||||
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>&& isInitializerTransferable);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,18 @@
|
|||
|
||||
namespace Dml
|
||||
{
|
||||
namespace
|
||||
{
|
||||
struct CompiledPartitionInfo
|
||||
{
|
||||
Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledOperator;
|
||||
onnxruntime::IndexedSubGraph indexedSubGraph;
|
||||
std::vector<uint8_t> isInputsUploadedByDmlEP;
|
||||
GraphDescBuilder::GraphDesc graphDesc;
|
||||
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> isInitializerTransferable;
|
||||
};
|
||||
}
|
||||
|
||||
DmlGraphFusionTransformer::DmlGraphFusionTransformer(
|
||||
const std::string& name,
|
||||
const onnxruntime::IExecutionProvider* provider
|
||||
|
|
@ -24,15 +36,6 @@ namespace Dml
|
|||
{
|
||||
}
|
||||
|
||||
struct CompiledPartitionInfo
|
||||
{
|
||||
Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledOperator;
|
||||
onnxruntime::IndexedSubGraph indexedSubGraph;
|
||||
std::vector<uint8_t> isInputsUploadedByDmlEP;
|
||||
GraphDescBuilder::GraphDesc graphDesc;
|
||||
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> isInitializerTransferable;
|
||||
};
|
||||
|
||||
onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImpl(
|
||||
onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
|
|
@ -87,6 +90,7 @@ namespace Dml
|
|||
{
|
||||
// Initializers needed by any graph partition
|
||||
std::unordered_set<std::string> requiredInitializerMap;
|
||||
std::unordered_set<std::string> dynamicCpuInputMap;
|
||||
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap;
|
||||
onnxruntime::GraphViewer graphViewer(graph);
|
||||
std::vector<std::unique_ptr<GraphPartition>> partitions = BuildPartitions(
|
||||
|
|
@ -96,8 +100,10 @@ namespace Dml
|
|||
m_providerImpl->GetSupportedDeviceDataTypeMask(),
|
||||
graphNodePropertyMap,
|
||||
requiredInitializerMap,
|
||||
dynamicCpuInputMap,
|
||||
additionalSplittingNodes,
|
||||
implicitInputDefs);
|
||||
implicitInputDefs,
|
||||
false);
|
||||
|
||||
// Reset the splitting nodes for the current iteration
|
||||
additionalSplittingNodes.clear();
|
||||
|
|
@ -190,17 +196,48 @@ namespace Dml
|
|||
std::move(graphNodePropertyMap));
|
||||
|
||||
// Convert partitionONNXGraph into DML EP GraphDesc
|
||||
auto modelPath = graph.ModelPath();
|
||||
|
||||
const gsl::span<const std::string> subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs;
|
||||
const gsl::span<const std::string> subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs;
|
||||
|
||||
std::vector<const onnxruntime::Node*> subgraphNodes;
|
||||
subgraphNodes.reserve(indexedSubGraph.nodes.size());
|
||||
|
||||
std::vector<const onnxruntime::NodeArg*> subgraphInputs;
|
||||
subgraphInputs.reserve(subGraphInputArgNames.size());
|
||||
|
||||
std::vector<const onnxruntime::NodeArg*> subgraphOutputs;
|
||||
subgraphOutputs.reserve(subGraphOutputArgNames.size());
|
||||
|
||||
for (size_t sortedNodeIndex : indexedSubGraph.nodes)
|
||||
{
|
||||
subgraphNodes.push_back(graph.GetNode(sortedNodeIndex));
|
||||
}
|
||||
|
||||
for (const std::string& graphInputName : subGraphInputArgNames)
|
||||
{
|
||||
subgraphInputs.push_back(graph.GetNodeArg(graphInputName));
|
||||
}
|
||||
|
||||
for (const std::string& graphOutputName : subGraphOutputArgNames)
|
||||
{
|
||||
subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName));
|
||||
}
|
||||
|
||||
ComPtr<IDMLDevice> device;
|
||||
ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf()));
|
||||
GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
|
||||
isInputsUploadedByDmlEP.data(),
|
||||
isInputsUploadedByDmlEP.size(),
|
||||
isInitializerTransferable,
|
||||
graph,
|
||||
indexedSubGraph,
|
||||
partitionNodePropsMap,
|
||||
device.Get(),
|
||||
m_providerImpl);
|
||||
m_providerImpl,
|
||||
modelPath,
|
||||
subgraphNodes,
|
||||
subgraphInputs,
|
||||
subgraphOutputs);
|
||||
|
||||
// Compile the operator
|
||||
auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,369 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h"
|
||||
|
||||
using namespace Windows::AI::MachineLearning::Adapter;
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
class DmlRuntimeFusedGraphKernel : public onnxruntime::OpKernel
|
||||
{
|
||||
public:
|
||||
DmlRuntimeFusedGraphKernel() = delete;
|
||||
|
||||
DmlRuntimeFusedGraphKernel(
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
|
||||
const onnxruntime::Path& modelPath,
|
||||
std::vector<std::shared_ptr<onnxruntime::Node>>&& subgraphNodes,
|
||||
std::vector<const onnxruntime::NodeArg*>&& subgraphInputs,
|
||||
std::vector<const onnxruntime::NodeArg*>&& subgraphOutputs,
|
||||
std::vector<std::shared_ptr<onnxruntime::NodeArg>>&& intermediateNodeArgs,
|
||||
std::unordered_map<std::string, GraphNodeProperties>&& partitionNodePropsMap,
|
||||
std::vector<ONNX_NAMESPACE::TensorProto>&& ownedInitializers)
|
||||
: OpKernel(kernelInfo),
|
||||
m_indexedSubGraph(std::move(indexedSubGraph)),
|
||||
m_modelPath(modelPath),
|
||||
m_subgraphNodes(std::move(subgraphNodes)),
|
||||
m_subgraphInputs(std::move(subgraphInputs)),
|
||||
m_subgraphOutputs(std::move(subgraphOutputs)),
|
||||
m_intermediateNodeArgs(std::move(intermediateNodeArgs)),
|
||||
m_partitionNodePropsMap(std::move(partitionNodePropsMap)),
|
||||
m_ownedInitializers(std::move(ownedInitializers))
|
||||
{
|
||||
for (const auto& initializer : m_ownedInitializers)
|
||||
{
|
||||
m_isInitializerTransferable[initializer.name()] = std::make_pair(&initializer, false);
|
||||
}
|
||||
|
||||
// Get the execution provider interfaces
|
||||
auto executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle();
|
||||
if (executionHandle)
|
||||
{
|
||||
// We assume the execution object inherits IUnknown as its first base
|
||||
ComPtr<IUnknown> providerExecutionObject = const_cast<IUnknown*>(static_cast<const IUnknown*>(executionHandle));
|
||||
|
||||
// Get the WinML-specific execution provider interface from the execution object.
|
||||
ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_provider));
|
||||
ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider));
|
||||
}
|
||||
|
||||
m_subgraphNodePointers.reserve(m_subgraphNodes.size());
|
||||
|
||||
for (auto& subgraphNode : m_subgraphNodes)
|
||||
{
|
||||
m_subgraphNodePointers.push_back(subgraphNode.get());
|
||||
}
|
||||
}
|
||||
|
||||
void TranslateAndCompileGraph(
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>>& initializeResourceRefs,
|
||||
std::vector<DML_BUFFER_BINDING> initInputBindings) const
|
||||
{
|
||||
// Allocate a persistent resource and initialize the operator
|
||||
UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize;
|
||||
if (persistentResourceSize > 0)
|
||||
{
|
||||
ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource(
|
||||
static_cast<size_t>(persistentResourceSize),
|
||||
AllocatorRoundingMode::Disabled,
|
||||
m_persistentResource.ReleaseAndGetAddressOf(),
|
||||
m_persistentResourceAllocatorUnk.ReleaseAndGetAddressOf()));
|
||||
|
||||
m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize };
|
||||
}
|
||||
|
||||
ORT_THROW_IF_FAILED(m_provider->InitializeOperator(
|
||||
m_compiledExecutionPlanOperator.Get(),
|
||||
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
|
||||
gsl::make_span(initInputBindings)));
|
||||
|
||||
std::for_each(
|
||||
initializeResourceRefs.begin(),
|
||||
initializeResourceRefs.end(),
|
||||
[&](ComPtr<ID3D12Resource>& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); }
|
||||
);
|
||||
}
|
||||
|
||||
onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override
|
||||
{
|
||||
ORT_THROW_HR_IF(E_UNEXPECTED, m_subgraphInputs.size() != kernelContext->InputCount());
|
||||
|
||||
bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr;
|
||||
|
||||
for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex)
|
||||
{
|
||||
const auto& input = kernelContext->RequiredInput<onnxruntime::Tensor>(inputIndex);
|
||||
const std::string& inputName = m_subgraphInputs[inputIndex]->Name();
|
||||
auto shapeIter = m_inferredInputShapes.find(inputName);
|
||||
|
||||
if (shapeIter == m_inferredInputShapes.end())
|
||||
{
|
||||
m_inferredInputShapes[inputName] = input.Shape();
|
||||
recompileNeeded = true;
|
||||
}
|
||||
else if (shapeIter->second != input.Shape())
|
||||
{
|
||||
shapeIter->second = input.Shape();
|
||||
recompileNeeded = true;
|
||||
}
|
||||
|
||||
// If we have CPU inputs that are not initializers (i.e. they were computed at runtime), add them to the initializer list
|
||||
if (input.Location().device.Type() == OrtDevice::CPU)
|
||||
{
|
||||
auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName);
|
||||
|
||||
// We can only avoid recompiling the graph when all CPU inputs are identical
|
||||
auto initializerIter = m_isInitializerTransferable.find(inputName);
|
||||
|
||||
if (initializerIter != m_isInitializerTransferable.end())
|
||||
{
|
||||
if (initializerIter->second.first->raw_data().length() == inputProto.raw_data().length())
|
||||
{
|
||||
for (int i = 0; i < inputProto.raw_data().length(); ++i)
|
||||
{
|
||||
if (initializerIter->second.first->raw_data()[i] != inputProto.raw_data()[i])
|
||||
{
|
||||
recompileNeeded = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
recompileNeeded = true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
recompileNeeded = true;
|
||||
}
|
||||
|
||||
m_ownedCpuInputs.push_back(std::make_unique<ONNX_NAMESPACE::TensorProto>(std::move(inputProto)));
|
||||
m_isInitializerTransferable[inputName] = std::make_pair(m_ownedCpuInputs.back().get(), false);
|
||||
}
|
||||
}
|
||||
|
||||
if (recompileNeeded)
|
||||
{
|
||||
// Go through all the node args and replace their shapes with the real ones
|
||||
for (auto& nodeArg : m_intermediateNodeArgs)
|
||||
{
|
||||
auto iter = m_inferredInputShapes.find(nodeArg->Name());
|
||||
if (iter != m_inferredInputShapes.end())
|
||||
{
|
||||
auto tensorShape = *nodeArg->Shape();
|
||||
ORT_THROW_HR_IF(E_UNEXPECTED, tensorShape.dim_size() != iter->second.NumDimensions());
|
||||
|
||||
for (int i = 0; i < tensorShape.dim_size(); ++i)
|
||||
{
|
||||
tensorShape.mutable_dim(i)->set_dim_value(iter->second.GetDims()[i]);
|
||||
}
|
||||
|
||||
nodeArg->SetShape(tensorShape);
|
||||
}
|
||||
}
|
||||
|
||||
// Populate input bindings for operator initialization
|
||||
const uint32_t fusedNodeInputCount = gsl::narrow_cast<uint32_t>(m_indexedSubGraph->GetMetaDef()->inputs.size());
|
||||
std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>> initializeResourceRefs; // For lifetime control
|
||||
std::vector<DML_BUFFER_BINDING> initInputBindings(fusedNodeInputCount);
|
||||
std::vector<uint8_t> isInputsUploadedByDmlEP(fusedNodeInputCount);
|
||||
auto providerImpl = static_cast<const ExecutionProvider*>(Info().GetExecutionProvider())->GetImpl();
|
||||
|
||||
// Convert partitionONNXGraph into DML EP GraphDesc
|
||||
ComPtr<IDMLDevice> device;
|
||||
ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf()));
|
||||
GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
|
||||
isInputsUploadedByDmlEP.data(),
|
||||
isInputsUploadedByDmlEP.size(),
|
||||
m_isInitializerTransferable,
|
||||
m_partitionNodePropsMap,
|
||||
device.Get(),
|
||||
providerImpl,
|
||||
m_modelPath,
|
||||
m_subgraphNodePointers,
|
||||
m_subgraphInputs,
|
||||
m_subgraphOutputs);
|
||||
|
||||
m_outputShapes = graphDesc.outputShapes;
|
||||
|
||||
// Walk through each graph edge and mark used inputs
|
||||
m_inputsUsed.resize(fusedNodeInputCount, false);
|
||||
for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges)
|
||||
{
|
||||
m_inputsUsed[edge.GraphInputIndex] = true;
|
||||
}
|
||||
|
||||
// Compile the operator
|
||||
m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator(
|
||||
graphDesc,
|
||||
*m_indexedSubGraph,
|
||||
providerImpl);
|
||||
|
||||
// Queue references to objects which must be kept alive until resulting GPU work completes
|
||||
m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get());
|
||||
|
||||
TranslateAndCompileGraph(
|
||||
Info(),
|
||||
initializeResourceRefs,
|
||||
initInputBindings);
|
||||
}
|
||||
|
||||
// Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator
|
||||
OpKernelContextWrapper contextWrapper(
|
||||
kernelContext,
|
||||
Info().GetExecutionProvider(),
|
||||
true,
|
||||
nullptr);
|
||||
|
||||
ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());
|
||||
|
||||
// Get input resources for execution, excluding those which were specified as owned by DML and provided
|
||||
// at initialization instead.
|
||||
std::vector<ComPtr<IMLOperatorTensor>> inputTensors(kernelContext->InputCount());
|
||||
std::vector<ID3D12Resource*> inputPtrs(kernelContext->InputCount());
|
||||
|
||||
for (int i = 0; i < kernelContext->InputCount(); ++i)
|
||||
{
|
||||
if (!m_inputsUsed[i])
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf()));
|
||||
inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get());
|
||||
}
|
||||
|
||||
auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes);
|
||||
ExecuteOperator(
|
||||
m_compiledExecutionPlanOperator.Get(),
|
||||
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
|
||||
inputPtrs,
|
||||
outputTensors);
|
||||
|
||||
ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());
|
||||
|
||||
return onnxruntime::Status::OK();
|
||||
}
|
||||
|
||||
void ExecuteOperator(
|
||||
IDMLCompiledOperator* op,
|
||||
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
|
||||
gsl::span<ID3D12Resource*> inputTensors,
|
||||
gsl::span<IMLOperatorTensor*> outputTensors) const
|
||||
{
|
||||
auto FillBindingsFromTensors = [this](auto& bufferBindings, auto& bindingDescs, gsl::span<IMLOperatorTensor*>& tensors)
|
||||
{
|
||||
for (IMLOperatorTensor* tensor : tensors)
|
||||
{
|
||||
if (tensor)
|
||||
{
|
||||
assert(tensor->IsDataInterface());
|
||||
ID3D12Resource* resource = m_provider->DecodeResource(MLOperatorTensor(tensor).GetDataInterface().Get());
|
||||
D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc();
|
||||
bufferBindings.push_back({ resource, 0, resourceDesc.Width });
|
||||
bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() });
|
||||
}
|
||||
else
|
||||
{
|
||||
bufferBindings.push_back({ nullptr, 0, 0 });
|
||||
bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto FillBindingsFromBuffers = [](auto& bufferBindings, auto& bindingDescs, gsl::span<ID3D12Resource*>& resources)
|
||||
{
|
||||
for (ID3D12Resource* resource : resources)
|
||||
{
|
||||
if (resource)
|
||||
{
|
||||
D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc();
|
||||
bufferBindings.push_back({ resource, 0, resourceDesc.Width });
|
||||
bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() });
|
||||
}
|
||||
else
|
||||
{
|
||||
bufferBindings.push_back({ nullptr, 0, 0 });
|
||||
bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<DML_BUFFER_BINDING> inputBufferBindings;
|
||||
inputBufferBindings.reserve(inputTensors.size());
|
||||
std::vector<DML_BINDING_DESC> inputBindings;
|
||||
inputBindings.reserve(inputTensors.size());
|
||||
FillBindingsFromBuffers(inputBufferBindings, inputBindings, inputTensors);
|
||||
|
||||
std::vector<DML_BUFFER_BINDING> outputBufferBindings;
|
||||
outputBufferBindings.reserve(outputTensors.size());
|
||||
std::vector<DML_BINDING_DESC> outputBindings;
|
||||
outputBindings.reserve(outputTensors.size());
|
||||
FillBindingsFromTensors(outputBufferBindings, outputBindings, outputTensors);
|
||||
|
||||
ORT_THROW_IF_FAILED(m_provider->ExecuteOperator(
|
||||
op,
|
||||
persistentResourceBinding,
|
||||
inputBindings,
|
||||
outputBindings));
|
||||
}
|
||||
|
||||
private:
|
||||
ComPtr<IWinmlExecutionProvider> m_winmlProvider;
|
||||
ComPtr<Dml::IExecutionProvider> m_provider;
|
||||
|
||||
mutable std::optional<DML_BUFFER_BINDING> m_persistentResourceBinding;
|
||||
std::shared_ptr<const onnxruntime::IndexedSubGraph> m_indexedSubGraph;
|
||||
const onnxruntime::Path& m_modelPath;
|
||||
|
||||
std::vector<std::shared_ptr<onnxruntime::Node>> m_subgraphNodes;
|
||||
std::vector<const onnxruntime::NodeArg*> m_subgraphInputs;
|
||||
std::vector<const onnxruntime::NodeArg*> m_subgraphOutputs;
|
||||
mutable std::vector<std::shared_ptr<onnxruntime::NodeArg>> m_intermediateNodeArgs;
|
||||
std::unordered_map<std::string, GraphNodeProperties> m_partitionNodePropsMap;
|
||||
std::vector<ONNX_NAMESPACE::TensorProto> m_ownedInitializers;
|
||||
mutable std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> m_isInitializerTransferable;
|
||||
std::vector<const onnxruntime::Node*> m_subgraphNodePointers;
|
||||
|
||||
// Bindings from previous executions of a re-used command list
|
||||
mutable std::vector<std::unique_ptr<ONNX_NAMESPACE::TensorProto>> m_ownedCpuInputs;
|
||||
mutable ComPtr<IDMLCompiledOperator> m_compiledExecutionPlanOperator;
|
||||
mutable std::vector<bool> m_inputsUsed;
|
||||
mutable ComPtr<ID3D12Resource> m_persistentResource;
|
||||
mutable ComPtr<IUnknown> m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator
|
||||
mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes;
|
||||
mutable std::unordered_map<std::string, onnxruntime::TensorShape> m_inferredInputShapes;
|
||||
};
|
||||
|
||||
onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel(
|
||||
const onnxruntime::OpKernelInfo& info,
|
||||
std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
|
||||
const onnxruntime::Path& modelPath,
|
||||
std::vector<std::shared_ptr<onnxruntime::Node>>&& subgraphNodes,
|
||||
std::vector<const onnxruntime::NodeArg*>&& subgraphInputs,
|
||||
std::vector<const onnxruntime::NodeArg*>&& subgraphOutputs,
|
||||
std::vector<std::shared_ptr<onnxruntime::NodeArg>>&& intermediateNodeArgs,
|
||||
std::unordered_map<std::string, GraphNodeProperties>&& partitionNodePropsMap,
|
||||
std::vector<ONNX_NAMESPACE::TensorProto>&& ownedInitializers)
|
||||
{
|
||||
return new DmlRuntimeFusedGraphKernel(
|
||||
info,
|
||||
std::move(indexedSubGraph),
|
||||
modelPath,
|
||||
std::move(subgraphNodes),
|
||||
std::move(subgraphInputs),
|
||||
std::move(subgraphOutputs),
|
||||
std::move(intermediateNodeArgs),
|
||||
std::move(partitionNodePropsMap),
|
||||
std::move(ownedInitializers)
|
||||
);
|
||||
}
|
||||
} // namespace Dml
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "GraphDescBuilder.h"
|
||||
#include "DmlRuntimeGraphFusionTransformer.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel(
|
||||
const onnxruntime::OpKernelInfo& info,
|
||||
std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
|
||||
const onnxruntime::Path& modelPath,
|
||||
std::vector<std::shared_ptr<onnxruntime::Node>>&& subgraphNodes,
|
||||
std::vector<const onnxruntime::NodeArg*>&& subgraphInputs,
|
||||
std::vector<const onnxruntime::NodeArg*>&& subgraphOutputs,
|
||||
std::vector<std::shared_ptr<onnxruntime::NodeArg>>&& intermediateNodeArgs,
|
||||
std::unordered_map<std::string, GraphNodeProperties>&& partitionNodePropsMap,
|
||||
std::vector<ONNX_NAMESPACE::TensorProto>&& ownedInitializers
|
||||
);
|
||||
} // namespace Dml
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
#pragma once
|
||||
|
||||
#include "precomp.h"
|
||||
#include "GraphDescBuilder.h"
|
||||
#include "ExecutionProvider.h"
|
||||
#include "DmlRuntimeGraphFusionTransformer.h"
|
||||
#include "GraphPartitioner.h"
|
||||
#include "core/framework/kernel_type_str_resolver.h"
|
||||
#include "core/framework/kernel_lookup.h"
|
||||
#include "core/optimizer/constant_sharing.h"
|
||||
#include "DmlRuntimeFusedGraphKernel.h"
|
||||
#include "MLOperatorAuthorImpl.h"
|
||||
#include "DmlGraphFusionHelper.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
namespace
|
||||
{
|
||||
struct CompiledPartitionInfo
|
||||
{
|
||||
std::shared_ptr<onnxruntime::IndexedSubGraph> indexedSubGraph;
|
||||
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> isInitializerTransferable;
|
||||
};
|
||||
}
|
||||
|
||||
DmlRuntimeGraphFusionTransformer::DmlRuntimeGraphFusionTransformer(
|
||||
const std::string& name,
|
||||
const onnxruntime::IExecutionProvider* provider
|
||||
)
|
||||
:onnxruntime::GraphTransformer(name),
|
||||
m_providerImpl(static_cast<const ExecutionProvider*>(provider)->GetImpl())
|
||||
{
|
||||
}
|
||||
|
||||
onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImpl(
|
||||
onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
int graphLevel,
|
||||
const onnxruntime::logging::Logger& logger) const
|
||||
{
|
||||
return ApplyImplHelper(graph, modified, graphLevel, logger, {});
|
||||
}
|
||||
|
||||
onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImplHelper(
|
||||
onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
int graphLevel,
|
||||
const onnxruntime::logging::Logger& logger,
|
||||
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputDefs) const
|
||||
{
|
||||
onnxruntime::ProviderType providerType = onnxruntime::kDmlExecutionProvider;
|
||||
const gsl::not_null<const onnxruntime::KernelRegistry*> registry = m_providerImpl->GetKernelRegistry().get();
|
||||
const auto kernelTypeStrResolver = onnxruntime::OpSchemaKernelTypeStrResolver{};
|
||||
const auto kernelLookup = onnxruntime::KernelLookup(
|
||||
providerType,
|
||||
gsl::make_span(®istry, 1),
|
||||
kernelTypeStrResolver);
|
||||
|
||||
onnxruntime::GraphViewer graphViewer(graph);
|
||||
const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder();
|
||||
|
||||
for (auto nodeIndex : nodeTopologyList)
|
||||
{
|
||||
auto* node = graph.GetNode(nodeIndex);
|
||||
if (!node)
|
||||
{
|
||||
continue; // node was removed
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, const onnxruntime::NodeArg*> subgraphImplicitInputDefs;
|
||||
for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs())
|
||||
{
|
||||
subgraphImplicitInputDefs[inputDef->Name()] = inputDef;
|
||||
}
|
||||
|
||||
for (auto& entry : node->GetAttributeNameToMutableSubgraphMap())
|
||||
{
|
||||
auto& subgraph = *entry.second;
|
||||
ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graphLevel + 1, logger, subgraphImplicitInputDefs));
|
||||
}
|
||||
}
|
||||
|
||||
// Initializers needed by any graph partition
|
||||
std::vector<onnxruntime::NodeIndex> additionalSplittingNodes;
|
||||
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap;
|
||||
std::unordered_set<std::string> requiredInitializerMap;
|
||||
std::unordered_set<std::string> dynamicCpuInputMap;
|
||||
std::vector<std::unique_ptr<GraphPartition>> partitions = BuildPartitions(
|
||||
graphViewer,
|
||||
*m_providerImpl->GetInternalRegistrationInfoMap(),
|
||||
kernelLookup,
|
||||
m_providerImpl->GetSupportedDeviceDataTypeMask(),
|
||||
graphNodePropertyMap,
|
||||
requiredInitializerMap,
|
||||
dynamicCpuInputMap,
|
||||
additionalSplittingNodes,
|
||||
implicitInputDefs,
|
||||
true);
|
||||
|
||||
// Reset the splitting nodes for the current iteration
|
||||
additionalSplittingNodes.clear();
|
||||
|
||||
// Reset the compiled operators for the current iteration
|
||||
std::vector<std::shared_ptr<CompiledPartitionInfo>> compiledPartitionInfos(partitions.size());
|
||||
|
||||
// Create a map between each initialized tensor and the partition(s) it is part of.
|
||||
auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions);
|
||||
|
||||
for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex)
|
||||
{
|
||||
auto& partition = partitions[partitionIndex];
|
||||
|
||||
if (partition->GetRootMergedPartition() != partition.get() ||
|
||||
!partition->IsDmlPartition())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (partition->IsDmlGraphPartition())
|
||||
{
|
||||
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> isInitializerTransferable;
|
||||
|
||||
std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_";
|
||||
m_providerImpl->IncreasePartitionKernelPrefixVal();
|
||||
|
||||
// populate isInitializerTransferable
|
||||
for (const auto& input : partition->GetInputs())
|
||||
{
|
||||
const onnx::TensorProto* tensor = nullptr;
|
||||
if (graph.GetInitializedTensor(input, tensor) && requiredInitializerMap.find(input) != requiredInitializerMap.end())
|
||||
{
|
||||
isInitializerTransferable[input] = {tensor, false};
|
||||
}
|
||||
}
|
||||
|
||||
compiledPartitionInfos[partitionIndex] = std::make_shared<CompiledPartitionInfo>();
|
||||
compiledPartitionInfos[partitionIndex]->indexedSubGraph = std::make_shared<onnxruntime::IndexedSubGraph>(
|
||||
DmlGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix));
|
||||
compiledPartitionInfos[partitionIndex]->isInitializerTransferable = std::move(isInitializerTransferable);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto&& compiledPartitionInfo : compiledPartitionInfos)
|
||||
{
|
||||
// Null compiled operators were not DML partitions
|
||||
if (compiledPartitionInfo)
|
||||
{
|
||||
DmlGraphFusionHelper::RegisterDynamicKernel(
|
||||
graph,
|
||||
m_providerImpl->GetKernelRegistry().get(),
|
||||
m_providerImpl,
|
||||
graphNodePropertyMap,
|
||||
dynamicCpuInputMap,
|
||||
std::move(compiledPartitionInfo->indexedSubGraph),
|
||||
std::move(compiledPartitionInfo->isInitializerTransferable));
|
||||
}
|
||||
}
|
||||
|
||||
return onnxruntime::common::Status::OK();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/framework/execution_providers.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
class ExecutionProviderImpl;
|
||||
|
||||
class DmlRuntimeGraphFusionTransformer : public onnxruntime::GraphTransformer
|
||||
{
|
||||
public:
|
||||
DmlRuntimeGraphFusionTransformer(
|
||||
const std::string& name,
|
||||
const onnxruntime::IExecutionProvider* provider
|
||||
);
|
||||
|
||||
public:
|
||||
static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlRuntimeFusedNode_";
|
||||
static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlRuntimeFusedNodeDomain";
|
||||
|
||||
private:
|
||||
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
int graphLevel,
|
||||
const onnxruntime::logging::Logger& logger) const final;
|
||||
|
||||
onnxruntime::common::Status ApplyImplHelper(
|
||||
onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
int graphLevel,
|
||||
const onnxruntime::logging::Logger& logger,
|
||||
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputDefs) const;
|
||||
|
||||
private:
|
||||
const ExecutionProviderImpl* m_providerImpl = nullptr;
|
||||
};
|
||||
}
|
||||
|
|
@ -67,7 +67,8 @@ namespace Dml
|
|||
ExecutionProvider::ExecutionProvider(
|
||||
IDMLDevice* dmlDevice,
|
||||
ID3D12CommandQueue* commandQueue,
|
||||
bool enableMetacommands) :
|
||||
bool enableMetacommands,
|
||||
bool enableDynamicGraphFusion) :
|
||||
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0))
|
||||
{
|
||||
D3D12_COMMAND_LIST_TYPE queueType = commandQueue->GetDesc().Type;
|
||||
|
|
@ -80,7 +81,7 @@ namespace Dml
|
|||
ComPtr<ID3D12Device> device;
|
||||
GRAPHICS_THROW_IF_FAILED(commandQueue->GetDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));
|
||||
|
||||
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), commandQueue, enableMetacommands);
|
||||
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), commandQueue, enableMetacommands, enableDynamicGraphFusion);
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
|
||||
|
|
@ -147,12 +148,12 @@ namespace Dml
|
|||
// Task 24384515: Update ORT AIInfra release agent pool to install 19H1 SDK on VM bootstrap
|
||||
#define D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE ((D3D_FEATURE_LEVEL)0x1000)
|
||||
|
||||
ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands)
|
||||
ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands, bool enableDynamicGraphFusion)
|
||||
: m_d3d12Device(d3d12Device),
|
||||
m_dmlDevice(dmlDevice),
|
||||
m_areMetacommandsEnabled(enableMetacommands)
|
||||
m_areMetacommandsEnabled(enableMetacommands),
|
||||
m_dynamicGraphFusionEnabled(enableDynamicGraphFusion)
|
||||
{
|
||||
|
||||
D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {};
|
||||
|
||||
D3D_FEATURE_LEVEL featureLevelsList[] = {
|
||||
|
|
@ -1093,6 +1094,11 @@ namespace Dml
|
|||
return m_areMetacommandsEnabled;
|
||||
}
|
||||
|
||||
bool ExecutionProviderImpl::DynamicGraphFusionEnabled() const noexcept
|
||||
{
|
||||
return m_dynamicGraphFusionEnabled;
|
||||
}
|
||||
|
||||
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap>
|
||||
ExecutionProviderImpl::GetInternalRegistrationInfoMap() const
|
||||
{
|
||||
|
|
@ -1129,9 +1135,10 @@ namespace Dml
|
|||
std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
|
||||
IDMLDevice* dmlDevice,
|
||||
ID3D12CommandQueue* commandQueue,
|
||||
bool enableMetacommands)
|
||||
bool enableMetacommands,
|
||||
bool enableDynamicGraphFusion)
|
||||
{
|
||||
return std::make_unique<Dml::ExecutionProvider>(dmlDevice, commandQueue, enableMetacommands);
|
||||
return std::make_unique<Dml::ExecutionProvider>(dmlDevice, commandQueue, enableMetacommands, enableDynamicGraphFusion);
|
||||
}
|
||||
|
||||
ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include "GraphTransformer.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h"
|
||||
|
||||
#include <wrl/client.h>
|
||||
#include <wrl/implements.h>
|
||||
|
|
@ -34,7 +35,8 @@ namespace Dml
|
|||
IDMLDevice* dmlDevice,
|
||||
ID3D12Device* d3d12Device,
|
||||
ID3D12CommandQueue* queue,
|
||||
bool enableMetacommands = true);
|
||||
bool enableMetacommands,
|
||||
bool enableDynamicGraphFusion);
|
||||
|
||||
void ReleaseCompletedReferences();
|
||||
|
||||
|
|
@ -150,6 +152,7 @@ namespace Dml
|
|||
STDMETHOD_(bool, IsMcdmDevice)() const noexcept final;
|
||||
|
||||
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final;
|
||||
bool DynamicGraphFusionEnabled() const noexcept;
|
||||
std::shared_ptr<onnxruntime::IAllocator> GetGpuAllocator();
|
||||
std::shared_ptr<onnxruntime::IAllocator> GetCpuInputAllocator();
|
||||
|
||||
|
|
@ -184,6 +187,7 @@ namespace Dml
|
|||
ComPtr<IDMLDevice> m_dmlDevice;
|
||||
bool m_isMcdmDevice = false;
|
||||
bool m_areMetacommandsEnabled = true;
|
||||
bool m_dynamicGraphFusionEnabled = false;
|
||||
bool m_native16BitShaderOpsSupported = false;
|
||||
std::shared_ptr<ExecutionContext> m_context;
|
||||
std::unique_ptr<PooledUploadHeap> m_uploadHeap;
|
||||
|
|
@ -236,7 +240,8 @@ namespace Dml
|
|||
explicit ExecutionProvider(
|
||||
IDMLDevice* dmlDevice,
|
||||
ID3D12CommandQueue* commandQueue,
|
||||
bool enableMetacommands = true
|
||||
bool enableMetacommands,
|
||||
bool enableDynamicGraphFusion
|
||||
);
|
||||
|
||||
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const final override
|
||||
|
|
@ -299,9 +304,9 @@ namespace Dml
|
|||
return m_impl.Get();
|
||||
}
|
||||
|
||||
void MetacommandsEnabled()
|
||||
bool DynamicGraphFusionEnabled() const
|
||||
{
|
||||
m_impl->MetacommandsEnabled();
|
||||
return m_impl->DynamicGraphFusionEnabled();
|
||||
}
|
||||
|
||||
virtual std::vector<onnxruntime::AllocatorPtr> CreatePreferredAllocators() override
|
||||
|
|
|
|||
|
|
@ -147,14 +147,14 @@ namespace Dml::GraphDescBuilder
|
|||
const uint8_t* isConstGpuGraphInput,
|
||||
const size_t isConstGpuGraphInputCount,
|
||||
const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
|
||||
const onnxruntime::Graph& graph,
|
||||
const onnxruntime::IndexedSubGraph& indexedSubGraph,
|
||||
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
|
||||
IDMLDevice* device,
|
||||
const void* executionHandle)
|
||||
const void* executionHandle,
|
||||
const onnxruntime::Path& modelPath,
|
||||
gsl::span<const onnxruntime::Node* const> subgraphNodes,
|
||||
gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,
|
||||
gsl::span<const onnxruntime::NodeArg* const> subgraphOutputs)
|
||||
{
|
||||
const gsl::span<const std::string> subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs;
|
||||
const gsl::span<const std::string> subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs;
|
||||
struct NodeAndIndex
|
||||
{
|
||||
uint32_t nodeIndex; // The index of the node itself
|
||||
|
|
@ -164,12 +164,14 @@ namespace Dml::GraphDescBuilder
|
|||
// Map from Lotus node argument names to the new node and index where it will be produced
|
||||
std::unordered_map<std::string, NodeAndIndex> nameToNodeAndIndexMap;
|
||||
|
||||
std::unordered_map<std::string, EdgeShapes> nodeOutputShapes;
|
||||
|
||||
// Map from Lotus node argument names to input indices of the fused kernel node.
|
||||
std::unordered_map<std::string, uint32_t> nameToDmlFusedNodeInputIndex;
|
||||
|
||||
for (size_t inputIndex = 0; inputIndex < subGraphInputArgNames.size(); ++inputIndex)
|
||||
for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex)
|
||||
{
|
||||
const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(subGraphInputArgNames[inputIndex]);
|
||||
const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex];
|
||||
|
||||
if (!graphInput)
|
||||
{
|
||||
|
|
@ -196,13 +198,11 @@ namespace Dml::GraphDescBuilder
|
|||
const uint32_t minNodeCountToReuseCommandList = 5;
|
||||
bool reuseCommandList = false;
|
||||
|
||||
if (indexedSubGraph.nodes.size() >= minNodeCountToReuseCommandList)
|
||||
if (subgraphNodes.size() >= minNodeCountToReuseCommandList)
|
||||
{
|
||||
reuseCommandList = true;
|
||||
}
|
||||
|
||||
auto modelPath = graph.ModelPath();
|
||||
|
||||
auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName)
|
||||
{
|
||||
ComPtr<OnnxTensorWrapper> tensorWrapper;
|
||||
|
|
@ -219,9 +219,11 @@ namespace Dml::GraphDescBuilder
|
|||
|
||||
// Iterate through each node and create a corresponding node in the new graph
|
||||
// We can iterate the nodes in any order because the edge connectivity will take care of the topological order
|
||||
for (size_t sortedNodeIndex : indexedSubGraph.nodes)
|
||||
std::unordered_map<std::string, std::vector<uint32_t>> inferredOutputShapes;
|
||||
|
||||
for (const onnxruntime::Node* subgraphNode : subgraphNodes)
|
||||
{
|
||||
const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex);
|
||||
const onnxruntime::Node& node = *subgraphNode;
|
||||
|
||||
const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second;
|
||||
const auto& requiredConstantCpuInputs = graphNodeProps.internalRegInfo->requiredConstantCpuInputs;
|
||||
|
|
@ -244,14 +246,45 @@ namespace Dml::GraphDescBuilder
|
|||
return tensor;
|
||||
};
|
||||
|
||||
EdgeShapes inputShapesOverrides(node.InputDefs().size());
|
||||
|
||||
// Override the input shapes with shapes that were previously inferred
|
||||
for (int inputIndex = 0; inputIndex < node.InputDefs().size(); ++inputIndex)
|
||||
{
|
||||
auto inputDef = node.InputDefs()[inputIndex];
|
||||
|
||||
auto outputShapesIter = inferredOutputShapes.find(inputDef->Name());
|
||||
if (outputShapesIter != inferredOutputShapes.end())
|
||||
{
|
||||
inputShapesOverrides.GetMutableShape(inputIndex) = outputShapesIter->second;
|
||||
}
|
||||
else if (inputDef->HasTensorOrScalarShape())
|
||||
{
|
||||
for (int i = 0; i < inputDef->Shape()->dim_size(); ++i)
|
||||
{
|
||||
ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->Shape()->dim(i).has_dim_value());
|
||||
inputShapesOverrides.GetMutableShape(inputIndex).push_back(gsl::narrow_cast<uint32_t>(inputDef->Shape()->dim(i).dim_value()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EdgeShapes outputShapes;
|
||||
DmlGraphNodeCreateInfo graphNodeCreateInfo;
|
||||
graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory(
|
||||
node,
|
||||
constantCpuNodeInputGetter,
|
||||
executionHandle,
|
||||
&inputShapesOverrides,
|
||||
/*out*/ &outputShapes,
|
||||
/*out*/ &graphNodeCreateInfo
|
||||
);
|
||||
|
||||
ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size());
|
||||
for (int i = 0; i < node.OutputDefs().size(); ++i)
|
||||
{
|
||||
inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i);
|
||||
}
|
||||
|
||||
// Create a map between operatorGraphNodeIndex to mainGraphNodeIndex.
|
||||
std::unordered_map<uint32_t, uint32_t> operatorGraphNodeIndexToMainGraphNodeIndexMap;
|
||||
uint32_t graphNodeCount = gsl::narrow_cast<uint32_t>(graphNodes.size());
|
||||
|
|
@ -347,6 +380,8 @@ namespace Dml::GraphDescBuilder
|
|||
operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex],
|
||||
operatorGraphOutputEdge.FromNodeOutputIndex
|
||||
};
|
||||
|
||||
nodeOutputShapes[arg->Name()] = outputShapes;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -367,10 +402,12 @@ namespace Dml::GraphDescBuilder
|
|||
}
|
||||
}
|
||||
|
||||
EdgeShapes graphOutputShapes(subgraphOutputs.size());
|
||||
|
||||
// Add graph output nodes, which might be in a different order from the encapsulating node
|
||||
for (size_t outputIndex = 0; outputIndex < subGraphOutputArgNames.size(); ++outputIndex)
|
||||
for (size_t outputIndex = 0; outputIndex < subgraphOutputs.size(); ++outputIndex)
|
||||
{
|
||||
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(subGraphOutputArgNames[outputIndex]);
|
||||
const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex];
|
||||
|
||||
ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg");
|
||||
const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());
|
||||
|
|
@ -380,6 +417,7 @@ namespace Dml::GraphDescBuilder
|
|||
edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex;
|
||||
edge.GraphOutputIndex = gsl::narrow_cast<uint32_t>(outputIndex);
|
||||
graphOutputEdges.push_back(edge);
|
||||
graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex);
|
||||
}
|
||||
|
||||
RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges);
|
||||
|
|
@ -390,6 +428,7 @@ namespace Dml::GraphDescBuilder
|
|||
graphDesc.outputEdges = std::move(graphOutputEdges);
|
||||
graphDesc.intermediateEdges = std::move(graphIntermediateEdges);
|
||||
graphDesc.reuseCommandList = reuseCommandList;
|
||||
graphDesc.outputShapes = std::move(graphOutputShapes);
|
||||
return graphDesc;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ namespace Dml
|
|||
{
|
||||
struct GraphNodeProperties
|
||||
{
|
||||
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfo>
|
||||
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfo>
|
||||
internalRegInfo;
|
||||
|
||||
// These are currently passed from the partitioning step since the only DML operators current
|
||||
// These are currently passed from the partitioning step since the only DML operators current
|
||||
// supporting graph nodes don't customize the order of edges or shapes, other than coercing
|
||||
// dimension count. This will change as the supported set of operators as graph nodes increases.
|
||||
Windows::AI::MachineLearning::Adapter::EdgeShapes inputShapes;
|
||||
|
|
@ -38,16 +38,19 @@ namespace Dml
|
|||
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
|
||||
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
|
||||
bool reuseCommandList;
|
||||
Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes;
|
||||
};
|
||||
|
||||
GraphDesc BuildGraphDesc(
|
||||
const uint8_t* isConstGpuGraphInput,
|
||||
const size_t isConstGpuGraphInputCount,
|
||||
const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
|
||||
const onnxruntime::Graph& graph,
|
||||
const onnxruntime::IndexedSubGraph& indexedSubGraph,
|
||||
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
|
||||
IDMLDevice* device,
|
||||
const void* executionHandle);
|
||||
const void* executionHandle,
|
||||
const onnxruntime::Path& modelPath,
|
||||
gsl::span<const onnxruntime::Node* const> subgraphNodes,
|
||||
gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,
|
||||
gsl::span<const onnxruntime::NodeArg* const> subgraphOutputs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -151,6 +151,8 @@ namespace Dml
|
|||
_In_opt_ const std::unordered_map<std::string, GraphPartition*>* nodeNameToPartitionMap,
|
||||
_Inout_ std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& dmlNodePropertyMap,
|
||||
_Inout_ std::unordered_set<std::string>& requiredInitializerMap,
|
||||
_Inout_ std::unordered_set<std::string>& dynamicCpuInputMap,
|
||||
bool allowDmlGraphDynamicShapes,
|
||||
_Out_ bool* isDmlGraphNode
|
||||
)
|
||||
{
|
||||
|
|
@ -172,36 +174,68 @@ namespace Dml
|
|||
|
||||
if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration)
|
||||
{
|
||||
bool requiredCpuInputsConstant = true;
|
||||
for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs)
|
||||
if (allowDmlGraphDynamicShapes)
|
||||
{
|
||||
if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists())
|
||||
for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs)
|
||||
{
|
||||
continue;
|
||||
if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const onnx::TensorProto* tensor = nullptr;
|
||||
const std::string& inputName = node.InputDefs()[inputIndex]->Name();
|
||||
|
||||
if (graph.GetInitializedTensor(inputName, tensor))
|
||||
{
|
||||
requiredInitializerMap.insert(inputName);
|
||||
}
|
||||
else
|
||||
{
|
||||
dynamicCpuInputMap.insert(inputName);
|
||||
}
|
||||
}
|
||||
|
||||
const onnx::TensorProto* tensor = nullptr;
|
||||
const std::string& inputName = node.InputDefs()[inputIndex]->Name();
|
||||
|
||||
if (!graph.GetInitializedTensor(inputName, tensor))
|
||||
std::optional<uint32_t> requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount;
|
||||
if (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())
|
||||
{
|
||||
requiredCpuInputsConstant = false;
|
||||
break;
|
||||
*isDmlGraphNode = true;
|
||||
graphNodeProperty.first->second.internalRegInfo = internalRegInfo;
|
||||
}
|
||||
|
||||
requiredInitializerMap.insert(inputName);
|
||||
}
|
||||
|
||||
std::optional<uint32_t> requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount;
|
||||
if (requiredCpuInputsConstant &&
|
||||
TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) &&
|
||||
!ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) &&
|
||||
TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) &&
|
||||
!ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) &&
|
||||
(requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()))
|
||||
else
|
||||
{
|
||||
*isDmlGraphNode = true;
|
||||
graphNodeProperty.first->second.internalRegInfo = internalRegInfo;
|
||||
bool requiredCpuInputsConstant = true;
|
||||
for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs)
|
||||
{
|
||||
if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const onnx::TensorProto* tensor = nullptr;
|
||||
const std::string& inputName = node.InputDefs()[inputIndex]->Name();
|
||||
|
||||
if (!graph.GetInitializedTensor(inputName, tensor))
|
||||
{
|
||||
requiredCpuInputsConstant = false;
|
||||
break;
|
||||
}
|
||||
|
||||
requiredInitializerMap.insert(inputName);
|
||||
}
|
||||
|
||||
std::optional<uint32_t> requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount;
|
||||
if (requiredCpuInputsConstant &&
|
||||
TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) &&
|
||||
!ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) &&
|
||||
TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) &&
|
||||
!ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) &&
|
||||
(requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()))
|
||||
{
|
||||
*isDmlGraphNode = true;
|
||||
graphNodeProperty.first->second.internalRegInfo = internalRegInfo;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -379,8 +413,10 @@ namespace Dml
|
|||
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
|
||||
std::unordered_set<std::string>& requiredInitializerMap,
|
||||
std::unordered_set<std::string>& dynamicCpuInputMap,
|
||||
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes,
|
||||
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs)
|
||||
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs,
|
||||
bool allowDmlGraphDynamicShapes)
|
||||
{
|
||||
// Nodes are uniquely identified by the name of their first output argument
|
||||
std::vector<std::unique_ptr<GraphPartition>> partitions;
|
||||
|
|
@ -443,6 +479,8 @@ namespace Dml
|
|||
&nodeNameToPartitionMap,
|
||||
graphNodePropertyMap,
|
||||
requiredInitializerMap,
|
||||
dynamicCpuInputMap,
|
||||
allowDmlGraphDynamicShapes,
|
||||
/*out*/ &isDmlGraphNode
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ namespace Dml
|
|||
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
|
||||
std::unordered_set<std::string>& requiredInitializerMap,
|
||||
std::unordered_set<std::string>& dynamicCpuInputMap,
|
||||
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes,
|
||||
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs);
|
||||
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs,
|
||||
bool allowDmlGraphDynamicShapes);
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -2,8 +2,15 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <d3d12.h>
|
||||
|
||||
#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h"
|
||||
|
||||
interface IDMLCompiledOperator;
|
||||
struct DML_BUFFER_BINDING;
|
||||
struct DML_BINDING_DESC;
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
struct Binding
|
||||
|
|
|
|||
|
|
@ -1356,13 +1356,14 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
const onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext>* protoHelper,
|
||||
const void* executionHandle,
|
||||
bool isInternalOperator,
|
||||
const EdgeShapes* inputShapesOverrides,
|
||||
const EdgeShapes* inferredOutputShapes,
|
||||
const AttributeMap* defaultAttributes,
|
||||
DmlGraphNodeCreateInfo* graphNodeCreateInfo,
|
||||
gsl::span<const uint32_t> requiredConstantCpuInputs,
|
||||
MLOperatorTensorGetter& constantInputGetter
|
||||
)
|
||||
: OpNodeInfoWrapper(protoHelper, nullptr, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr),
|
||||
: OpNodeInfoWrapper(protoHelper, inputShapesOverrides, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr),
|
||||
m_inferredOutputShapes(inferredOutputShapes),
|
||||
m_internalOperator(isInternalOperator),
|
||||
m_graphNodeCreateInfo(graphNodeCreateInfo)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#pragma once
|
||||
#include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h"
|
||||
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/customregistry.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
|
@ -93,42 +94,6 @@ public:
|
|||
|
||||
using AttributeMap = std::map<std::string, AttributeValue>;
|
||||
|
||||
// Encapsulation of shapes across different edges of an operator. Non-tensor
|
||||
// edges and unused edges have an empty array of dimensions.
|
||||
class EdgeShapes
|
||||
{
|
||||
public:
|
||||
EdgeShapes() = default;
|
||||
|
||||
EdgeShapes(size_t count) : m_shapes(count) {}
|
||||
|
||||
const std::vector<uint32_t>& GetShape(size_t edgeIndex) const
|
||||
{
|
||||
return m_shapes[edgeIndex];
|
||||
}
|
||||
|
||||
std::vector<uint32_t>& GetMutableShape(size_t edgeIndex)
|
||||
{
|
||||
return m_shapes[edgeIndex];
|
||||
}
|
||||
|
||||
size_t EdgeCount() const { return m_shapes.size(); }
|
||||
|
||||
void Reset(size_t edge_count)
|
||||
{
|
||||
m_shapes.clear();
|
||||
m_shapes.resize(edge_count);
|
||||
}
|
||||
|
||||
bool operator!=(const EdgeShapes& other) const noexcept
|
||||
{
|
||||
return (m_shapes != other.m_shapes);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::vector<uint32_t>> m_shapes;
|
||||
};
|
||||
|
||||
// Base class for ABI objects which may be "Closed", at which point calls will predictably
|
||||
// fail or return a dummy value. This is used for transient ABI context objects which
|
||||
// are passed to methods on kernel or inferencers, and which wrap Lotus objects whose lifetimes
|
||||
|
|
@ -434,6 +399,7 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper<
|
|||
const onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> * protoHelper,
|
||||
const void* executionHandle,
|
||||
bool isInternalOperator,
|
||||
const EdgeShapes* inputShapesOverrides,
|
||||
const EdgeShapes* inferredOutputShapes,
|
||||
const AttributeMap* defaultAttributes,
|
||||
DmlGraphNodeCreateInfo* graphNodeCreateInfo,
|
||||
|
|
|
|||
|
|
@ -30,8 +30,12 @@ namespace onnxruntime {
|
|||
|
||||
struct DMLProviderFactory : IExecutionProviderFactory {
|
||||
DMLProviderFactory(IDMLDevice* dml_device,
|
||||
ID3D12CommandQueue* cmd_queue) : dml_device_(dml_device),
|
||||
cmd_queue_(cmd_queue) {}
|
||||
ID3D12CommandQueue* cmd_queue,
|
||||
bool disable_metacommands,
|
||||
bool enable_dynamic_graph_fusion) : dml_device_(dml_device),
|
||||
cmd_queue_(cmd_queue),
|
||||
metacommands_enabled_(!disable_metacommands),
|
||||
dynamic_graph_fusion_enabled_(enable_dynamic_graph_fusion) {}
|
||||
~DMLProviderFactory() override {}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> CreateProvider() override;
|
||||
|
|
@ -42,10 +46,11 @@ struct DMLProviderFactory : IExecutionProviderFactory {
|
|||
ComPtr<IDMLDevice> dml_device_{};
|
||||
ComPtr<ID3D12CommandQueue> cmd_queue_{};
|
||||
bool metacommands_enabled_ = true;
|
||||
bool dynamic_graph_fusion_enabled_ = false;
|
||||
};
|
||||
|
||||
std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {
|
||||
auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_);
|
||||
auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_, dynamic_graph_fusion_enabled_);
|
||||
return provider;
|
||||
}
|
||||
|
||||
|
|
@ -54,7 +59,9 @@ void DMLProviderFactory::SetMetacommandsEnabled(bool metacommands_enabled) {
|
|||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(IDMLDevice* dml_device,
|
||||
ID3D12CommandQueue* cmd_queue) {
|
||||
ID3D12CommandQueue* cmd_queue,
|
||||
bool disable_metacommands,
|
||||
bool enable_dynamic_graph_fusion) {
|
||||
#ifndef _GAMING_XBOX
|
||||
// Validate that the D3D12 devices match between DML and the command queue. This specifically asks for IUnknown in
|
||||
// order to be able to compare the pointers for COM object identity.
|
||||
|
|
@ -73,7 +80,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(ID
|
|||
const Env& env = Env::Default();
|
||||
auto luid = d3d12_device->GetAdapterLuid();
|
||||
env.GetTelemetryProvider().LogExecutionProviderEvent(&luid);
|
||||
return std::make_shared<onnxruntime::DMLProviderFactory>(dml_device, cmd_queue);
|
||||
return std::make_shared<onnxruntime::DMLProviderFactory>(dml_device, cmd_queue, disable_metacommands, enable_dynamic_graph_fusion);
|
||||
}
|
||||
|
||||
void DmlConfigureProviderFactoryMetacommandsEnabled(IExecutionProviderFactory* factory, bool metacommandsEnabled) {
|
||||
|
|
@ -234,12 +241,10 @@ static void SortHeterogenousDXCoreAdapterList(
|
|||
std::sort(adapter_infos.begin(), adapter_infos.end(), policy);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id) {
|
||||
return Create(device_id, /*skip_software_device_check*/ false);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromOptions(
|
||||
OrtDmlDeviceOptions* device_options) {
|
||||
OrtDmlDeviceOptions* device_options,
|
||||
bool disable_metacommands,
|
||||
bool enable_dynamic_graph_fusion) {
|
||||
auto default_device_options = OrtDmlDeviceOptions { Default, Gpu };
|
||||
if (device_options == nullptr) {
|
||||
device_options = &default_device_options;
|
||||
|
|
@ -286,7 +291,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFrom
|
|||
adapters.begin(),
|
||||
[](auto& a){ return a.Adapter; });
|
||||
|
||||
return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters));
|
||||
return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters), disable_metacommands, enable_dynamic_graph_fusion);
|
||||
}
|
||||
|
||||
static std::optional<OrtDmlPerformancePreference> ParsePerformancePreference(const ProviderOptions& provider_options) {
|
||||
|
|
@ -354,12 +359,32 @@ static std::optional<int> ParseDeviceId(const ProviderOptions& provider_options)
|
|||
return {};
|
||||
}
|
||||
|
||||
static bool ParseBoolean(const ProviderOptions& provider_options, const std::string& key) {
|
||||
auto preference_it = provider_options.find(key);
|
||||
if (preference_it != provider_options.end() && !preference_it->second.empty()) {
|
||||
if (preference_it->second == "True" || preference_it->second == "true") {
|
||||
return true;
|
||||
} else if (preference_it->second == "False" || preference_it->second == "false") {
|
||||
return false;
|
||||
} else {
|
||||
ORT_THROW("[ERROR] [DirectML] The value for the key '" + key + "' should be 'True' or 'False'. Default value is 'False'.\n");
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromProviderOptions(
|
||||
const ProviderOptions& provider_options) {
|
||||
const ProviderOptions& provider_options) {
|
||||
|
||||
bool disable_metacommands = ParseBoolean(provider_options, "disable_metacommands");
|
||||
bool enable_dynamic_graph_fusion = ParseBoolean(provider_options, "enable_dynamic_graph_fusion");
|
||||
bool skip_software_device_check = false;
|
||||
auto device_id = ParseDeviceId(provider_options);
|
||||
|
||||
if (device_id.has_value())
|
||||
{
|
||||
return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value());
|
||||
return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value(), skip_software_device_check, disable_metacommands, enable_dynamic_graph_fusion);
|
||||
}
|
||||
|
||||
auto preference = ParsePerformancePreference(provider_options);
|
||||
|
|
@ -367,7 +392,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFrom
|
|||
|
||||
// If no preference/filters are specified then create with default preference/filters.
|
||||
if (!preference.has_value() && !filter.has_value()) {
|
||||
return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr);
|
||||
return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr, disable_metacommands, enable_dynamic_graph_fusion);
|
||||
}
|
||||
|
||||
if (!preference.has_value()) {
|
||||
|
|
@ -381,7 +406,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFrom
|
|||
OrtDmlDeviceOptions device_options;
|
||||
device_options.Preference = preference.value();
|
||||
device_options.Filter = filter.value();
|
||||
return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options);
|
||||
return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options, disable_metacommands, enable_dynamic_graph_fusion);
|
||||
}
|
||||
|
||||
Microsoft::WRL::ComPtr<ID3D12Device> DMLProviderFactoryCreator::CreateD3D12Device(
|
||||
|
|
@ -441,7 +466,10 @@ Microsoft::WRL::ComPtr<IDMLDevice> DMLProviderFactoryCreator::CreateDMLDevice(ID
|
|||
return dml_device;
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateDMLDeviceAndProviderFactory(ID3D12Device* d3d12_device) {
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateDMLDeviceAndProviderFactory(
|
||||
ID3D12Device* d3d12_device,
|
||||
bool disable_metacommands,
|
||||
bool enable_dynamic_graph_fusion) {
|
||||
D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
|
||||
cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
|
||||
cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
|
||||
|
|
@ -450,16 +478,22 @@ std::shared_ptr<IExecutionProviderFactory> CreateDMLDeviceAndProviderFactory(ID3
|
|||
ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));
|
||||
|
||||
auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device);
|
||||
return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get());
|
||||
return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get(), disable_metacommands, enable_dynamic_graph_fusion);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) {
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(
|
||||
int device_id,
|
||||
bool skip_software_device_check,
|
||||
bool disable_metacommands,
|
||||
bool enable_dynamic_graph_fusion) {
|
||||
ComPtr<ID3D12Device> d3d12_device = CreateD3D12Device(device_id, skip_software_device_check);
|
||||
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
|
||||
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromAdapterList(
|
||||
std::vector<ComPtr<IDXCoreAdapter>>&& dxcore_devices) {
|
||||
std::vector<ComPtr<IDXCoreAdapter>>&& dxcore_devices,
|
||||
bool disable_metacommands,
|
||||
bool enable_dynamic_graph_fusion) {
|
||||
// Choose the first device from the list since it's the highest priority
|
||||
auto dxcore_device = dxcore_devices[0];
|
||||
|
||||
|
|
@ -467,7 +501,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFrom
|
|||
ComPtr<ID3D12Device> d3d12_device;
|
||||
ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
|
||||
|
||||
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
|
||||
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -477,7 +511,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFrom
|
|||
// The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead.
|
||||
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id) {
|
||||
API_IMPL_BEGIN
|
||||
options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id));
|
||||
options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id, false, false, false));
|
||||
API_IMPL_END
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -489,7 +523,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSess
|
|||
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue) {
|
||||
API_IMPL_BEGIN
|
||||
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_DML(dml_device,
|
||||
cmd_queue));
|
||||
cmd_queue,
|
||||
false,
|
||||
false));
|
||||
API_IMPL_END
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -517,7 +553,7 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) {
|
|||
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) {
|
||||
API_IMPL_BEGIN
|
||||
#ifdef USE_DML
|
||||
auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options);
|
||||
auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options, false, false);
|
||||
// return the create function for a dxcore device
|
||||
options->provider_factories.push_back(factory);
|
||||
#endif // USE_DML
|
||||
|
|
|
|||
|
|
@ -17,15 +17,24 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
struct DMLProviderFactoryCreator {
|
||||
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
|
||||
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id, bool skip_software_device_check);
|
||||
static std::shared_ptr<IExecutionProviderFactory> Create(
|
||||
int device_id,
|
||||
bool skip_software_device_check,
|
||||
bool disable_metacommands,
|
||||
bool enable_dynamic_graph_fusion);
|
||||
|
||||
static std::shared_ptr<IExecutionProviderFactory> CreateFromProviderOptions(
|
||||
const ProviderOptions& provider_options_map);
|
||||
static std::shared_ptr<IExecutionProviderFactory> CreateFromOptions(OrtDmlDeviceOptions* device_options);
|
||||
|
||||
static std::shared_ptr<IExecutionProviderFactory> CreateFromOptions(
|
||||
OrtDmlDeviceOptions* device_options,
|
||||
bool disable_metacommands,
|
||||
bool enable_dynamic_graph_fusion);
|
||||
|
||||
static std::shared_ptr<IExecutionProviderFactory> CreateFromAdapterList(
|
||||
std::vector<Microsoft::WRL::ComPtr<IDXCoreAdapter>>&& dxcore_devices);
|
||||
std::vector<Microsoft::WRL::ComPtr<IDXCoreAdapter>>&& dxcore_devices,
|
||||
bool disable_metacommands,
|
||||
bool enable_dynamic_graph_fusion);
|
||||
|
||||
static Microsoft::WRL::ComPtr<ID3D12Device> CreateD3D12Device(int device_id, bool skip_software_device_check);
|
||||
static Microsoft::WRL::ComPtr<IDMLDevice> CreateDMLDevice(ID3D12Device* d3d12_device);
|
||||
|
|
|
|||
|
|
@ -52,8 +52,10 @@
|
|||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h"
|
||||
#include "core/providers/dml/dml_session_options_config_keys.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h"
|
||||
#endif
|
||||
#include "core/session/environment.h"
|
||||
#include "core/session/user_logging_sink.h"
|
||||
|
|
@ -1598,7 +1600,9 @@ common::Status InferenceSession::Initialize() {
|
|||
record_runtime_optimization_produced_op_schema));
|
||||
|
||||
#ifdef USE_DML
|
||||
if (execution_providers_.Get(kDmlExecutionProvider)) {
|
||||
const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider);
|
||||
|
||||
if (dmlExecutionProvider) {
|
||||
// DML graph fusion is an important runtime optimization that cannot be done ahead of time; it must be disabled
|
||||
// when running in "offline mode" and saving an optimized model to disk. To support users that want to optimize
|
||||
// models offline, and then disable graph optimizations when running "online", this transformer ignores the ORT
|
||||
|
|
@ -1608,11 +1612,20 @@ common::Status InferenceSession::Initialize() {
|
|||
|
||||
if (dml_graph_fusion_enabled) {
|
||||
std::unique_ptr<onnxruntime::GraphTransformer> dmlGraphFusionTransformer = std::make_unique<Dml::DmlGraphFusionTransformer>("DmlGraphFusionTransformer",
|
||||
execution_providers_.Get(kDmlExecutionProvider));
|
||||
dmlExecutionProvider);
|
||||
if (dmlGraphFusionTransformer == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr");
|
||||
}
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3));
|
||||
|
||||
if (static_cast<const Dml::ExecutionProvider*>(dmlExecutionProvider)->DynamicGraphFusionEnabled()) {
|
||||
std::unique_ptr<onnxruntime::GraphTransformer> dmlRuntimeGraphFusionTransformer = std::make_unique<Dml::DmlRuntimeGraphFusionTransformer>("DmlRuntimeGraphFusionTransformer",
|
||||
dmlExecutionProvider);
|
||||
if (dmlRuntimeGraphFusionTransformer == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "DmlRuntimeGraphFusionTransformer is nullptr");
|
||||
}
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlRuntimeGraphFusionTransformer), onnxruntime::TransformerLevel::Level3));
|
||||
}
|
||||
}
|
||||
|
||||
// This transformer applies DML-specific fusions that go beyond what ORT offers by default
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) {
|
|||
onnxruntime::ArmNNProviderFactoryCreator::Create(0),
|
||||
#endif
|
||||
#ifdef USE_DML
|
||||
onnxruntime::DMLProviderFactoryCreator::Create(0, /*skip_software_device_check*/ true),
|
||||
onnxruntime::DMLProviderFactoryCreator::Create(0, false, false, false),
|
||||
#endif
|
||||
#ifdef USE_NNAPI
|
||||
onnxruntime::NnapiProviderFactoryCreator::Create(0, std::optional<std::string>()),
|
||||
|
|
|
|||
|
|
@ -268,7 +268,7 @@ std::unique_ptr<IExecutionProvider> DefaultCannExecutionProvider() {
|
|||
|
||||
std::unique_ptr<IExecutionProvider> DefaultDmlExecutionProvider() {
|
||||
#ifdef USE_DML
|
||||
if (auto factory = DMLProviderFactoryCreator::Create(0))
|
||||
if (auto factory = DMLProviderFactoryCreator::Create(0, false, false, false))
|
||||
return factory->CreateProvider();
|
||||
#endif
|
||||
return nullptr;
|
||||
|
|
|
|||
Loading…
Reference in a new issue