[DML EP] Add dynamic graph compilation (#17876)

Historically, DML was only able to fuse partitions when all sizes were
known in advance or when we were overriding them at session creation
time. But in practice, it should be possible to compile partitions at
compute time if the caller knows that the dimensions won't change on
every inference (e.g. resizing a webcam window, or padding the input
to powers of 2). The compiled graph will be cached and reused until the
sizes change.

This is an opt-in feature gated behind the `enable_dynamic_graph_fusion`
option, which means that it will only be enabled when the caller
requests it, since the caller has more context on how the model will be
called between inferences.

This PR also adds the option to disable metacommands from the python
API, which is an option for the C API but was lacking for python.
This commit is contained in:
Patrice Vignola 2023-10-25 19:56:16 -07:00 committed by GitHub
parent d30d4d372a
commit 538e97cbda
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 1127 additions and 143 deletions

View file

@ -128,7 +128,7 @@ struct OrtDmlApi {
/**
* SessionOptionsAppendExecutionProvider_DML2
* Creates a DirectML Execution Provider given the supplied device options that contain a performance preference
* (high power, low power, or defult) and a device filter (None, GPU, or NPU).
* (high power, low power, or default) and a device filter (None, GPU, or NPU).
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
};

View file

@ -3,6 +3,9 @@
#pragma once
interface IMLOperatorRegistry;
interface IDMLDevice;
interface ID3D12CommandQueue;
interface ID3D12Resource;
#include "core/common/status.h"
#include "core/framework/data_transfer.h"
@ -28,7 +31,8 @@ namespace Dml
std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* commandQueue,
bool enableMetacommands = true);
bool enableMetacommands,
bool enableDynamicGraphFusion);
ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr);
void FlushContext(onnxruntime::IExecutionProvider* provider);

View file

@ -7,11 +7,14 @@
#include <functional>
#include <variant>
#include <optional>
#include <wrl/client.h>
#include "core/framework/op_kernel.h"
#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h"
struct AbstractOperatorDesc;
interface IMLOperatorTensor;
interface IDMLOperator;
struct DML_INPUT_GRAPH_EDGE_DESC;
struct DML_OUTPUT_GRAPH_EDGE_DESC;
struct DML_INTERMEDIATE_GRAPH_EDGE_DESC;
@ -92,6 +95,8 @@ namespace Windows::AI::MachineLearning::Adapter
const onnxruntime::Node& node,
MLOperatorTensorGetter& constantInputGetter,
const void* executionHandle,
const EdgeShapes* inputShapesOverrides,
/*out*/ EdgeShapes* outputShapes,
/*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
)>;

View file

@ -491,6 +491,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
const onnxruntime::Node& node,
MLOperatorTensorGetter& constantInputGetter,
const void* executionHandle,
const EdgeShapes* inputShapesOverrides,
/*out*/ EdgeShapes* outputShapes,
/*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
)
{
@ -498,15 +500,15 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> protoHelper(&nodeContext);
// Use the same list of required constant inputs for the shape inferrer and the kernel.
EdgeShapes outputShapes;
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes);
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes);
// Create the kernel while allowing input shape and output shape queries according to options
ComPtr<DmlGraphOpKernelInfoWrapper> kernelInfoWrapper = wil::MakeOrThrow<DmlGraphOpKernelInfoWrapper>(
&protoHelper,
executionHandle,
true,
&outputShapes,
inputShapesOverrides,
outputShapes,
&defaultAttributesCapture,
graphNodeCreateInfo,
constantCpuInputCapture,

View file

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace Windows::AI::MachineLearning::Adapter
{
// Holds one shape (an array of dimensions) per edge, addressed by edge index.
// Edges without a recorded shape, such as unused edges, have an empty array of dimensions.
class EdgeShapes
{
public:
    EdgeShapes() = default;

    // Creates `count` edges, each with an empty shape.
    EdgeShapes(size_t count) : m_shapes(count) {}

    // Returns the shape stored for `edgeIndex`. The index is not bounds-checked.
    const std::vector<uint32_t>& GetShape(size_t edgeIndex) const
    {
        return m_shapes[edgeIndex];
    }

    // Returns a writable reference to the shape stored for `edgeIndex`. The index is not bounds-checked.
    std::vector<uint32_t>& GetMutableShape(size_t edgeIndex)
    {
        return m_shapes[edgeIndex];
    }

    // Number of edges tracked, including those whose shape is empty.
    size_t EdgeCount() const { return m_shapes.size(); }

    // Discards every stored shape and leaves `edge_count` edges with empty shapes.
    void Reset(size_t edge_count)
    {
        m_shapes.clear();
        m_shapes.resize(edge_count);
    }

    // Two instances are equal when they have the same edge count and identical per-edge dimensions.
    bool operator==(const EdgeShapes& other) const noexcept
    {
        return (m_shapes == other.m_shapes);
    }

    bool operator!=(const EdgeShapes& other) const noexcept
    {
        return !(*this == other);
    }

private:
    std::vector<std::vector<uint32_t>> m_shapes;
};
}

View file

@ -1,7 +1,7 @@
#pragma once
#include "DmlGraphFusionHelper.h"
#include "DmlRuntimeFusedGraphKernel.h"
namespace Dml
{
@ -501,5 +501,171 @@ namespace DmlGraphFusionHelper
graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode);
}
// Registers a kernel for the partition described by `indexedSubGraph` whose DML graph is built at
// kernel-compute time instead of at session creation, then fuses the partition's nodes in `graph`
// into a single node assigned to the DML execution provider.
//
// graph                       - graph owning the partition; its nodes are replaced by the fused node.
// registryForPartitionKernels - registry that receives the kernel definition and creation function.
// providerImpl                - not used by this function.
// graphNodePropertyMap        - per-node properties, forwarded to CreatePartitionNodePropsMap.
// dynamicCpuInputMap          - names of partition inputs that must be allocated on the CPU.
// indexedSubGraph             - the partition being fused; shared with the kernel creation function.
// isInitializerTransferable   - initializers used by the partition; each entry is re-pointed at a copy
//                               made here, because the graph's tensors go away with the fused nodes.
//
// Throws (via ORT_THROW_IF_ERROR) when the kernel cannot be registered.
void RegisterDynamicKernel(
    onnxruntime::Graph& graph,
    onnxruntime::KernelRegistry* registryForPartitionKernels,
    const ExecutionProviderImpl* providerImpl,
    std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap,
    const std::unordered_set<std::string>& dynamicCpuInputMap,
    std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
    std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>&& isInitializerTransferable)
{
    // Snapshot of everything required to rebuild an onnxruntime::Node after the original node has
    // been removed from the graph by the fusion at the end of this function.
    struct NodeInfo
    {
        std::string name;
        std::string opType;
        std::string description;
        std::string domain;
        onnxruntime::NodeAttributes attributes;
        std::vector<onnxruntime::NodeArg*> inputDefPointers;
        std::vector<onnxruntime::NodeArg*> outputDefPointers;
    };

    auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap(
        graph,
        *indexedSubGraph,
        std::move(graphNodePropertyMap));

    const gsl::span<const std::string> subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs;
    const gsl::span<const std::string> subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs;

    std::vector<NodeInfo> nodesInfo;
    nodesInfo.reserve(indexedSubGraph->nodes.size());

    std::vector<const onnxruntime::NodeArg*> subgraphInputs;
    subgraphInputs.reserve(subGraphInputArgNames.size());

    std::vector<const onnxruntime::NodeArg*> subgraphOutputs;
    subgraphOutputs.reserve(subGraphOutputArgNames.size());

    // Private copies of every node arg referenced by the partition's nodes. The kernel rewrites
    // their shapes at compute time, so they must not alias the graph's own node args.
    std::vector<std::shared_ptr<onnxruntime::NodeArg>> intermediateNodeArgs;

    for (size_t sortedNodeIndex : indexedSubGraph->nodes)
    {
        auto node = graph.GetNode(sortedNodeIndex);

        NodeInfo nodeInfo{};
        nodeInfo.name = node->Name();
        nodeInfo.opType = node->OpType();
        nodeInfo.description = node->Description();
        nodeInfo.domain = node->Domain();
        nodeInfo.attributes = node->GetAttributes();
        nodeInfo.inputDefPointers.reserve(node->InputDefs().size());
        nodeInfo.outputDefPointers.reserve(node->OutputDefs().size());

        for (const onnxruntime::NodeArg* inputDef : node->InputDefs())
        {
            intermediateNodeArgs.emplace_back(std::make_shared<onnxruntime::NodeArg>(inputDef->Name(), inputDef->TypeAsProto()));
            nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get());
        }

        for (const onnxruntime::NodeArg* outputDef : node->OutputDefs())
        {
            intermediateNodeArgs.emplace_back(std::make_shared<onnxruntime::NodeArg>(outputDef->Name(), outputDef->TypeAsProto()));
            nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get());
        }

        nodesInfo.push_back(std::move(nodeInfo));
    }

    for (const std::string& graphInputName : subGraphInputArgNames)
    {
        subgraphInputs.push_back(graph.GetNodeArg(graphInputName));
    }

    for (const std::string& graphOutputName : subGraphOutputArgNames)
    {
        subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName));
    }

    // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph.
    // The reserve() below guarantees no reallocation, so the addresses stored back into
    // isInitializerTransferable stay valid while the vector's storage is alive.
    std::vector<ONNX_NAMESPACE::TensorProto> ownedInitializers;
    ownedInitializers.reserve(isInitializerTransferable.size());

    for (auto& kvp : isInitializerTransferable)
    {
        ONNX_NAMESPACE::TensorProto tensorProto;
        tensorProto.set_data_type(kvp.second.first->data_type());
        tensorProto.set_raw_data(kvp.second.first->raw_data());
        tensorProto.set_name(kvp.second.first->name());

        for (int i = 0; i < kvp.second.first->dims_size(); ++i)
        {
            tensorProto.add_dims(kvp.second.first->dims(i));
        }

        ownedInitializers.push_back(std::move(tensorProto));
        kvp.second.first = &ownedInitializers.back();
    }

    // Lambda captures for the kernel registration.
    // The model path is captured BY VALUE: this lambda is stored in the kernel registry and invoked
    // after this function has returned, so a reference to a local here would dangle.
    // NOTE(review): the captured state is moved into the kernel on invocation, so this creation
    // function is presumably expected to run only once per registration - confirm.
    auto fused_kernel_func = [
        indexedSubGraph,
        modelPath = graph.ModelPath(),
        nodesInfo = std::move(nodesInfo),
        intermediateNodeArgs = std::move(intermediateNodeArgs),
        subgraphInputs = std::move(subgraphInputs),
        subgraphOutputs = std::move(subgraphOutputs),
        partitionNodePropsMap = std::move(partitionNodePropsMap),
        ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr<onnxruntime::OpKernel>& out) mutable ->onnxruntime::Status
    {
        // Rebuild standalone nodes from the snapshots; they refer to the private node arg copies.
        std::vector<std::shared_ptr<onnxruntime::Node>> subgraphNodes;
        subgraphNodes.reserve(nodesInfo.size());

        for (const NodeInfo& nodeInfo : nodesInfo)
        {
            subgraphNodes.emplace_back(std::make_shared<onnxruntime::Node>(
                nodeInfo.name,
                nodeInfo.opType,
                nodeInfo.description,
                nodeInfo.inputDefPointers,
                nodeInfo.outputDefPointers,
                &nodeInfo.attributes,
                nodeInfo.domain));
        }

        out.reset(CreateRuntimeFusedGraphKernel(
            info,
            indexedSubGraph,
            modelPath,
            std::move(subgraphNodes),
            std::move(subgraphInputs),
            std::move(subgraphOutputs),
            std::move(intermediateNodeArgs),
            std::move(partitionNodePropsMap),
            std::move(ownedInitializers)));

        return Status::OK();
    };

    // Build the kernel definition on the fly, and register it to the fused kernel registry.
    onnxruntime::KernelDefBuilder builder;
    builder.SetName(indexedSubGraph->GetMetaDef()->name)
        .SetDomain(indexedSubGraph->GetMetaDef()->domain)
        .SinceVersion(indexedSubGraph->GetMetaDef()->since_version)
        .Provider(onnxruntime::kDmlExecutionProvider);

    // Force the CPU inputs to be allocated on the CPU
    for (size_t i = 0; i < subGraphInputArgNames.size(); ++i)
    {
        if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end())
        {
            builder.InputMemoryType(OrtMemTypeCPUInput, static_cast<int>(i));
        }
    }

    ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func));

    auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name);
    fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider);

    graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode);
}
}
}

View file

@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper
std::vector<uint8_t>&& isInputsUploadedByDmlEP,
const GraphDescBuilder::GraphDesc& graphDesc,
Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledExecutionPlanOperator);
// Registers a kernel for `indexedSubGraph` whose DML graph is built at kernel-compute time, then
// fuses the partition's nodes in `graph` into one node assigned to the DML execution provider.
// Inputs named in `dynamicCpuInputMap` are registered as CPU inputs. The initializers referenced by
// `isInitializerTransferable` are copied so they remain available after the nodes are fused.
void RegisterDynamicKernel(
onnxruntime::Graph& graph,
onnxruntime::KernelRegistry* registryForPartitionKernels,
const ExecutionProviderImpl* providerImpl,
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap,
const std::unordered_set<std::string>& dynamicCpuInputMap,
std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>&& isInitializerTransferable);
}
}

View file

@ -15,6 +15,18 @@
namespace Dml
{
namespace
{
struct CompiledPartitionInfo
{
Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledOperator;
onnxruntime::IndexedSubGraph indexedSubGraph;
std::vector<uint8_t> isInputsUploadedByDmlEP;
GraphDescBuilder::GraphDesc graphDesc;
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> isInitializerTransferable;
};
}
DmlGraphFusionTransformer::DmlGraphFusionTransformer(
const std::string& name,
const onnxruntime::IExecutionProvider* provider
@ -24,15 +36,6 @@ namespace Dml
{
}
struct CompiledPartitionInfo
{
Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledOperator;
onnxruntime::IndexedSubGraph indexedSubGraph;
std::vector<uint8_t> isInputsUploadedByDmlEP;
GraphDescBuilder::GraphDesc graphDesc;
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> isInitializerTransferable;
};
onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImpl(
onnxruntime::Graph& graph,
bool& modified,
@ -87,6 +90,7 @@ namespace Dml
{
// Initializers needed by any graph partition
std::unordered_set<std::string> requiredInitializerMap;
std::unordered_set<std::string> dynamicCpuInputMap;
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap;
onnxruntime::GraphViewer graphViewer(graph);
std::vector<std::unique_ptr<GraphPartition>> partitions = BuildPartitions(
@ -96,8 +100,10 @@ namespace Dml
m_providerImpl->GetSupportedDeviceDataTypeMask(),
graphNodePropertyMap,
requiredInitializerMap,
dynamicCpuInputMap,
additionalSplittingNodes,
implicitInputDefs);
implicitInputDefs,
false);
// Reset the splitting nodes for the current iteration
additionalSplittingNodes.clear();
@ -190,17 +196,48 @@ namespace Dml
std::move(graphNodePropertyMap));
// Convert partitionONNXGraph into DML EP GraphDesc
auto modelPath = graph.ModelPath();
const gsl::span<const std::string> subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs;
const gsl::span<const std::string> subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs;
std::vector<const onnxruntime::Node*> subgraphNodes;
subgraphNodes.reserve(indexedSubGraph.nodes.size());
std::vector<const onnxruntime::NodeArg*> subgraphInputs;
subgraphInputs.reserve(subGraphInputArgNames.size());
std::vector<const onnxruntime::NodeArg*> subgraphOutputs;
subgraphOutputs.reserve(subGraphOutputArgNames.size());
for (size_t sortedNodeIndex : indexedSubGraph.nodes)
{
subgraphNodes.push_back(graph.GetNode(sortedNodeIndex));
}
for (const std::string& graphInputName : subGraphInputArgNames)
{
subgraphInputs.push_back(graph.GetNodeArg(graphInputName));
}
for (const std::string& graphOutputName : subGraphOutputArgNames)
{
subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName));
}
ComPtr<IDMLDevice> device;
ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf()));
GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
isInputsUploadedByDmlEP.data(),
isInputsUploadedByDmlEP.size(),
isInitializerTransferable,
graph,
indexedSubGraph,
partitionNodePropsMap,
device.Get(),
m_providerImpl);
m_providerImpl,
modelPath,
subgraphNodes,
subgraphInputs,
subgraphOutputs);
// Compile the operator
auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator(

View file

@ -0,0 +1,369 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h"
#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h"
#include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h"
using namespace Windows::AI::MachineLearning::Adapter;
namespace Dml
{
class DmlRuntimeFusedGraphKernel : public onnxruntime::OpKernel
{
public:
DmlRuntimeFusedGraphKernel() = delete;

// Takes ownership of the partition description (nodes, node args, node properties and copied
// initializers) and resolves the provider interfaces used later by Compute. Nothing is compiled
// here; the DML graph is built on the first call to Compute.
DmlRuntimeFusedGraphKernel(
    const onnxruntime::OpKernelInfo& kernelInfo,
    std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
    const onnxruntime::Path& modelPath,
    std::vector<std::shared_ptr<onnxruntime::Node>>&& subgraphNodes,
    std::vector<const onnxruntime::NodeArg*>&& subgraphInputs,
    std::vector<const onnxruntime::NodeArg*>&& subgraphOutputs,
    std::vector<std::shared_ptr<onnxruntime::NodeArg>>&& intermediateNodeArgs,
    std::unordered_map<std::string, GraphNodeProperties>&& partitionNodePropsMap,
    std::vector<ONNX_NAMESPACE::TensorProto>&& ownedInitializers)
    : OpKernel(kernelInfo)
    , m_indexedSubGraph(std::move(indexedSubGraph))
    , m_modelPath(modelPath)
    , m_subgraphNodes(std::move(subgraphNodes))
    , m_subgraphInputs(std::move(subgraphInputs))
    , m_subgraphOutputs(std::move(subgraphOutputs))
    , m_intermediateNodeArgs(std::move(intermediateNodeArgs))
    , m_partitionNodePropsMap(std::move(partitionNodePropsMap))
    , m_ownedInitializers(std::move(ownedInitializers))
{
    // Index the owned initializers by name; every entry starts out flagged `false`.
    for (const ONNX_NAMESPACE::TensorProto& ownedInitializer : m_ownedInitializers)
    {
        m_isInitializerTransferable.insert_or_assign(
            ownedInitializer.name(),
            std::make_pair(&ownedInitializer, false));
    }

    // Get the execution provider interfaces
    const auto executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle();
    if (executionHandle != nullptr)
    {
        // We assume the execution object inherits IUnknown as its first base
        const IUnknown* constExecutionObject = static_cast<const IUnknown*>(executionHandle);
        ComPtr<IUnknown> providerExecutionObject = const_cast<IUnknown*>(constExecutionObject);

        // Get the WinML-specific execution provider interface from the execution object.
        ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_provider));
        ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider));
    }

    // Non-owning view of the nodes, in the form consumed by the graph description builder.
    m_subgraphNodePointers.reserve(m_subgraphNodes.size());
    for (const std::shared_ptr<onnxruntime::Node>& ownedNode : m_subgraphNodes)
    {
        m_subgraphNodePointers.push_back(ownedNode.get());
    }
}
// Initializes the freshly compiled operator held in m_compiledExecutionPlanOperator: allocates its
// persistent resource when it asks for one, runs operator initialization with `initInputBindings`,
// and queues a reference on each entry of `initializeResourceRefs` with the provider.
// `kernelInfo` is not used by this function.
void TranslateAndCompileGraph(
    const onnxruntime::OpKernelInfo& kernelInfo,
    std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>>& initializeResourceRefs,
    std::vector<DML_BUFFER_BINDING> initInputBindings) const
{
    // Allocate a persistent resource and initialize the operator
    const UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize;
    if (persistentResourceSize > 0)
    {
        ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource(
            static_cast<size_t>(persistentResourceSize),
            AllocatorRoundingMode::Disabled,
            m_persistentResource.ReleaseAndGetAddressOf(),
            m_persistentResourceAllocatorUnk.ReleaseAndGetAddressOf()));

        m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize };
    }
    else
    {
        // This kernel is recompiled whenever the input shapes change. Drop any binding left over
        // from a previous operator so it is not handed to an operator that requested no
        // persistent resource.
        m_persistentResourceBinding.reset();
    }

    ORT_THROW_IF_FAILED(m_provider->InitializeOperator(
        m_compiledExecutionPlanOperator.Get(),
        m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
        gsl::make_span(initInputBindings)));

    // Hand the initialization resources to the provider's reference queue.
    for (ComPtr<ID3D12Resource>& resource : initializeResourceRefs)
    {
        m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get());
    }
}
// Executes the fused partition for the inputs in `kernelContext`.
//
// The DML graph is (re)built and compiled when no compiled operator exists yet, when the shape of
// any input differs from the shape seen on the previous call, or when the contents of a CPU input
// differ from the cached copy. Otherwise the cached compiled operator is reused.
//
// Throws on failure (ORT_THROW_* macros); returns Status::OK() otherwise.
// NOTE(review): this const method updates mutable members without any synchronization visible
// here - confirm that Compute is never invoked concurrently on the same kernel instance.
onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override
{
    ORT_THROW_HR_IF(E_UNEXPECTED, m_subgraphInputs.size() != kernelContext->InputCount());

    bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr;

    for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex)
    {
        const auto& input = kernelContext->RequiredInput<onnxruntime::Tensor>(inputIndex);
        const std::string& inputName = m_subgraphInputs[inputIndex]->Name();

        // Track the most recent shape of every input; any change invalidates the compiled graph.
        auto shapeIter = m_inferredInputShapes.find(inputName);
        if (shapeIter == m_inferredInputShapes.end())
        {
            m_inferredInputShapes[inputName] = input.Shape();
            recompileNeeded = true;
        }
        else if (shapeIter->second != input.Shape())
        {
            shapeIter->second = input.Shape();
            recompileNeeded = true;
        }

        // If we have CPU inputs that are not initializers (i.e. they were computed at runtime), add them to the initializer list
        if (input.Location().device.Type() == OrtDevice::CPU)
        {
            auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName);

            auto initializerIter = m_isInitializerTransferable.find(inputName);
            const ONNX_NAMESPACE::TensorProto* cachedProto =
                initializerIter != m_isInitializerTransferable.end() ? initializerIter->second.first : nullptr;

            // We can only avoid recompiling the graph when all CPU inputs are identical
            if (cachedProto != nullptr && cachedProto->raw_data() == inputProto.raw_data())
            {
                // Unchanged: keep the cached copy instead of storing another identical proto on
                // every inference.
                initializerIter->second.second = false;
                continue;
            }

            recompileNeeded = true;

            // Reuse the slot of the copy being replaced (when this kernel owns it) so that
            // m_ownedCpuInputs holds at most one proto per CPU input.
            std::unique_ptr<ONNX_NAMESPACE::TensorProto>* reusableSlot = nullptr;
            if (cachedProto != nullptr)
            {
                for (auto& ownedCpuInput : m_ownedCpuInputs)
                {
                    if (ownedCpuInput.get() == cachedProto)
                    {
                        reusableSlot = &ownedCpuInput;
                        break;
                    }
                }
            }

            auto refreshedProto = std::make_unique<ONNX_NAMESPACE::TensorProto>(std::move(inputProto));

            // Re-point the map before the previous copy is destroyed.
            m_isInitializerTransferable[inputName] = std::make_pair(refreshedProto.get(), false);

            if (reusableSlot != nullptr)
            {
                *reusableSlot = std::move(refreshedProto);
            }
            else
            {
                m_ownedCpuInputs.push_back(std::move(refreshedProto));
            }
        }
    }

    if (recompileNeeded)
    {
        // Go through all the node args and replace their shapes with the real ones
        for (auto& nodeArg : m_intermediateNodeArgs)
        {
            auto iter = m_inferredInputShapes.find(nodeArg->Name());
            if (iter != m_inferredInputShapes.end())
            {
                // A node arg without any shape information cannot be patched in place.
                auto declaredShape = nodeArg->Shape();
                ORT_THROW_HR_IF(E_UNEXPECTED, declaredShape == nullptr);

                auto tensorShape = *declaredShape;
                ORT_THROW_HR_IF(E_UNEXPECTED, tensorShape.dim_size() != iter->second.NumDimensions());

                for (int i = 0; i < tensorShape.dim_size(); ++i)
                {
                    tensorShape.mutable_dim(i)->set_dim_value(iter->second.GetDims()[i]);
                }

                nodeArg->SetShape(tensorShape);
            }
        }

        // Populate input bindings for operator initialization
        const uint32_t fusedNodeInputCount = gsl::narrow_cast<uint32_t>(m_indexedSubGraph->GetMetaDef()->inputs.size());
        std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>> initializeResourceRefs; // For lifetime control
        std::vector<DML_BUFFER_BINDING> initInputBindings(fusedNodeInputCount);
        std::vector<uint8_t> isInputsUploadedByDmlEP(fusedNodeInputCount);

        auto providerImpl = static_cast<const ExecutionProvider*>(Info().GetExecutionProvider())->GetImpl();

        // Convert partitionONNXGraph into DML EP GraphDesc
        ComPtr<IDMLDevice> device;
        ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf()));
        GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
            isInputsUploadedByDmlEP.data(),
            isInputsUploadedByDmlEP.size(),
            m_isInitializerTransferable,
            m_partitionNodePropsMap,
            device.Get(),
            providerImpl,
            m_modelPath,
            m_subgraphNodePointers,
            m_subgraphInputs,
            m_subgraphOutputs);

        m_outputShapes = graphDesc.outputShapes;

        // Walk through each graph edge and mark used inputs.
        // assign() rather than resize(): flags left over from a previous compilation must be cleared.
        m_inputsUsed.assign(fusedNodeInputCount, false);
        for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges)
        {
            m_inputsUsed[edge.GraphInputIndex] = true;
        }

        // Compile the operator
        m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator(
            graphDesc,
            *m_indexedSubGraph,
            providerImpl);

        // NOTE(review): the "Try" prefix suggests compilation may yield no operator - confirm.
        // Fail with an error rather than dereferencing a null operator below; the member stays
        // null, so the next call attempts the compilation again.
        ORT_THROW_HR_IF(E_UNEXPECTED, m_compiledExecutionPlanOperator == nullptr);

        // Queue references to objects which must be kept alive until resulting GPU work completes
        m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get());

        TranslateAndCompileGraph(
            Info(),
            initializeResourceRefs,
            initInputBindings);
    }

    // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator
    OpKernelContextWrapper contextWrapper(
        kernelContext,
        Info().GetExecutionProvider(),
        true,
        nullptr);

    ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());

    // Get input resources for execution, excluding those which were specified as owned by DML and provided
    // at initialization instead.
    std::vector<ComPtr<IMLOperatorTensor>> inputTensors(kernelContext->InputCount());
    std::vector<ID3D12Resource*> inputPtrs(kernelContext->InputCount());

    for (int i = 0; i < kernelContext->InputCount(); ++i)
    {
        // Inputs not referenced by any graph edge keep a null resource pointer.
        if (!m_inputsUsed[i])
        {
            continue;
        }

        ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf()));
        inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get());
    }

    auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes);
    ExecuteOperator(
        m_compiledExecutionPlanOperator.Get(),
        m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
        inputPtrs,
        outputTensors);

    ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());

    return onnxruntime::Status::OK();
}
// Binds `inputTensors` and `outputTensors` to `op` and submits it through the provider.
// A null entry in either span is bound as DML_BINDING_TYPE_NONE; every other entry is bound as a
// buffer covering the full width of its D3D12 resource.
void ExecuteOperator(
    IDMLCompiledOperator* op,
    _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
    gsl::span<ID3D12Resource*> inputTensors,
    gsl::span<IMLOperatorTensor*> outputTensors) const
{
    // Both helpers append one buffer binding plus the descriptor pointing at it. The descriptor
    // stores the address of the buffer binding, so the vectors must already have enough capacity
    // reserved that push_back never reallocates.
    const auto appendBufferBinding = [](
        std::vector<DML_BUFFER_BINDING>& buffers,
        std::vector<DML_BINDING_DESC>& descs,
        ID3D12Resource* boundResource)
    {
        D3D12_RESOURCE_DESC boundResourceDesc = boundResource->GetDesc();
        buffers.push_back({ boundResource, 0, boundResourceDesc.Width });
        descs.push_back({ DML_BINDING_TYPE_BUFFER, &buffers.back() });
    };

    const auto appendEmptyBinding = [](
        std::vector<DML_BUFFER_BINDING>& buffers,
        std::vector<DML_BINDING_DESC>& descs)
    {
        buffers.push_back({ nullptr, 0, 0 });
        descs.push_back({ DML_BINDING_TYPE_NONE, nullptr });
    };

    // Inputs arrive as raw resources.
    std::vector<DML_BUFFER_BINDING> inputBufferBindings;
    inputBufferBindings.reserve(inputTensors.size());
    std::vector<DML_BINDING_DESC> inputBindings;
    inputBindings.reserve(inputTensors.size());

    for (ID3D12Resource* inputResource : inputTensors)
    {
        if (inputResource)
        {
            appendBufferBinding(inputBufferBindings, inputBindings, inputResource);
        }
        else
        {
            appendEmptyBinding(inputBufferBindings, inputBindings);
        }
    }

    // Outputs arrive as operator tensors and are decoded to their underlying resource first.
    std::vector<DML_BUFFER_BINDING> outputBufferBindings;
    outputBufferBindings.reserve(outputTensors.size());
    std::vector<DML_BINDING_DESC> outputBindings;
    outputBindings.reserve(outputTensors.size());

    for (IMLOperatorTensor* outputTensor : outputTensors)
    {
        if (!outputTensor)
        {
            appendEmptyBinding(outputBufferBindings, outputBindings);
            continue;
        }

        assert(outputTensor->IsDataInterface());
        ID3D12Resource* outputResource = m_provider->DecodeResource(MLOperatorTensor(outputTensor).GetDataInterface().Get());
        appendBufferBinding(outputBufferBindings, outputBindings, outputResource);
    }

    ORT_THROW_IF_FAILED(m_provider->ExecuteOperator(
        op,
        persistentResourceBinding,
        inputBindings,
        outputBindings));
}
private:
// Provider interfaces obtained from the execution handle in the constructor; they stay null when
// the execution provider returns no handle.
ComPtr<IWinmlExecutionProvider> m_winmlProvider;
ComPtr<Dml::IExecutionProvider> m_provider;
// Binding for m_persistentResource; set when the compiled operator reports a non-zero persistent resource size.
mutable std::optional<DML_BUFFER_BINDING> m_persistentResourceBinding;
// The partition this kernel executes.
std::shared_ptr<const onnxruntime::IndexedSubGraph> m_indexedSubGraph;
// NOTE(review): held by reference, so the Path passed to the constructor must outlive this
// kernel - confirm that the caller guarantees this.
const onnxruntime::Path& m_modelPath;
// Nodes of the partition, owned by this kernel.
std::vector<std::shared_ptr<onnxruntime::Node>> m_subgraphNodes;
// Inputs and outputs of the partition (non-owning).
std::vector<const onnxruntime::NodeArg*> m_subgraphInputs;
std::vector<const onnxruntime::NodeArg*> m_subgraphOutputs;
// Node args whose shapes Compute overwrites with the actual input shapes before compiling.
mutable std::vector<std::shared_ptr<onnxruntime::NodeArg>> m_intermediateNodeArgs;
std::unordered_map<std::string, GraphNodeProperties> m_partitionNodePropsMap;
// Copies of the partition's initializers; m_isInitializerTransferable points into this vector.
std::vector<ONNX_NAMESPACE::TensorProto> m_ownedInitializers;
// Maps a tensor name to its proto (an owned initializer or a captured CPU input) and a flag that
// this kernel always sets to false.
mutable std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> m_isInitializerTransferable;
// Raw, non-owning view of m_subgraphNodes.
std::vector<const onnxruntime::Node*> m_subgraphNodePointers;
// Protos created by Compute from inputs located on the CPU; m_isInitializerTransferable points at them
mutable std::vector<std::unique_ptr<ONNX_NAMESPACE::TensorProto>> m_ownedCpuInputs;
// Operator compiled for the most recently seen input shapes; null until the first Compute call.
mutable ComPtr<IDMLCompiledOperator> m_compiledExecutionPlanOperator;
// One flag per fused-node input: true when an input edge of the compiled graph reads that input.
mutable std::vector<bool> m_inputsUsed;
mutable ComPtr<ID3D12Resource> m_persistentResource;
mutable ComPtr<IUnknown> m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator
// Output shapes reported by the graph description built during the last compilation.
mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes;
// Shape of each input as seen on the previous Compute call, keyed by input name.
mutable std::unordered_map<std::string, onnxruntime::TensorShape> m_inferredInputShapes;
};
// Factory for DmlRuntimeFusedGraphKernel, whose definition is private to this translation unit.
// Forwards every argument to the kernel's constructor and returns a heap-allocated kernel that the
// caller owns.
onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel(
    const onnxruntime::OpKernelInfo& info,
    std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
    const onnxruntime::Path& modelPath,
    std::vector<std::shared_ptr<onnxruntime::Node>>&& subgraphNodes,
    std::vector<const onnxruntime::NodeArg*>&& subgraphInputs,
    std::vector<const onnxruntime::NodeArg*>&& subgraphOutputs,
    std::vector<std::shared_ptr<onnxruntime::NodeArg>>&& intermediateNodeArgs,
    std::unordered_map<std::string, GraphNodeProperties>&& partitionNodePropsMap,
    std::vector<ONNX_NAMESPACE::TensorProto>&& ownedInitializers)
{
    auto runtimeKernel = std::make_unique<DmlRuntimeFusedGraphKernel>(
        info,
        std::move(indexedSubGraph),
        modelPath,
        std::move(subgraphNodes),
        std::move(subgraphInputs),
        std::move(subgraphOutputs),
        std::move(intermediateNodeArgs),
        std::move(partitionNodePropsMap),
        std::move(ownedInitializers));

    // Ownership passes to the caller as a raw pointer.
    return runtimeKernel.release();
}
} // namespace Dml

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/op_kernel.h"
#include "GraphDescBuilder.h"
#include "DmlRuntimeGraphFusionTransformer.h"
namespace Dml
{
// Creates the kernel that builds and compiles a partition's DML graph at compute time.
// The vectors and the map are taken by rvalue reference and moved into the kernel; `modelPath` is
// taken by reference. The returned kernel is heap-allocated and owned by the caller.
onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel(
const onnxruntime::OpKernelInfo& info,
std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
const onnxruntime::Path& modelPath,
std::vector<std::shared_ptr<onnxruntime::Node>>&& subgraphNodes,
std::vector<const onnxruntime::NodeArg*>&& subgraphInputs,
std::vector<const onnxruntime::NodeArg*>&& subgraphOutputs,
std::vector<std::shared_ptr<onnxruntime::NodeArg>>&& intermediateNodeArgs,
std::unordered_map<std::string, GraphNodeProperties>&& partitionNodePropsMap,
std::vector<ONNX_NAMESPACE::TensorProto>&& ownedInitializers
);
} // namespace Dml

View file

@ -0,0 +1,161 @@
#pragma once
#include "precomp.h"
#include "GraphDescBuilder.h"
#include "ExecutionProvider.h"
#include "DmlRuntimeGraphFusionTransformer.h"
#include "GraphPartitioner.h"
#include "core/framework/kernel_type_str_resolver.h"
#include "core/framework/kernel_lookup.h"
#include "core/optimizer/constant_sharing.h"
#include "DmlRuntimeFusedGraphKernel.h"
#include "MLOperatorAuthorImpl.h"
#include "DmlGraphFusionHelper.h"
namespace Dml
{
namespace
{
// Per-partition data gathered before the partition is registered as a dynamic kernel.
struct CompiledPartitionInfo
{
// The nodes and meta definition that make up the partition.
std::shared_ptr<onnxruntime::IndexedSubGraph> indexedSubGraph;
// Initializers consumed by the partition, keyed by name; the flag is set to false when the entry is added.
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> isInitializerTransferable;
};
// Creates the transformer under the given name and caches the DML provider implementation.
// `provider` is cast to the DML ExecutionProvider without a runtime check.
DmlRuntimeGraphFusionTransformer::DmlRuntimeGraphFusionTransformer(
    const std::string& name,
    const onnxruntime::IExecutionProvider* provider)
    : onnxruntime::GraphTransformer(name)
    , m_providerImpl(static_cast<const ExecutionProvider*>(provider)->GetImpl())
{
}
// GraphTransformer entry point: delegates to ApplyImplHelper with an empty set of implicit input defs.
onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImpl(
    onnxruntime::Graph& graph,
    bool& modified,
    int graphLevel,
    const onnxruntime::logging::Logger& logger) const
{
    const std::unordered_map<std::string, const onnxruntime::NodeArg*> noImplicitInputDefs;
    return ApplyImplHelper(graph, modified, graphLevel, logger, noImplicitInputDefs);
}
// Recursively processes every subgraph attached to the nodes of `graph`, then partitions
// `graph` itself and registers a dynamic kernel for each partition that is a DML graph
// partition and the root of its merged group.
//
// graph             - graph to partition; also handed to RegisterDynamicKernel.
// modified          - only forwarded to the recursive calls; never written by this function.
// graphLevel        - nesting depth; each level of subgraph recursion receives graphLevel + 1.
// logger            - only forwarded to the recursive calls.
// implicitInputDefs - implicit inputs of the node that owns `graph`, keyed by arg name
//                     (empty for the outermost call); passed through to BuildPartitions.
//
// Returns the first non-OK status produced by a recursive call, otherwise OK.
onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImplHelper(
    onnxruntime::Graph& graph,
    bool& modified,
    int graphLevel,
    const onnxruntime::logging::Logger& logger,
    const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputDefs) const
{
    // Kernel lookup limited to the single registry owned by the DML provider.
    onnxruntime::ProviderType dmlProviderType = onnxruntime::kDmlExecutionProvider;
    const gsl::not_null<const onnxruntime::KernelRegistry*> dmlRegistry = m_providerImpl->GetKernelRegistry().get();
    const auto typeStrResolver = onnxruntime::OpSchemaKernelTypeStrResolver{};
    const auto kernelLookup = onnxruntime::KernelLookup(
        dmlProviderType,
        gsl::make_span(&dmlRegistry, 1),
        typeStrResolver);

    onnxruntime::GraphViewer graphViewer(graph);

    // Handle nested subgraphs first, one recursion level per subgraph attribute.
    for (auto topoNodeIndex : graphViewer.GetNodesInTopologicalOrder())
    {
        auto* currentNode = graph.GetNode(topoNodeIndex);
        if (currentNode == nullptr)
        {
            continue; // node was removed
        }

        // The implicit inputs of this node become the implicit inputs of its subgraphs.
        std::unordered_map<std::string, const onnxruntime::NodeArg*> nestedImplicitInputs;
        for (const onnxruntime::NodeArg* implicitInput : currentNode->ImplicitInputDefs())
        {
            nestedImplicitInputs.insert_or_assign(implicitInput->Name(), implicitInput);
        }

        for (auto& attributeAndSubgraph : currentNode->GetAttributeNameToMutableSubgraphMap())
        {
            ORT_RETURN_IF_ERROR(ApplyImplHelper(*attributeAndSubgraph.second, modified, graphLevel + 1, logger, nestedImplicitInputs));
        }
    }

    // State exchanged with BuildPartitions.
    std::vector<onnxruntime::NodeIndex> additionalSplittingNodes;
    std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap;
    std::unordered_set<std::string> requiredInitializerMap;
    std::unordered_set<std::string> dynamicCpuInputMap;

    std::vector<std::unique_ptr<GraphPartition>> partitions = BuildPartitions(
        graphViewer,
        *m_providerImpl->GetInternalRegistrationInfoMap(),
        kernelLookup,
        m_providerImpl->GetSupportedDeviceDataTypeMask(),
        graphNodePropertyMap,
        requiredInitializerMap,
        dynamicCpuInputMap,
        additionalSplittingNodes,
        implicitInputDefs,
        true);

    // Reset the splitting nodes for the current iteration
    additionalSplittingNodes.clear();

    // One slot per partition; slots that stay null are partitions that get no dynamic kernel.
    std::vector<std::shared_ptr<CompiledPartitionInfo>> compiledPartitionInfos(partitions.size());

    // Map between each initialized tensor and the partition(s) it is part of.
    // The returned map is not read anywhere below in this function.
    auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions);

    for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex)
    {
        GraphPartition* const candidate = partitions[partitionIndex].get();

        // Only partitions that are their own merge root, belong to DML, and are DML graph
        // partitions are prepared for registration.
        const bool isRootDmlGraphPartition =
            candidate->GetRootMergedPartition() == candidate &&
            candidate->IsDmlPartition() &&
            candidate->IsDmlGraphPartition();
        if (!isRootDmlGraphPartition)
        {
            continue;
        }

        // Each prepared partition consumes one value of the provider's kernel prefix counter.
        std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_";
        m_providerImpl->IncreasePartitionKernelPrefixVal();

        // Record the partition inputs that are graph initializers and were marked as required.
        std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> transferableInitializers;
        for (const auto& inputName : candidate->GetInputs())
        {
            const onnx::TensorProto* initializer = nullptr;
            if (graph.GetInitializedTensor(inputName, initializer) && requiredInitializerMap.count(inputName) != 0)
            {
                transferableInitializers[inputName] = {initializer, false};
            }
        }

        auto partitionInfo = std::make_shared<CompiledPartitionInfo>();
        partitionInfo->indexedSubGraph = std::make_shared<onnxruntime::IndexedSubGraph>(
            DmlGraphFusionHelper::CreateIndexedSubGraph(candidate, partitionIndex, partitionKernelPrefix));
        partitionInfo->isInitializerTransferable = std::move(transferableInitializers);
        compiledPartitionInfos[partitionIndex] = std::move(partitionInfo);
    }

    for (auto&& partitionInfo : compiledPartitionInfos)
    {
        // Null entries were not prepared above and are skipped.
        if (!partitionInfo)
        {
            continue;
        }

        DmlGraphFusionHelper::RegisterDynamicKernel(
            graph,
            m_providerImpl->GetKernelRegistry().get(),
            m_providerImpl,
            graphNodePropertyMap,
            dynamicCpuInputMap,
            std::move(partitionInfo->indexedSubGraph),
            std::move(partitionInfo->isInitializerTransferable));
    }

    return onnxruntime::common::Status::OK();
}
}

View file

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <unordered_map>
#include "core/optimizer/graph_transformer.h"
#include "core/framework/execution_providers.h"
namespace Dml
{
class ExecutionProviderImpl;
class DmlRuntimeGraphFusionTransformer : public onnxruntime::GraphTransformer
{
public:
DmlRuntimeGraphFusionTransformer(
const std::string& name,
const onnxruntime::IExecutionProvider* provider
);
public:
static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlRuntimeFusedNode_";
static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlRuntimeFusedNodeDomain";
private:
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph,
bool& modified,
int graphLevel,
const onnxruntime::logging::Logger& logger) const final;
onnxruntime::common::Status ApplyImplHelper(
onnxruntime::Graph& graph,
bool& modified,
int graphLevel,
const onnxruntime::logging::Logger& logger,
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputDefs) const;
private:
const ExecutionProviderImpl* m_providerImpl = nullptr;
};
}

View file

@ -67,7 +67,8 @@ namespace Dml
ExecutionProvider::ExecutionProvider(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* commandQueue,
bool enableMetacommands) :
bool enableMetacommands,
bool enableDynamicGraphFusion) :
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0))
{
D3D12_COMMAND_LIST_TYPE queueType = commandQueue->GetDesc().Type;
@ -80,7 +81,7 @@ namespace Dml
ComPtr<ID3D12Device> device;
GRAPHICS_THROW_IF_FAILED(commandQueue->GetDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), commandQueue, enableMetacommands);
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), commandQueue, enableMetacommands, enableDynamicGraphFusion);
}
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
@ -147,12 +148,12 @@ namespace Dml
// Task 24384515: Update ORT AIInfra release agent pool to install 19H1 SDK on VM bootstrap
#define D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE ((D3D_FEATURE_LEVEL)0x1000)
ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands)
ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands, bool enableDynamicGraphFusion)
: m_d3d12Device(d3d12Device),
m_dmlDevice(dmlDevice),
m_areMetacommandsEnabled(enableMetacommands)
m_areMetacommandsEnabled(enableMetacommands),
m_dynamicGraphFusionEnabled(enableDynamicGraphFusion)
{
D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {};
D3D_FEATURE_LEVEL featureLevelsList[] = {
@ -1093,6 +1094,11 @@ namespace Dml
return m_areMetacommandsEnabled;
}
// Returns the dynamic graph fusion flag stored in m_dynamicGraphFusionEnabled, which the
// constructor initializes from its enableDynamicGraphFusion argument.
bool ExecutionProviderImpl::DynamicGraphFusionEnabled() const noexcept
{
return m_dynamicGraphFusionEnabled;
}
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap>
ExecutionProviderImpl::GetInternalRegistrationInfoMap() const
{
@ -1129,9 +1135,10 @@ namespace Dml
std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* commandQueue,
bool enableMetacommands)
bool enableMetacommands,
bool enableDynamicGraphFusion)
{
return std::make_unique<Dml::ExecutionProvider>(dmlDevice, commandQueue, enableMetacommands);
return std::make_unique<Dml::ExecutionProvider>(dmlDevice, commandQueue, enableMetacommands, enableDynamicGraphFusion);
}
ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr)

View file

@ -5,6 +5,7 @@
#include "GraphTransformer.h"
#include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h"
#include "core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h"
#include <wrl/client.h>
#include <wrl/implements.h>
@ -34,7 +35,8 @@ namespace Dml
IDMLDevice* dmlDevice,
ID3D12Device* d3d12Device,
ID3D12CommandQueue* queue,
bool enableMetacommands = true);
bool enableMetacommands,
bool enableDynamicGraphFusion);
void ReleaseCompletedReferences();
@ -150,6 +152,7 @@ namespace Dml
STDMETHOD_(bool, IsMcdmDevice)() const noexcept final;
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final;
bool DynamicGraphFusionEnabled() const noexcept;
std::shared_ptr<onnxruntime::IAllocator> GetGpuAllocator();
std::shared_ptr<onnxruntime::IAllocator> GetCpuInputAllocator();
@ -184,6 +187,7 @@ namespace Dml
ComPtr<IDMLDevice> m_dmlDevice;
bool m_isMcdmDevice = false;
bool m_areMetacommandsEnabled = true;
bool m_dynamicGraphFusionEnabled = false;
bool m_native16BitShaderOpsSupported = false;
std::shared_ptr<ExecutionContext> m_context;
std::unique_ptr<PooledUploadHeap> m_uploadHeap;
@ -236,7 +240,8 @@ namespace Dml
explicit ExecutionProvider(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* commandQueue,
bool enableMetacommands = true
bool enableMetacommands,
bool enableDynamicGraphFusion
);
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const final override
@ -299,9 +304,9 @@ namespace Dml
return m_impl.Get();
}
void MetacommandsEnabled()
bool DynamicGraphFusionEnabled() const
{
m_impl->MetacommandsEnabled();
return m_impl->DynamicGraphFusionEnabled();
}
virtual std::vector<onnxruntime::AllocatorPtr> CreatePreferredAllocators() override

View file

@ -147,14 +147,14 @@ namespace Dml::GraphDescBuilder
const uint8_t* isConstGpuGraphInput,
const size_t isConstGpuGraphInputCount,
const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
const onnxruntime::Graph& graph,
const onnxruntime::IndexedSubGraph& indexedSubGraph,
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
IDMLDevice* device,
const void* executionHandle)
const void* executionHandle,
const onnxruntime::Path& modelPath,
gsl::span<const onnxruntime::Node* const> subgraphNodes,
gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,
gsl::span<const onnxruntime::NodeArg* const> subgraphOutputs)
{
const gsl::span<const std::string> subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs;
const gsl::span<const std::string> subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs;
struct NodeAndIndex
{
uint32_t nodeIndex; // The index of the node itself
@ -164,12 +164,14 @@ namespace Dml::GraphDescBuilder
// Map from Lotus node argument names to the new node and index where it will be produced
std::unordered_map<std::string, NodeAndIndex> nameToNodeAndIndexMap;
std::unordered_map<std::string, EdgeShapes> nodeOutputShapes;
// Map from Lotus node argument names to input indices of the fused kernel node.
std::unordered_map<std::string, uint32_t> nameToDmlFusedNodeInputIndex;
for (size_t inputIndex = 0; inputIndex < subGraphInputArgNames.size(); ++inputIndex)
for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex)
{
const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(subGraphInputArgNames[inputIndex]);
const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex];
if (!graphInput)
{
@ -196,13 +198,11 @@ namespace Dml::GraphDescBuilder
const uint32_t minNodeCountToReuseCommandList = 5;
bool reuseCommandList = false;
if (indexedSubGraph.nodes.size() >= minNodeCountToReuseCommandList)
if (subgraphNodes.size() >= minNodeCountToReuseCommandList)
{
reuseCommandList = true;
}
auto modelPath = graph.ModelPath();
auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName)
{
ComPtr<OnnxTensorWrapper> tensorWrapper;
@ -219,9 +219,11 @@ namespace Dml::GraphDescBuilder
// Iterate through each node and create a corresponding node in the new graph
// We can iterate the nodes in any order because the edge connectivity will take care of the topological order
for (size_t sortedNodeIndex : indexedSubGraph.nodes)
std::unordered_map<std::string, std::vector<uint32_t>> inferredOutputShapes;
for (const onnxruntime::Node* subgraphNode : subgraphNodes)
{
const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex);
const onnxruntime::Node& node = *subgraphNode;
const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second;
const auto& requiredConstantCpuInputs = graphNodeProps.internalRegInfo->requiredConstantCpuInputs;
@ -244,14 +246,45 @@ namespace Dml::GraphDescBuilder
return tensor;
};
EdgeShapes inputShapesOverrides(node.InputDefs().size());
// Override the input shapes with shapes that were previously inferred
for (int inputIndex = 0; inputIndex < node.InputDefs().size(); ++inputIndex)
{
auto inputDef = node.InputDefs()[inputIndex];
auto outputShapesIter = inferredOutputShapes.find(inputDef->Name());
if (outputShapesIter != inferredOutputShapes.end())
{
inputShapesOverrides.GetMutableShape(inputIndex) = outputShapesIter->second;
}
else if (inputDef->HasTensorOrScalarShape())
{
for (int i = 0; i < inputDef->Shape()->dim_size(); ++i)
{
ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->Shape()->dim(i).has_dim_value());
inputShapesOverrides.GetMutableShape(inputIndex).push_back(gsl::narrow_cast<uint32_t>(inputDef->Shape()->dim(i).dim_value()));
}
}
}
EdgeShapes outputShapes;
DmlGraphNodeCreateInfo graphNodeCreateInfo;
graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory(
node,
constantCpuNodeInputGetter,
executionHandle,
&inputShapesOverrides,
/*out*/ &outputShapes,
/*out*/ &graphNodeCreateInfo
);
ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size());
for (int i = 0; i < node.OutputDefs().size(); ++i)
{
inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i);
}
// Create a map between operatorGraphNodeIndex to mainGraphNodeIndex.
std::unordered_map<uint32_t, uint32_t> operatorGraphNodeIndexToMainGraphNodeIndexMap;
uint32_t graphNodeCount = gsl::narrow_cast<uint32_t>(graphNodes.size());
@ -347,6 +380,8 @@ namespace Dml::GraphDescBuilder
operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex],
operatorGraphOutputEdge.FromNodeOutputIndex
};
nodeOutputShapes[arg->Name()] = outputShapes;
}
}
@ -367,10 +402,12 @@ namespace Dml::GraphDescBuilder
}
}
EdgeShapes graphOutputShapes(subgraphOutputs.size());
// Add graph output nodes, which might be in a different order from the encapsulating node
for (size_t outputIndex = 0; outputIndex < subGraphOutputArgNames.size(); ++outputIndex)
for (size_t outputIndex = 0; outputIndex < subgraphOutputs.size(); ++outputIndex)
{
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(subGraphOutputArgNames[outputIndex]);
const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex];
ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg");
const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());
@ -380,6 +417,7 @@ namespace Dml::GraphDescBuilder
edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex;
edge.GraphOutputIndex = gsl::narrow_cast<uint32_t>(outputIndex);
graphOutputEdges.push_back(edge);
graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex);
}
RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges);
@ -390,6 +428,7 @@ namespace Dml::GraphDescBuilder
graphDesc.outputEdges = std::move(graphOutputEdges);
graphDesc.intermediateEdges = std::move(graphIntermediateEdges);
graphDesc.reuseCommandList = reuseCommandList;
graphDesc.outputShapes = std::move(graphOutputShapes);
return graphDesc;
}
}

