mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Update session state initializer to support overriding initializers.
Update test
This commit is contained in:
parent
bfaade660b
commit
989b00321e
3 changed files with 10 additions and 8 deletions
|
|
@ -486,8 +486,7 @@ static bool IsArgNameInInputsOutputs(const std::string& name,
|
|||
common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph,
|
||||
const KernelRegistryManager& custom_registry_manager,
|
||||
SessionState& session_state) {
|
||||
auto& weights_map = graph.GetAllInitializedTensors();
|
||||
auto& graph_inputs = graph.GetInputs();
|
||||
auto& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
auto& graph_outputs = graph.GetOutputs();
|
||||
|
||||
for (auto& node : graph.Nodes()) {
|
||||
|
|
@ -495,7 +494,7 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
|
|||
onnxruntime::Node::ForEachWithIndex(
|
||||
node.InputDefs(),
|
||||
[&](const onnxruntime::NodeArg& arg, size_t index) {
|
||||
if (arg.Name().empty() || weights_map.count(arg.Name())) {
|
||||
if (arg.Name().empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -454,9 +454,9 @@ class InferenceSession::Impl {
|
|||
[&ostr](const std::string& elem) {
|
||||
ostr << elem << " ";
|
||||
});
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Invalid Feed Input Names:" + invalid_names.str() +
|
||||
" Valid input names are: " + ostr.str());
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Invalid Feed Input Names:", invalid_names.str(),
|
||||
". Valid input names are: ", ostr.str());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -924,11 +924,14 @@ TEST(InferenceSessionTests, TestOptionalInputs) {
|
|||
|
||||
// required, optional and invalid input
|
||||
status = RunOptionalInputTest(true, true, true);
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Names: unknown_input"));
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Valid input names are: required_input optional_input"));
|
||||
|
||||
// missing required
|
||||
status = RunOptionalInputTest(false, true, false);
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing required inputs: required_input"));
|
||||
}
|
||||
|
||||
TEST(ExecutionProviderTest, FunctionTest) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue