mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
[DML EP] Fix external data unpacking (#19415)
### Description
This change
55a669409a
didn't take into account external data when unpacking initializer, and
therefore crashes when trying to unpack them.
This commit is contained in:
parent
91b2e660fe
commit
302d4be7d9
3 changed files with 20 additions and 8 deletions
|
|
@ -344,20 +344,25 @@ namespace Dml::GraphDescBuilder
|
|||
dmlFusedNodeInputIndex < isConstGpuGraphInputCount &&
|
||||
isConstGpuGraphInput[dmlFusedNodeInputIndex])
|
||||
{
|
||||
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
|
||||
// across the graph input as well as every consumer's unique constant node. However it is currently
|
||||
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
|
||||
// across the graph input as well as every consumer's unique constant node. However it is currently
|
||||
// only used for small inputs.
|
||||
uint32_t c_maxConstNodeDataSize = 8;
|
||||
|
||||
ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());
|
||||
|
||||
auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
|
||||
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors();
|
||||
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
|
||||
ComPtr<OnnxTensorWrapper> constantInput;
|
||||
|
||||
if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize)
|
||||
if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize)
|
||||
{
|
||||
// The tensor description's size should be no larger than the constant input unless it was rounded to
|
||||
constantInput = constantCpuGraphInputGetter(arg->Name());
|
||||
}
|
||||
|
||||
if (constantInput)
|
||||
{
|
||||
// The tensor description's size should be no larger than the constant input unless it was rounded to
|
||||
// the required alignment.
|
||||
assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes);
|
||||
size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast<size_t>(tensorDesc->totalTensorSizeInBytes));
|
||||
|
|
|
|||
|
|
@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
}
|
||||
ORT_CATCH_RETURN
|
||||
}
|
||||
|
||||
|
||||
template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
|
||||
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
|
||||
{
|
||||
|
|
@ -1168,7 +1168,7 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
m_requiredConstantCpuInputs.begin(),
|
||||
m_requiredConstantCpuInputs.end(),
|
||||
inputIndex) != m_requiredConstantCpuInputs.end();
|
||||
|
||||
|
||||
// This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present.
|
||||
ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
|
||||
}
|
||||
|
|
@ -1562,7 +1562,13 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
OnnxTensorWrapper::OnnxTensorWrapper(onnx::TensorProto* impl, const onnxruntime::Path& modelPath) : m_impl(impl)
|
||||
{
|
||||
// The tensor may be stored as raw data or in typed fields.
|
||||
if (impl->has_raw_data())
|
||||
if (impl->data_location() == onnx::TensorProto_DataLocation_EXTERNAL)
|
||||
{
|
||||
THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*impl, modelPath, m_unpackedExternalTensor));
|
||||
m_dataPtr = reinterpret_cast<std::byte*>(m_unpackedExternalTensor.data());
|
||||
m_tensorByteSize = m_unpackedExternalTensor.size();
|
||||
}
|
||||
else if (impl->has_raw_data())
|
||||
{
|
||||
m_dataPtr = reinterpret_cast<std::byte*>(impl->mutable_raw_data()->data());
|
||||
m_tensorByteSize = impl->raw_data().size();
|
||||
|
|
|
|||
|
|
@ -309,6 +309,7 @@ class OnnxTensorWrapper : public WRL::Base<IMLOperatorTensor>, public Closable
|
|||
private:
|
||||
size_t m_tensorByteSize = 0;
|
||||
std::unique_ptr<std::byte[]> m_unpackedTensor;
|
||||
std::vector<uint8_t> m_unpackedExternalTensor;
|
||||
std::byte* m_dataPtr = nullptr;
|
||||
|
||||
// Lifetime is managed by the caller and guaranteed to outlive this class
|
||||
|
|
|
|||
Loading…
Reference in a new issue