[DML EP] Fix GetInputTensor crash when accessing null tensor (#14809)

This commit is contained in:
Patrice Vignola 2023-02-23 18:13:40 -08:00 committed by GitHub
parent 29428cd9dc
commit abb51ec975
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1909,9 +1909,8 @@ namespace Windows::AI::MachineLearning::Adapter
ML_CHECK_BOOL(inputIndex < m_inputTensors.size());
auto opKernelContextWrapper = const_cast<OpKernelContextWrapper*>(this);
if (m_inputTensors[inputIndex][0]->GetInterface() == nullptr)
if (m_inputTensors[inputIndex][0] == nullptr)
{
assert(m_impl->InputType(gsl::narrow_cast<int>(inputIndex))->IsTensorType());
auto inputTensor = m_impl->Input<onnxruntime::Tensor>(gsl::narrow_cast<int>(inputIndex));
if (inputTensor != nullptr)
{
@ -1949,9 +1948,8 @@ namespace Windows::AI::MachineLearning::Adapter
opKernelContextWrapper->m_inputTensors[inputIndex].resize(sequenceIndex+1);
}
if (m_inputTensors[inputIndex][sequenceIndex]->GetInterface() == nullptr)
if (m_inputTensors[inputIndex][sequenceIndex] == nullptr)
{
assert(m_impl->InputType(gsl::narrow_cast<int>(inputIndex))->IsTensorSequenceType());
auto inputTensorSeq = m_impl->Input<onnxruntime::TensorSeq>(gsl::narrow_cast<int>(inputIndex));
ML_CHECK_BOOL(inputTensorSeq != nullptr);
@ -2000,7 +1998,7 @@ namespace Windows::AI::MachineLearning::Adapter
}
// Verify that the provided shape matches the shape determined using the kernel's shape inference function.
if (m_outputTensors[outputIndex][sequenceIndex]->GetInterface() == nullptr)
if (m_outputTensors[outputIndex][sequenceIndex] == nullptr)
{
auto outputTensorSeq = m_impl->Output<onnxruntime::TensorSeq>(gsl::narrow_cast<int>(outputIndex));
ML_CHECK_BOOL(outputTensorSeq != nullptr);
@ -2108,7 +2106,7 @@ namespace Windows::AI::MachineLearning::Adapter
ML_CHECK_BOOL(outputIndex < m_outputTensors.size());
// Verify that the provided shape matches the shape determined using the kernel's shape inference function.
if (m_outputTensors[outputIndex][0]->GetInterface() == nullptr)
if (m_outputTensors[outputIndex][0] == nullptr)
{
if (m_outputShapes)
{