Load CheckpointState from a buffer (#16457)

This commit is contained in:
Baiju Meswani 2023-06-26 09:18:38 -07:00 committed by GitHub
parent efe0af3720
commit 1f60414bc2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 243 additions and 27 deletions

View file

@ -40,6 +40,7 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr TrainingSessionGetEvalModelInputName;
public IntPtr AddProperty;
public IntPtr GetProperty;
public IntPtr LoadCheckpointFromBuffer;
}
internal static class NativeTrainingMethods

View file

@ -138,6 +138,113 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) {
}
}
/**
* Load ONNX model from file path, save into ORT checkpoint files,
* Then load it into a bytes buffer and then load the buffer to a checkpoint, compare with the initial parameter values.
*/
TEST(CheckpointApiTest, SaveOnnxModelAsCheckpointThenLoadFromBufferCPU) {
/// Phase 1 - Test Preparation
/// Prepare the data and dest folder for saving checkpoint.
/// Also cooked the data for test result comparison.
// Model path and trainable parameter name definitions.
auto model_uri = MODEL_FOLDER "transform/computation_reduction/gathernd/e2e.onnx";
std::vector<std::string> expected_trainable_param_names{
"bert.encoder.layer.2.output.LayerNorm.weight",
"bert.encoder.layer.2.output.LayerNorm.bias",
"add1_initializerr",
"cls.predictions.transform.LayerNorm.weight",
"cls.predictions.transform.LayerNorm.bias",
"bert.embeddings.word_embeddings.weight_transposed",
"cls.predictions.bias",
};
// Extract a weight value baseline to compare.
// expected_trainable_param_name_to_ort_value is used to compare with the values after restoring from checkpoint.
auto logger_ptr = std::make_unique<logging::Logger>(logging::LoggingManager::DefaultLogger());
std::shared_ptr<Model> p_model;
ORT_ENFORCE(Model::Load(model_uri, p_model, nullptr, *logger_ptr).IsOK());
Graph& graph = p_model->MainGraph();
std::vector<ONNX_NAMESPACE::TensorProto> trainable_param_values;
trainable_param_values.reserve(expected_trainable_param_names.size());
std::vector<ONNX_NAMESPACE::TensorProto> non_trainable_param_values;
const auto& initializer_tensors = graph.GetAllInitializedTensors();
for (const auto& [initializer_name, tensor_proto] : initializer_tensors) {
if (std::find(expected_trainable_param_names.begin(), expected_trainable_param_names.end(), initializer_name) !=
expected_trainable_param_names.end()) {
trainable_param_values.emplace_back(static_cast<ONNX_NAMESPACE::TensorProto>(*tensor_proto));
} else {
non_trainable_param_values.emplace_back(static_cast<ONNX_NAMESPACE::TensorProto>(*tensor_proto));
}
}
std::unordered_map<std::string, OrtValue> expected_trainable_param_name_to_ort_value;
ORT_ENFORCE(CreateOrtValuesFromTensorProtos(trainable_param_values, expected_trainable_param_name_to_ort_value)
.IsOK());
// Remove the temporary directory if it already exists.
auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir");
TemporaryDirectory tmp_dir{ckpt_test_root_dir};
/// Phase 2 - Run save checkpoint APIs.
/// And check the result checkpoint files.
// Call Save APIs.
PathString checkpoint_path{
ConcatPathComponent<PathChar>(tmp_dir.Path(), ORT_TSTR("e2e_ckpt_save_cpu"))};
ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path));
/// Phase 3 - Run load checkpoint APIs.
/// And check the result comparable with initial parameter values.
// Call Load APIs
size_t num_bytes = 0;
ASSERT_STATUS_OK(Env::Default().GetFileLength(checkpoint_path.c_str(), num_bytes));
std::vector<uint8_t> checkpoint_bytes(num_bytes);
std::ifstream bytes_stream(checkpoint_path, std::ifstream::in | std::ifstream::binary);
bytes_stream.read(reinterpret_cast<char*>(checkpoint_bytes.data()), num_bytes);
ASSERT_TRUE(bytes_stream);
CheckpointState checkpoint_state_to_load;
ASSERT_STATUS_OK(LoadCheckpointFromBuffer(checkpoint_bytes, checkpoint_state_to_load));
ModuleCheckpointState module_state = checkpoint_state_to_load.module_checkpoint_state;
const auto& param_states = module_state.named_parameters;
std::unordered_map<std::string, OrtValue> restored_param_name_to_ort_values;
std::vector<std::string> restored_trainable_param_names;
for (auto it = param_states.begin(); it != param_states.end(); ++it) {
restored_param_name_to_ort_values.insert({it->first, it->second->Data()});
if (it->second->RequiresGrad()) {
restored_trainable_param_names.emplace_back(it->first);
}
}
// Check loaded parameter's values are same with original ones.
ASSERT_EQ(expected_trainable_param_name_to_ort_value.size(), restored_trainable_param_names.size());
ASSERT_EQ(expected_trainable_param_name_to_ort_value.size(), 7);
ASSERT_EQ(restored_param_name_to_ort_values.size(), 9);
std::sort(expected_trainable_param_names.begin(), expected_trainable_param_names.end());
std::sort(restored_trainable_param_names.begin(), restored_trainable_param_names.end());
ASSERT_EQ(expected_trainable_param_names, restored_trainable_param_names);
for (const auto& name : restored_trainable_param_names) {
const auto& restored_ort_value = restored_param_name_to_ort_values[name];
const auto& expected_ort_value = expected_trainable_param_name_to_ort_value.at(name);
ASSERT_TRUE(restored_ort_value.IsTensor() && expected_ort_value.IsTensor());
const Tensor& restored_tensor = restored_ort_value.Get<Tensor>();
const Tensor& expected_tensor = expected_ort_value.Get<Tensor>();
ASSERT_EQ(expected_tensor.DataType(), restored_tensor.DataType());
ASSERT_EQ(expected_tensor.SizeInBytes(), restored_tensor.SizeInBytes());
ASSERT_EQ(expected_tensor.DataType(), restored_tensor.DataType());
ASSERT_EQ(std::memcmp(expected_tensor.DataRaw(), restored_tensor.DataRaw(), expected_tensor.SizeInBytes()), 0);
}
}
/**
* Load ONNX model with parameters set to 0 from file path, Load Checkpoint weights into the Model,
* Then compare the new weights to 0 to make sure they were changed after loading checkpoint to model.

View file

@ -10,6 +10,7 @@
#include "orttraining/training_api/checkpoint.h"
#include "orttraining/test/training_api/core/data_utils.h"
#include "test/util/include/asserts.h"
#include "test/util/include/temp_dir.h"
namespace onnxruntime::training::test {
@ -38,6 +39,33 @@ TEST(TrainingCApiTest, SaveCheckpoint) {
new_checkpoint_state, model_uri);
}
TEST(TrainingCApiTest, LoadCheckpointFromBuffer) {
Ort::Env env;
size_t num_bytes = 0;
PathString checkpoint_path = MODEL_FOLDER "checkpoint.ckpt";
ASSERT_STATUS_OK(Env::Default().GetFileLength(checkpoint_path.c_str(), num_bytes));
std::vector<uint8_t> checkpoint_bytes(num_bytes);
std::ifstream bytes_stream(checkpoint_path, std::ifstream::in | std::ifstream::binary);
bytes_stream.read(reinterpret_cast<char*>(checkpoint_bytes.data()), num_bytes);
ASSERT_TRUE(bytes_stream);
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpointFromBuffer(checkpoint_bytes);
auto test_dir = ORT_TSTR("save_checkpoint_dir");
if (Env::Default().FolderExists(test_dir)) {
ORT_ENFORCE(Env::Default().DeleteFolder(test_dir).IsOK());
}
onnxruntime::test::TemporaryDirectory tmp_dir{test_dir};
PathString new_checkpoint_path{
ConcatPathComponent<PathChar>(tmp_dir.Path(), ORT_TSTR("new_checkpoint.ckpt"))};
Ort::CheckpointState::SaveCheckpoint(checkpoint_state, new_checkpoint_path);
Ort::CheckpointState new_checkpoint_state = Ort::CheckpointState::LoadCheckpoint(new_checkpoint_path);
}
TEST(TrainingCApiTest, AddIntProperty) {
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");

View file

@ -473,9 +473,6 @@ Status FromFile(const PathString& checkpoint_path, InlinedVector<uint8_t>& check
ORT_RETURN_IF_NOT(bytes_stream, "Loading checkpoint from ", ToUTF8String(checkpoint_path), " failed. Only ",
bytes_stream.gcount(), "/", num_bytes, " bytes could be read.");
flatbuffers::Verifier verifier(checkpoint_bytes.data(), checkpoint_bytes.size());
ORT_RETURN_IF_NOT(fbs::VerifyCheckpointBuffer(verifier), "Checkpoint verification failed.");
return Status::OK();
}
@ -622,10 +619,10 @@ Status ToPropertyBag(const onnxruntime::fbs::PropertyBag& fbs_property_bag,
* @param model_proto Model proto to be populated.
* @return Status of the operation.
*/
Status ToModelProto(const PathString& checkpoint_path,
Status ToModelProto(gsl::span<const uint8_t> checkpoint_bytes,
ONNX_NAMESPACE::ModelProto& model_proto) {
InlinedVector<uint8_t> checkpoint_bytes;
ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes));
flatbuffers::Verifier verifier(checkpoint_bytes.data(), checkpoint_bytes.size());
ORT_RETURN_IF_NOT(fbs::VerifyCheckpointBuffer(verifier), "Checkpoint verification failed.");
const auto* fbs_checkpoint = fbs::GetCheckpoint(checkpoint_bytes.data());
ORT_RETURN_IF_NOT(fbs_checkpoint, "Checkpoint is invalid. Expected: Valid checkpoint flatbuffer. Actual: nullptr.");
@ -687,9 +684,9 @@ Status ToModelProto(const PathString& checkpoint_path,
* @param state Checkpoint state to be populated.
* @return Status of the operation.
*/
Status ToCheckpointState(const PathString& checkpoint_path, CheckpointState& state) {
InlinedVector<uint8_t> checkpoint_bytes;
ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes));
Status ToCheckpointState(gsl::span<const uint8_t> checkpoint_bytes, CheckpointState& state) {
flatbuffers::Verifier verifier(checkpoint_bytes.data(), checkpoint_bytes.size());
ORT_RETURN_IF_NOT(fbs::VerifyCheckpointBuffer(verifier), "Checkpoint verification failed.");
const auto* fbs_checkpoint = fbs::GetCheckpoint(checkpoint_bytes.data());
ORT_RETURN_IF_NOT(fbs_checkpoint, "Checkpoint is invalid. Expected: Valid checkpoint flatbuffer. Actual: nullptr.");
@ -737,14 +734,26 @@ Status SaveCheckpoint(const CheckpointState& states, const PathString& checkpoin
Status LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkpoint_states) {
ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines");
return load::ToCheckpointState(checkpoint_path, checkpoint_states);
InlinedVector<uint8_t> checkpoint_bytes;
ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes));
return load::ToCheckpointState(checkpoint_bytes, checkpoint_states);
}
Status LoadCheckpointFromBuffer(gsl::span<const uint8_t> checkpoint_bytes, CheckpointState& checkpoint_state) {
ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines");
return load::ToCheckpointState(checkpoint_bytes, checkpoint_state);
}
#if !defined(ORT_MINIMAL_BUILD)
Status LoadCheckpointToModel(const PathString& checkpoint_path,
ONNX_NAMESPACE::ModelProto& model_proto) {
ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines");
return load::ToModelProto(checkpoint_path, model_proto);
InlinedVector<uint8_t> checkpoint_bytes;
ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes));
return load::ToModelProto(checkpoint_bytes, model_proto);
}
#endif

