diff --git a/onnxruntime/core/flatbuffers/ort_format_version.h b/onnxruntime/core/flatbuffers/ort_format_version.h index c6c0ad7c20..61e0afb982 100644 --- a/onnxruntime/core/flatbuffers/ort_format_version.h +++ b/onnxruntime/core/flatbuffers/ort_format_version.h @@ -20,10 +20,10 @@ namespace onnxruntime { // Version 3 - add `graph_doc_string` to Model // Version 4 - update kernel def hashing to not depend on ordering of type constraint types (NOT BACKWARDS COMPATIBLE) // Version 5 - deprecate kernel def hashes and add KernelTypeStrResolver info to replace them (NOT BACKWARDS COMPATIBLE) -constexpr const char* kOrtModelVersion = "5"; +constexpr const int kOrtModelVersion = 5; // Check if the given ort model version is supported in this build -inline bool IsOrtModelVersionSupported(std::string_view ort_model_version) { +inline bool IsOrtModelVersionSupported(const int ort_model_version) { // The ort model versions we will support in this build // This may contain more versions than the kOrtModelVersion, based on the compatibilities constexpr std::array kSupportedOrtModelVersions{ diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 75b3749997..7b046dcfcb 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -79,13 +79,12 @@ Model::Model(const std::string& graph_name, p_domain_to_version = &domain_to_version_static; } - for (const auto& domain : *p_domain_to_version) { - model_load_utils::ValidateOpsetForDomain( - domain_to_version_static, logger, allow_released_opsets_only_final, - domain.first, domain.second); + for (const auto& [domain, version] : *p_domain_to_version) { + model_load_utils::ValidateOpsetForDomain(domain_to_version_static, logger, allow_released_opsets_only_final, + domain, version); const gsl::not_null opset_id_proto{model_proto_.add_opset_import()}; - opset_id_proto->set_domain(domain.first); - opset_id_proto->set_version(domain.second); + opset_id_proto->set_domain(domain); + opset_id_proto->set_version(version); } for (auto& func : model_local_functions) { @@ -202,12 +201,12 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, auto domain_map = allow_official_onnx_release_only_final ? schema_registry->GetLastReleasedOpsetVersions(false) : schema_registry->GetLatestOpsetVersions(false); - for (const auto& domain : domain_map) { - if (domain_to_version.find(domain.first) == domain_to_version.end()) { - domain_to_version[domain.first] = domain.second; + for (const auto& [domain, version] : domain_map) { + if (domain_to_version.find(domain) == domain_to_version.end()) { + domain_to_version[domain] = version; const gsl::not_null opset_id_proto{model_proto_.add_opset_import()}; - opset_id_proto->set_domain(domain.first); - opset_id_proto->set_version(domain.second); + opset_id_proto->set_domain(domain); + opset_id_proto->set_version(version); } } @@ -230,12 +229,14 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; model_local_function_templates_.push_back(std::move(func_template_ptr)); - model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = model_local_function_templates_.back().get(); + model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = + model_local_function_templates_.back().get(); } // create instance. need to call private ctor so can't use make_unique GSL_SUPPRESS(r .11) - graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry, logger, options.strict_shape_type_inference)); + graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry, + logger, options.strict_shape_type_inference)); } const InlinedHashMap& Model::GetModelLocalFunctionTemplates() const { @@ -828,6 +829,14 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, ORT_RETURN_IF(nullptr == fbs_graph, "Graph is null. Invalid ORT format model."); #if !defined(ORT_MINIMAL_BUILD) + // add the opset imports to the model_proto in case we're updating an ORT format model and need those to be + // included when SaveToOrtFormat is called later + for (const auto& [domain, version] : domain_to_version) { + const gsl::not_null opset_id_proto{model->model_proto_.add_opset_import()}; + opset_id_proto->set_domain(domain); + opset_id_proto->set_version(version); + } + ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, schema_registry, can_use_flatbuffer_for_initializers, logger, model->graph_)); #else diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 89315c0e7c..097a812b3a 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -604,7 +604,7 @@ common::Status InferenceSession::SaveToOrtFormat(const PathString& filepath) con fbs_buffer_size = ((fbs_buffer_size + m_bytes - 1) / m_bytes) * m_bytes; flatbuffers::FlatBufferBuilder builder(fbs_buffer_size); - auto ort_model_version = builder.CreateString(kOrtModelVersion); + auto ort_model_version = builder.CreateString(std::to_string(kOrtModelVersion)); flatbuffers::Offset fbs_model; ORT_RETURN_IF_ERROR( model_->SaveToOrtFormat(builder, fbs_model)); @@ -1024,16 +1024,30 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort const auto* fbs_ort_model_version = fbs_session->ort_version(); ORT_RETURN_IF(fbs_ort_model_version == nullptr, "Serialized version info is null. Invalid ORT format model."); + auto model_version = std::stoi(fbs_ort_model_version->str()); + bool is_supported = IsOrtModelVersionSupported(model_version); + +#if defined(ORT_MINIMAL_BUILD) // Note about the ORT format version 5 breaking change. // TODO This change was introduced in 1.13. Remove this note a few releases later, e.g., 1.15. constexpr auto* kOrtFormatVersion5BreakingChangeNote = "This build doesn't support ORT format models older than version 5. " "See: https://github.com/microsoft/onnxruntime/blob/rel-1.13.0/docs/ORT_Format_Update_in_1.13.md"; - ORT_RETURN_IF_NOT(IsOrtModelVersionSupported(fbs_ort_model_version->string_view()), + ORT_RETURN_IF(!is_supported, + "The ORT format model version [", fbs_ort_model_version->string_view(), + "] is not supported in this build ", ORT_VERSION, ". ", + kOrtFormatVersion5BreakingChangeNote); +#else + // models prior to v5 can be handled by inserting the kernel constraints in a full build + bool is_supported_with_update = !is_supported && model_version < 5; + + // currently this means the model is using a future version + // i.e. attempted load of model created with future version of ORT + ORT_RETURN_IF_NOT(is_supported || is_supported_with_update, "The ORT format model version [", fbs_ort_model_version->string_view(), - "] is not supported in this build ", ORT_VERSION, ". ", - kOrtFormatVersion5BreakingChangeNote); + "] is not supported in this build ", ORT_VERSION, ". "); +#endif const auto* fbs_model = fbs_session->model(); ORT_RETURN_IF(nullptr == fbs_model, "Missing Model. Invalid ORT format model."); @@ -1066,7 +1080,15 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort if (const auto* fbs_kernel_type_str_resolver = fbs_session->kernel_type_str_resolver(); fbs_kernel_type_str_resolver != nullptr) { ORT_RETURN_IF_ERROR(kernel_type_str_resolver.LoadFromOrtFormat(*fbs_kernel_type_str_resolver)); + } else { +#if !defined(ORT_MINIMAL_BUILD) + // insert the kernel type constraints if we're updating an old model that had kernel hashes. + if (is_supported_with_update) { + ORT_RETURN_IF_ERROR(kernel_type_str_resolver.RegisterGraphNodeOpSchemas(model_->MainGraph())); + } +#endif } + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) ORT_RETURN_IF_ERROR( kernel_type_str_resolver_utils::AddLayoutTransformationRequiredOpsToKernelTypeStrResolver( diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index 0898895be9..595f9fb821 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -1,24 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/flatbuffers/schema/ort.fbs.h" #include "core/framework/data_types.h" #include "core/framework/tensorprotoutils.h" +#include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" +#include "core/session/onnxruntime_cxx_api.h" #include "core/session/inference_session.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/graph/model.h" -#include "test/test_environment.h" #include "test_utils.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/test_environment.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" -#include "core/flatbuffers/schema/ort.fbs.h" + #include "flatbuffers/idl.h" #include "flatbuffers/util.h" -#include "core/session/onnxruntime_cxx_api.h" - #include "gtest/gtest.h" -#include "test/util/include/asserts.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -37,6 +37,7 @@ struct OrtModelTestInfo { bool run_use_buffer{false}; bool disable_copy_ort_buffer{false}; bool use_buffer_for_initializers{false}; + TransformerLevel optimization_level = TransformerLevel::Level3; }; static void RunOrtModel(const OrtModelTestInfo& test_info) { @@ -54,6 +55,8 @@ static void RunOrtModel(const OrtModelTestInfo& test_info) { } } + so.graph_optimization_level = test_info.optimization_level; + std::vector model_data; InferenceSessionWrapper session_object{so, GetEnvironment()}; if (test_info.run_use_buffer) { @@ -94,6 +97,7 @@ static void CompareTensors(const OrtValue& left_value, const OrtValue& right_val EXPECT_EQ(left_strings[i], right_strings[i]) << "Mismatch index:" << i; } } else { + ASSERT_EQ(memcmp(left.DataRaw(), right.DataRaw(), left.SizeInBytes()), 0); } } @@ -227,16 +231,20 @@ static void CompareSessionMetadata(const InferenceSessionWrapper& session_object ASSERT_EQ(model_1.ProducerVersion(), model_2.ProducerVersion()); } -static void SaveAndCompareModels(const std::string& onnx_file, const std::basic_string& ort_file) { +static void SaveAndCompareModels(const std::basic_string& orig_file, + const std::basic_string& ort_file, + TransformerLevel optimization_level = TransformerLevel::Level3) { SessionOptions so; so.session_logid = "SerializeToOrtFormat"; so.optimized_model_filepath = ort_file; + so.graph_optimization_level = optimization_level; + // not strictly necessary - type should be inferred from the filename ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigSaveModelFormat, "ORT")); InferenceSessionWrapper session_object{so, GetEnvironment()}; // create .ort file during Initialize due to values in SessionOptions - ASSERT_STATUS_OK(session_object.Load(onnx_file)); + ASSERT_STATUS_OK(session_object.Load(orig_file)); ASSERT_STATUS_OK(session_object.Initialize()); SessionOptions so2; @@ -282,8 +290,8 @@ We could take steps to handle this scenario in a full build, but for consistency on any ORT format model. */ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) { - const std::basic_string ort_file = ORT_TSTR("testdata/mnist.onnx.test_output.ort"); - SaveAndCompareModels("testdata/mnist.onnx", ort_file); + const auto ort_file = ORT_TSTR("testdata/mnist.onnx.test_output.ort"); + SaveAndCompareModels(ORT_TSTR("testdata/mnist.onnx"), ort_file); // DumpOrtModelAsJson(ToUTF8String(ort_file)); @@ -309,8 +317,8 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) { } TEST(OrtModelOnlyTests, SerializeToOrtFormat) { - const std::basic_string ort_file = ORT_TSTR("testdata/ort_github_issue_4031.onnx.test_output.ort"); - SaveAndCompareModels("testdata/ort_github_issue_4031.onnx", ort_file); + const auto ort_file = ORT_TSTR("testdata/ort_github_issue_4031.onnx.test_output.ort"); + SaveAndCompareModels(ORT_TSTR("testdata/ort_github_issue_4031.onnx"), ort_file); // DumpOrtModelAsJson(ToUTF8String(ort_file)); @@ -336,9 +344,8 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) { } TEST(OrtModelOnlyTests, SparseInitializerHandling) { - const std::basic_string ort_file = - ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.test_output.ort"); - SaveAndCompareModels("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx", ort_file); + const auto ort_file = ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.test_output.ort"); + SaveAndCompareModels(ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx"), ort_file); SessionOptions so; so.session_logid = "SparseInitializerHandling"; @@ -355,22 +362,75 @@ TEST(OrtModelOnlyTests, SparseInitializerHandling) { // regression test to make sure the model path is correctly passed through when serializing a tensor attribute TEST(OrtModelOnlyTests, TensorAttributeSerialization) { - const std::basic_string ort_file = - ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx.test_output.ort"); - SaveAndCompareModels("testdata/ort_minimal_test_models/tensor_attribute.onnx", ort_file); + const auto ort_file = ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx.test_output.ort"); + SaveAndCompareModels(ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx"), ort_file); } TEST(OrtModelOnlyTests, MetadataSerialization) { - const std::basic_string ort_file = - ORT_TSTR("testdata/model_with_metadata.onnx.test_output.ort"); - SaveAndCompareModels("testdata/model_with_metadata.onnx", ort_file); + const auto ort_file = ORT_TSTR("testdata/model_with_metadata.onnx.test_output.ort"); + SaveAndCompareModels(ORT_TSTR("testdata/model_with_metadata.onnx"), ort_file); +} + +// test we can load an old ORT format model and run it in a full build. +// we changed from using kernel hashes to kernel type constraints in v5, so any old model should be able to be loaded +// in a full build if we add the kernel type constraints during loading. this also means we can save the updated +// ORT format model to effectively upgrade it to v5. +TEST(OrtModelOnlyTests, UpdateOrtModelVersion) { + // input is ORT format model using v4 where we used kernel hashes instead of constraints + const auto onnx_file = ORT_TSTR("testdata/mnist.onnx"); + const auto ort_file_v4 = ORT_TSTR("testdata/mnist.basic.v4.ort"); + const auto ort_file_v5 = ORT_TSTR("testdata/mnist.basic.v5.test_output.ort"); + + // update v4 model and save as v5. do not run optimizations in order to preserve the model as-is. + SaveAndCompareModels(ort_file_v4, ort_file_v5, TransformerLevel::Default); + + // run the original, v4 and v5 models and check the output is the same + RandomValueGenerator random{}; + std::vector input_dims{1, 1, 28, 28}; + std::vector input_data = random.Gaussian(input_dims, 0.0f, 0.9f); + + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), + input_dims, input_data, &ml_value); + + OrtModelTestInfo test_info; + + // keep the onnx and ort models to the same optimization level + test_info.optimization_level = TransformerLevel::Level1; + + test_info.inputs.insert(std::make_pair("Input3", ml_value)); + test_info.output_names = {"Plus214_Output_0"}; + + OrtValue orig_out, v4_out, v5_out; + + test_info.model_filename = onnx_file; + test_info.output_verifier = [&orig_out](const std::vector& fetches) { + orig_out = fetches[0]; + }; + RunOrtModel(test_info); + + // run with v4 as input. this should also update to v5 prior to execution. + test_info.model_filename = ort_file_v4; + test_info.output_verifier = [&v4_out](const std::vector& fetches) { + v4_out = fetches[0]; + }; + RunOrtModel(test_info); + + // validate the model saved as v5 also works + test_info.model_filename = ort_file_v5; + test_info.output_verifier = [&v5_out](const std::vector& fetches) { + v5_out = fetches[0]; + }; + RunOrtModel(test_info); + + CompareTensors(orig_out, v4_out); + CompareTensors(v4_out, v5_out); } #if !defined(DISABLE_ML_OPS) TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) { - const std::basic_string ort_file = - ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.onnx.test_output.ort"); - SaveAndCompareModels("testdata/sklearn_bin_voting_classifier_soft.onnx", ort_file); + const auto ort_file = ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.onnx.test_output.ort"); + SaveAndCompareModels(ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.onnx"), ort_file); OrtModelTestInfo test_info; test_info.model_filename = ort_file; @@ -414,7 +474,7 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) { // test loading ORT format model with sparse initializers TEST(OrtModelOnlyTests, LoadSparseInitializersOrtFormat) { - const std::basic_string ort_file = ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.ort"); + const auto ort_file = ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.ort"); SessionOptions so; so.session_logid = "LoadOrtFormat"; ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigLoadModelFormat, "ORT")); @@ -537,5 +597,6 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBufferNoCopy) { } #endif // !defined(DISABLE_ML_OPS) + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc index 037108b3eb..7558d49773 100644 --- a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc +++ b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "gtest/gtest.h" @@ -128,7 +129,7 @@ TEST(GraphRuntimeOptimizationTest, SaveRuntimeOptimizationToOrtFormat) { flatbuffers::Offset fbs_session_offset = fbs::CreateInferenceSessionDirect(builder, - kOrtModelVersion, + std::to_string(kOrtModelVersion).c_str(), fbs_model_offset, 0); diff --git a/onnxruntime/test/testdata/mnist.basic.v4.ort b/onnxruntime/test/testdata/mnist.basic.v4.ort new file mode 100644 index 0000000000..066e14dcab Binary files /dev/null and b/onnxruntime/test/testdata/mnist.basic.v4.ort differ