Add backwards compatibility for all versions of ORT format models in a full build. (#13242)

### Description
<!-- Describe your changes. -->
Add ability to upgrade an ORT format model when loaded in a full build
by inserting the kernel constraint info and ignoring the kernel hashes.

This also allows upgrading the model to the latest format by saving the
model after loading.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve? -->
Provide an official path for upgrading an ORT format model directly (vs.
reconverting).

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
This commit is contained in:
Scott McKay 2022-10-12 17:45:52 +10:00 committed by GitHub
parent 67bde18d0d
commit cbe4eb65b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 138 additions and 45 deletions

View file

@ -20,10 +20,10 @@ namespace onnxruntime {
// Version 3 - add `graph_doc_string` to Model
// Version 4 - update kernel def hashing to not depend on ordering of type constraint types (NOT BACKWARDS COMPATIBLE)
// Version 5 - deprecate kernel def hashes and add KernelTypeStrResolver info to replace them (NOT BACKWARDS COMPATIBLE)
constexpr const char* kOrtModelVersion = "5";
constexpr const int kOrtModelVersion = 5;
// Check if the given ort model version is supported in this build
inline bool IsOrtModelVersionSupported(std::string_view ort_model_version) {
inline bool IsOrtModelVersionSupported(const int ort_model_version) {
// The ort model versions we will support in this build
// This may contain more versions than the kOrtModelVersion, based on the compatibilities
constexpr std::array kSupportedOrtModelVersions{

View file

@ -79,13 +79,12 @@ Model::Model(const std::string& graph_name,
p_domain_to_version = &domain_to_version_static;
}
for (const auto& domain : *p_domain_to_version) {
model_load_utils::ValidateOpsetForDomain(
domain_to_version_static, logger, allow_released_opsets_only_final,
domain.first, domain.second);
for (const auto& [domain, version] : *p_domain_to_version) {
model_load_utils::ValidateOpsetForDomain(domain_to_version_static, logger, allow_released_opsets_only_final,
domain, version);
const gsl::not_null<OperatorSetIdProto*> opset_id_proto{model_proto_.add_opset_import()};
opset_id_proto->set_domain(domain.first);
opset_id_proto->set_version(domain.second);
opset_id_proto->set_domain(domain);
opset_id_proto->set_version(version);
}
for (auto& func : model_local_functions) {
@ -202,12 +201,12 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path,
auto domain_map = allow_official_onnx_release_only_final
? schema_registry->GetLastReleasedOpsetVersions(false)
: schema_registry->GetLatestOpsetVersions(false);
for (const auto& domain : domain_map) {
if (domain_to_version.find(domain.first) == domain_to_version.end()) {
domain_to_version[domain.first] = domain.second;
for (const auto& [domain, version] : domain_map) {
if (domain_to_version.find(domain) == domain_to_version.end()) {
domain_to_version[domain] = version;
const gsl::not_null<OperatorSetIdProto*> opset_id_proto{model_proto_.add_opset_import()};
opset_id_proto->set_domain(domain.first);
opset_id_proto->set_version(domain.second);
opset_id_proto->set_domain(domain);
opset_id_proto->set_version(version);
}
}
@ -230,12 +229,14 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path,
func_template_ptr->op_schema_ = std::move(func_schema_ptr);
func_template_ptr->onnx_func_proto_ = &func;
model_local_function_templates_.push_back(std::move(func_template_ptr));
model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = model_local_function_templates_.back().get();
model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] =
model_local_function_templates_.back().get();
}
// create instance. need to call private ctor so can't use make_unique
GSL_SUPPRESS(r .11)
graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry, logger, options.strict_shape_type_inference));
graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry,
logger, options.strict_shape_type_inference));
}
const InlinedHashMap<std::string, FunctionTemplate*>& Model::GetModelLocalFunctionTemplates() const {
@ -828,6 +829,14 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
ORT_RETURN_IF(nullptr == fbs_graph, "Graph is null. Invalid ORT format model.");
#if !defined(ORT_MINIMAL_BUILD)
// add the opset imports to the model_proto in case we're updating an ORT format model and need those to be
// included when SaveToOrtFormat is called later
for (const auto& [domain, version] : domain_to_version) {
const gsl::not_null<OperatorSetIdProto*> opset_id_proto{model->model_proto_.add_opset_import()};
opset_id_proto->set_domain(domain);
opset_id_proto->set_version(version);
}
ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, schema_registry,
can_use_flatbuffer_for_initializers, logger, model->graph_));
#else

View file

@ -604,7 +604,7 @@ common::Status InferenceSession::SaveToOrtFormat(const PathString& filepath) con
fbs_buffer_size = ((fbs_buffer_size + m_bytes - 1) / m_bytes) * m_bytes;
flatbuffers::FlatBufferBuilder builder(fbs_buffer_size);
auto ort_model_version = builder.CreateString(kOrtModelVersion);
auto ort_model_version = builder.CreateString(std::to_string(kOrtModelVersion));
flatbuffers::Offset<fbs::Model> fbs_model;
ORT_RETURN_IF_ERROR(
model_->SaveToOrtFormat(builder, fbs_model));
@ -1024,16 +1024,30 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort
const auto* fbs_ort_model_version = fbs_session->ort_version();
ORT_RETURN_IF(fbs_ort_model_version == nullptr, "Serialized version info is null. Invalid ORT format model.");
auto model_version = std::stoi(fbs_ort_model_version->str());
bool is_supported = IsOrtModelVersionSupported(model_version);
#if defined(ORT_MINIMAL_BUILD)
// Note about the ORT format version 5 breaking change.
// TODO This change was introduced in 1.13. Remove this note a few releases later, e.g., 1.15.
constexpr auto* kOrtFormatVersion5BreakingChangeNote =
"This build doesn't support ORT format models older than version 5. "
"See: https://github.com/microsoft/onnxruntime/blob/rel-1.13.0/docs/ORT_Format_Update_in_1.13.md";
ORT_RETURN_IF_NOT(IsOrtModelVersionSupported(fbs_ort_model_version->string_view()),
ORT_RETURN_IF(!is_supported,
"The ORT format model version [", fbs_ort_model_version->string_view(),
"] is not supported in this build ", ORT_VERSION, ". ",
kOrtFormatVersion5BreakingChangeNote);
#else
// models prior to v5 can be handled by inserting the kernel constraints in a full build
bool is_supported_with_update = !is_supported && model_version < 5;
// currently this means the model is using a future version
// i.e. attempted load of model created with future version of ORT
ORT_RETURN_IF_NOT(is_supported || is_supported_with_update,
"The ORT format model version [", fbs_ort_model_version->string_view(),
"] is not supported in this build ", ORT_VERSION, ". ",
kOrtFormatVersion5BreakingChangeNote);
"] is not supported in this build ", ORT_VERSION, ". ");
#endif
const auto* fbs_model = fbs_session->model();
ORT_RETURN_IF(nullptr == fbs_model, "Missing Model. Invalid ORT format model.");
@ -1066,7 +1080,15 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort
if (const auto* fbs_kernel_type_str_resolver = fbs_session->kernel_type_str_resolver();
fbs_kernel_type_str_resolver != nullptr) {
ORT_RETURN_IF_ERROR(kernel_type_str_resolver.LoadFromOrtFormat(*fbs_kernel_type_str_resolver));
} else {
#if !defined(ORT_MINIMAL_BUILD)
// insert the kernel type constraints if we're updating an old model that had kernel hashes.
if (is_supported_with_update) {
ORT_RETURN_IF_ERROR(kernel_type_str_resolver.RegisterGraphNodeOpSchemas(model_->MainGraph()));
}
#endif
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(
kernel_type_str_resolver_utils::AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(

View file

@ -1,24 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/flatbuffers/schema/ort.fbs.h"
#include "core/framework/data_types.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/model.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/inference_session.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/graph/model.h"
#include "test/test_environment.h"
#include "test_utils.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/test_environment.h"
#include "test/util/include/asserts.h"
#include "test/util/include/inference_session_wrapper.h"
#include "core/flatbuffers/schema/ort.fbs.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "gtest/gtest.h"
#include "test/util/include/asserts.h"
using namespace std;
using namespace ONNX_NAMESPACE;
@ -37,6 +37,7 @@ struct OrtModelTestInfo {
bool run_use_buffer{false};
bool disable_copy_ort_buffer{false};
bool use_buffer_for_initializers{false};
TransformerLevel optimization_level = TransformerLevel::Level3;
};
static void RunOrtModel(const OrtModelTestInfo& test_info) {
@ -54,6 +55,8 @@ static void RunOrtModel(const OrtModelTestInfo& test_info) {
}
}
so.graph_optimization_level = test_info.optimization_level;
std::vector<char> model_data;
InferenceSessionWrapper session_object{so, GetEnvironment()};
if (test_info.run_use_buffer) {
@ -94,6 +97,7 @@ static void CompareTensors(const OrtValue& left_value, const OrtValue& right_val
EXPECT_EQ(left_strings[i], right_strings[i]) << "Mismatch index:" << i;
}
} else {
ASSERT_EQ(memcmp(left.DataRaw(), right.DataRaw(), left.SizeInBytes()), 0);
}
}
@ -227,16 +231,20 @@ static void CompareSessionMetadata(const InferenceSessionWrapper& session_object
ASSERT_EQ(model_1.ProducerVersion(), model_2.ProducerVersion());
}
static void SaveAndCompareModels(const std::string& onnx_file, const std::basic_string<ORTCHAR_T>& ort_file) {
static void SaveAndCompareModels(const std::basic_string<ORTCHAR_T>& orig_file,
const std::basic_string<ORTCHAR_T>& ort_file,
TransformerLevel optimization_level = TransformerLevel::Level3) {
SessionOptions so;
so.session_logid = "SerializeToOrtFormat";
so.optimized_model_filepath = ort_file;
so.graph_optimization_level = optimization_level;
// not strictly necessary - type should be inferred from the filename
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigSaveModelFormat, "ORT"));
InferenceSessionWrapper session_object{so, GetEnvironment()};
// create .ort file during Initialize due to values in SessionOptions
ASSERT_STATUS_OK(session_object.Load(onnx_file));
ASSERT_STATUS_OK(session_object.Load(orig_file));
ASSERT_STATUS_OK(session_object.Initialize());
SessionOptions so2;
@ -282,8 +290,8 @@ We could take steps to handle this scenario in a full build, but for consistency
on any ORT format model.
*/
TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) {
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/mnist.onnx.test_output.ort");
SaveAndCompareModels("testdata/mnist.onnx", ort_file);
const auto ort_file = ORT_TSTR("testdata/mnist.onnx.test_output.ort");
SaveAndCompareModels(ORT_TSTR("testdata/mnist.onnx"), ort_file);
// DumpOrtModelAsJson(ToUTF8String(ort_file));
@ -309,8 +317,8 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) {
}
TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/ort_github_issue_4031.onnx.test_output.ort");
SaveAndCompareModels("testdata/ort_github_issue_4031.onnx", ort_file);
const auto ort_file = ORT_TSTR("testdata/ort_github_issue_4031.onnx.test_output.ort");
SaveAndCompareModels(ORT_TSTR("testdata/ort_github_issue_4031.onnx"), ort_file);
// DumpOrtModelAsJson(ToUTF8String(ort_file));
@ -336,9 +344,8 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
}
TEST(OrtModelOnlyTests, SparseInitializerHandling) {
const std::basic_string<ORTCHAR_T> ort_file =
ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.test_output.ort");
SaveAndCompareModels("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx", ort_file);
const auto ort_file = ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.test_output.ort");
SaveAndCompareModels(ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx"), ort_file);
SessionOptions so;
so.session_logid = "SparseInitializerHandling";
@ -355,22 +362,75 @@ TEST(OrtModelOnlyTests, SparseInitializerHandling) {
// regression test to make sure the model path is correctly passed through when serializing a tensor attribute
TEST(OrtModelOnlyTests, TensorAttributeSerialization) {
const std::basic_string<ORTCHAR_T> ort_file =
ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx.test_output.ort");
SaveAndCompareModels("testdata/ort_minimal_test_models/tensor_attribute.onnx", ort_file);
const auto ort_file = ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx.test_output.ort");
SaveAndCompareModels(ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx"), ort_file);
}
TEST(OrtModelOnlyTests, MetadataSerialization) {
const std::basic_string<ORTCHAR_T> ort_file =
ORT_TSTR("testdata/model_with_metadata.onnx.test_output.ort");
SaveAndCompareModels("testdata/model_with_metadata.onnx", ort_file);
const auto ort_file = ORT_TSTR("testdata/model_with_metadata.onnx.test_output.ort");
SaveAndCompareModels(ORT_TSTR("testdata/model_with_metadata.onnx"), ort_file);
}
// Verifies that a full build can load and run an ORT format model that predates format version 5.
// Version 5 replaced kernel hashes with kernel type constraints, so an older model is expected to load once the
// constraints are added at load time. Saving the loaded model afterwards effectively upgrades the file to v5.
TEST(OrtModelOnlyTests, UpdateOrtModelVersion) {
  // the v4 file is the ORT format input that still uses kernel hashes instead of constraints
  const auto original_model = ORT_TSTR("testdata/mnist.onnx");
  const auto v4_model = ORT_TSTR("testdata/mnist.basic.v4.ort");
  const auto upgraded_model = ORT_TSTR("testdata/mnist.basic.v5.test_output.ort");

  // load the v4 model and write it back out as v5. optimizations are kept off so the model is preserved as-is.
  SaveAndCompareModels(v4_model, upgraded_model, TransformerLevel::Default);

  // build one random input that is fed to the original, v4 and v5 models
  RandomValueGenerator rng{};
  std::vector<int64_t> dims{1, 1, 28, 28};
  std::vector<float> values = rng.Gaussian<float>(dims, 0.0f, 0.9f);
  OrtValue input_value;
  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault),
                       dims, values, &input_value);

  OrtModelTestInfo test_info;
  // use the same optimization level for the onnx and ort models
  test_info.optimization_level = TransformerLevel::Level1;
  test_info.inputs.insert(std::make_pair("Input3", input_value));
  test_info.output_names = {"Plus214_Output_0"};

  // runs the model at `model_path` with the shared settings and stores its first output in `result`
  const auto run_and_capture = [&test_info](const ORTCHAR_T* model_path, OrtValue& result) {
    test_info.model_filename = model_path;
    test_info.output_verifier = [&result](const std::vector<OrtValue>& fetches) {
      result = fetches[0];
    };
    RunOrtModel(test_info);
  };

  OrtValue orig_out;
  OrtValue v4_out;
  OrtValue v5_out;

  // baseline from the original ONNX model
  run_and_capture(original_model, orig_out);

  // v4 as input. this should also be updated to v5 prior to execution.
  run_and_capture(v4_model, v4_out);

  // the model that was saved as v5 must work as well
  run_and_capture(upgraded_model, v5_out);

  // all three runs must produce the same output
  CompareTensors(orig_out, v4_out);
  CompareTensors(v4_out, v5_out);
}
#if !defined(DISABLE_ML_OPS)
TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) {
const std::basic_string<ORTCHAR_T> ort_file =
ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.onnx.test_output.ort");
SaveAndCompareModels("testdata/sklearn_bin_voting_classifier_soft.onnx", ort_file);
const auto ort_file = ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.onnx.test_output.ort");
SaveAndCompareModels(ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.onnx"), ort_file);
OrtModelTestInfo test_info;
test_info.model_filename = ort_file;
@ -414,7 +474,7 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) {
// test loading ORT format model with sparse initializers
TEST(OrtModelOnlyTests, LoadSparseInitializersOrtFormat) {
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.ort");
const auto ort_file = ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.ort");
SessionOptions so;
so.session_logid = "LoadOrtFormat";
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigLoadModelFormat, "ORT"));
@ -537,5 +597,6 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBufferNoCopy) {
}
#endif // !defined(DISABLE_ML_OPS)
} // namespace test
} // namespace onnxruntime

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include <algorithm>
#include <string>
#include "gtest/gtest.h"
@ -128,7 +129,7 @@ TEST(GraphRuntimeOptimizationTest, SaveRuntimeOptimizationToOrtFormat) {
flatbuffers::Offset<fbs::InferenceSession> fbs_session_offset =
fbs::CreateInferenceSessionDirect(builder,
kOrtModelVersion,
std::to_string(kOrtModelVersion).c_str(),
fbs_model_offset,
0);

Binary file not shown.