Revert "Add flatbuffers verifier for ORT format buffer (#5378)"

This reverts commit e8bf3ba2bb383055b5974e8904324c33e0f4cbb1.
This commit is contained in:
Tianlei Wu 2020-10-09 11:55:51 -07:00
parent 6782866529
commit fc0fc80db2
4 changed files with 10 additions and 56 deletions

View file

@ -38,7 +38,7 @@
#include "core/platform/threadpool.h"
#include "core/providers/cpu/controlflow/utils.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/flatbuffers/flatbuffers_utils.h"
#include "core/flatbuffers/ort.fbs.h"
#ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph
#include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h"
#endif
@ -486,7 +486,7 @@ common::Status InferenceSession::SaveToOrtFormat(const std::basic_string<ORTCHAR
sb.add_model(model);
sb.add_session_state(session_state);
auto session = sb.Finish();
builder.Finish(session, fbs::InferenceSessionIdentifier());
builder.Finish(session);
// TODO: Do we need to catch any std::exceptions from creating/writing to disk and convert to Status codes?
{
@ -576,7 +576,7 @@ common::Status InferenceSession::Load(const std::string& model_uri) {
bool has_explicit_type = !model_type.empty();
if ((has_explicit_type && model_type == "ORT") ||
(!has_explicit_type && experimental::utils::IsOrtFormatModel(model_uri))) {
(!has_explicit_type && inference_session_utils::IsOrtFormatModel(model_uri))) {
#if defined(ENABLE_ORT_FORMAT_LOAD)
return LoadOrtModel(model_uri);
#else
@ -603,7 +603,7 @@ common::Status InferenceSession::Load(const std::wstring& model_uri) {
bool has_explicit_type = !model_type.empty();
if ((has_explicit_type && model_type == "ORT") ||
(!has_explicit_type && experimental::utils::IsOrtFormatModel(model_uri))) {
(!has_explicit_type && inference_session_utils::IsOrtFormatModel(model_uri))) {
#if defined(ENABLE_ORT_FORMAT_LOAD)
return LoadOrtModel(model_uri);
#else
@ -631,7 +631,7 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len
if ((has_explicit_type && model_type == "ORT") ||
(!has_explicit_type &&
experimental::utils::IsOrtFormatModelBytes(model_data, model_data_len))) {
inference_session_utils::IsOrtFormatModelBytes(model_data, model_data_len))) {
#if defined(ENABLE_ORT_FORMAT_LOAD)
return LoadOrtModel(model_data, model_data_len);
#else
@ -939,11 +939,6 @@ Status InferenceSession::LoadOrtModel(std::function<Status()> load_ort_format_mo
}
ORT_RETURN_IF_ERROR(load_ort_format_model_bytes());
// Verify the ort_format_model_bytes_ is a valid InferenceSessionBuffer before we access the data
flatbuffers::Verifier verifier(ort_format_model_bytes_.data(), ort_format_model_bytes_.size());
ORT_RETURN_IF_NOT(fbs::VerifyInferenceSessionBuffer(verifier));
const auto* fbs_session = fbs::GetInferenceSession(ort_format_model_bytes_.data());
ORT_RETURN_IF(nullptr == fbs_session, "InferenceSession is null. Invalid ORT format model.");
@ -1151,7 +1146,7 @@ common::Status InferenceSession::Initialize() {
if ((has_explicit_type && model_type == "ORT") ||
(!has_explicit_type &&
experimental::utils::IsOrtFormatModel(session_options_.optimized_model_filepath))) {
inference_session_utils::IsOrtFormatModel(session_options_.optimized_model_filepath))) {
ORT_RETURN_IF_ERROR_SESSIONID_(SaveToOrtFormat(session_options_.optimized_model_filepath));
} else {
ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath));

View file

@ -50,7 +50,6 @@ struct OrtModelTestInfo {
std::vector<std::string> output_names;
std::function<void(const std::vector<OrtValue>&)> output_verifier;
std::vector<std::pair<std::string, std::string>> configs;
bool run_use_buffer{false};
};
static void RunOrtModel(const OrtModelTestInfo& test_info) {
@ -59,21 +58,8 @@ static void RunOrtModel(const OrtModelTestInfo& test_info) {
for (const auto& config : test_info.configs)
so.AddConfigEntry(config.first.c_str(), config.second.c_str());
std::vector<char> model_data;
InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()};
if (test_info.run_use_buffer) {
// Load the file into a buffer and use the buffer to create inference session
size_t num_bytes = 0;
ASSERT_STATUS_OK(Env::Default().GetFileLength(test_info.model_filename.c_str(), num_bytes));
model_data.resize(num_bytes);
std::ifstream bytes_stream(test_info.model_filename, std::ifstream::in | std::ifstream::binary);
bytes_stream.read(model_data.data(), num_bytes);
bytes_stream.close();
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(num_bytes)));
} else {
ASSERT_STATUS_OK(session_object.Load(test_info.model_filename)); // infer type from filename
}
ASSERT_STATUS_OK(session_object.Load(test_info.model_filename)); // infer type from filename
ASSERT_STATUS_OK(session_object.Initialize());
std::vector<OrtValue> fetches;
@ -318,7 +304,8 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) {
#endif // #if !defined(DISABLE_ML_OPS)
#endif // #if !defined(ORT_MINIMAL_BUILD)
OrtModelTestInfo GetTestInfoForLoadOrtFormatModel() {
// test that we can deserialize and run a previously saved ORT format model
TEST(OrtModelOnlyTests, LoadOrtFormatModel) {
OrtModelTestInfo test_info;
test_info.model_filename = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort");
test_info.logid = "LoadOrtFormatModel";
@ -336,26 +323,13 @@ OrtModelTestInfo GetTestInfoForLoadOrtFormatModel() {
ASSERT_TRUE(output.Data<float>()[0] == 125.f);
};
return test_info;
}
// test that we can deserialize and run a previously saved ORT format model
TEST(OrtModelOnlyTests, LoadOrtFormatModel) {
OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModel();
RunOrtModel(test_info);
}
// Load the model from a buffer instead of a file path
TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBuffer) {
OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModel();
test_info.run_use_buffer = true;
RunOrtModel(test_info);
}
#if !defined(DISABLE_ML_OPS)
// test that we can deserialize and run a previously saved ORT format model
// for a model with sequence and map outputs
OrtModelTestInfo GetTestInfoForLoadOrtFormatModelMLOps() {
TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOps) {
OrtModelTestInfo test_info;
test_info.model_filename = ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.ort");
test_info.logid = "LoadOrtFormatModelMLOps";
@ -389,23 +363,8 @@ OrtModelTestInfo GetTestInfoForLoadOrtFormatModelMLOps() {
}
};
return test_info;
}
// test that we can deserialize and run a previously saved ORT format model
// for a model with sequence and map outputs
TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOps) {
OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModelMLOps();
RunOrtModel(test_info);
}
// Load the model from a buffer instead of a file path
TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBuffer) {
OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModelMLOps();
test_info.run_use_buffer = true;
RunOrtModel(test_info);
}
#endif // !defined(DISABLE_ML_OPS)
} // namespace test