View file

@ -9,10 +9,10 @@ namespace Dml
{
struct GraphNodeProperties
{
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfo>
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfo>
internalRegInfo;
// These are currently passed from the partitioning step since the only DML operators current
// These are currently passed from the partitioning step since the only DML operators currently
// supporting graph nodes don't customize the order of edges or shapes, other than coercing
// dimension count. This will change as the supported set of operators as graph nodes increases.
Windows::AI::MachineLearning::Adapter::EdgeShapes inputShapes;
@ -38,16 +38,19 @@ namespace Dml
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
bool reuseCommandList;
Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes;
};
GraphDesc BuildGraphDesc(
const uint8_t* isConstGpuGraphInput,
const size_t isConstGpuGraphInputCount,
const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
const onnxruntime::Graph& graph,
const onnxruntime::IndexedSubGraph& indexedSubGraph,
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
IDMLDevice* device,
const void* executionHandle);
const void* executionHandle,
const onnxruntime::Path& modelPath,
gsl::span<const onnxruntime::Node* const> subgraphNodes,
gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,
gsl::span<const onnxruntime::NodeArg* const> subgraphOutputs);
}
}
}

View file

@ -151,6 +151,8 @@ namespace Dml
_In_opt_ const std::unordered_map<std::string, GraphPartition*>* nodeNameToPartitionMap,
_Inout_ std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& dmlNodePropertyMap,
_Inout_ std::unordered_set<std::string>& requiredInitializerMap,
_Inout_ std::unordered_set<std::string>& dynamicCpuInputMap,
bool allowDmlGraphDynamicShapes,
_Out_ bool* isDmlGraphNode
)
{
@ -172,36 +174,68 @@ namespace Dml
if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration)
{
bool requiredCpuInputsConstant = true;
for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs)
if (allowDmlGraphDynamicShapes)
{
if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists())
for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs)
{
continue;
if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists())
{
continue;
}
const onnx::TensorProto* tensor = nullptr;
const std::string& inputName = node.InputDefs()[inputIndex]->Name();
if (graph.GetInitializedTensor(inputName, tensor))
{
requiredInitializerMap.insert(inputName);
}
else
{
dynamicCpuInputMap.insert(inputName);
}
}
const onnx::TensorProto* tensor = nullptr;
const std::string& inputName = node.InputDefs()[inputIndex]->Name();
if (!graph.GetInitializedTensor(inputName, tensor))
std::optional<uint32_t> requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount;
if (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())
{
requiredCpuInputsConstant = false;
break;
*isDmlGraphNode = true;
graphNodeProperty.first->second.internalRegInfo = internalRegInfo;
}
requiredInitializerMap.insert(inputName);
}
std::optional<uint32_t> requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount;
if (requiredCpuInputsConstant &&
TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) &&
!ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) &&
TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) &&
!ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) &&
(requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()))
else
{
*isDmlGraphNode = true;
graphNodeProperty.first->second.internalRegInfo = internalRegInfo;
bool requiredCpuInputsConstant = true;
for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs)
{
if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists())
{
continue;
}
const onnx::TensorProto* tensor = nullptr;
const std::string& inputName = node.InputDefs()[inputIndex]->Name();
if (!graph.GetInitializedTensor(inputName, tensor))
{
requiredCpuInputsConstant = false;
break;
}
requiredInitializerMap.insert(inputName);
}
std::optional<uint32_t> requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount;
if (requiredCpuInputsConstant &&
TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) &&
!ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) &&
TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) &&
!ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) &&
(requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()))
{
*isDmlGraphNode = true;
graphNodeProperty.first->second.internalRegInfo = internalRegInfo;
}
}
}
}
@ -379,8 +413,10 @@ namespace Dml
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
std::unordered_set<std::string>& requiredInitializerMap,
std::unordered_set<std::string>& dynamicCpuInputMap,
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes,
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs)
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs,
bool allowDmlGraphDynamicShapes)
{
// Nodes are uniquely identified by the name of their first output argument
std::vector<std::unique_ptr<GraphPartition>> partitions;
@ -443,6 +479,8 @@ namespace Dml
&nodeNameToPartitionMap,
graphNodePropertyMap,
requiredInitializerMap,
dynamicCpuInputMap,
allowDmlGraphDynamicShapes,
/*out*/ &isDmlGraphNode
);
}

