mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
[DML EP] Fix GetInputTensor crash when accessing null tensor (#14809)
This commit is contained in:
parent
29428cd9dc
commit
abb51ec975
1 changed files with 4 additions and 6 deletions
|
|
@ -1909,9 +1909,8 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
ML_CHECK_BOOL(inputIndex < m_inputTensors.size());
|
||||
|
||||
auto opKernelContextWrapper = const_cast<OpKernelContextWrapper*>(this);
|
||||
if (m_inputTensors[inputIndex][0]->GetInterface() == nullptr)
|
||||
if (m_inputTensors[inputIndex][0] == nullptr)
|
||||
{
|
||||
assert(m_impl->InputType(gsl::narrow_cast<int>(inputIndex))->IsTensorType());
|
||||
auto inputTensor = m_impl->Input<onnxruntime::Tensor>(gsl::narrow_cast<int>(inputIndex));
|
||||
if (inputTensor != nullptr)
|
||||
{
|
||||
|
|
@ -1949,9 +1948,8 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
opKernelContextWrapper->m_inputTensors[inputIndex].resize(sequenceIndex+1);
|
||||
}
|
||||
|
||||
if (m_inputTensors[inputIndex][sequenceIndex]->GetInterface() == nullptr)
|
||||
if (m_inputTensors[inputIndex][sequenceIndex] == nullptr)
|
||||
{
|
||||
assert(m_impl->InputType(gsl::narrow_cast<int>(inputIndex))->IsTensorSequenceType());
|
||||
auto inputTensorSeq = m_impl->Input<onnxruntime::TensorSeq>(gsl::narrow_cast<int>(inputIndex));
|
||||
ML_CHECK_BOOL(inputTensorSeq != nullptr);
|
||||
|
||||
|
|
@ -2000,7 +1998,7 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
}
|
||||
|
||||
// Verify that the provided shape matches the shape determined using the kernel's shape inference function.
|
||||
if (m_outputTensors[outputIndex][sequenceIndex]->GetInterface() == nullptr)
|
||||
if (m_outputTensors[outputIndex][sequenceIndex] == nullptr)
|
||||
{
|
||||
auto outputTensorSeq = m_impl->Output<onnxruntime::TensorSeq>(gsl::narrow_cast<int>(outputIndex));
|
||||
ML_CHECK_BOOL(outputTensorSeq != nullptr);
|
||||
|
|
@ -2108,7 +2106,7 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
ML_CHECK_BOOL(outputIndex < m_outputTensors.size());
|
||||
|
||||
// Verify that the provided shape matches the shape determined using the kernel's shape inference function.
|
||||
if (m_outputTensors[outputIndex][0]->GetInterface() == nullptr)
|
||||
if (m_outputTensors[outputIndex][0] == nullptr)
|
||||
{
|
||||
if (m_outputShapes)
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in a new issue