From 7201dbebe53cde956a9294a47b74d61196065d0b Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Wed, 9 Aug 2023 19:53:15 -0700
Subject: [PATCH] [DML EP] Split fused kernels when the persistent resource is too big (#16780)

The approach is the following:

1. Build partitions.
2. Try compiling each partition into an `IDMLCompiledOperator`.
3. If the compiled operator's persistent resource is bigger than 4GB, tell the partitioner to split the partition in the middle and try again.
4. Once all partitions have been successfully compiled into an `IDMLCompiledOperator`, fuse the partitions into an ORT operator and register them all.

This change is relatively simple (essentially a basic retry mechanism), but it required a lot of refactoring just to make sure that we don't modify the graph until **all** partitions have been compiled successfully. This is because partially modifying the graph before making sure that all partitions can be compiled would break future retries.

This path is not expected to be taken often, and even then the loop should rarely iterate more than twice. This is a very specific edge case for large models in which a large number of nodes were merged into a single partition.
--- .../src/DmlGraphFusionHelper.cpp | 164 ++++++------ .../src/DmlGraphFusionHelper.h | 36 +-- .../src/DmlGraphFusionTransformer.cpp | 246 ++++++++++++------ .../src/FusedGraphKernel.cpp | 16 +- .../src/FusedGraphKernel.h | 4 +- .../src/GraphPartitioner.cpp | 21 +- .../src/GraphPartitioner.h | 2 +- 7 files changed, 293 insertions(+), 196 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 17aa197396..51b93efb3a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -106,7 +106,7 @@ namespace DmlGraphFusionHelper void ProcessInputData( const ExecutionProviderImpl* providerImpl, const std::vector& isInputsUploadedByDmlEP, - std::vector& inputEdges, + const std::vector& inputEdges, const gsl::span subGraphInputArgNames, const std::unordered_map>& initializerNameToInitializerMap, onnxruntime::Graph& graph, @@ -325,37 +325,60 @@ namespace DmlGraphFusionHelper dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data(); } - void CreateIDmlCompiledOperatorAndRegisterKernel( - onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - const onnxruntime::Node& fusedNode, - const std::unordered_map& partitionNodePropsMap, - const std::unordered_map>& initializerNameToInitializerMap, - const ExecutionProviderImpl* providerImpl, - onnxruntime::KernelRegistry* registryForPartitionKernels) + onnxruntime::IndexedSubGraph CreateIndexedSubGraph( + GraphPartition* partition, + uint32_t partitionIndex, + const std::string& partitionKernelPrefix) { - // convert partitionONNXGraph into DML EP GraphDesc - const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); - const uint32_t fusedNodeOutputCount = 
gsl::narrow_cast(indexedSubGraph.GetMetaDef()->outputs.size()); + assert(partition->IsDmlGraphPartition()); - std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); - for (uint32_t index = 0; index < fusedNodeInputCount; ++index) + onnxruntime::IndexedSubGraph indexedSubGraph; + // Create a definition for the node. The name must be unique. + auto def = std::make_unique(); + def->name = DmlGraphFusionTransformer::DML_GRAPH_FUSION_NODE_NAME_PREFIX + partitionKernelPrefix + std::to_string(partitionIndex); + def->domain = DmlGraphFusionTransformer::DML_GRAPH_FUSION_NODE_DOMAIN; + def->since_version = 1; + def->inputs.insert(def->inputs.begin(), partition->GetInputs().begin(), partition->GetInputs().end()); + def->outputs.insert(def->outputs.begin(), partition->GetOutputs().begin(), partition->GetOutputs().end()); + + indexedSubGraph.SetMetaDef(std::move(def)); + indexedSubGraph.nodes = std::move(partition->GetNodeIndices()); + + return indexedSubGraph; + } + + std::unordered_map CreatePartitionNodePropsMap( + const onnxruntime::Graph& graph, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::unordered_map&& graphNodePropertyMap) + { + // Populate properties which will be passed to OpKernel for this graph via the function below + std::unordered_map partitionNodePropsMap; + for (auto nodeIndex : indexedSubGraph.nodes) { - auto iter = initializerNameToInitializerMap.find(indexedSubGraph.GetMetaDef()->inputs[index]); - isInputsUploadedByDmlEP[index] = iter != initializerNameToInitializerMap.end() ? 
true : false; + const onnxruntime::Node* node = graph.GetNode(nodeIndex); + +#ifdef PRINT_PARTITON_INFO + printf("Partition %u\t%s\n", partitionIndex, GraphDescBuilder::GetUniqueNodeName(*node).c_str()); +#endif + partitionNodePropsMap.insert(std::make_pair( + GraphDescBuilder::GetUniqueNodeName(*node), std::move(graphNodePropertyMap[node]))); } - ComPtr device; - ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); - GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( - isInputsUploadedByDmlEP.data(), - isInputsUploadedByDmlEP.size(), - initializerNameToInitializerMap, - graph, - indexedSubGraph, - partitionNodePropsMap, - device.Get(), - providerImpl); +#ifdef PRINT_PARTITON_INFO + printf("\n"); +#endif + + return partitionNodePropsMap; + } + + Microsoft::WRL::ComPtr TryCreateCompiledOperator( + const GraphDescBuilder::GraphDesc& graphDesc, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + const ExecutionProviderImpl* providerImpl) + { + const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); + const uint32_t fusedNodeOutputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->outputs.size()); // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator DML_GRAPH_DESC dmlGraphDesc = {}; @@ -387,14 +410,42 @@ namespace DmlGraphFusionHelper executionFlags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS; } + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + ComPtr device1; ORT_THROW_IF_FAILED(device.As(&device1)); + ComPtr compiledExecutionPlanOperator; ORT_THROW_IF_FAILED(device1->CompileGraph( &dmlGraphDesc, executionFlags, IID_PPV_ARGS(&compiledExecutionPlanOperator))); + // UINT32_MAX is currently the maximum number of bytes allowed by D3D12 for the offset of a view over a resource + if (compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize > UINT32_MAX) + { + return nullptr; + } + + return 
compiledExecutionPlanOperator; + } + + void FusePartitionAndRegisterKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const std::unordered_map>& initializerNameToInitializerMap, + const ExecutionProviderImpl* providerImpl, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::vector&& isInputsUploadedByDmlEP, + const GraphDescBuilder::GraphDesc& graphDesc, + Microsoft::WRL::ComPtr compiledExecutionPlanOperator) + { + auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name); + fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); + + const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); + // Populate input bindings for operator initialization std::vector> initializeResourceRefs; // For lifetime control std::vector initInputBindings(fusedNodeInputCount); @@ -424,8 +475,8 @@ namespace DmlGraphFusionHelper nonOwnedGraphInputsFromInitializers, initializeResourceRefs, initInputBindings, - isInputsUploadedByDmlEP, - inputsUsed] + isInputsUploadedByDmlEP = std::move(isInputsUploadedByDmlEP), + inputsUsed = std::move(inputsUsed)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status { out.reset(CreateFusedGraphKernel(info, @@ -435,8 +486,8 @@ namespace DmlGraphFusionHelper nonOwnedGraphInputsFromInitializers, initializeResourceRefs, initInputBindings, - isInputsUploadedByDmlEP, - inputsUsed)); + std::move(isInputsUploadedByDmlEP), + std::move(inputsUsed))); return Status::OK(); }; @@ -447,58 +498,7 @@ namespace DmlGraphFusionHelper .SinceVersion(indexedSubGraph.GetMetaDef()->since_version) .Provider(onnxruntime::kDmlExecutionProvider); ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); - } - void FusePartitionAndRegisterKernel( - GraphPartition* partition, - uint32_t partitionIndex, - onnxruntime::Graph& 
graph, - std::unordered_map& graphNodePropertyMap, - onnxruntime::KernelRegistry* registryForPartitionKernels, - const std::string& partitionKernelPrefix, - const std::unordered_map>& initializerNameToInitializerMap, - const ExecutionProviderImpl* providerImpl) - { - assert(partition->IsDmlGraphPartition()); - - onnxruntime::IndexedSubGraph indexedSubGraph; - // Create a definition for the node. The name must be unique. - auto def = std::make_unique(); - def->name = DmlGraphFusionTransformer::DML_GRAPH_FUSION_NODE_NAME_PREFIX + partitionKernelPrefix + std::to_string(partitionIndex); - def->domain = DmlGraphFusionTransformer::DML_GRAPH_FUSION_NODE_DOMAIN; - def->since_version = 1; - def->inputs.insert(def->inputs.begin(), partition->GetInputs().begin(), partition->GetInputs().end()); - def->outputs.insert(def->outputs.begin(), partition->GetOutputs().begin(), partition->GetOutputs().end()); - - indexedSubGraph.SetMetaDef(std::move(def)); - indexedSubGraph.nodes = std::move(partition->GetNodeIndices()); - auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name); - fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); - - // Populate properties which will be passed to OpKernel for this graph via the function below - std::unordered_map partitionNodePropsMap; - for (auto nodeIndex : indexedSubGraph.nodes) - { - const onnxruntime::Node* node = graph.GetNode(nodeIndex); - -#ifdef PRINT_PARTITON_INFO - printf("Partition %u\t%s\n", partitionIndex, GraphDescBuilder::GetUniqueNodeName(*node).c_str()); -#endif - partitionNodePropsMap.insert(std::make_pair( - GraphDescBuilder::GetUniqueNodeName(*node), std::move(graphNodePropertyMap[node]))); - } - -#ifdef PRINT_PARTITON_INFO - printf("\n"); -#endif - CreateIDmlCompiledOperatorAndRegisterKernel( - graph, - indexedSubGraph, - fusedNode, - partitionNodePropsMap, - initializerNameToInitializerMap, - providerImpl, - registryForPartitionKernels); 
graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index f2533bb37b..030cffc2a8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -56,23 +56,29 @@ namespace DmlGraphFusionHelper _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges); - void CreateIDmlCompiledOperatorAndRegisterKernel( - onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - const onnxruntime::Node& fusedNode, - const std::unordered_map& partitionNodePropsMap, - const std::unordered_map>& isInitializerTransferable, - const ExecutionProviderImpl* providerImpl, - onnxruntime::KernelRegistry* registryForPartitionKernels); - - void FusePartitionAndRegisterKernel( + onnxruntime::IndexedSubGraph CreateIndexedSubGraph( GraphPartition* partition, uint32_t partitionIndex, - onnxruntime::Graph& graph, - std::unordered_map& graphNodePropertyMap, - onnxruntime::KernelRegistry* registryForPartitionKernels, - const std::string& partitionKernelPrefix, - const std::unordered_map>& isInitializerTransferable, + const std::string& partitionKernelPrefix); + + std::unordered_map CreatePartitionNodePropsMap( + const onnxruntime::Graph& graph, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::unordered_map&& graphNodePropertyMap); + + Microsoft::WRL::ComPtr TryCreateCompiledOperator( + const GraphDescBuilder::GraphDesc& graphDesc, + const onnxruntime::IndexedSubGraph& indexedSubGraph, const ExecutionProviderImpl* providerImpl); + + void FusePartitionAndRegisterKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const std::unordered_map>& initializerNameToInitializerMap, + const ExecutionProviderImpl* 
providerImpl, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::vector&& isInputsUploadedByDmlEP, + const GraphDescBuilder::GraphDesc& graphDesc, + Microsoft::WRL::ComPtr compiledExecutionPlanOperator); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index c8c5fddc78..a9d19a022d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -23,7 +23,16 @@ namespace Dml m_providerImpl(static_cast(provider)->GetImpl()) { } - + + struct CompiledPartitionInfo + { + Microsoft::WRL::ComPtr compiledOperator; + onnxruntime::IndexedSubGraph indexedSubGraph; + std::vector isInputsUploadedByDmlEP; + GraphDescBuilder::GraphDesc graphDesc; + std::unordered_map> isInitializerTransferable; + }; + onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, @@ -37,96 +46,173 @@ namespace Dml gsl::make_span(®istry, 1), kernel_type_str_resolver}; - // Initializers needed by any graph partition - std::unordered_set requiredInitializerMap; - std::unordered_map graphNodePropertyMap; - onnxruntime::GraphViewer graphViewer(graph); - std::vector> partitions = BuildPartitions( - graphViewer, - *m_providerImpl->GetInternalRegistrationInfoMap(), - kernel_lookup, - m_providerImpl->GetSupportedDeviceDataTypeMask(), - graphNodePropertyMap, - requiredInitializerMap); + std::vector> compiledPartitionInfos; + std::vector additionalSplittingNodes; - // Create a map between each initialized tensor and the partition(s) it is part of. 
- auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); - - for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) + do { - auto& partition = partitions[partitionIndex]; + // Initializers needed by any graph partition + std::unordered_set requiredInitializerMap; + std::unordered_map graphNodePropertyMap; + onnxruntime::GraphViewer graphViewer(graph); + std::vector> partitions = BuildPartitions( + graphViewer, + *m_providerImpl->GetInternalRegistrationInfoMap(), + kernel_lookup, + m_providerImpl->GetSupportedDeviceDataTypeMask(), + graphNodePropertyMap, + requiredInitializerMap, + additionalSplittingNodes); - if (partition->GetRootMergedPartition() != partition.get() || - !partition->IsDmlPartition()) + // Reset the splitting nodes for the current iteration + additionalSplittingNodes.clear(); + + // Reset the compiled operators for the current iteration + compiledPartitionInfos.clear(); + compiledPartitionInfos.resize(partitions.size()); + + // Create a map between each initialized tensor and the partition(s) it is part of. + auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); + + for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) { - continue; - } + auto& partition = partitions[partitionIndex]; - // This map will tell which initializer can be removed from onnxruntime::Graph (and from it's field - // onnx::GraphProto) while we upload the initializer to GPU. - // Why we want to remove the initializer from ORT? - // 1. To keep the peak memory usage as low as possible. That's why we are doing incremental upload to GPU. - // What is initializer? - // An initializer is a input tensor to an operator or the graph itself, which is contant and will never change. - // Why are we uploading the initialzer now? 
- // This prevents OnnxRuntime from allocating GPU resources and uploading those initializers, - // so the partiton's kernel can do so. In the process, it will pre-process weights while consuming a CPU - // backed resource, avoiding an extra set of GPU resources in memory. - std::unordered_map> isInitializerTransferable; - - - if (partition->IsDmlGraphPartition()) - { - // populate transferredInitializerMap - for (const auto& input : partition->GetInputs()) + if (partition->GetRootMergedPartition() != partition.get() || + !partition->IsDmlPartition()) { - const onnx::TensorProto* tensor = nullptr; - if (graph.GetInitializedTensor(input, tensor)) - { - // It's only safe to transfer tensors which are used by this partition alone. - auto iter = initializerPartitionMap.find(tensor); - assert(iter != initializerPartitionMap.end()); - if (iter->second.size() > 1) - { - // By including non-transferrable tensors in isInitializerTransferable, it causes DML to upload and preprocess them - // to duplicate locations rather than treating them as being non-constant, which is helpful for optimization. - // The size threshold for this should be no smaller than that used to combine initializers in the constant - // sharing transform to prevent that transform from hurting performance. - // If the kernel relies on this input to be initialized, it should also be small enough to copy cheaply. 
- const uint64_t maximumElementsForDuplicationTensor = 64; - static_assert(maximumElementsForDuplicationTensor >= onnxruntime::ConstantSharing::TENSOR_ELEM_COUNT_THRESHOLD); - - uint64_t totalElementCount = 1; - for (int i = 0; i < tensor->dims().size(); ++i) - { - totalElementCount *= tensor->dims()[i]; - } - - if (totalElementCount <= maximumElementsForDuplicationTensor || - requiredInitializerMap.find(input) != requiredInitializerMap.end()) - { - isInitializerTransferable[input] = {tensor, false}; - } - - continue; - } - isInitializerTransferable[input] = {tensor, true}; - } + continue; } - std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; - m_providerImpl->IncreasePartitionKernelPrefixVal(); + // This map will tell which initializer can be removed from onnxruntime::Graph (and from it's field + // onnx::GraphProto) while we upload the initializer to GPU. + // Why we want to remove the initializer from ORT? + // 1. To keep the peak memory usage as low as possible. That's why we are doing incremental upload to GPU. + // What is initializer? + // An initializer is a input tensor to an operator or the graph itself, which is contant and will never change. + // Why are we uploading the initialzer now? + // This prevents OnnxRuntime from allocating GPU resources and uploading those initializers, + // so the partiton's kernel can do so. In the process, it will pre-process weights while consuming a CPU + // backed resource, avoiding an extra set of GPU resources in memory. + std::unordered_map> isInitializerTransferable; + if (partition->IsDmlGraphPartition()) + { + // populate isInitializerTransferable + for (const auto& input : partition->GetInputs()) + { + const onnx::TensorProto* tensor = nullptr; + if (graph.GetInitializedTensor(input, tensor)) + { + // It's only safe to transfer tensors which are used by this partition alone. 
+ auto iter = initializerPartitionMap.find(tensor); + assert(iter != initializerPartitionMap.end()); + if (iter->second.size() > 1) + { + // By including non-transferrable tensors in isInitializerTransferable, it causes DML to upload and preprocess them + // to duplicate locations rather than treating them as being non-constant, which is helpful for optimization. + // The size threshold for this should be no smaller than that used to combine initializers in the constant + // sharing transform to prevent that transform from hurting performance. + // If the kernel relies on this input to be initialized, it should also be small enough to copy cheaply. + constexpr uint64_t maximumElementsForDuplicationTensor = 64; + static_assert(maximumElementsForDuplicationTensor >= onnxruntime::ConstantSharing::TENSOR_ELEM_COUNT_THRESHOLD); + + uint64_t totalElementCount = 1; + for (int i = 0; i < tensor->dims().size(); ++i) + { + totalElementCount *= tensor->dims()[i]; + } + + if (totalElementCount <= maximumElementsForDuplicationTensor || + requiredInitializerMap.find(input) != requiredInitializerMap.end()) + { + isInitializerTransferable[input] = {tensor, false}; + } + + continue; + } + isInitializerTransferable[input] = {tensor, true}; + } + } + + std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; + m_providerImpl->IncreasePartitionKernelPrefixVal(); + + auto indexedSubGraph = DmlGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix); + + // Create a map of which inputs are uploaded by the DML EP + const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); + for (uint32_t index = 0; index < fusedNodeInputCount; ++index) + { + auto iter = isInitializerTransferable.find(indexedSubGraph.GetMetaDef()->inputs[index]); + isInputsUploadedByDmlEP[index] = iter != 
isInitializerTransferable.end() ? true : false; + } + + auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap( + graph, + indexedSubGraph, + std::move(graphNodePropertyMap)); + + // Convert partitionONNXGraph into DML EP GraphDesc + ComPtr device; + ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); + GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( + isInputsUploadedByDmlEP.data(), + isInputsUploadedByDmlEP.size(), + isInitializerTransferable, + graph, + indexedSubGraph, + partitionNodePropsMap, + device.Get(), + m_providerImpl); + + // Compile the operator + auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( + graphDesc, + indexedSubGraph, + m_providerImpl); + + if (!compiledPartition) + { + // Fail early if even a single operator is too big to compile. This is highly unlikely. + ORT_THROW_HR_IF(E_INVALIDARG, indexedSubGraph.nodes.size() < 2); + + // Tell the partitioner to split the current partition in half, in the middle + additionalSplittingNodes.push_back(indexedSubGraph.nodes[indexedSubGraph.nodes.size() / 2]); + + // Exit early since we need to repartition + break; + } + else + { + auto compiledPartitionInfo = std::make_shared(); + compiledPartitionInfo->compiledOperator = std::move(compiledPartition); + compiledPartitionInfo->indexedSubGraph = std::move(indexedSubGraph); + compiledPartitionInfo->isInputsUploadedByDmlEP = std::move(isInputsUploadedByDmlEP); + compiledPartitionInfo->graphDesc = std::move(graphDesc); + compiledPartitionInfo->isInitializerTransferable = std::move(isInitializerTransferable); + compiledPartitionInfos[partitionIndex] = std::move(compiledPartitionInfo); + } + } + } + } + while (!additionalSplittingNodes.empty()); + + for (auto&& compiledPartitionInfo : compiledPartitionInfos) + { + // Null compiled operators were not DML partitions + if (compiledPartitionInfo) + { DmlGraphFusionHelper::FusePartitionAndRegisterKernel( - partition.get(), 
- partitionIndex, - graph, - graphNodePropertyMap, + graph, m_providerImpl->GetKernelRegistry().get(), - partitionKernelPrefix, - isInitializerTransferable, - m_providerImpl - ); + compiledPartitionInfo->isInitializerTransferable, + m_providerImpl, + compiledPartitionInfo->indexedSubGraph, + std::move(compiledPartitionInfo->isInputsUploadedByDmlEP), + compiledPartitionInfo->graphDesc, + compiledPartitionInfo->compiledOperator); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index b7f24d49d1..67c3f110e5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -24,13 +24,13 @@ namespace Dml std::vector>& nonOwnedGraphInputsFromInitializers, std::vector>& initializeResourceRefs, std::vector initInputBindings, - std::vector& isInputsUploadedByDmlEP, - std::vector& inputsUsed) : + std::vector&& isInputsUploadedByDmlEP, + std::vector&& inputsUsed) : OpKernel(kernelInfo), m_compiledExecutionPlanOperator(compiledExecutionPlanOperator), - m_inputsUsed(inputsUsed), + m_inputsUsed(std::move(inputsUsed)), m_outputShapes(outputShapes), - m_isInputsUploadedByDmlEP(isInputsUploadedByDmlEP), + m_isInputsUploadedByDmlEP(std::move(isInputsUploadedByDmlEP)), m_nonOwnedGraphInputsFromInitializers(nonOwnedGraphInputsFromInitializers) { // Get the execution provider interfaces @@ -443,8 +443,8 @@ namespace Dml std::vector>& nonOwnedGraphInputsFromInitializers, std::vector>& initializeResourceRefs, std::vector initInputBindings, - std::vector& isInputsUploadedByDmlEP, - std::vector& inputsUsed + std::vector&& isInputsUploadedByDmlEP, + std::vector&& inputsUsed ) { return new FusedGraphKernel( @@ -455,8 +455,8 @@ namespace Dml nonOwnedGraphInputsFromInitializers, initializeResourceRefs, initInputBindings, - isInputsUploadedByDmlEP, - inputsUsed + 
std::move(isInputsUploadedByDmlEP), + std::move(inputsUsed) ); } } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h index 00a858d54e..ced6160a99 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h @@ -15,7 +15,7 @@ namespace Dml std::vector>& nonOwnedGraphInputsFromInitializers, std::vector>& initializeResourceRefs, std::vector initInputBindings, - std::vector& isInputsUploadedByDmlEP, - std::vector& inputsUsed + std::vector&& isInputsUploadedByDmlEP, + std::vector&& inputsUsed ); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index b9c5f8849a..2c8d4e4459 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -209,14 +209,15 @@ namespace Dml // Creates a partition for a node which is not a DML graph node, and finalizes partitions // which are inputs of the new partition. - std::unique_ptr CreateNonGraphNodePartitionAndFinalizeInputs( + std::unique_ptr CreatePartitionAndFinalizeInputs( const onnxruntime::Node& node, bool isDmlNode, + bool isDmlGraphPartitionNode, std::unordered_map& nodeNameToPartitionMap ) { std::unique_ptr partition = std::make_unique(); - partition->SetIsDmlGraphPartition(false); + partition->SetIsDmlGraphPartition(isDmlGraphPartitionNode); partition->SetIsDmlPartition(isDmlNode); partition->AddNodeIndex(node.Index()); @@ -383,7 +384,7 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. 
std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, - std::function onNodeUnsupportedInGraph) + gsl::span additionalSplittingNodes) { // Nodes are uniquely identified by the name of their first output argument std::vector> partitions; @@ -420,6 +421,8 @@ namespace Dml // Check whether this graph is a subgraph, or contains any node with a subgraph. bool modelUsesSubgraph = ModelUsesSubgraph(graph); + uint32_t splittingNodeIndex = 0; + // Build up partitions while traversing the graph. for (size_t nodeIndex : toplogicalOrder) { @@ -456,12 +459,14 @@ namespace Dml // anyhow due to CPU/GPU copies. if (modelUsesSubgraph || !isDmlGraphNode) { - if (onNodeUnsupportedInGraph) - { - onNodeUnsupportedInGraph(node); - } + partitions.push_back(CreatePartitionAndFinalizeInputs(node, isDmlNode, false, nodeNameToPartitionMap)); + continue; + } - partitions.push_back(CreateNonGraphNodePartitionAndFinalizeInputs(node, isDmlNode, nodeNameToPartitionMap)); + if (splittingNodeIndex < additionalSplittingNodes.size() && additionalSplittingNodes[splittingNodeIndex] == nodeIndex) + { + partitions.push_back(CreatePartitionAndFinalizeInputs(node, isDmlNode, isDmlGraphNode, nodeNameToPartitionMap)); + ++splittingNodeIndex; continue; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index c5b15c7b1c..990ba00fc4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -48,5 +48,5 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, - std::function onNodeUnsupportedInGraph = nullptr); + gsl::span additionalSplittingNodes); } // namespace Dml