View file

@ -50,6 +50,8 @@ namespace Dml
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
std::unordered_set<std::string>& requiredInitializerMap,
std::unordered_set<std::string>& dynamicCpuInputMap,
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes,
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs);
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs,
bool allowDmlGraphDynamicShapes);
} // namespace Dml

View file

@ -2,8 +2,15 @@
// Licensed under the MIT License.
#pragma once
#include <d3d12.h>
#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h"
interface IDMLCompiledOperator;
struct DML_BUFFER_BINDING;
struct DML_BINDING_DESC;
namespace Dml
{
struct Binding

View file

@ -1356,13 +1356,14 @@ namespace Windows::AI::MachineLearning::Adapter
const onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext>* protoHelper,
const void* executionHandle,
bool isInternalOperator,
const EdgeShapes* inputShapesOverrides,
const EdgeShapes* inferredOutputShapes,
const AttributeMap* defaultAttributes,
DmlGraphNodeCreateInfo* graphNodeCreateInfo,
gsl::span<const uint32_t> requiredConstantCpuInputs,
MLOperatorTensorGetter& constantInputGetter
)
: OpNodeInfoWrapper(protoHelper, nullptr, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr),
: OpNodeInfoWrapper(protoHelper, inputShapesOverrides, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr),
m_inferredOutputShapes(inferredOutputShapes),
m_internalOperator(isInternalOperator),
m_graphNodeCreateInfo(graphNodeCreateInfo)

