Enable metadata and signature changes in graph transformers (#4783)

After applying all the graph transformations the metadata and signature could have changes
(e.g.: new outputs got added, or the outputs/inputs got renamed). Therefore the local
copies of metadata and signature, that InferenceSession administrated for faster lookup, has to be updated.
For this the `SaveModelMetadata`, that now has to be idempotent, should be called after resolving the transformed graph
This commit is contained in:
Dudeldu 2020-09-03 00:46:36 +02:00 committed by GitHub
parent 4fd4b74149
commit 4a0f6595eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1038,6 +1038,9 @@ common::Status InferenceSession::Initialize() {
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());
// Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph
ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_));
#endif // !defined(ORT_MINIMAL_BUILD)
// need to keep the initializers if we're going to save the optimized model
@ -1568,11 +1571,13 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod
model_metadata_.custom_metadata_map = model.MetaData();
model_metadata_.graph_name = graph.Name();
required_inputs_.clear();
for (auto input : graph.GetInputs()) {
required_inputs_.insert(input->Name());
}
auto add_inputs = [this](const InputDefList& inputs) {
input_def_map_.clear();
input_def_map_.reserve(inputs.size());
for (auto elem : inputs) {
auto elem_type = utils::GetMLDataType(*elem);
@ -1598,6 +1603,8 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod
// save outputs
const auto& outputs = graph.GetOutputs();
output_def_list_ = outputs; // A direct copy of outputs
model_output_names_.clear();
model_output_names_.reserve(outputs.size());
for (const auto& elem : outputs) {
model_output_names_.insert(elem->Name());