View file

@ -66,6 +66,15 @@ Status SaveCheckpoint(gsl::span<const ONNX_NAMESPACE::TensorProto> trainable_ten
Status LoadCheckpoint(const PathString& checkpoint_path,
CheckpointState& checkpoint_state);
/**
* @brief Load training states from ORT checkpoint bytes buffer.
* @param checkpoint_bytes bytes buffer of the checkpoint.
* @param checkpoint_state parameter/optimizer and other user defined training states.
* @return Status
*/
Status LoadCheckpointFromBuffer(gsl::span<const uint8_t> checkpoint_bytes,
CheckpointState& checkpoint_state);
#if !defined(ORT_MINIMAL_BUILD)
/**
* @brief Load training states from ORT checkpoint into a ModelProto.

View file

@ -13,7 +13,7 @@
*
* In order to train a model with onnxruntime, the following training artifacts must be generated:
* - The training onnx model
* - The checkpoint directory
* - The checkpoint file
* - The optimizer onnx model
* - The eval onnx model model (optional)
*
@ -123,9 +123,9 @@ struct OrtTrainingApi {
/// \name Accessing The Training Session State
/// @{
/** \brief Load a checkpoint state from directory on disk into checkpoint_state.
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
*
* This function will parse a checkpoint directory, pull relevant files and load the training
* This function will parse a checkpoint file, pull relevant data and load the training
* state into the checkpoint_state. This checkpoint state can then be used to create the
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
* session will resume training from the given checkpoint state.
@ -133,7 +133,7 @@ struct OrtTrainingApi {
* training state (including model parameters, its gradients, the optimizer states and the properties).
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
*
* \param[in] checkpoint_path Path to the checkpoint directory
* \param[in] checkpoint_path Path to the checkpoint file
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
@ -142,14 +142,14 @@ struct OrtTrainingApi {
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
_Outptr_ OrtCheckpointState** checkpoint_state);
/** \brief Save the given state to a checkpoint directory on disk.
/** \brief Save the given state to a checkpoint file on disk.
*
* This function serializes the provided checkpoint state to a directory on disk.
* This function serializes the provided checkpoint state to a file on disk.
* This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume
* training from this snapshot of the state.
*
* \param[in] checkpoint_state The checkpoint state to save.
* \param[in] checkpoint_path Path to the checkpoint directory.
* \param[in] checkpoint_path Path to the checkpoint file.
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
@ -172,7 +172,7 @@ struct OrtTrainingApi {
* - The training onnx model
* - The evaluation onnx model (optional)
* - The optimizer onnx model
* - The checkpoint directory
* - The checkpoint file
*
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
*
@ -623,6 +623,30 @@ struct OrtTrainingApi {
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
/// @}
/// \name Accessing The Training Session State
/// @{
/** \brief Load a checkpoint state from a buffer into checkpoint_state.
*
* This function will parse a checkpoint bytes buffer, pull relevant data and load the training
* state into the checkpoint_state. This checkpoint state can then be used to create the
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
* session will resume training from the given checkpoint state.
* \note Note that the training session created with a checkpoint state uses this state to store the entire
* training state (including model parameters, its gradients, the optimizer states and the properties).
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
*
* \param[in] checkpoint_buffer Path to the checkpoint bytes buffer.
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
/// @}
};
typedef struct OrtTrainingApi OrtTrainingApi;

View file

@ -71,27 +71,40 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
/// \name Accessing The Training Session State
/// @{
/** \brief Load a checkpoint state from directory on disk into checkpoint_state.
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
*
* This function will parse a checkpoint directory, pull relevant files and load the training
* This function will parse a checkpoint file, pull relevant data and load the training
* state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
* training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
* training from the given checkpoint state.
*
* \param[in] path_to_checkpoint Path to the checkpoint directory
* \param[in] path_to_checkpoint Path to the checkpoint file
* \return Ort::CheckpointState object which holds the state of the training session parameters.
*
*/
static CheckpointState LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint);
/** \brief Save the given state to a checkpoint directory on disk.
/** \brief Load a checkpoint state from a buffer.
*
* This function serializes the provided checkpoint state to a directory on disk.
* This function will parse a checkpoint buffer, pull relevant data and load the training
* state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
* training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
* training from the given checkpoint state.
*
* \param[in] buffer Buffer containing the checkpoint data.
* \return Ort::CheckpointState object which holds the state of the training session parameters.
*
*/
static CheckpointState LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer);
/** \brief Save the given state to a checkpoint file on disk.
*
* This function serializes the provided checkpoint state to a file on disk.
* This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume
* training from this snapshot of the state.
*
* \param[in] checkpoint_state The checkpoint state to save.
* \param[in] path_to_checkpoint Path to the checkpoint directory.
* \param[in] path_to_checkpoint Path to the checkpoint file.
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
*
*/
@ -131,7 +144,7 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
* - The training onnx model
* - The evaluation onnx model (optional)
* - The optimizer onnx model
* - The checkpoint directory
* - The checkpoint file
*
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
*