View file

@ -4,6 +4,7 @@
#pragma once
#include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h"
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h"
#include "core/framework/op_kernel.h"
#include "core/framework/customregistry.h"
#include "core/framework/tensorprotoutils.h"
@ -93,42 +94,6 @@ public:
using AttributeMap = std::map<std::string, AttributeValue>;
// Stores one dimension list per edge of an operator. An edge that is unused or
// that does not carry a tensor is represented by an empty dimension list.
class EdgeShapes
{
    // Dimension list of every edge, indexed by edge position.
    std::vector<std::vector<uint32_t>> m_shapes;

public:
    // Creates a container with no edges.
    EdgeShapes() = default;

    // Creates a container with `count` edges, each with an empty dimension list.
    EdgeShapes(size_t count) : m_shapes(count) {}

    // Number of edges currently tracked.
    size_t EdgeCount() const
    {
        return m_shapes.size();
    }

    // Read-only access to the dimensions of edge `edgeIndex` (index is not range-checked).
    const std::vector<uint32_t>& GetShape(size_t edgeIndex) const
    {
        return m_shapes[edgeIndex];
    }

    // Writable access to the dimensions of edge `edgeIndex` (index is not range-checked).
    std::vector<uint32_t>& GetMutableShape(size_t edgeIndex)
    {
        return m_shapes[edgeIndex];
    }

