mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
parent
a7b1a8969c
commit
93fb62bb3e
2 changed files with 9 additions and 54 deletions
|
|
@ -572,34 +572,13 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
|
|||
feeds.size(), " elements.");
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> seen_names;
|
||||
seen_names.reserve(feeds.size());
|
||||
size_t seen_required_inputs = 0;
|
||||
const Graph& graph = model_->MainGraph();
|
||||
|
||||
for (size_t i = 0; i < feeds.size(); ++i) {
|
||||
const auto& feed_name = feed_names[i];
|
||||
|
||||
if (seen_names.insert(feed_name).second == false) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Duplicate name in feeds: ", feed_name);
|
||||
}
|
||||
|
||||
auto iter = input_def_map_.find(feed_name);
|
||||
if (input_def_map_.end() == iter) {
|
||||
// if IR < 4 all initializers are required to have a matching graph input with the same name,
|
||||
// however we disallow using that graph input to override the initializer, and treat the initializers as constant.
|
||||
// check for this and output a nicer error message if that is the case.
|
||||
// As we've already moved all initializers to SessionState we need to check if it's in the constant initializers there
|
||||
int idx;
|
||||
bool is_constant_initializer = session_state_.GetOrtValueNameIdxMap().GetIdx(feed_name, idx).IsOK() &&
|
||||
session_state_.GetConstantInitializedTensors().find(idx) !=
|
||||
session_state_.GetConstantInitializedTensors().cend();
|
||||
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Invalid Feed Input Name:", feed_name,
|
||||
is_constant_initializer ? ". Initializers may not be overridden by feeds"
|
||||
" if model IR version is less than 4."
|
||||
: ".");
|
||||
"Invalid Feed Input Name:", feed_name);
|
||||
}
|
||||
|
||||
auto expected_type = utils::GetMLDataType(*iter->second);
|
||||
|
|
@ -612,31 +591,6 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
|
|||
auto input_type = input_ml_value.Type();
|
||||
ORT_RETURN_IF_ERROR(CheckTypes(input_type, expected_type));
|
||||
}
|
||||
|
||||
if (!graph.CanOverrideInitializer() || // all entries in input_def_map_ are required.
|
||||
required_inputs_.find(feed_name) != required_inputs_.cend()) {
|
||||
++seen_required_inputs;
|
||||
}
|
||||
}
|
||||
|
||||
if (seen_required_inputs < required_inputs_.size()) {
|
||||
std::ostringstream req_input_str;
|
||||
auto cur = required_inputs_.cbegin(), end = required_inputs_.cend();
|
||||
req_input_str << "Required inputs: ";
|
||||
req_input_str << *(cur++);
|
||||
while (cur != end) {
|
||||
req_input_str << ", " << *(cur++);
|
||||
}
|
||||
|
||||
req_input_str << " . Got: ";
|
||||
for (size_t i = 0; i < feed_names.size(); ++i) {
|
||||
if (i > 0)
|
||||
req_input_str << ", ";
|
||||
req_input_str << feed_names[i];
|
||||
}
|
||||
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "One or more missing required inputs. ",
|
||||
req_input_str.str());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -914,23 +914,24 @@ TEST(InferenceSessionTests, TestOptionalInputs) {
|
|||
|
||||
// required and optional input
|
||||
status = RunOptionalInputTest(true, true, false, version);
|
||||
if (version < 4) {
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
EXPECT_THAT(status.ErrorMessage(),
|
||||
testing::HasSubstr("Initializers may not be overridden by feeds if model IR version is less than 4"));
|
||||
if (version == 3) {
|
||||
ASSERT_FALSE(status.IsOK()) << status.ErrorMessage();
|
||||
} else {
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
}
|
||||
|
||||
// required, optional and invalid input
|
||||
status = RunOptionalInputTest(true, true, true, version);
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name"));
|
||||
|
||||
// missing required
|
||||
status = RunOptionalInputTest(false, false, false, version);
|
||||
status = RunOptionalInputTest(false, true, false, version);
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("One or more missing required inputs"));
|
||||
if (version == 3) {
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name"));
|
||||
} else {
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing Input:"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue