Update session state initializer to support overriding initializers.

Update test
This commit is contained in:
Scott McKay 2018-11-29 14:31:36 +10:00
parent bfaade660b
commit 989b00321e
3 changed files with 10 additions and 8 deletions

View file

@ -486,8 +486,7 @@ static bool IsArgNameInInputsOutputs(const std::string& name,
common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph,
const KernelRegistryManager& custom_registry_manager,
SessionState& session_state) {
auto& weights_map = graph.GetAllInitializedTensors();
auto& graph_inputs = graph.GetInputs();
auto& graph_inputs = graph.GetInputsIncludingInitializers();
auto& graph_outputs = graph.GetOutputs();
for (auto& node : graph.Nodes()) {
@ -495,7 +494,7 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
onnxruntime::Node::ForEachWithIndex(
node.InputDefs(),
[&](const onnxruntime::NodeArg& arg, size_t index) {
if (arg.Name().empty() || weights_map.count(arg.Name())) {
if (arg.Name().empty()) {
return Status::OK();
}

View file

@ -454,9 +454,9 @@ class InferenceSession::Impl {
[&ostr](const std::string& elem) {
ostr << elem << " ";
});
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Invalid Feed Input Names:" + invalid_names.str() +
" Valid input names are: " + ostr.str());
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid Feed Input Names:", invalid_names.str(),
". Valid input names are: ", ostr.str());
}
return Status::OK();

View file

@ -924,11 +924,14 @@ TEST(InferenceSessionTests, TestOptionalInputs) {
// required, optional and invalid input
status = RunOptionalInputTest(true, true, true);
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Names: unknown_input"));
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Valid input names are: required_input optional_input"));
// missing required
status = RunOptionalInputTest(false, true, false);
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing required inputs: required_input"));
}
TEST(ExecutionProviderTest, FunctionTest) {