    // Discards all stored shapes and leaves `edge_count` edges with empty dimension lists.
    void Reset(size_t edge_count)
    {
        m_shapes.assign(edge_count, std::vector<uint32_t>{});
    }

    // True when the two containers differ in edge count or in any edge's dimensions.
    bool operator!=(const EdgeShapes& other) const noexcept
    {
        return !(m_shapes == other.m_shapes);
    }
};
// Base class for ABI objects which may be "Closed", at which point calls will predictably
// fail or return a dummy value. This is used for transient ABI context objects which
// are passed to methods on kernel or inferencers, and which wrap Lotus objects whose lifetimes
@ -434,6 +399,7 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper<
const onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> * protoHelper,
const void* executionHandle,
bool isInternalOperator,
const EdgeShapes* inputShapesOverrides,
const EdgeShapes* inferredOutputShapes,
const AttributeMap* defaultAttributes,
DmlGraphNodeCreateInfo* graphNodeCreateInfo,

View file

@ -30,8 +30,12 @@ namespace onnxruntime {
struct DMLProviderFactory : IExecutionProviderFactory {
DMLProviderFactory(IDMLDevice* dml_device,
ID3D12CommandQueue* cmd_queue) : dml_device_(dml_device),
cmd_queue_(cmd_queue) {}
ID3D12CommandQueue* cmd_queue,
bool disable_metacommands,
bool enable_dynamic_graph_fusion) : dml_device_(dml_device),
cmd_queue_(cmd_queue),
metacommands_enabled_(!disable_metacommands),
dynamic_graph_fusion_enabled_(enable_dynamic_graph_fusion) {}
~DMLProviderFactory() override {}
std::unique_ptr<IExecutionProvider> CreateProvider() override;
@ -42,10 +46,11 @@ struct DMLProviderFactory : IExecutionProviderFactory {
ComPtr<IDMLDevice> dml_device_{};
ComPtr<ID3D12CommandQueue> cmd_queue_{};
bool metacommands_enabled_ = true;
bool dynamic_graph_fusion_enabled_ = false;
};
std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {
auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_);
auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_, dynamic_graph_fusion_enabled_);
return provider;
}
@ -54,7 +59,9 @@ void DMLProviderFactory::SetMetacommandsEnabled(bool metacommands_enabled) {
}
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(IDMLDevice* dml_device,
ID3D12CommandQueue* cmd_queue) {
ID3D12CommandQueue* cmd_queue,
bool disable_metacommands,
bool enable_dynamic_graph_fusion) {
#ifndef _GAMING_XBOX
// Validate that the D3D12 devices match between DML and the command queue. This specifically asks for IUnknown in
// order to be able to compare the pointers for COM object identity.
@ -73,7 +80,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(ID
const Env& env = Env::Default();
auto luid = d3d12_device->GetAdapterLuid();
env.GetTelemetryProvider().LogExecutionProviderEvent(&luid);
return std::make_shared<onnxruntime::DMLProviderFactory>(dml_device, cmd_queue);
return std::make_shared<onnxruntime::DMLProviderFactory>(dml_device, cmd_queue, disable_metacommands, enable_dynamic_graph_fusion);
}
void DmlConfigureProviderFactoryMetacommandsEnabled(IExecutionProviderFactory* factory, bool metacommandsEnabled) {
@ -234,12 +241,10 @@ static void SortHeterogenousDXCoreAdapterList(
std::sort(adapter_infos.begin(), adapter_infos.end(), policy);
}
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id) {
return Create(device_id, /*skip_software_device_check*/ false);
}
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromOptions(
OrtDmlDeviceOptions* device_options) {
OrtDmlDeviceOptions* device_options,
bool disable_metacommands,
bool enable_dynamic_graph_fusion) {
auto default_device_options = OrtDmlDeviceOptions { Default, Gpu };
if (device_options == nullptr) {
device_options = &default_device_options;
@ -286,7 +291,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFrom
adapters.begin(),
[](auto& a){ return a.Adapter; });
return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters));
return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters), disable_metacommands, enable_dynamic_graph_fusion);
}
static std::optional<OrtDmlPerformancePreference> ParsePerformancePreference(const ProviderOptions& provider_options) {
@ -354,12 +359,32 @@ static std::optional<int> ParseDeviceId(const ProviderOptions& provider_options)
return {};
}
// Reads the boolean provider option stored under `key`.
//
// Returns true when the value is "True" or "true". Returns false when the value is
// "False" or "false", when the key is absent, or when its value is empty.
// Any other value is reported through ORT_THROW.
static bool ParseBoolean(const ProviderOptions& provider_options, const std::string& key) {
  const auto option_it = provider_options.find(key);
  const bool has_value = option_it != provider_options.end() && !option_it->second.empty();

  if (has_value) {
    const auto& value = option_it->second;
    if (value == "True" || value == "true") {
      return true;
    }

    const bool is_false = value == "False" || value == "false";
    if (!is_false) {
      ORT_THROW("[ERROR] [DirectML] The value for the key '" + key + "' should be 'True' or 'False'. Default value is 'False'.\n");
    }
  }

  // Missing key, empty value, or an explicit "False"/"false".
  return false;
}
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromProviderOptions(
const ProviderOptions& provider_options) {
const ProviderOptions& provider_options) {
bool disable_metacommands = ParseBoolean(provider_options, "disable_metacommands");
bool enable_dynamic_graph_fusion = ParseBoolean(provider_options, "enable_dynamic_graph_fusion");
bool skip_software_device_check = false;
auto device_id = ParseDeviceId(provider_options);
if (device_id.has_value())
{
return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value());
return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value(), skip_software_device_check, disable_metacommands, enable_dynamic_graph_fusion);
}
auto preference = ParsePerformancePreference(provider_options);
@ -367,7 +392,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFrom
// If no preference/filters are specified then create with default preference/filters.
if (!preference.has_value() && !filter.has_value()) {
return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr);
return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr, disable_metacommands, enable_dynamic_graph_fusion);
}
if (!preference.has_value()) {
@ -381,7 +406,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFrom
OrtDmlDeviceOptions device_options;
device_options.Preference = preference.value();
device_options.Filter = filter.value();
return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options);
return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options, disable_metacommands, enable_dynamic_graph_fusion);
}
Microsoft::WRL::ComPtr<ID3D12Device> DMLProviderFactoryCreator::CreateD3D12Device(
@ -441,7 +466,10 @@ Microsoft::WRL::ComPtr<IDMLDevice> DMLProviderFactoryCreator::CreateDMLDevice(ID
return dml_device;
}
std::shared_ptr<IExecutionProviderFactory> CreateDMLDeviceAndProviderFactory(ID3D12Device* d3d12_device) {
std::shared_ptr<IExecutionProviderFactory> CreateDMLDeviceAndProviderFactory(
ID3D12Device* d3d12_device,
bool disable_metacommands,
bool enable_dynamic_graph_fusion) {
D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
@ -450,16 +478,22 @@ std::shared_ptr<IExecutionProviderFactory> CreateDMLDeviceAndProviderFactory(ID3
ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));
auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device);
return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get());
return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get(), disable_metacommands, enable_dynamic_graph_fusion);
}
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) {
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(
int device_id,
bool skip_software_device_check,
bool disable_metacommands,
bool enable_dynamic_graph_fusion) {
ComPtr<ID3D12Device> d3d12_device = CreateD3D12Device(device_id, skip_software_device_check);
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion);
}
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromAdapterList(
std::vector<ComPtr<IDXCoreAdapter>>&& dxcore_devices) {
std::vector<ComPtr<IDXCoreAdapter>>&& dxcore_devices,
bool disable_metacommands,
bool enable_dynamic_graph_fusion) {
// Choose the first device from the list since it's the highest priority
auto dxcore_device = dxcore_devices[0];
@ -467,7 +501,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFrom
ComPtr<ID3D12Device> d3d12_device;
ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion);
}
} // namespace onnxruntime
@ -477,7 +511,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFrom
// The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead.
// Appends a DML execution provider factory for device_id to options->provider_factories.
// Returns nullptr once the factory has been appended; the body is wrapped in
// API_IMPL_BEGIN / API_IMPL_END (presumably the exception-to-status boundary -- confirm
// against the macro definitions).
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id) {
API_IMPL_BEGIN
options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id));
// Positional arguments after device_id, in order: skip_software_device_check = false,
// disable_metacommands = false, enable_dynamic_graph_fusion = false.
options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id, false, false, false));
API_IMPL_END
return nullptr;
}
@ -489,7 +523,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSess
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue) {
API_IMPL_BEGIN
// Appends a factory built from the caller-supplied dml_device and cmd_queue to
// options->provider_factories, then returns nullptr.
// The two trailing positional arguments are disable_metacommands = false and
// enable_dynamic_graph_fusion = false, so this entry point cannot opt in to either.
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_DML(dml_device,
cmd_queue));
cmd_queue,
false,
false));
API_IMPL_END
return nullptr;
}
@ -517,7 +553,7 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) {
// Appends a DML execution provider factory created from device_options to
// options->provider_factories. The factory creation and the append are compiled only
// when USE_DML is defined.
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) {
API_IMPL_BEGIN
#ifdef USE_DML
auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options);
// Positional arguments after device_options, in order: disable_metacommands = false,
// enable_dynamic_graph_fusion = false.
auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options, false, false);
// return the create function for a dxcore device
options->provider_factories.push_back(factory);
#endif  // USE_DML

