diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index c6a15e76f4..2456b396de 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -344,20 +344,25 @@ namespace Dml::GraphDescBuilder dmlFusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[dmlFusedNodeInputIndex]) { - // This is a highly inefficient approach to generating constant nodes. It duplicates constant data - // across the graph input as well as every consumer's unique constant node. However it is currently + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently // only used for small inputs. uint32_t c_maxConstNodeDataSize = 8; - ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; std::vector toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors(); DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; + ComPtr constantInput; - if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) + if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) { - // The tensor description's size should be no larger than the constant input unless it was rounded to + constantInput = constantCpuGraphInputGetter(arg->Name()); + } + + if (constantInput) + { + // The tensor description's size should be no larger than the constant input unless it was rounded to // the required alignment. assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index dbd06abf82..d524780de7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter } ORT_CATCH_RETURN } - + template HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { @@ -1168,7 +1168,7 @@ namespace Windows::AI::MachineLearning::Adapter m_requiredConstantCpuInputs.begin(), m_requiredConstantCpuInputs.end(), inputIndex) != m_requiredConstantCpuInputs.end(); - + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); } @@ -1562,7 +1562,13 @@ namespace Windows::AI::MachineLearning::Adapter OnnxTensorWrapper::OnnxTensorWrapper(onnx::TensorProto* impl, const onnxruntime::Path& modelPath) : m_impl(impl) { // The tensor may be stored as raw data or in typed fields. - if (impl->has_raw_data()) + if (impl->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) + { + THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*impl, modelPath, m_unpackedExternalTensor)); + m_dataPtr = reinterpret_cast(m_unpackedExternalTensor.data()); + m_tensorByteSize = m_unpackedExternalTensor.size(); + } + else if (impl->has_raw_data()) { m_dataPtr = reinterpret_cast(impl->mutable_raw_data()->data()); m_tensorByteSize = impl->raw_data().size(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 6530d89d89..59e253e884 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -309,6 +309,7 @@ class OnnxTensorWrapper : public WRL::Base, public Closable private: size_t m_tensorByteSize = 0; std::unique_ptr m_unpackedTensor; + std::vector m_unpackedExternalTensor; std::byte* m_dataPtr = nullptr; // Lifetime is managed by the caller and guaranteed to outlive this class