More code cleanup (#1405)

* More code cleanup

* More cleanup
This commit is contained in:
Yuan Yu 2019-07-17 14:45:50 -07:00 committed by Ke Zhang
parent a7b1a8969c
commit 93fb62bb3e
2 changed files with 9 additions and 54 deletions

View file

@ -572,34 +572,13 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
feeds.size(), " elements.");
}
std::unordered_set<std::string> seen_names;
seen_names.reserve(feeds.size());
size_t seen_required_inputs = 0;
const Graph& graph = model_->MainGraph();
for (size_t i = 0; i < feeds.size(); ++i) {
const auto& feed_name = feed_names[i];
if (seen_names.insert(feed_name).second == false) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Duplicate name in feeds: ", feed_name);
}
auto iter = input_def_map_.find(feed_name);
if (input_def_map_.end() == iter) {
// if IR < 4 all initializers are required to have a matching graph input with the same name,
// however we disallow using that graph input to override the initializer, and treat the initializers as constant.
// check for this and output a nicer error message if that is the case.
// As we've already moved all initializers to SessionState we need to check if it's in the constant initializers there
int idx;
bool is_constant_initializer = session_state_.GetOrtValueNameIdxMap().GetIdx(feed_name, idx).IsOK() &&
session_state_.GetConstantInitializedTensors().find(idx) !=
session_state_.GetConstantInitializedTensors().cend();
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid Feed Input Name:", feed_name,
is_constant_initializer ? ". Initializers may not be overridden by feeds"
" if model IR version is less than 4."
: ".");
"Invalid Feed Input Name:", feed_name);
}
auto expected_type = utils::GetMLDataType(*iter->second);
@ -612,31 +591,6 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
auto input_type = input_ml_value.Type();
ORT_RETURN_IF_ERROR(CheckTypes(input_type, expected_type));
}
if (!graph.CanOverrideInitializer() || // all entries in input_def_map_ are required.
required_inputs_.find(feed_name) != required_inputs_.cend()) {
++seen_required_inputs;
}
}
if (seen_required_inputs < required_inputs_.size()) {
std::ostringstream req_input_str;
auto cur = required_inputs_.cbegin(), end = required_inputs_.cend();
req_input_str << "Required inputs: ";
req_input_str << *(cur++);
while (cur != end) {
req_input_str << ", " << *(cur++);
}
req_input_str << " . Got: ";
for (size_t i = 0; i < feed_names.size(); ++i) {
if (i > 0)
req_input_str << ", ";
req_input_str << feed_names[i];
}
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "One or more missing required inputs. ",
req_input_str.str());
}
return Status::OK();

View file

@ -914,23 +914,24 @@ TEST(InferenceSessionTests, TestOptionalInputs) {
// required and optional input
status = RunOptionalInputTest(true, true, false, version);
if (version < 4) {
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(),
testing::HasSubstr("Initializers may not be overridden by feeds if model IR version is less than 4"));
if (version == 3) {
ASSERT_FALSE(status.IsOK()) << status.ErrorMessage();
} else {
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
}
// required, optional and invalid input
status = RunOptionalInputTest(true, true, true, version);
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name"));
// missing required
status = RunOptionalInputTest(false, false, false, version);
status = RunOptionalInputTest(false, true, false, version);
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("One or more missing required inputs"));
if (version == 3) {
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name"));
} else {
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing Input:"));
}
}
}