View file

@ -17,15 +17,24 @@
namespace onnxruntime {
// Static helpers that create DirectML IExecutionProviderFactory instances, plus the
// D3D12 / DML device helpers they build on.
// disable_metacommands and enable_dynamic_graph_fusion are plain positional bools with
// no default values, so every call site must spell them out explicitly.
struct DMLProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id, bool skip_software_device_check);
// Creates a factory for device_id; skip_software_device_check is passed on to
// CreateD3D12Device.
static std::shared_ptr<IExecutionProviderFactory> Create(
int device_id,
bool skip_software_device_check,
bool disable_metacommands,
bool enable_dynamic_graph_fusion);

// Creates a factory from a ProviderOptions map; unlike its siblings it takes no
// explicit bool flags.
static std::shared_ptr<IExecutionProviderFactory> CreateFromProviderOptions(
const ProviderOptions& provider_options_map);
static std::shared_ptr<IExecutionProviderFactory> CreateFromOptions(OrtDmlDeviceOptions* device_options);
// Creates a factory from an OrtDmlDeviceOptions pointer plus the two explicit flags.
static std::shared_ptr<IExecutionProviderFactory> CreateFromOptions(
OrtDmlDeviceOptions* device_options,
bool disable_metacommands,
bool enable_dynamic_graph_fusion);
// Creates a factory from a list of DXCore adapters, taken by rvalue reference.
static std::shared_ptr<IExecutionProviderFactory> CreateFromAdapterList(
std::vector<Microsoft::WRL::ComPtr<IDXCoreAdapter>>&& dxcore_devices);
std::vector<Microsoft::WRL::ComPtr<IDXCoreAdapter>>&& dxcore_devices,
bool disable_metacommands,
bool enable_dynamic_graph_fusion);

