From 8f900666394d7b73a4a0ed8b399ccdbfa714d4b1 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Thu, 23 Feb 2023 18:13:53 -0800 Subject: [PATCH] [DML EP] Fix GetInputTensor crash when accessing null tensor (#14811) --- .../DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index e811505ddb..dec109c2e9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1827,9 +1827,8 @@ namespace Windows::AI::MachineLearning::Adapter ML_CHECK_BOOL(inputIndex < m_inputTensors.size()); auto opKernelContextWrapper = const_cast(this); - if (m_inputTensors[inputIndex][0]->GetInterface() == nullptr) + if (m_inputTensors[inputIndex][0] == nullptr) { - assert(m_impl->InputType(gsl::narrow_cast(inputIndex))->IsTensorType()); auto inputTensor = m_impl->Input(gsl::narrow_cast(inputIndex)); if (inputTensor != nullptr) { @@ -1867,9 +1866,8 @@ namespace Windows::AI::MachineLearning::Adapter opKernelContextWrapper->m_inputTensors[inputIndex].resize(sequenceIndex+1); } - if (m_inputTensors[inputIndex][sequenceIndex]->GetInterface() == nullptr) + if (m_inputTensors[inputIndex][sequenceIndex] == nullptr) { - assert(m_impl->InputType(gsl::narrow_cast(inputIndex))->IsTensorSequenceType()); auto inputTensorSeq = m_impl->Input(gsl::narrow_cast(inputIndex)); ML_CHECK_BOOL(inputTensorSeq != nullptr); @@ -1918,7 +1916,7 @@ namespace Windows::AI::MachineLearning::Adapter } // Verify that the provided shape matches the shape determined using the kernel's shape inference function. - if (m_outputTensors[outputIndex][sequenceIndex]->GetInterface() == nullptr) + if (m_outputTensors[outputIndex][sequenceIndex] == nullptr) { auto outputTensorSeq = m_impl->Output(gsl::narrow_cast(outputIndex)); ML_CHECK_BOOL(outputTensorSeq != nullptr); @@ -2026,7 +2024,7 @@ namespace Windows::AI::MachineLearning::Adapter ML_CHECK_BOOL(outputIndex < m_outputTensors.size()); // Verify that the provided shape matches the shape determined using the kernel's shape inference function. - if (m_outputTensors[outputIndex][0]->GetInterface() == nullptr) + if (m_outputTensors[outputIndex][0] == nullptr) { if (m_outputShapes) {