View file

@ -175,6 +175,12 @@ inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string<O
return CheckpointState(checkpoint_state);
}
inline CheckpointState CheckpointState::LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer) {
OrtCheckpointState* checkpoint_state;
ThrowOnError(GetTrainingApi().LoadCheckpointFromBuffer(buffer.data(), buffer.size(), &checkpoint_state));
return CheckpointState(checkpoint_state);
}
inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states,
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
const bool include_optimizer_state) {

View file

@ -281,6 +281,22 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::SaveCheckpoint, _In_ OrtCheckpointState* ch
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) {
API_IMPL_BEGIN
*checkpoint_state = nullptr;
auto chkpt_state = std::make_unique<onnxruntime::training::api::CheckpointState>();
const auto* checkpoint_bytes = reinterpret_cast<const uint8_t*>(checkpoint_buffer);
gsl::span checkpoint_span(checkpoint_bytes, num_bytes);
ORT_API_RETURN_IF_STATUS_NOT_OK(
onnxruntime::training::api::LoadCheckpointFromBuffer(checkpoint_span, *chkpt_state));
*checkpoint_state = reinterpret_cast<OrtCheckpointState*>(chkpt_state.release());
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtTrainingApis::GetParametersSize, _Inout_ OrtTrainingSession* sess,
_Out_ size_t* out, bool trainable_only) {
API_IMPL_BEGIN
@ -527,7 +543,7 @@ static constexpr OrtTrainingApi ort_training_api = {
&OrtTrainingApis::TrainingSessionGetEvalModelInputName,
&OrtTrainingApis::AddProperty,
&OrtTrainingApis::GetProperty,
};
&OrtTrainingApis::LoadCheckpointFromBuffer};
ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) {
// No constraints on the API version yet.

View file

@ -84,4 +84,7 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state
_In_ const char* property_name, _Inout_ OrtAllocator* allocator,
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
} // namespace OrtTrainingApis