// Device helpers: CreateD3D12Device returns a D3D12 device for device_id, and
// CreateDMLDevice returns a DML device for an existing D3D12 device.
static Microsoft::WRL::ComPtr<ID3D12Device> CreateD3D12Device(int device_id, bool skip_software_device_check);
static Microsoft::WRL::ComPtr<IDMLDevice> CreateDMLDevice(ID3D12Device* d3d12_device);

View file

@ -52,8 +52,10 @@
#include "core/providers/cpu/cpu_execution_provider.h"
#ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph
#include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h"
#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h"
#include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h"
#include "core/providers/dml/dml_session_options_config_keys.h"
#include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h"
#endif
#include "core/session/environment.h"
#include "core/session/user_logging_sink.h"
@ -1598,7 +1600,9 @@ common::Status InferenceSession::Initialize() {
record_runtime_optimization_produced_op_schema));
#ifdef USE_DML
// Registers the DML graph fusion transformers when a DML execution provider is present.
// The provider is looked up once and the same pointer is reused for both transformers.
if (execution_providers_.Get(kDmlExecutionProvider)) {
const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider);
if (dmlExecutionProvider) {
// DML graph fusion is an important runtime optimization that cannot be done ahead of time; it must be disabled
// when running in "offline mode" and saving an optimized model to disk. To support users that want to optimize
// models offline, and then disable graph optimizations when running "online", this transformer ignores the ORT
@ -1608,11 +1612,20 @@ common::Status InferenceSession::Initialize() {
if (dml_graph_fusion_enabled) {
std::unique_ptr<onnxruntime::GraphTransformer> dmlGraphFusionTransformer = std::make_unique<Dml::DmlGraphFusionTransformer>("DmlGraphFusionTransformer",
execution_providers_.Get(kDmlExecutionProvider));
dmlExecutionProvider);
// NOTE(review): std::make_unique reports allocation failure by throwing, not by returning
// nullptr, so this check looks unreachable -- confirm before relying on it.
if (dmlGraphFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr");
}
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3));

