diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 7369925e9d..ae902322e5 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -38,7 +38,7 @@ #include "core/platform/threadpool.h" #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/cpu_execution_provider.h" -#include "core/flatbuffers/flatbuffers_utils.h" +#include "core/flatbuffers/ort.fbs.h" #ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph #include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h" #endif @@ -486,7 +486,7 @@ common::Status InferenceSession::SaveToOrtFormat(const std::basic_string load_ort_format_mo } ORT_RETURN_IF_ERROR(load_ort_format_model_bytes()); - - // Verify the ort_format_model_bytes_ is a valid InferenceSessionBuffer before we access the data - flatbuffers::Verifier verifier(ort_format_model_bytes_.data(), ort_format_model_bytes_.size()); - ORT_RETURN_IF_NOT(fbs::VerifyInferenceSessionBuffer(verifier)); - const auto* fbs_session = fbs::GetInferenceSession(ort_format_model_bytes_.data()); ORT_RETURN_IF(nullptr == fbs_session, "InferenceSession is null. Invalid ORT format model."); @@ -1151,7 +1146,7 @@ common::Status InferenceSession::Initialize() { if ((has_explicit_type && model_type == "ORT") || (!has_explicit_type && - experimental::utils::IsOrtFormatModel(session_options_.optimized_model_filepath))) { + inference_session_utils::IsOrtFormatModel(session_options_.optimized_model_filepath))) { ORT_RETURN_IF_ERROR_SESSIONID_(SaveToOrtFormat(session_options_.optimized_model_filepath)); } else { ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath)); diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index be81f8e9a7..c7000cd046 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -50,7 +50,6 @@ struct OrtModelTestInfo { std::vector output_names; std::function&)> output_verifier; std::vector> configs; - bool run_use_buffer{false}; }; static void RunOrtModel(const OrtModelTestInfo& test_info) { @@ -59,21 +58,8 @@ static void RunOrtModel(const OrtModelTestInfo& test_info) { for (const auto& config : test_info.configs) so.AddConfigEntry(config.first.c_str(), config.second.c_str()); - std::vector model_data; InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()}; - if (test_info.run_use_buffer) { - // Load the file into a buffer and use the buffer to create inference session - size_t num_bytes = 0; - ASSERT_STATUS_OK(Env::Default().GetFileLength(test_info.model_filename.c_str(), num_bytes)); - model_data.resize(num_bytes); - std::ifstream bytes_stream(test_info.model_filename, std::ifstream::in | std::ifstream::binary); - bytes_stream.read(model_data.data(), num_bytes); - bytes_stream.close(); - ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(num_bytes))); - } else { - ASSERT_STATUS_OK(session_object.Load(test_info.model_filename)); // infer type from filename - } - + ASSERT_STATUS_OK(session_object.Load(test_info.model_filename)); // infer type from filename ASSERT_STATUS_OK(session_object.Initialize()); std::vector fetches; @@ -318,7 +304,8 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) { #endif // #if !defined(DISABLE_ML_OPS) #endif // #if !defined(ORT_MINIMAL_BUILD) -OrtModelTestInfo GetTestInfoForLoadOrtFormatModel() { +// test that we can deserialize and run a previously saved ORT format model +TEST(OrtModelOnlyTests, LoadOrtFormatModel) { OrtModelTestInfo test_info; test_info.model_filename = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort"); test_info.logid = "LoadOrtFormatModel"; @@ -336,26 +323,13 @@ OrtModelTestInfo GetTestInfoForLoadOrtFormatModel() { ASSERT_TRUE(output.Data()[0] == 125.f); }; - return test_info; -} - -// test that we can deserialize and run a previously saved ORT format model -TEST(OrtModelOnlyTests, LoadOrtFormatModel) { - OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModel(); - RunOrtModel(test_info); -} - -// Load the model from a buffer instead of a file path -TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBuffer) { - OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModel(); - test_info.run_use_buffer = true; RunOrtModel(test_info); } #if !defined(DISABLE_ML_OPS) // test that we can deserialize and run a previously saved ORT format model // for a model with sequence and map outputs -OrtModelTestInfo GetTestInfoForLoadOrtFormatModelMLOps() { +TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOps) { OrtModelTestInfo test_info; test_info.model_filename = ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.ort"); test_info.logid = "LoadOrtFormatModelMLOps"; @@ -389,23 +363,8 @@ OrtModelTestInfo GetTestInfoForLoadOrtFormatModelMLOps() { } }; - return test_info; -} - -// test that we can deserialize and run a previously saved ORT format model -// for a model with sequence and map outputs -TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOps) { - OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModelMLOps(); RunOrtModel(test_info); } - -// Load the model from a buffer instead of a file path -TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBuffer) { - OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModelMLOps(); - test_info.run_use_buffer = true; - RunOrtModel(test_info); -} - #endif // !defined(DISABLE_ML_OPS) } // namespace test diff --git a/onnxruntime/test/testdata/ort_github_issue_4031.onnx.ort b/onnxruntime/test/testdata/ort_github_issue_4031.onnx.ort index 6d10929180..0fe711ee71 100644 Binary files a/onnxruntime/test/testdata/ort_github_issue_4031.onnx.ort and b/onnxruntime/test/testdata/ort_github_issue_4031.onnx.ort differ diff --git a/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort index eaef3784e7..7eae709f43 100644 Binary files a/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort and b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort differ