diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index e2d57cfda4..71268400aa 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -941,6 +941,15 @@ class Graph {
   const ONNX_NAMESPACE::GraphProto& ToGraphProto();
   ONNX_NAMESPACE::GraphProto ToGraphProto() const;
 
+  /** Gets the GraphProto representation of this Graph
+  @param external_file_name name of the binary file to use for initializers
+  @param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved
+  in the external file. Initializers smaller than this threshold are included in the onnx file.
+  @returns GraphProto serialization of the graph.
+  */
+  ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
+                                                                  size_t initializer_size_threshold) const;
+
   /** Gets the ISchemaRegistry instances being used with this Graph. */
   IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const;
 
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 9cd8d9d5a4..70eddde7cf 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -3030,6 +3030,79 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const {
   return result;
 }
 
+// Serializes the graph; every dense initializer of at least initializer_size_threshold bytes is written to
+// the binary file external_file_name and referenced from the returned proto through its
+// location/offset/length external_data entries. Sparse and smaller initializers stay in the proto.
+ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
+                                                                       size_t initializer_size_threshold) const {
+  GraphProto result;
+  ToGraphProtoInternal(result);
+  const auto& model_path = ModelPath();
+
+  std::ofstream external_stream(external_file_name, std::ofstream::out | std::ofstream::binary);
+  ORT_ENFORCE(external_stream.is_open(), "Failed to open external initializer file: ", external_file_name);
+  // Byte offset in the external file at which the next initializer is written.
+  int64_t external_offset = 0;
+
+  // Add the initializers to the result graph.
+  const auto sparse_end = sparse_tensor_names_.end();
+  for (const auto& initializer : graph_proto_->initializer()) {
+    if (sparse_end != sparse_tensor_names_.find(initializer.name())) {
+      // Sparse tensors are added to the ONNX file.
+      auto& sparse_initializer = *result.add_sparse_initializer();
+      auto status = utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer);
+      ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse");
+    } else {
+      // Dense tensors at or above the threshold are added to the external file.
+      TensorProto* output_proto = result.add_initializer();
+
+      size_t tensor_bytes_size = 0;
+      std::unique_ptr<unsigned char[]> raw_data;
+      // Use the model path, as the sparse branch does, instead of an empty Path.
+      // NOTE(review): presumably required to resolve initializers whose data is already external - confirm.
+      ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data, tensor_bytes_size));
+
+      if (tensor_bytes_size < initializer_size_threshold) {
+        // Small tensor: keep it embedded in the ONNX file unchanged.
+        *output_proto = initializer;
+        continue;
+      }
+
+      // Write the payload with a single call rather than one stream insertion per byte.
+      external_stream.write(reinterpret_cast<const char*>(raw_data.get()),
+                            static_cast<std::streamsize>(tensor_bytes_size));
+      ORT_ENFORCE(external_stream.good(), "Failed to write initializer '", initializer.name(),
+                  "' to external file: ", external_file_name);
+
+      output_proto->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
+      ONNX_NAMESPACE::StringStringEntryProto* location = output_proto->add_external_data();
+      location->set_key("location");
+      location->set_value(external_file_name);
+      ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto->add_external_data();
+      offset->set_key("offset");
+      offset->set_value(std::to_string(external_offset));
+      ONNX_NAMESPACE::StringStringEntryProto* length = output_proto->add_external_data();
+      length->set_key("length");
+      length->set_value(std::to_string(tensor_bytes_size));
+
+      output_proto->set_name(initializer.name());
+      output_proto->set_data_type(initializer.data_type());
+      for (int i = 0; i != initializer.dims_size(); ++i) {
+        output_proto->add_dims(initializer.dims(i));
+      }
+      output_proto->set_doc_string(initializer.doc_string());
+
+      external_offset += static_cast<int64_t>(tensor_bytes_size);
+    }
+  }
+
+  // Push buffered bytes out and surface any late I/O failure before handing back the proto.
+  external_stream.flush();
+  ORT_ENFORCE(external_stream.good(), "Failed to flush external initializer file: ", external_file_name);
+
+  return result;
+}
+
 void
Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const {
   graph_proto_->clear_node();
   graph_proto_->clear_input();
diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc
index 7666d8390c..302b37bac4 100644
--- a/onnxruntime/core/graph/model.cc
+++ b/onnxruntime/core/graph/model.cc
@@ -271,6 +271,17 @@ ModelProto Model::ToProto() {
   return result;
 }
 
+// Copies model_proto_ and replaces its graph with the result of
+// Graph::ToGraphProtoWithExternalInitializers, called with the same file name and threshold.
+ModelProto Model::ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
+                                                       size_t initializer_size_threshold) {
+  ModelProto result(model_proto_);
+  const auto& graph = *graph_;
+  *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name,
+                                                                         initializer_size_threshold);
+  return result;
+}
+
 Status Model::Load(std::istream& model_istream, ModelProto* p_model_proto) {
   if (!model_istream.good()) {
     return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid istream object.");
@@ -445,6 +456,35 @@ Status Model::Save(Model& model, const std::wstring& file_path) {
 }
 #endif
 
+// Opens file_path for writing, saves the model through the fd-based overload and closes the descriptor
+// whether or not the save succeeded. T is the path string type handed to Env::FileOpenWr.
+template <typename T>
+static Status SaveModelWithExternalInitializers(Model& model,
+                                                const T& file_path,
+                                                const std::string& external_file_name,
+                                                size_t initializer_size_threshold) {
+  int fd = 0;
+  Status status = Env::Default().FileOpenWr(file_path, fd);
+  ORT_RETURN_IF_ERROR(status);
+
+  ORT_TRY {
+    status = Model::SaveWithExternalInitializers(model, fd, external_file_name,
+                                                 initializer_size_threshold);
+  }
+  ORT_CATCH(const std::exception& ex) {
+    ORT_HANDLE_EXCEPTION([&]() {
+      status = Status(ONNXRUNTIME, FAIL, ex.what());
+    });
+  }
+  if (!status.IsOK()) {
+    // The save error is what gets returned; the result of closing the descriptor is deliberately ignored.
+    GSL_SUPPRESS(es .84)
+    ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
+    return status;
+  }
+  return Env::Default().FileClose(fd);
+}
+
 Status Model::Load(const PathString& file_path,
                    ONNX_NAMESPACE::ModelProto& model_proto) {
   return LoadModel(file_path, model_proto);
@@ -462,6 +502,13 @@ Status Model::Save(Model& model, const std::string& file_path) {
   return SaveModel(model, file_path);
 }
 
+// Path-based overload: forwards to SaveModelWithExternalInitializers, which opens and closes the file.
+Status Model::SaveWithExternalInitializers(Model& model, const PathString& file_path,
+                                           const std::string& external_file_name,
+                                           size_t initializer_size_threshold) {
+  return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold);
+}
+
 Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
   const bool result = model_proto.ParseFromArray(p_bytes, count);
   if (!result) {
@@ -569,6 +616,27 @@ Status Model::Save(Model& model, int p_fd) {
   return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf serialization failed.");
 }
 
+// Resolves the main graph, then serializes the model to the already-open descriptor fd. The descriptor is
+// not closed here. Returns INVALID_ARGUMENT for a negative fd and INVALID_PROTOBUF if serialization fails.
+Status Model::SaveWithExternalInitializers(Model& model,
+                                           int fd,
+                                           const std::string& external_file_name,
+                                           size_t initializer_size_threshold) {
+  if (fd < 0) {
+    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "<fd> is less than 0.");
+  }
+
+  ORT_RETURN_IF_ERROR(model.MainGraph().Resolve());
+
+  auto model_proto = model.ToGraphProtoWithExternalInitializers(external_file_name, initializer_size_threshold);
+  google::protobuf::io::FileOutputStream output(fd);
+  const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush();
+  if (result) {
+    return Status::OK();
+  }
+  return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf serialization failed.");
+}
+
 common::Status Model::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
                                       flatbuffers::Offset<fbs::Model>& fbs_model) const {
   auto producer_name = experimental::utils::SaveStringToOrtFormat(
diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h
index 8083ff9f14..37ca815991 100644
--- a/onnxruntime/core/graph/model.h
+++ b/onnxruntime/core/graph/model.h
@@ -165,6 +165,12 @@ class Model {
   // Get model's serialization proto data.
   ONNX_NAMESPACE::ModelProto ToProto();
 
+  // Get model's serialization proto data.
+  // Save initializers larger than or equal to the given threshold (in bytes) into an external binary file
+  // with the given name. This function is useful to avoid hitting the size limit of protobuf files.
+  ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
+                                                                  size_t initializer_size_threshold);
+
 #ifdef _WIN32
   static common::Status Save(Model& model, const std::wstring& file_path);
 #endif
@@ -172,6 +178,19 @@ class Model {
 
   static common::Status Save(Model& model, int fd);
 
+  // Save the model to file using an external file for initializers larger than or equal to the given threshold
+  // (in bytes). file_path is a PathString, while external_file_name is always a plain string because it is
+  // saved inside the output protobuf as a plain string, where wchar is not supported.
+  static common::Status SaveWithExternalInitializers(Model& model,
+                                                     const PathString& file_path,
+                                                     const std::string& external_file_name,
+                                                     size_t initializer_size_threshold);
+
+  static common::Status SaveWithExternalInitializers(Model& model,
+                                                     int fd,
+                                                     const std::string& external_file_name,
+                                                     size_t initializer_size_threshold);
+
   static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto);
 
   static common::Status Load(const PathString& file_path,
diff --git a/onnxruntime/test/framework/save_model_with_external_initializers.cc b/onnxruntime/test/framework/save_model_with_external_initializers.cc
new file mode 100644
index 0000000000..6ac48646d5
--- /dev/null
+++ b/onnxruntime/test/framework/save_model_with_external_initializers.cc
@@ -0,0 +1,78 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/framework/data_types.h"
+#include "core/graph/model.h"
+#include "core/framework/tensorprotoutils.h"
+#include "test/test_environment.h"
+#include "test_utils.h"
+#include "test/util/include/asserts.h"
+
+#include "gtest/gtest.h"
+
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime;
+
+namespace onnxruntime {
+namespace test {
+
+// Loads input_onnx, saves it with initializers of at least initializer_size_threshold bytes moved to
+// external_init_file, reloads the saved model and checks the location and bytes of every initializer.
+void LoadSaveAndCompareModel(const std::string& input_onnx, const std::string& output_onnx,
+                             const std::string& external_init_file, size_t initializer_size_threshold) {
+  std::shared_ptr<Model> model;
+  ASSERT_STATUS_OK(Model::Load(ToPathString(input_onnx), model, nullptr, DefaultLoggingManager().DefaultLogger()));
+  std::remove(output_onnx.c_str());
+  std::remove(external_init_file.c_str());
+  ASSERT_STATUS_OK(Model::SaveWithExternalInitializers(*model, ToPathString(output_onnx), external_init_file, initializer_size_threshold));
+
+  std::shared_ptr<Model> model_from_external;
+  ASSERT_STATUS_OK(Model::Load(ToPathString(output_onnx), model_from_external, nullptr, DefaultLoggingManager().DefaultLogger()));
+
+  Graph& graph = model->MainGraph();
+  // Shape inference succeeding here means the integer initializers used by reshape and transpose were readable.
+  ASSERT_STATUS_OK(graph.Resolve());
+  Graph& graph_from_external = model_from_external->MainGraph();
+
+  const InitializedTensorSet& initializers = graph.GetAllInitializedTensors();
+  const InitializedTensorSet& initializers_from_external = graph_from_external.GetAllInitializedTensors();
+  ASSERT_EQ(initializers.size(), initializers_from_external.size());
+
+  // Compare the initializers of the two versions.
+  for (const auto& i : initializers) {
+    const std::string& kInitName = i.first;
+    const ONNX_NAMESPACE::TensorProto* tensor_proto = i.second;
+    // Use find(): operator[] would insert a null entry for a missing name, which is dereferenced below.
+    const auto from_external_it = initializers_from_external.find(kInitName);
+    ASSERT_TRUE(from_external_it != initializers_from_external.end()) << "Missing initializer: " << kInitName;
+    const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = from_external_it->second;
+
+    size_t tensor_proto_size = 0;
+    std::unique_ptr<unsigned char[]> tensor_proto_data;
+    ORT_THROW_IF_ERROR(utils::UnpackInitializerData(*tensor_proto, Path(), tensor_proto_data, tensor_proto_size));
+    size_t from_external_tensor_proto_size = 0;
+    std::unique_ptr<unsigned char[]> from_external_tensor_proto_data;
+    ORT_THROW_IF_ERROR(utils::UnpackInitializerData(*from_external_tensor_proto, Path(), from_external_tensor_proto_data, from_external_tensor_proto_size));
+
+    if (from_external_tensor_proto_size < initializer_size_threshold) {
+      // 'Small' tensors should be embedded in the onnx file.
+      EXPECT_EQ(from_external_tensor_proto->data_location(), ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_DEFAULT);
+    } else {
+      // 'Large' tensors should be added to the external binary file.
+      EXPECT_EQ(from_external_tensor_proto->data_location(), ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
+    }
+
+    ASSERT_EQ(tensor_proto_size, from_external_tensor_proto_size);
+    EXPECT_EQ(memcmp(tensor_proto_data.get(), from_external_tensor_proto_data.get(), tensor_proto_size), 0);
+  }
+  // Cleanup.
+  ASSERT_EQ(std::remove(output_onnx.c_str()), 0);
+  ASSERT_EQ(std::remove(external_init_file.c_str()), 0);
+}
+
+TEST(SaveWithExternalInitializers, Mnist) {
+  LoadSaveAndCompareModel("testdata/mnist.onnx", "testdata/mnist_with_external_initializers.onnx", "mnist_external_initializers.bin", 100);
+}
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc
index 564557f856..2f3d15d8e6 100644
--- a/orttraining/orttraining/core/session/training_session.cc
+++ b/orttraining/orttraining/core/session/training_session.cc
@@ -1002,6 +1002,16 @@ static Status UpdateWeightsBeforeSaving(
   return Status::OK();
 }
 
+Status TrainingSession::SaveWithExternalInitializers(const PathString& model_uri,
+                                                     const std::string& external_file_name,
+                                                     size_t initializer_size_threshold) {
+  // Delete the old files before saving; the results are ignored, so a file that does not exist is fine.
+  std::remove(ToMBString(model_uri).c_str());
+  std::remove(external_file_name.c_str());
+
+  return Model::SaveWithExternalInitializers(*model_, model_uri, external_file_name, initializer_size_threshold);
+}
+
 Status TrainingSession::Save(const PathString& model_uri, TrainingSession::SaveOption opt) {
   // Delete the old file before saving.
   std::remove(ToMBString(model_uri).c_str());  // TODO would be good to have something like RemoveFile(PathString)
diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h
index 7b94818fe2..8b1dc501cf 100644
--- a/orttraining/orttraining/core/session/training_session.h
+++ b/orttraining/orttraining/core/session/training_session.h
@@ -337,6 +337,19 @@ class TrainingSession : public InferenceSession {
   */
   common::Status Save(const PathString& model_uri, SaveOption opt);
 
+  /** Save the model using an external file for initializers larger than or equal to the threshold (in bytes).
+  This function is useful to avoid hitting the size limit for protobufs when using large models, in
+  particular after auto-diff.
+  @param model_uri the path for the new model.
+  @param external_file_name the name for the external initializers file. This is a plain string because
+  it needs to be saved into the onnx protobuf, where wchar is not supported.
+  @param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved
+  in the external file. Initializers smaller than this threshold are included in the onnx file.
+  */
+  common::Status SaveWithExternalInitializers(const PathString& model_uri,
+                                              const std::string& external_file_name,
+                                              size_t initializer_size_threshold);
+
   /** Update the session initializers with passed-in state tensors
    * @param state_tensors A map of state tensors to set, usually loaded from a checkpoint.
    * @param strict Whether entries in state_tensors which are unknown or not present in the model are treated as an error or ignored.