// The runtime (dynamic) fusion transformer is registered only inside the
// dml_graph_fusion_enabled branch, after the static fusion transformer and at the same Level3.
// NOTE(review): the static_cast assumes the provider registered under kDmlExecutionProvider is
// always a Dml::ExecutionProvider -- confirm.
if (static_cast<const Dml::ExecutionProvider*>(dmlExecutionProvider)->DynamicGraphFusionEnabled()) {
std::unique_ptr<onnxruntime::GraphTransformer> dmlRuntimeGraphFusionTransformer = std::make_unique<Dml::DmlRuntimeGraphFusionTransformer>("DmlRuntimeGraphFusionTransformer",
dmlExecutionProvider);
// NOTE(review): same as above -- std::make_unique does not return nullptr.
if (dmlRuntimeGraphFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlRuntimeGraphFusionTransformer is nullptr");
}
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlRuntimeGraphFusionTransformer), onnxruntime::TransformerLevel::Level3));
}
}
// This transformer applies DML-specific fusions that go beyond what ORT offers by default

View file

@ -59,7 +59,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) {
onnxruntime::ArmNNProviderFactoryCreator::Create(0),
#endif
#ifdef USE_DML
onnxruntime::DMLProviderFactoryCreator::Create(0, /*skip_software_device_check*/ true),
// Keep skip_software_device_check = true, exactly as the two-argument call passed it before
// the signature grew; an unlabeled `false` here would silently flip that flag for schema
// enumeration. The two new options stay off.
onnxruntime::DMLProviderFactoryCreator::Create(0,
/*skip_software_device_check*/ true,
/*disable_metacommands*/ false,
/*enable_dynamic_graph_fusion*/ false),
#endif
#ifdef USE_NNAPI
onnxruntime::NnapiProviderFactoryCreator::Create(0, std::optional<std::string>()),

View file

@ -268,7 +268,7 @@ std::unique_ptr<IExecutionProvider> DefaultCannExecutionProvider() {
std::unique_ptr<IExecutionProvider> DefaultDmlExecutionProvider() {
#ifdef USE_DML
if (auto factory = DMLProviderFactoryCreator::Create(0))
if (auto factory = DMLProviderFactoryCreator::Create(0, false, false, false))
return factory->CreateProvider();
#endif
return nullptr;