diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index b3621cb1e1..d126fe5e59 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -115,7 +115,7 @@ void SessionState::AddInputNameToNodeInfoMapping(const std::string& input_name, common::Status SessionState::GetInputNodeInfo(const std::string& input_name, std::vector& node_info_vec) const { if (!input_names_to_nodeinfo_mapping_.count(input_name)) { - return Status(ONNXRUNTIME, FAIL, "Failed to find input name in the mapping"); + return Status(ONNXRUNTIME, FAIL, "Failed to find input name in the mapping: " + input_name); } node_info_vec = input_names_to_nodeinfo_mapping_.at(input_name); return Status::OK(); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e8b580ae31..f4cd485c5e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -421,10 +421,21 @@ class InferenceSession::Impl { } common::Status ValidateInputNames(const NameMLValMap& feeds) { - if (model_input_names_.size() != feeds.size()) { + std::string missing_required_inputs; + + std::for_each(required_model_input_names_.cbegin(), required_model_input_names_.cend(), + [&](const std::string& required_input) { + if (feeds.find(required_input) == feeds.cend()) { + if (!missing_required_inputs.empty()) + missing_required_inputs += ","; + + missing_required_inputs += required_input; + } + }); + + if (!missing_required_inputs.empty()) { return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "The number of feeds is not same as the number of the model input, expect ", - model_input_names_.size(), " got ", feeds.size()); + "Missing required inputs: ", missing_required_inputs); } bool valid = true; @@ -804,7 +815,7 @@ class InferenceSession::Impl { } } - return std::make_pair(common::Status::OK(), &input_def_list_); + return std::make_pair(common::Status::OK(), &required_input_def_list_); } std::pair GetModelOutputs() const { @@ -896,28 +907,33 @@ class InferenceSession::Impl { model_metadata_.custom_metadata_map = model.MetaData(); model_metadata_.graph_name = graph.Name(); - // save inputs - auto& inputs = graph.GetInputs(); // inputs excluding initializers - input_def_list_.reserve(inputs.size()); - for (const auto& elem : inputs) { - if (!elem) { - return common::Status(common::ONNXRUNTIME, common::FAIL, "Got null input nodearg ptr"); - } + // save required inputs + const auto& required_inputs = graph.GetInputs(); // inputs excluding initializers + required_input_def_list_.reserve(required_inputs.size()); + required_model_input_names_.reserve(required_inputs.size()); + for (const auto& elem : required_inputs) { + required_input_def_list_.push_back(elem); + required_model_input_names_.insert(elem->Name()); + } + // save all valid inputs + const auto& all_inputs = graph.GetInputsIncludingInitializers(); + input_def_list_.reserve(all_inputs.size()); + model_input_names_.reserve(all_inputs.size()); + for (const auto& elem : all_inputs) { input_def_list_.push_back(elem); model_input_names_.insert(elem->Name()); } // save outputs - auto& outputs = graph.GetOutputs(); + const auto& outputs = graph.GetOutputs(); output_def_list_.reserve(outputs.size()); + model_output_names_.reserve(outputs.size()); for (const auto& elem : outputs) { - if (!elem) { - return common::Status(common::ONNXRUNTIME, common::FAIL, "Got null output nodearg ptr"); - } output_def_list_.push_back(elem); model_output_names_.insert(elem->Name()); } + VLOGS(*session_logger_, 1) << "Done saving model metadata"; return common::Status::OK(); } @@ -1030,10 +1046,12 @@ class InferenceSession::Impl { SessionState session_state_; ModelMetadata model_metadata_; + InputDefList required_input_def_list_; InputDefList input_def_list_; OutputDefList output_def_list_; // names of model inputs and outputs used for quick validation. + std::unordered_set required_model_input_names_; std::unordered_set model_input_names_; std::unordered_set model_output_names_; diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index f6209e0a62..f8317bcbaa 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -25,6 +25,7 @@ #include "core/session/IOBinding.h" #include "test/capturing_sink.h" #include "test/test_environment.h" +#include "test/providers/provider_test_utils.h" #include "test_utils.h" #include "gtest/gtest.h" @@ -808,6 +809,128 @@ TEST(InferenceSessionTests, ModelWithoutOpset) { } } +static ONNX_NAMESPACE::ModelProto CreateModelWithOptionalInputs() { + Model model("ModelWithOptionalInputs"); + auto& graph = model.MainGraph(); + + // create an initializer, which is an optional input that can be overridden + onnx::TensorProto tensor_proto; + tensor_proto.add_dims(1); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_float_data(1.f); + tensor_proto.set_name("optional_input"); + + graph.AddInitializedTensor(tensor_proto); + + TypeProto single_float; + single_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + single_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + auto& required_input = graph.GetOrCreateNodeArg("required_input", &single_float); + auto& optional_input = graph.GetOrCreateNodeArg("optional_input", nullptr); + auto& add_output = graph.GetOrCreateNodeArg("add_output", &single_float); + + EXPECT_TRUE(optional_input.Shape() != nullptr) << "AddInitializedTensor should have created the NodeArg with shape."; + + graph.AddNode("add", "Add", "Add required and optional inputs", {&required_input, &optional_input}, {&add_output}); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + auto model_proto = model.ToProto(); + + return model_proto; +} + +static common::Status RunOptionalInputTest(bool add_required_input, + bool add_optional_input, + bool add_invalid_input) { + auto model_proto = CreateModelWithOptionalInputs(); + + SessionOptions so; + so.session_logid = "InferenceSessionTests.TestOptionalInputs"; + + InferenceSession session_object{so, &DefaultLoggingManager()}; + + std::stringstream s1; + model_proto.SerializeToOstream(&s1); + auto status = session_object.Load(s1); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = session_object.Initialize(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + RunOptions run_options; + run_options.run_tag = so.session_logid; + + // prepare inputs + std::vector dims = {1}; + std::vector required_input_val = {1.f}; + std::vector optional_input_val = {10.f}; // override initializer value of 1 + std::vector unknown_input_val = {20.f}; + + MLValue required_input_mlvalue; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault), + dims, required_input_val, &required_input_mlvalue); + + MLValue optional_input_mlvalue; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault), + dims, optional_input_val, &optional_input_mlvalue); + + MLValue unknown_input_mlvalue; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault), + dims, unknown_input_val, &unknown_input_mlvalue); + + NameMLValMap feeds; + + if (add_required_input) + feeds.insert(std::make_pair("required_input", required_input_mlvalue)); + + if (add_optional_input) + feeds.insert(std::make_pair("optional_input", optional_input_mlvalue)); + + if (add_invalid_input) + feeds.insert(std::make_pair("unknown_input", unknown_input_mlvalue)); + + // prepare outputs + std::vector output_names; + output_names.push_back("add_output"); + std::vector fetches; + + float expected_value = required_input_val[0]; + expected_value += add_optional_input ? optional_input_val[0] : 1.f; + + status = session_object.Run(run_options, feeds, output_names, &fetches); + + if (status.IsOK()) { + MLValue& output = fetches.front(); + const auto& tensor = output.Get(); + float output_value = *tensor.Data(); + if (output_value != expected_value) { + status = ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output of ", output_value, " != ", expected_value); + } + } + + return status; +} + +TEST(InferenceSessionTests, TestOptionalInputs) { + // required input only + auto status = RunOptionalInputTest(true, false, false); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // required and optional input + status = RunOptionalInputTest(true, true, false); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // required, optional and invalid input + status = RunOptionalInputTest(true, true, true); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // missing required + status = RunOptionalInputTest(false, true, false); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); +} + TEST(ExecutionProviderTest, FunctionTest) { onnxruntime::Model model("graph_1"); auto& graph = model.MainGraph();