From 93fb62bb3e891c8f631e10daa32eba979f87ded2 Mon Sep 17 00:00:00 2001 From: Yuan Yu Date: Wed, 17 Jul 2019 14:45:50 -0700 Subject: [PATCH] More code cleanup (#1405) * More code cleanup * More cleanup --- onnxruntime/core/session/inference_session.cc | 48 +------------------ .../test/framework/inference_session_test.cc | 15 +++--- 2 files changed, 9 insertions(+), 54 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 26a630efaa..248e21d58a 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -572,34 +572,13 @@ common::Status InferenceSession::ValidateInputs(const std::vector& feeds.size(), " elements."); } - std::unordered_set seen_names; - seen_names.reserve(feeds.size()); - size_t seen_required_inputs = 0; - const Graph& graph = model_->MainGraph(); - for (size_t i = 0; i < feeds.size(); ++i) { const auto& feed_name = feed_names[i]; - if (seen_names.insert(feed_name).second == false) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Duplicate name in feeds: ", feed_name); - } - auto iter = input_def_map_.find(feed_name); if (input_def_map_.end() == iter) { - // if IR < 4 all initializers are required to have a matching graph input with the same name, - // however we disallow using that graph input to override the initializer, and treat the initializers as constant. - // check for this and output a nicer error message if that is the case. - // As we've already moved all initializers to SessionState we need to check if it's in the constant initializers there - int idx; - bool is_constant_initializer = session_state_.GetOrtValueNameIdxMap().GetIdx(feed_name, idx).IsOK() && - session_state_.GetConstantInitializedTensors().find(idx) != - session_state_.GetConstantInitializedTensors().cend(); - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Invalid Feed Input Name:", feed_name, - is_constant_initializer ? ". Initializers may not be overridden by feeds" - " if model IR version is less than 4." - : "."); + "Invalid Feed Input Name:", feed_name); } auto expected_type = utils::GetMLDataType(*iter->second); @@ -612,31 +591,6 @@ common::Status InferenceSession::ValidateInputs(const std::vector& auto input_type = input_ml_value.Type(); ORT_RETURN_IF_ERROR(CheckTypes(input_type, expected_type)); } - - if (!graph.CanOverrideInitializer() || // all entries in input_def_map_ are required. - required_inputs_.find(feed_name) != required_inputs_.cend()) { - ++seen_required_inputs; - } - } - - if (seen_required_inputs < required_inputs_.size()) { - std::ostringstream req_input_str; - auto cur = required_inputs_.cbegin(), end = required_inputs_.cend(); - req_input_str << "Required inputs: "; - req_input_str << *(cur++); - while (cur != end) { - req_input_str << ", " << *(cur++); - } - - req_input_str << " . Got: "; - for (size_t i = 0; i < feed_names.size(); ++i) { - if (i > 0) - req_input_str << ", "; - req_input_str << feed_names[i]; - } - - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "One or more missing required inputs. ", - req_input_str.str()); } return Status::OK(); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index cfbd405a86..1c4cdad4f8 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -914,23 +914,24 @@ TEST(InferenceSessionTests, TestOptionalInputs) { // required and optional input status = RunOptionalInputTest(true, true, false, version); - if (version < 4) { - ASSERT_FALSE(status.IsOK()); - EXPECT_THAT(status.ErrorMessage(), - testing::HasSubstr("Initializers may not be overridden by feeds if model IR version is less than 4")); + if (version == 3) { + ASSERT_FALSE(status.IsOK()) << status.ErrorMessage(); } else { ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); } - // required, optional and invalid input status = RunOptionalInputTest(true, true, true, version); ASSERT_FALSE(status.IsOK()); EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name")); // missing required - status = RunOptionalInputTest(false, false, false, version); + status = RunOptionalInputTest(false, true, false, version); ASSERT_FALSE(status.IsOK()); - EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("One or more missing required inputs")); + if (version == 3) { + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name")); + } else { + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing Input:")); + } } }