mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Initial changes.
Optional inputs aren't being handled properly in SaveInputOutputNamesToNodeMapping
This commit is contained in:
parent
2f234e4e78
commit
bfaade660b
3 changed files with 157 additions and 16 deletions
|
|
@ -115,7 +115,7 @@ void SessionState::AddInputNameToNodeInfoMapping(const std::string& input_name,
|
|||
|
||||
common::Status SessionState::GetInputNodeInfo(const std::string& input_name, std::vector<NodeInfo>& node_info_vec) const {
|
||||
if (!input_names_to_nodeinfo_mapping_.count(input_name)) {
|
||||
return Status(ONNXRUNTIME, FAIL, "Failed to find input name in the mapping");
|
||||
return Status(ONNXRUNTIME, FAIL, "Failed to find input name in the mapping: " + input_name);
|
||||
}
|
||||
node_info_vec = input_names_to_nodeinfo_mapping_.at(input_name);
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -421,10 +421,21 @@ class InferenceSession::Impl {
|
|||
}
|
||||
|
||||
common::Status ValidateInputNames(const NameMLValMap& feeds) {
|
||||
if (model_input_names_.size() != feeds.size()) {
|
||||
std::string missing_required_inputs;
|
||||
|
||||
std::for_each(required_model_input_names_.cbegin(), required_model_input_names_.cend(),
|
||||
[&](const std::string& required_input) {
|
||||
if (feeds.find(required_input) == feeds.cend()) {
|
||||
if (!missing_required_inputs.empty())
|
||||
missing_required_inputs += ",";
|
||||
|
||||
missing_required_inputs += required_input;
|
||||
}
|
||||
});
|
||||
|
||||
if (!missing_required_inputs.empty()) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"The number of feeds is not same as the number of the model input, expect ",
|
||||
model_input_names_.size(), " got ", feeds.size());
|
||||
"Missing required inputs: ", missing_required_inputs);
|
||||
}
|
||||
|
||||
bool valid = true;
|
||||
|
|
@ -804,7 +815,7 @@ class InferenceSession::Impl {
|
|||
}
|
||||
}
|
||||
|
||||
return std::make_pair(common::Status::OK(), &input_def_list_);
|
||||
return std::make_pair(common::Status::OK(), &required_input_def_list_);
|
||||
}
|
||||
|
||||
std::pair<common::Status, const OutputDefList*> GetModelOutputs() const {
|
||||
|
|
@ -896,28 +907,33 @@ class InferenceSession::Impl {
|
|||
model_metadata_.custom_metadata_map = model.MetaData();
|
||||
model_metadata_.graph_name = graph.Name();
|
||||
|
||||
// save inputs
|
||||
auto& inputs = graph.GetInputs(); // inputs excluding initializers
|
||||
input_def_list_.reserve(inputs.size());
|
||||
for (const auto& elem : inputs) {
|
||||
if (!elem) {
|
||||
return common::Status(common::ONNXRUNTIME, common::FAIL, "Got null input nodearg ptr");
|
||||
}
|
||||
// save required inputs
|
||||
const auto& required_inputs = graph.GetInputs(); // inputs excluding initializers
|
||||
required_input_def_list_.reserve(required_inputs.size());
|
||||
required_model_input_names_.reserve(required_inputs.size());
|
||||
for (const auto& elem : required_inputs) {
|
||||
required_input_def_list_.push_back(elem);
|
||||
required_model_input_names_.insert(elem->Name());
|
||||
}
|
||||
|
||||
// save all valid inputs
|
||||
const auto& all_inputs = graph.GetInputsIncludingInitializers();
|
||||
input_def_list_.reserve(all_inputs.size());
|
||||
model_input_names_.reserve(all_inputs.size());
|
||||
for (const auto& elem : all_inputs) {
|
||||
input_def_list_.push_back(elem);
|
||||
model_input_names_.insert(elem->Name());
|
||||
}
|
||||
|
||||
// save outputs
|
||||
auto& outputs = graph.GetOutputs();
|
||||
const auto& outputs = graph.GetOutputs();
|
||||
output_def_list_.reserve(outputs.size());
|
||||
model_output_names_.reserve(outputs.size());
|
||||
for (const auto& elem : outputs) {
|
||||
if (!elem) {
|
||||
return common::Status(common::ONNXRUNTIME, common::FAIL, "Got null output nodearg ptr");
|
||||
}
|
||||
output_def_list_.push_back(elem);
|
||||
model_output_names_.insert(elem->Name());
|
||||
}
|
||||
|
||||
VLOGS(*session_logger_, 1) << "Done saving model metadata";
|
||||
return common::Status::OK();
|
||||
}
|
||||
|
|
@ -1030,10 +1046,12 @@ class InferenceSession::Impl {
|
|||
SessionState session_state_;
|
||||
|
||||
ModelMetadata model_metadata_;
|
||||
InputDefList required_input_def_list_;
|
||||
InputDefList input_def_list_;
|
||||
OutputDefList output_def_list_;
|
||||
|
||||
// names of model inputs and outputs used for quick validation.
|
||||
std::unordered_set<std::string> required_model_input_names_;
|
||||
std::unordered_set<std::string> model_input_names_;
|
||||
std::unordered_set<std::string> model_output_names_;
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@
|
|||
#include "core/session/IOBinding.h"
|
||||
#include "test/capturing_sink.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test_utils.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
|
|
@ -808,6 +809,128 @@ TEST(InferenceSessionTests, ModelWithoutOpset) {
|
|||
}
|
||||
}
|
||||
|
||||
static ONNX_NAMESPACE::ModelProto CreateModelWithOptionalInputs() {
|
||||
Model model("ModelWithOptionalInputs");
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
// create an initializer, which is an optional input that can be overridden
|
||||
onnx::TensorProto tensor_proto;
|
||||
tensor_proto.add_dims(1);
|
||||
tensor_proto.set_data_type(TensorProto_DataType_FLOAT);
|
||||
tensor_proto.add_float_data(1.f);
|
||||
tensor_proto.set_name("optional_input");
|
||||
|
||||
graph.AddInitializedTensor(tensor_proto);
|
||||
|
||||
TypeProto single_float;
|
||||
single_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
single_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
auto& required_input = graph.GetOrCreateNodeArg("required_input", &single_float);
|
||||
auto& optional_input = graph.GetOrCreateNodeArg("optional_input", nullptr);
|
||||
auto& add_output = graph.GetOrCreateNodeArg("add_output", &single_float);
|
||||
|
||||
EXPECT_TRUE(optional_input.Shape() != nullptr) << "AddInitializedTensor should have created the NodeArg with shape.";
|
||||
|
||||
graph.AddNode("add", "Add", "Add required and optional inputs", {&required_input, &optional_input}, {&add_output});
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
auto model_proto = model.ToProto();
|
||||
|
||||
return model_proto;
|
||||
}
|
||||
|
||||
static common::Status RunOptionalInputTest(bool add_required_input,
|
||||
bool add_optional_input,
|
||||
bool add_invalid_input) {
|
||||
auto model_proto = CreateModelWithOptionalInputs();
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "InferenceSessionTests.TestOptionalInputs";
|
||||
|
||||
InferenceSession session_object{so, &DefaultLoggingManager()};
|
||||
|
||||
std::stringstream s1;
|
||||
model_proto.SerializeToOstream(&s1);
|
||||
auto status = session_object.Load(s1);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
status = session_object.Initialize();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = so.session_logid;
|
||||
|
||||
// prepare inputs
|
||||
std::vector<int64_t> dims = {1};
|
||||
std::vector<float> required_input_val = {1.f};
|
||||
std::vector<float> optional_input_val = {10.f}; // override initializer value of 1
|
||||
std::vector<float> unknown_input_val = {20.f};
|
||||
|
||||
MLValue required_input_mlvalue;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault),
|
||||
dims, required_input_val, &required_input_mlvalue);
|
||||
|
||||
MLValue optional_input_mlvalue;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault),
|
||||
dims, optional_input_val, &optional_input_mlvalue);
|
||||
|
||||
MLValue unknown_input_mlvalue;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault),
|
||||
dims, unknown_input_val, &unknown_input_mlvalue);
|
||||
|
||||
NameMLValMap feeds;
|
||||
|
||||
if (add_required_input)
|
||||
feeds.insert(std::make_pair("required_input", required_input_mlvalue));
|
||||
|
||||
if (add_optional_input)
|
||||
feeds.insert(std::make_pair("optional_input", optional_input_mlvalue));
|
||||
|
||||
if (add_invalid_input)
|
||||
feeds.insert(std::make_pair("unknown_input", unknown_input_mlvalue));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names;
|
||||
output_names.push_back("add_output");
|
||||
std::vector<MLValue> fetches;
|
||||
|
||||
float expected_value = required_input_val[0];
|
||||
expected_value += add_optional_input ? optional_input_val[0] : 1.f;
|
||||
|
||||
status = session_object.Run(run_options, feeds, output_names, &fetches);
|
||||
|
||||
if (status.IsOK()) {
|
||||
MLValue& output = fetches.front();
|
||||
const auto& tensor = output.Get<Tensor>();
|
||||
float output_value = *tensor.Data<float>();
|
||||
if (output_value != expected_value) {
|
||||
status = ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output of ", output_value, " != ", expected_value);
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
TEST(InferenceSessionTests, TestOptionalInputs) {
|
||||
// required input only
|
||||
auto status = RunOptionalInputTest(true, false, false);
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
// required and optional input
|
||||
status = RunOptionalInputTest(true, true, false);
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
// required, optional and invalid input
|
||||
status = RunOptionalInputTest(true, true, true);
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
// missing required
|
||||
status = RunOptionalInputTest(false, true, false);
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
}
|
||||
|
||||
TEST(ExecutionProviderTest, FunctionTest) {
|
||||
onnxruntime::Model model("graph_1");
|
||||
auto& graph = model.MainGraph();
|
||||
|
|
|
|||
Loading…
Reference in a new issue