From 4a0f6595eb38f4f9fca7c14548d999876fbc1282 Mon Sep 17 00:00:00 2001 From: Dudeldu Date: Thu, 3 Sep 2020 00:46:36 +0200 Subject: [PATCH] Enable metadata and signature changes in graph transformers (#4783) After applying all the graph transformations the metadata and signature could have changes (e.g.: new outputs got added, or the outputs/inputs got renamed). Therefore the local copies of metadata and signature, that InferenceSession administrated for faster lookup, has to be updated. For this the `SaveModelMetadata`, that now has to be idempotent, should be called after resolving the transformed graph --- onnxruntime/core/session/inference_session.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 6456d27dcf..0773a83230 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1038,6 +1038,9 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); + + // Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph + ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_)); #endif // !defined(ORT_MINIMAL_BUILD) // need to keep the initializers if we're going to save the optimized model @@ -1568,11 +1571,13 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod model_metadata_.custom_metadata_map = model.MetaData(); model_metadata_.graph_name = graph.Name(); + required_inputs_.clear(); for (auto input : graph.GetInputs()) { required_inputs_.insert(input->Name()); } auto add_inputs = [this](const InputDefList& inputs) { + input_def_map_.clear(); input_def_map_.reserve(inputs.size()); for (auto elem : inputs) { auto elem_type = utils::GetMLDataType(*elem); @@ -1598,6 +1603,8 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod // save outputs const auto& outputs = graph.GetOutputs(); output_def_list_ = outputs; // A direct copy of outputs + + model_output_names_.clear(); model_output_names_.reserve(outputs.size()); for (const auto& elem : outputs) { model_output_names_.insert(elem->Name());