diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 17aa197396..51b93efb3a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -106,7 +106,7 @@ namespace DmlGraphFusionHelper void ProcessInputData( const ExecutionProviderImpl* providerImpl, const std::vector& isInputsUploadedByDmlEP, - std::vector& inputEdges, + const std::vector& inputEdges, const gsl::span subGraphInputArgNames, const std::unordered_map>& initializerNameToInitializerMap, onnxruntime::Graph& graph, @@ -325,37 +325,60 @@ namespace DmlGraphFusionHelper dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data(); } - void CreateIDmlCompiledOperatorAndRegisterKernel( - onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - const onnxruntime::Node& fusedNode, - const std::unordered_map& partitionNodePropsMap, - const std::unordered_map>& initializerNameToInitializerMap, - const ExecutionProviderImpl* providerImpl, - onnxruntime::KernelRegistry* registryForPartitionKernels) + onnxruntime::IndexedSubGraph CreateIndexedSubGraph( + GraphPartition* partition, + uint32_t partitionIndex, + const std::string& partitionKernelPrefix) { - // convert partitionONNXGraph into DML EP GraphDesc - const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); - const uint32_t fusedNodeOutputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->outputs.size()); + assert(partition->IsDmlGraphPartition()); - std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); - for (uint32_t index = 0; index < fusedNodeInputCount; ++index) + onnxruntime::IndexedSubGraph indexedSubGraph; + // Create a definition for the node. The name must be unique. 
+ auto def = std::make_unique(); + def->name = DmlGraphFusionTransformer::DML_GRAPH_FUSION_NODE_NAME_PREFIX + partitionKernelPrefix + std::to_string(partitionIndex); + def->domain = DmlGraphFusionTransformer::DML_GRAPH_FUSION_NODE_DOMAIN; + def->since_version = 1; + def->inputs.insert(def->inputs.begin(), partition->GetInputs().begin(), partition->GetInputs().end()); + def->outputs.insert(def->outputs.begin(), partition->GetOutputs().begin(), partition->GetOutputs().end()); + + indexedSubGraph.SetMetaDef(std::move(def)); + indexedSubGraph.nodes = std::move(partition->GetNodeIndices()); + + return indexedSubGraph; + } + + std::unordered_map CreatePartitionNodePropsMap( + const onnxruntime::Graph& graph, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::unordered_map&& graphNodePropertyMap) + { + // Populate properties which will be passed to OpKernel for this graph via the function below + std::unordered_map partitionNodePropsMap; + for (auto nodeIndex : indexedSubGraph.nodes) { - auto iter = initializerNameToInitializerMap.find(indexedSubGraph.GetMetaDef()->inputs[index]); - isInputsUploadedByDmlEP[index] = iter != initializerNameToInitializerMap.end() ? 
true : false; + const onnxruntime::Node* node = graph.GetNode(nodeIndex); + +#ifdef PRINT_PARTITON_INFO + printf("Partition %u\t%s\n", partitionIndex, GraphDescBuilder::GetUniqueNodeName(*node).c_str()); +#endif + partitionNodePropsMap.insert(std::make_pair( + GraphDescBuilder::GetUniqueNodeName(*node), std::move(graphNodePropertyMap[node]))); } - ComPtr device; - ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); - GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( - isInputsUploadedByDmlEP.data(), - isInputsUploadedByDmlEP.size(), - initializerNameToInitializerMap, - graph, - indexedSubGraph, - partitionNodePropsMap, - device.Get(), - providerImpl); +#ifdef PRINT_PARTITON_INFO + printf("\n"); +#endif + + return partitionNodePropsMap; + } + + Microsoft::WRL::ComPtr TryCreateCompiledOperator( + const GraphDescBuilder::GraphDesc& graphDesc, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + const ExecutionProviderImpl* providerImpl) + { + const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); + const uint32_t fusedNodeOutputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->outputs.size()); // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator DML_GRAPH_DESC dmlGraphDesc = {}; @@ -387,14 +410,42 @@ namespace DmlGraphFusionHelper executionFlags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS; } + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + ComPtr device1; ORT_THROW_IF_FAILED(device.As(&device1)); + ComPtr compiledExecutionPlanOperator; ORT_THROW_IF_FAILED(device1->CompileGraph( &dmlGraphDesc, executionFlags, IID_PPV_ARGS(&compiledExecutionPlanOperator))); + // UINT32_MAX is currently the maximum number of bytes allowed by D3D12 for the offset of a view over a resource + if (compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize > UINT32_MAX) + { + return nullptr; + } + + return 
compiledExecutionPlanOperator; + } + + void FusePartitionAndRegisterKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const std::unordered_map>& initializerNameToInitializerMap, + const ExecutionProviderImpl* providerImpl, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::vector&& isInputsUploadedByDmlEP, + const GraphDescBuilder::GraphDesc& graphDesc, + Microsoft::WRL::ComPtr compiledExecutionPlanOperator) + { + auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name); + fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); + + const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); + // Populate input bindings for operator initialization std::vector> initializeResourceRefs; // For lifetime control std::vector initInputBindings(fusedNodeInputCount); @@ -424,8 +475,8 @@ namespace DmlGraphFusionHelper nonOwnedGraphInputsFromInitializers, initializeResourceRefs, initInputBindings, - isInputsUploadedByDmlEP, - inputsUsed] + isInputsUploadedByDmlEP = std::move(isInputsUploadedByDmlEP), + inputsUsed = std::move(inputsUsed)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status { out.reset(CreateFusedGraphKernel(info, @@ -435,8 +486,8 @@ namespace DmlGraphFusionHelper nonOwnedGraphInputsFromInitializers, initializeResourceRefs, initInputBindings, - isInputsUploadedByDmlEP, - inputsUsed)); + std::move(isInputsUploadedByDmlEP), + std::move(inputsUsed))); return Status::OK(); }; @@ -447,58 +498,7 @@ namespace DmlGraphFusionHelper .SinceVersion(indexedSubGraph.GetMetaDef()->since_version) .Provider(onnxruntime::kDmlExecutionProvider); ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); - } - void FusePartitionAndRegisterKernel( - GraphPartition* partition, - uint32_t partitionIndex, - onnxruntime::Graph& 
graph, - std::unordered_map& graphNodePropertyMap, - onnxruntime::KernelRegistry* registryForPartitionKernels, - const std::string& partitionKernelPrefix, - const std::unordered_map>& initializerNameToInitializerMap, - const ExecutionProviderImpl* providerImpl) - { - assert(partition->IsDmlGraphPartition()); - - onnxruntime::IndexedSubGraph indexedSubGraph; - // Create a definition for the node. The name must be unique. - auto def = std::make_unique(); - def->name = DmlGraphFusionTransformer::DML_GRAPH_FUSION_NODE_NAME_PREFIX + partitionKernelPrefix + std::to_string(partitionIndex); - def->domain = DmlGraphFusionTransformer::DML_GRAPH_FUSION_NODE_DOMAIN; - def->since_version = 1; - def->inputs.insert(def->inputs.begin(), partition->GetInputs().begin(), partition->GetInputs().end()); - def->outputs.insert(def->outputs.begin(), partition->GetOutputs().begin(), partition->GetOutputs().end()); - - indexedSubGraph.SetMetaDef(std::move(def)); - indexedSubGraph.nodes = std::move(partition->GetNodeIndices()); - auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name); - fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); - - // Populate properties which will be passed to OpKernel for this graph via the function below - std::unordered_map partitionNodePropsMap; - for (auto nodeIndex : indexedSubGraph.nodes) - { - const onnxruntime::Node* node = graph.GetNode(nodeIndex); - -#ifdef PRINT_PARTITON_INFO - printf("Partition %u\t%s\n", partitionIndex, GraphDescBuilder::GetUniqueNodeName(*node).c_str()); -#endif - partitionNodePropsMap.insert(std::make_pair( - GraphDescBuilder::GetUniqueNodeName(*node), std::move(graphNodePropertyMap[node]))); - } - -#ifdef PRINT_PARTITON_INFO - printf("\n"); -#endif - CreateIDmlCompiledOperatorAndRegisterKernel( - graph, - indexedSubGraph, - fusedNode, - partitionNodePropsMap, - initializerNameToInitializerMap, - providerImpl, - registryForPartitionKernels); 
graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index f2533bb37b..030cffc2a8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -56,23 +56,29 @@ namespace DmlGraphFusionHelper _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges); - void CreateIDmlCompiledOperatorAndRegisterKernel( - onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - const onnxruntime::Node& fusedNode, - const std::unordered_map& partitionNodePropsMap, - const std::unordered_map>& isInitializerTransferable, - const ExecutionProviderImpl* providerImpl, - onnxruntime::KernelRegistry* registryForPartitionKernels); - - void FusePartitionAndRegisterKernel( + onnxruntime::IndexedSubGraph CreateIndexedSubGraph( GraphPartition* partition, uint32_t partitionIndex, - onnxruntime::Graph& graph, - std::unordered_map& graphNodePropertyMap, - onnxruntime::KernelRegistry* registryForPartitionKernels, - const std::string& partitionKernelPrefix, - const std::unordered_map>& isInitializerTransferable, + const std::string& partitionKernelPrefix); + + std::unordered_map CreatePartitionNodePropsMap( + const onnxruntime::Graph& graph, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::unordered_map&& graphNodePropertyMap); + + Microsoft::WRL::ComPtr TryCreateCompiledOperator( + const GraphDescBuilder::GraphDesc& graphDesc, + const onnxruntime::IndexedSubGraph& indexedSubGraph, const ExecutionProviderImpl* providerImpl); + + void FusePartitionAndRegisterKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const std::unordered_map>& initializerNameToInitializerMap, + const ExecutionProviderImpl* 
providerImpl, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::vector&& isInputsUploadedByDmlEP, + const GraphDescBuilder::GraphDesc& graphDesc, + Microsoft::WRL::ComPtr compiledExecutionPlanOperator); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index c8c5fddc78..a9d19a022d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -23,7 +23,16 @@ namespace Dml m_providerImpl(static_cast(provider)->GetImpl()) { } - + + struct CompiledPartitionInfo + { + Microsoft::WRL::ComPtr compiledOperator; + onnxruntime::IndexedSubGraph indexedSubGraph; + std::vector isInputsUploadedByDmlEP; + GraphDescBuilder::GraphDesc graphDesc; + std::unordered_map> isInitializerTransferable; + }; + onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, @@ -37,96 +46,173 @@ namespace Dml gsl::make_span(®istry, 1), kernel_type_str_resolver}; - // Initializers needed by any graph partition - std::unordered_set requiredInitializerMap; - std::unordered_map graphNodePropertyMap; - onnxruntime::GraphViewer graphViewer(graph); - std::vector> partitions = BuildPartitions( - graphViewer, - *m_providerImpl->GetInternalRegistrationInfoMap(), - kernel_lookup, - m_providerImpl->GetSupportedDeviceDataTypeMask(), - graphNodePropertyMap, - requiredInitializerMap); + std::vector> compiledPartitionInfos; + std::vector additionalSplittingNodes; - // Create a map between each initialized tensor and the partition(s) it is part of. 
- auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); - - for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) + do { - auto& partition = partitions[partitionIndex]; + // Initializers needed by any graph partition + std::unordered_set requiredInitializerMap; + std::unordered_map graphNodePropertyMap; + onnxruntime::GraphViewer graphViewer(graph); + std::vector> partitions = BuildPartitions( + graphViewer, + *m_providerImpl->GetInternalRegistrationInfoMap(), + kernel_lookup, + m_providerImpl->GetSupportedDeviceDataTypeMask(), + graphNodePropertyMap, + requiredInitializerMap, + additionalSplittingNodes); - if (partition->GetRootMergedPartition() != partition.get() || - !partition->IsDmlPartition()) + // Reset the splitting nodes for the current iteration + additionalSplittingNodes.clear(); + + // Reset the compiled operators for the current iteration + compiledPartitionInfos.clear(); + compiledPartitionInfos.resize(partitions.size()); + + // Create a map between each initialized tensor and the partition(s) it is part of. + auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); + + for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) { - continue; - } + auto& partition = partitions[partitionIndex]; - // This map will tell which initializer can be removed from onnxruntime::Graph (and from it's field - // onnx::GraphProto) while we upload the initializer to GPU. - // Why we want to remove the initializer from ORT? - // 1. To keep the peak memory usage as low as possible. That's why we are doing incremental upload to GPU. - // What is initializer? - // An initializer is a input tensor to an operator or the graph itself, which is contant and will never change. - // Why are we uploading the initialzer now? 
- // This prevents OnnxRuntime from allocating GPU resources and uploading those initializers, - // so the partiton's kernel can do so. In the process, it will pre-process weights while consuming a CPU - // backed resource, avoiding an extra set of GPU resources in memory. - std::unordered_map> isInitializerTransferable; - - - if (partition->IsDmlGraphPartition()) - { - // populate transferredInitializerMap - for (const auto& input : partition->GetInputs()) + if (partition->GetRootMergedPartition() != partition.get() || + !partition->IsDmlPartition()) { - const onnx::TensorProto* tensor = nullptr; - if (graph.GetInitializedTensor(input, tensor)) - { - // It's only safe to transfer tensors which are used by this partition alone. - auto iter = initializerPartitionMap.find(tensor); - assert(iter != initializerPartitionMap.end()); - if (iter->second.size() > 1) - { - // By including non-transferrable tensors in isInitializerTransferable, it causes DML to upload and preprocess them - // to duplicate locations rather than treating them as being non-constant, which is helpful for optimization. - // The size threshold for this should be no smaller than that used to combine initializers in the constant - // sharing transform to prevent that transform from hurting performance. - // If the kernel relies on this input to be initialized, it should also be small enough to copy cheaply. 
- const uint64_t maximumElementsForDuplicationTensor = 64; - static_assert(maximumElementsForDuplicationTensor >= onnxruntime::ConstantSharing::TENSOR_ELEM_COUNT_THRESHOLD); - - uint64_t totalElementCount = 1; - for (int i = 0; i < tensor->dims().size(); ++i) - { - totalElementCount *= tensor->dims()[i]; - } - - if (totalElementCount <= maximumElementsForDuplicationTensor || - requiredInitializerMap.find(input) != requiredInitializerMap.end()) - { - isInitializerTransferable[input] = {tensor, false}; - } - - continue; - } - isInitializerTransferable[input] = {tensor, true}; - } + continue; } - std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; - m_providerImpl->IncreasePartitionKernelPrefixVal(); + // This map will tell which initializer can be removed from onnxruntime::Graph (and from its field + // onnx::GraphProto) while we upload the initializer to GPU. + // Why do we want to remove the initializer from ORT? + // 1. To keep the peak memory usage as low as possible. That's why we are doing incremental upload to GPU. + // What is an initializer? + // An initializer is an input tensor to an operator or the graph itself, which is constant and will never change. + // Why are we uploading the initializer now? + // This prevents OnnxRuntime from allocating GPU resources and uploading those initializers, + // so the partition's kernel can do so. In the process, it will pre-process weights while consuming a CPU + // backed resource, avoiding an extra set of GPU resources in memory. + std::unordered_map> isInitializerTransferable; + if (partition->IsDmlGraphPartition()) + { + // populate isInitializerTransferable + for (const auto& input : partition->GetInputs()) + { + const onnx::TensorProto* tensor = nullptr; + if (graph.GetInitializedTensor(input, tensor)) + { + // It's only safe to transfer tensors which are used by this partition alone. 
+ auto iter = initializerPartitionMap.find(tensor); + assert(iter != initializerPartitionMap.end()); + if (iter->second.size() > 1) + { + // By including non-transferrable tensors in isInitializerTransferable, it causes DML to upload and preprocess them + // to duplicate locations rather than treating them as being non-constant, which is helpful for optimization. + // The size threshold for this should be no smaller than that used to combine initializers in the constant + // sharing transform to prevent that transform from hurting performance. + // If the kernel relies on this input to be initialized, it should also be small enough to copy cheaply. + constexpr uint64_t maximumElementsForDuplicationTensor = 64; + static_assert(maximumElementsForDuplicationTensor >= onnxruntime::ConstantSharing::TENSOR_ELEM_COUNT_THRESHOLD); + + uint64_t totalElementCount = 1; + for (int i = 0; i < tensor->dims().size(); ++i) + { + totalElementCount *= tensor->dims()[i]; + } + + if (totalElementCount <= maximumElementsForDuplicationTensor || + requiredInitializerMap.find(input) != requiredInitializerMap.end()) + { + isInitializerTransferable[input] = {tensor, false}; + } + + continue; + } + isInitializerTransferable[input] = {tensor, true}; + } + } + + std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; + m_providerImpl->IncreasePartitionKernelPrefixVal(); + + auto indexedSubGraph = DmlGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix); + + // Create a map of which inputs are uploaded by the DML EP + const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); + for (uint32_t index = 0; index < fusedNodeInputCount; ++index) + { + auto iter = isInitializerTransferable.find(indexedSubGraph.GetMetaDef()->inputs[index]); + isInputsUploadedByDmlEP[index] = iter != 
isInitializerTransferable.end() ? true : false; + } + + auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap( + graph, + indexedSubGraph, + std::move(graphNodePropertyMap)); + + // Convert partitionONNXGraph into DML EP GraphDesc + ComPtr device; + ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); + GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( + isInputsUploadedByDmlEP.data(), + isInputsUploadedByDmlEP.size(), + isInitializerTransferable, + graph, + indexedSubGraph, + partitionNodePropsMap, + device.Get(), + m_providerImpl); + + // Compile the operator + auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( + graphDesc, + indexedSubGraph, + m_providerImpl); + + if (!compiledPartition) + { + // Fail early if even a single operator is too big to compile. This is highly unlikely. + ORT_THROW_HR_IF(E_INVALIDARG, indexedSubGraph.nodes.size() < 2); + + // Tell the partitioner to split the current partition in half, in the middle + additionalSplittingNodes.push_back(indexedSubGraph.nodes[indexedSubGraph.nodes.size() / 2]); + + // Exit early since we need to repartition + break; + } + else + { + auto compiledPartitionInfo = std::make_shared(); + compiledPartitionInfo->compiledOperator = std::move(compiledPartition); + compiledPartitionInfo->indexedSubGraph = std::move(indexedSubGraph); + compiledPartitionInfo->isInputsUploadedByDmlEP = std::move(isInputsUploadedByDmlEP); + compiledPartitionInfo->graphDesc = std::move(graphDesc); + compiledPartitionInfo->isInitializerTransferable = std::move(isInitializerTransferable); + compiledPartitionInfos[partitionIndex] = std::move(compiledPartitionInfo); + } + } + } + } + while (!additionalSplittingNodes.empty()); + + for (auto&& compiledPartitionInfo : compiledPartitionInfos) + { + // Null compiled operators were not DML partitions + if (compiledPartitionInfo) + { DmlGraphFusionHelper::FusePartitionAndRegisterKernel( - partition.get(), 
- partitionIndex, - graph, - graphNodePropertyMap, + graph, m_providerImpl->GetKernelRegistry().get(), - partitionKernelPrefix, - isInitializerTransferable, - m_providerImpl - ); + compiledPartitionInfo->isInitializerTransferable, + m_providerImpl, + compiledPartitionInfo->indexedSubGraph, + std::move(compiledPartitionInfo->isInputsUploadedByDmlEP), + compiledPartitionInfo->graphDesc, + compiledPartitionInfo->compiledOperator); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index b7f24d49d1..67c3f110e5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -24,13 +24,13 @@ namespace Dml std::vector>& nonOwnedGraphInputsFromInitializers, std::vector>& initializeResourceRefs, std::vector initInputBindings, - std::vector& isInputsUploadedByDmlEP, - std::vector& inputsUsed) : + std::vector&& isInputsUploadedByDmlEP, + std::vector&& inputsUsed) : OpKernel(kernelInfo), m_compiledExecutionPlanOperator(compiledExecutionPlanOperator), - m_inputsUsed(inputsUsed), + m_inputsUsed(std::move(inputsUsed)), m_outputShapes(outputShapes), - m_isInputsUploadedByDmlEP(isInputsUploadedByDmlEP), + m_isInputsUploadedByDmlEP(std::move(isInputsUploadedByDmlEP)), m_nonOwnedGraphInputsFromInitializers(nonOwnedGraphInputsFromInitializers) { // Get the execution provider interfaces @@ -443,8 +443,8 @@ namespace Dml std::vector>& nonOwnedGraphInputsFromInitializers, std::vector>& initializeResourceRefs, std::vector initInputBindings, - std::vector& isInputsUploadedByDmlEP, - std::vector& inputsUsed + std::vector&& isInputsUploadedByDmlEP, + std::vector&& inputsUsed ) { return new FusedGraphKernel( @@ -455,8 +455,8 @@ namespace Dml nonOwnedGraphInputsFromInitializers, initializeResourceRefs, initInputBindings, - isInputsUploadedByDmlEP, - inputsUsed + 
std::move(isInputsUploadedByDmlEP), + std::move(inputsUsed) ); } } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h index 00a858d54e..ced6160a99 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h @@ -15,7 +15,7 @@ namespace Dml std::vector>& nonOwnedGraphInputsFromInitializers, std::vector>& initializeResourceRefs, std::vector initInputBindings, - std::vector& isInputsUploadedByDmlEP, - std::vector& inputsUsed + std::vector&& isInputsUploadedByDmlEP, + std::vector&& inputsUsed ); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index b9c5f8849a..2c8d4e4459 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -209,14 +209,15 @@ namespace Dml // Creates a partition for a node which is not a DML graph node, and finalizes partitions // which are inputs of the new partition. - std::unique_ptr CreateNonGraphNodePartitionAndFinalizeInputs( + std::unique_ptr CreatePartitionAndFinalizeInputs( const onnxruntime::Node& node, bool isDmlNode, + bool isDmlGraphPartitionNode, std::unordered_map& nodeNameToPartitionMap ) { std::unique_ptr partition = std::make_unique(); - partition->SetIsDmlGraphPartition(false); + partition->SetIsDmlGraphPartition(isDmlGraphPartitionNode); partition->SetIsDmlPartition(isDmlNode); partition->AddNodeIndex(node.Index()); @@ -383,7 +384,7 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. 
std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, - std::function onNodeUnsupportedInGraph) + gsl::span additionalSplittingNodes) { // Nodes are uniquely identified by the name of their first output argument std::vector> partitions; @@ -420,6 +421,8 @@ namespace Dml // Check whether this graph is a subgraph, or contains any node with a subgraph. bool modelUsesSubgraph = ModelUsesSubgraph(graph); + uint32_t splittingNodeIndex = 0; + // Build up partitions while traversing the graph. for (size_t nodeIndex : toplogicalOrder) { @@ -456,12 +459,14 @@ namespace Dml // anyhow due to CPU/GPU copies. if (modelUsesSubgraph || !isDmlGraphNode) { - if (onNodeUnsupportedInGraph) - { - onNodeUnsupportedInGraph(node); - } + partitions.push_back(CreatePartitionAndFinalizeInputs(node, isDmlNode, false, nodeNameToPartitionMap)); + continue; + } - partitions.push_back(CreateNonGraphNodePartitionAndFinalizeInputs(node, isDmlNode, nodeNameToPartitionMap)); + if (splittingNodeIndex < additionalSplittingNodes.size() && additionalSplittingNodes[splittingNodeIndex] == nodeIndex) + { + partitions.push_back(CreatePartitionAndFinalizeInputs(node, isDmlNode, isDmlGraphNode, nodeNameToPartitionMap)); + ++splittingNodeIndex; continue; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index c5b15c7b1c..990ba00fc4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -48,5 +48,5 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, - std::function onNodeUnsupportedInGraph = nullptr); + gsl::span additionalSplittingNodes); } // namespace Dml