From 989b00321e9a2bc2043764622dc84b6ed1aeee03 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 29 Nov 2018 14:31:36 +1000 Subject: [PATCH] Update session state initializer to support overriding initializers. Update test --- onnxruntime/core/framework/session_state_initializer.cc | 5 ++--- onnxruntime/core/session/inference_session.cc | 6 +++--- onnxruntime/test/framework/inference_session_test.cc | 7 +++++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/framework/session_state_initializer.cc b/onnxruntime/core/framework/session_state_initializer.cc index 5cef936a71..edceb317f7 100644 --- a/onnxruntime/core/framework/session_state_initializer.cc +++ b/onnxruntime/core/framework/session_state_initializer.cc @@ -486,8 +486,7 @@ static bool IsArgNameInInputsOutputs(const std::string& name, common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph, const KernelRegistryManager& custom_registry_manager, SessionState& session_state) { - auto& weights_map = graph.GetAllInitializedTensors(); - auto& graph_inputs = graph.GetInputs(); + auto& graph_inputs = graph.GetInputsIncludingInitializers(); auto& graph_outputs = graph.GetOutputs(); for (auto& node : graph.Nodes()) { @@ -495,7 +494,7 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph onnxruntime::Node::ForEachWithIndex( node.InputDefs(), [&](const onnxruntime::NodeArg& arg, size_t index) { - if (arg.Name().empty() || weights_map.count(arg.Name())) { + if (arg.Name().empty()) { return Status::OK(); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f4cd485c5e..6178e7810c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -454,9 +454,9 @@ class InferenceSession::Impl { [&ostr](const std::string& elem) { ostr << elem << " "; }); - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Invalid Feed Input Names:" + invalid_names.str() + - " Valid input names are: " + ostr.str()); + return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid Feed Input Names:", invalid_names.str(), + ". Valid input names are: ", ostr.str()); } return Status::OK(); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index f8317bcbaa..c2010b3af2 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -924,11 +924,14 @@ TEST(InferenceSessionTests, TestOptionalInputs) { // required, optional and invalid input status = RunOptionalInputTest(true, true, true); - ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Names: unknown_input")); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Valid input names are: required_input optional_input")); // missing required status = RunOptionalInputTest(false, true, false); - ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing required inputs: required_input")); } TEST(ExecutionProviderTest, FunctionTest) {