From dbcc60bed5b094e0e769f6b4ef957918541000bb Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 5 Sep 2023 15:25:12 -0700 Subject: [PATCH] Introduce output type/shape validation (#17301) ### Description Validate outputs type and shapes. Make sure sparse initializers are taken into account. ### Motivation and Context ORT currently does not validate output types or shapes. Further, neither inputs or outputs take into account sparse initializers that are converted from dense. It is currently possible to pre-allocate a wrong type/shape buffer for output. Cc: @Craigacp --- onnxruntime/core/session/inference_session.cc | 257 ++++++++++-------- onnxruntime/core/session/inference_session.h | 41 +-- .../test/framework/execution_frame_test.cc | 28 +- .../test/framework/inference_session_test.cc | 4 +- onnxruntime/test/shared_lib/test_inference.cc | 4 +- 5 files changed, 187 insertions(+), 147 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 6a70176ebc..5a2a6efb6d 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1829,83 +1829,102 @@ const DataTransferManager& InferenceSession::GetDataTransferManager() const { return data_transfer_mgr_; } -common::Status InferenceSession::CheckShapes(const std::string& input_name, const TensorShape& input_shape, - const TensorShape& expected_shape) const { - auto input_shape_sz = input_shape.NumDimensions(); - auto expected_shape_sz = expected_shape.NumDimensions(); - if (input_shape_sz != expected_shape_sz) { - std::ostringstream ostr; - ostr << "Invalid rank for input: " << input_name << " Got: " << input_shape_sz << " Expected: " << expected_shape_sz - << " Please fix either the inputs or the model."; - return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str()); +common::Status InferenceSession::CheckShapes(const std::string& input_output_name, const TensorShape& input_output_shape, + const TensorShape& expected_shape, const char* input_output_moniker) const { + const auto shape_size = input_output_shape.NumDimensions(); + const auto expected_shape_size = expected_shape.NumDimensions(); + if (shape_size != expected_shape_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid rank for ", input_output_moniker, ": ", + input_output_name, " Got: ", shape_size, " Expected: ", expected_shape_size, + " Please fix either the inputs/outputs or the model."); } - std::vector invalid_dim_indices; - for (size_t i = 0; i < input_shape_sz; ++i) { + InlinedVector invalid_dim_indices; + for (size_t i = 0; i < shape_size; ++i) { if (expected_shape[i] < 0) { continue; // this represents a symbolic shape dimension } - if (input_shape[i] != expected_shape[i]) { + if (input_output_shape[i] != expected_shape[i]) { invalid_dim_indices.push_back(i); } } if (!invalid_dim_indices.empty()) { std::ostringstream ostr; - ostr << "Got invalid dimensions for input: " << input_name << " for the following indices\n"; + ostr << "Got invalid dimensions for " << input_output_moniker << ": " << input_output_name << " for the following indices\n"; for (size_t i = 0, end = invalid_dim_indices.size(); i < end; ++i) { size_t idx = invalid_dim_indices[i]; - ostr << " index: " << idx << " Got: " << input_shape[idx] << " Expected: " << expected_shape[idx] << "\n"; + ostr << " index: " << idx << " Got: " << input_output_shape[idx] << " Expected: " << expected_shape[idx] << "\n"; } - ostr << " Please fix either the inputs or the model."; + ostr << " Please fix either the inputs/outputs or the model."; return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str()); } return Status::OK(); } -static common::Status CheckTypes(MLDataType actual, MLDataType expected, const std::string& base_type) { +static common::Status CheckTypes(MLDataType actual, MLDataType expected, const std::string& base_type, + const char* input_output_moniker) { if (actual == expected) { return Status::OK(); } - std::ostringstream ostr; - ostr << "Unexpected input data type. Actual: ("; - ostr << base_type; - ostr << "("; - ostr << DataTypeImpl::ToString(actual); - ostr << ")) , expected: ("; - ostr << base_type; - ostr << "("; - ostr << DataTypeImpl::ToString(expected); - ostr << "))"; - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unexpected ", input_output_moniker, " data type. Actual: (", + base_type, "(", + DataTypeImpl::ToString(actual), ")) , expected: (", base_type, "(", + DataTypeImpl::ToString(expected), "))"); } -common::Status InferenceSession::ValidateInputs(gsl::span feed_names, - gsl::span feeds) const { - if (feed_names.size() != feeds.size()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: feed_names has ", feed_names.size(), - "elements, but feeds has ", feeds.size(), " elements."); +common::Status InferenceSession::ValidateInputsOutputs(gsl::span names, + gsl::span feeds_fetches, + const InputOutputDefMetaMap& input_output_meta_map, + ArgType arg_type) const { + ORT_ENFORCE(arg_type == ArgType::kInput || arg_type == ArgType::kOutput, "Valid values kInput, kOutput"); + + const bool is_inputs = arg_type == ArgType::kInput; + + const char* const input_output_moniker = is_inputs ? "input" : "output"; + const char* const feed_fetches_moniker = is_inputs ? "feed" : "fetch"; + +#if !defined(DISABLE_SPARSE_TENSORS) + auto is_sparse_initializer = [this](const std::string& name) -> bool { + int idx = -1; + if (session_state_->GetOrtValueNameIdxMap().GetIdx(name, idx).IsOK()) { + return session_state_->IsSparseInitializer(idx); + } + return false; + }; +#endif + + if (names.size() != feeds_fetches.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, feed_fetches_moniker, " names has ", names.size(), + " elements, but ", feed_fetches_moniker, " has ", feeds_fetches.size(), " elements."); } - for (size_t i = 0; i < feeds.size(); ++i) { - const auto& feed_name = feed_names[i]; + for (size_t i = 0; i < feeds_fetches.size(); ++i) { + const auto& name = names[i]; - auto iter = input_def_map_.find(feed_name); - if (input_def_map_.end() == iter) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Feed Input Name:", feed_name); + auto iter = input_output_meta_map.find(name); + if (input_output_meta_map.end() == iter) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid ", input_output_moniker, " name: ", name); + } + + const auto& input_output_ml_value = feeds_fetches[i]; + + // For outputs the user may supply an unallocated placeholder. + if (!is_inputs && !input_output_ml_value.IsAllocated()) { + continue; } auto expected_type = iter->second.ml_data_type; - auto& input_ml_value = feeds[i]; - if (input_ml_value.IsTensor()) { + + if (input_output_ml_value.IsTensor()) { if (!expected_type->IsTensorType() #if !defined(DISABLE_OPTIONAL_TYPE) && !utils::IsOptionalTensor(expected_type) #endif ) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, - " is not expected to be of type tensor."); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name: '", name, + "' expected to be of type: ", static_cast(expected_type->type_), " but received a tensor"); } // check for type @@ -1919,44 +1938,56 @@ common::Status InferenceSession::ValidateInputs(gsl::span fee auto expected_element_type = expected_type->AsTensorType()->GetElementType(); #endif - auto input_element_type = input_ml_value.Get().DataType(); - ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type, "tensor")); + const auto& input_output_tensor = input_output_ml_value.Get(); + ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_output_tensor.DataType(), + expected_element_type, "tensor", input_output_moniker)); // check for shape - const auto& expected_shape = iter->second.tensor_shape; - if (expected_shape.NumDimensions() > 0) { - const auto& input_shape = input_ml_value.Get().Shape(); - ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape)); + if (iter->second.tensor_shape.has_value()) { + ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, input_output_tensor.Shape(), + *iter->second.tensor_shape, input_output_moniker)); } - } else if (input_ml_value.IsSparseTensor()) { + } else if (input_output_ml_value.IsSparseTensor()) { #if !defined(DISABLE_SPARSE_TENSORS) - if (!expected_type->IsSparseTensorType()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, - " is not expected to be of type sparse tensor."); - } - auto expected_element_type = expected_type->AsSparseTensorType()->GetElementType(); - const SparseTensor& sparse_tensor = input_ml_value.Get(); - auto input_element_type = sparse_tensor.DataType(); - ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type, "sparse_tensor")); - // Check shape - const auto& expected_shape = iter->second.tensor_shape; - if (expected_shape.NumDimensions() > 0) { - const auto& input_shape = sparse_tensor.DenseShape(); - ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape)); + + const SparseTensor& sparse_tensor = input_output_ml_value.Get(); + if (expected_type->IsSparseTensorType()) { + auto expected_element_type = expected_type->AsSparseTensorType()->GetElementType(); + ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(sparse_tensor.DataType(), expected_element_type, + "sparse_tensor", input_output_moniker)); + // Check shape + if (iter->second.tensor_shape.has_value()) { + ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(), + *iter->second.tensor_shape, input_output_moniker)); + } + } else if (is_sparse_initializer(name) && + expected_type->IsTensorType()) { + // If this metadata came from a sparse initializer converted to dense, then still validate it. + auto expected_element_type = expected_type->AsTensorType()->GetElementType(); + ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(sparse_tensor.DataType(), expected_element_type, + "sparse_tensor", input_output_moniker)); + // Check shape + if (iter->second.tensor_shape.has_value()) { + ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(), + *iter->second.tensor_shape, input_output_moniker)); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name: '", name, + "' expected to be of type: ", static_cast(expected_type->type_), " but received a sparse tensor"); } #else - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name ", feed_name, + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name ", name, " is a sparse tensor, which is not supported in this build."); #endif - } else if (input_ml_value.IsTensorSequence()) { + } else if (input_output_ml_value.IsTensorSequence()) { if (!expected_type->IsTensorSequenceType() #if !defined(DISABLE_OPTIONAL_TYPE) && !utils::IsOptionalSeqTensor(expected_type) #endif ) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, - " is not expected to be of type tensor sequence."); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name: '", name, + "' expected to be of type: ", static_cast(expected_type->type_), " but received a tensor sequence"); } #if !defined(DISABLE_OPTIONAL_TYPE) @@ -1969,43 +2000,40 @@ common::Status InferenceSession::ValidateInputs(gsl::span fee auto expected_element_type = expected_type->AsSequenceTensorType()->GetElementType(); #endif - auto input_element_type = input_ml_value.Get().DataType(); - ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type, "seq")); + auto input_output_element_type = input_output_ml_value.Get().DataType(); + ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_output_element_type, expected_element_type, "seq", input_output_moniker)); } else { - auto input_type = input_ml_value.Type(); - ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_type, expected_type, "")); + auto input_output_type = input_output_ml_value.Type(); + ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_output_type, expected_type, "", input_output_moniker)); } } return Status::OK(); } +common::Status InferenceSession::ValidateInputs(gsl::span feed_names, + gsl::span feeds) const { + return ValidateInputsOutputs(feed_names, feeds, input_def_map_, ArgType::kInput); +} + common::Status InferenceSession::ValidateOutputs(gsl::span output_names, const std::vector* p_fetches) const { - if (p_fetches == nullptr) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Output vector pointer is NULL"); - } - if (output_names.empty()) { return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "At least one output should be requested."); } - if (!p_fetches->empty() && (output_names.size() != p_fetches->size())) { - std::ostringstream ostr; - ostr << "Output vector incorrectly sized: output_names.size(): " << output_names.size() - << "p_fetches->size(): " << p_fetches->size(); - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); - } + const auto fetches = (p_fetches == nullptr) ? EmptySpan() : gsl::make_span(*p_fetches); - for (const auto& name : output_names) { - if (model_output_names_.find(name) == model_output_names_.end()) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid Output Name:" + name); + if (fetches.empty()) { + for (const auto& name : output_names) { + if (output_def_map_.count(name) == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output name:", name); + } } + return Status::OK(); } - // TODO add more validation here like checking shape of the allocated buffers - - return common::Status::OK(); + return ValidateInputsOutputs(output_names, fetches, output_def_map_, ArgType::kOutput); } #ifdef ENABLE_TRAINING @@ -2483,7 +2511,7 @@ std::pair InferenceSession::GetModelOutput } } - return std::make_pair(common::Status::OK(), &output_def_list_); + return std::make_pair(common::Status::OK(), &model_->MainGraph().GetOutputs()); } common::Status InferenceSession::NewIOBinding(std::unique_ptr* io_binding) { @@ -2697,43 +2725,40 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod model_metadata_.custom_metadata_map = model.MetaData(); model_metadata_.graph_name = graph.Name(); - required_inputs_.clear(); - for (auto input : graph.GetInputs()) { - required_inputs_.insert(input->Name()); - } - - auto add_inputs = [this](const InputDefList& inputs) { - input_def_map_.clear(); - input_def_map_.reserve(inputs.size()); - for (auto elem : inputs) { + auto add_inputs_outputs = [](const InputDefList& inputs_outputs, InputOutputDefMetaMap& map) { + map.reserve(inputs_outputs.size()); + for (auto elem : inputs_outputs) { auto elem_type = utils::GetMLDataType(*elem); - auto elem_shape_proto = elem->Shape(); - input_def_map_.insert( - {elem->Name(), - InputDefMetaData( - elem, elem_type, - elem_shape_proto ? utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto) : TensorShape())}); + const auto* elem_shape_proto = elem->Shape(); + if (elem_shape_proto != nullptr) { + map.emplace(elem->Name(), InputOutputDefMetaData( + elem, elem_type, + utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto))); + } else { + map.emplace(elem->Name(), InputOutputDefMetaData(elem, elem_type)); + } } }; - if (graph.CanOverrideInitializer()) { - // for IR 4 or higher it is optional to have a matching graph input for an initializer, and if one exists the - // initializer is explicitly overridable. - add_inputs(graph.GetInputsIncludingInitializers()); - } else { - // for IR < 4 we don't allow overriding initializers so that they can be treated as constant. exclude them from - // the list of valid inputs by just using the GetInputs() list. - add_inputs(graph.GetInputs()); + { + InputOutputDefMetaMap input_defs; + if (graph.CanOverrideInitializer()) { + // for IR 4 or higher it is optional to have a matching graph input for an initializer, and if one exists the + // initializer is explicitly overridable. + add_inputs_outputs(graph.GetInputsIncludingInitializers(), input_defs); + } else { + // for IR < 4 we don't allow overriding initializers so that they can be treated as constant. exclude them from + // the list of valid inputs by just using the GetInputs() list. + add_inputs_outputs(graph.GetInputs(), input_defs); + } + input_def_map_.swap(input_defs); } - // save outputs const auto& outputs = graph.GetOutputs(); - output_def_list_ = outputs; // A direct copy of outputs - - model_output_names_.clear(); - model_output_names_.reserve(outputs.size()); - for (const auto& elem : outputs) { - model_output_names_.insert(elem->Name()); + { + InputOutputDefMetaMap output_defs; + add_inputs_outputs(outputs, output_defs); + output_def_map_.swap(output_defs); } VLOGS(*session_logger_, 1) << "Done saving model metadata"; diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index e4127085b3..9259e014b9 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include @@ -103,6 +104,22 @@ struct ModelMetadata { */ class InferenceSession { + struct InputOutputDefMetaData { + InputOutputDefMetaData(const NodeArg* node_arg0, MLDataType ml_data_type0, TensorShape&& tensor_shape0) + : node_arg(node_arg0), ml_data_type(ml_data_type0), tensor_shape(std::move(tensor_shape0)) { + } + + InputOutputDefMetaData(const NodeArg* node_arg0, MLDataType ml_data_type0) + : node_arg(node_arg0), ml_data_type(ml_data_type0) { + } + + gsl::not_null node_arg; + MLDataType ml_data_type; + std::optional tensor_shape; // not applicable if the input is non-tensor type + }; + + using InputOutputDefMetaMap = InlinedHashMap; + public: #if !defined(ORT_MINIMAL_BUILD) @@ -570,9 +587,6 @@ class InferenceSession { // if they need. std::shared_ptr model_; - // names of model outputs used for quick validation. - std::unordered_set model_output_names_; - // The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx PathString model_location_; @@ -628,7 +642,7 @@ class InferenceSession { void InitLogger(logging::LoggingManager* logging_manager); [[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, - const TensorShape& expected_shape) const; + const TensorShape& expected_shape, const char* input_output_moniker) const; [[nodiscard]] common::Status ValidateInputs(gsl::span feed_names, gsl::span feeds) const; @@ -636,6 +650,11 @@ class InferenceSession { [[nodiscard]] common::Status ValidateOutputs(gsl::span output_names, const std::vector* p_fetches) const; + [[nodiscard]] common::Status ValidateInputsOutputs(gsl::span feed_fetches_names, + gsl::span feeds_fetches, + const InputOutputDefMetaMap& input_output_meta_map, + ArgType arg_type) const; + [[nodiscard]] common::Status WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms); template @@ -737,19 +756,9 @@ class InferenceSession { #endif ModelMetadata model_metadata_; - std::unordered_set required_inputs_; - struct InputDefMetaData { - InputDefMetaData(const NodeArg* node_arg0, MLDataType ml_data_type0, TensorShape&& tensor_shape0) - : node_arg(node_arg0), ml_data_type(ml_data_type0), tensor_shape(std::move(tensor_shape0)) { - } - const NodeArg* node_arg; - MLDataType ml_data_type; - TensorShape tensor_shape; // not applicable if the input is non-tensor type - }; - - std::unordered_map input_def_map_; - OutputDefList output_def_list_; + InputOutputDefMetaMap input_def_map_; + InputOutputDefMetaMap output_def_map_; // Data transfer manager. DataTransferManager data_transfer_mgr_; diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index 4da0d9b488..ec572ce9de 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -496,14 +496,16 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) { #if !defined(DISABLE_SPARSE_TENSORS) TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) { - const std::vector dense_shape{3, 3}; - std::vector dense_data = { - 0, 0, 1.764052391052246f, - 0.40015721321105957f, 0, 0.978738009929657f, - 0, 0, 0}; + constexpr std::array dense_shape{3, 3}; - const std::vector expected_values = {1.764052391052246f, 0.40015721321105957f, 0.978738009929657f}; - const std::vector expected_linear_indices = {2, 3, 5}; + // Tensor data in a dense form, useful for debugging and reference. + // constexpr std::array dense_data = { + // 0, 0, 1.764052391052246f, + // 0.40015721321105957f, 0, 0.978738009929657f, + // 0, 0, 0}; + + constexpr std::array expected_values = {1.764052391052246f, 0.40015721321105957f, 0.978738009929657f}; + constexpr std::array expected_linear_indices = {2, 3, 5}; // sparse_initializer_as_output.onnx SessionOptions so; @@ -515,14 +517,18 @@ TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) { ASSERT_STATUS_OK(session.Initialize()); auto allocator = test::AllocatorManager::Instance().GetAllocator(CPU); - auto p_tensor = std::make_unique(); std::vector results; results.resize(1); - auto ml_type = DataTypeImpl::GetType(); - results[0].Init(p_tensor.release(), ml_type, ml_type->GetDeleteFunc()); + + // Initialize the output value as a SparseTensor with pre-allocated memory + // this is done here to test output types. + auto element_type = DataTypeImpl::GetSparseTensorType()->AsSparseTensorType()->GetElementType(); + SparseTensor::InitOrtValue(element_type, TensorShape(dense_shape), allocator, results[0]); + RunOptions ro; - ASSERT_STATUS_OK(session.Run(ro, EmptySpan(), EmptySpan(), AsSpan({"values"}), &results, nullptr)); + ASSERT_STATUS_OK(session.Run(ro, EmptySpan(), EmptySpan(), + AsSpan({"values"}), &results, nullptr)); ASSERT_TRUE(results[0].IsAllocated()); ASSERT_TRUE(results[0].IsSparseTensor()); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index fa3d61a28b..077c6ff58e 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -1218,13 +1218,13 @@ TEST(InferenceSessionTests, TestOptionalInputs) { // required, optional and invalid input status = RunOptionalInputTest(true, true, true, version, sess_env); ASSERT_FALSE(status.IsOK()); - EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name")); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid input name")); // missing required status = RunOptionalInputTest(false, true, false, version, sess_env); ASSERT_FALSE(status.IsOK()); if (version == 3) { - EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name")); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid input name")); } else { EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing Input:")); } diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index f3a0058c6f..8357ce22fb 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -159,8 +159,8 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod expected_values_y, nullptr); // with preallocated output tensor - Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), - expected_dims_y.data(), expected_dims_y.size()); + Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), + expected_dims_y.data(), expected_dims_y.size()); // test it twice for (int i = 0; i != 2; ++i)