mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Load CheckpointState from a buffer (#16457)
This commit is contained in:
parent
efe0af3720
commit
1f60414bc2
10 changed files with 243 additions and 27 deletions
|
|
@ -40,6 +40,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public IntPtr TrainingSessionGetEvalModelInputName;
|
||||
public IntPtr AddProperty;
|
||||
public IntPtr GetProperty;
|
||||
public IntPtr LoadCheckpointFromBuffer;
|
||||
}
|
||||
|
||||
internal static class NativeTrainingMethods
|
||||
|
|
|
|||
|
|
@ -138,6 +138,113 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load ONNX model from file path, save into ORT checkpoint files,
|
||||
* then read the checkpoint file into a byte buffer, load a checkpoint state from that buffer, and compare the restored parameters with the initial parameter values.
|
||||
*/
|
||||
TEST(CheckpointApiTest, SaveOnnxModelAsCheckpointThenLoadFromBufferCPU) {
  /// Phase 1 - Test Preparation
  /// Prepare the data and dest folder for saving checkpoint.
  /// Also prepare the expected data for test result comparison.

  // Model path and trainable parameter name definitions.
  auto model_uri = MODEL_FOLDER "transform/computation_reduction/gathernd/e2e.onnx";
  std::vector<std::string> expected_trainable_param_names{
      "bert.encoder.layer.2.output.LayerNorm.weight",
      "bert.encoder.layer.2.output.LayerNorm.bias",
      "add1_initializerr",
      "cls.predictions.transform.LayerNorm.weight",
      "cls.predictions.transform.LayerNorm.bias",
      "bert.embeddings.word_embeddings.weight_transposed",
      "cls.predictions.bias",
  };

  // Extract a weight value baseline to compare.
  // expected_trainable_param_name_to_ort_value is used to compare with the values after restoring from checkpoint.
  auto logger_ptr = std::make_unique<logging::Logger>(logging::LoggingManager::DefaultLogger());
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_ptr));
  Graph& graph = p_model->MainGraph();

  // Split the graph initializers into trainable (listed above) and non-trainable (everything else).
  std::vector<ONNX_NAMESPACE::TensorProto> trainable_param_values;
  trainable_param_values.reserve(expected_trainable_param_names.size());
  std::vector<ONNX_NAMESPACE::TensorProto> non_trainable_param_values;
  const auto& initializer_tensors = graph.GetAllInitializedTensors();
  for (const auto& [initializer_name, tensor_proto] : initializer_tensors) {
    const bool is_trainable =
        std::find(expected_trainable_param_names.begin(), expected_trainable_param_names.end(), initializer_name) !=
        expected_trainable_param_names.end();
    if (is_trainable) {
      trainable_param_values.emplace_back(*tensor_proto);
    } else {
      non_trainable_param_values.emplace_back(*tensor_proto);
    }
  }

  std::unordered_map<std::string, OrtValue> expected_trainable_param_name_to_ort_value;
  ASSERT_STATUS_OK(
      CreateOrtValuesFromTensorProtos(trainable_param_values, expected_trainable_param_name_to_ort_value));

  // Temporary directory used as the destination for the checkpoint file.
  auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir");
  TemporaryDirectory tmp_dir{ckpt_test_root_dir};

  /// Phase 2 - Run save checkpoint APIs.
  /// And check the result checkpoint files.

  // Call Save APIs.
  PathString checkpoint_path{
      ConcatPathComponent<PathChar>(tmp_dir.Path(), ORT_TSTR("e2e_ckpt_save_cpu"))};
  ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path));

  /// Phase 3 - Run load checkpoint APIs.
  /// And check the result comparable with initial parameter values.

  // Read the whole checkpoint file into a byte buffer.
  size_t num_bytes = 0;
  ASSERT_STATUS_OK(Env::Default().GetFileLength(checkpoint_path.c_str(), num_bytes));
  std::vector<uint8_t> checkpoint_bytes(num_bytes);

  std::ifstream bytes_stream(checkpoint_path, std::ifstream::in | std::ifstream::binary);
  bytes_stream.read(reinterpret_cast<char*>(checkpoint_bytes.data()), static_cast<std::streamsize>(num_bytes));

  // The stream is in a failed state if the file could not be opened or fewer than num_bytes were read.
  ASSERT_TRUE(bytes_stream);

  // Call Load APIs on the in-memory buffer.
  CheckpointState checkpoint_state_to_load;
  ASSERT_STATUS_OK(LoadCheckpointFromBuffer(checkpoint_bytes, checkpoint_state_to_load));

  // Bind by reference: the state outlives this scope's use of it, so no copy is needed.
  const auto& param_states = checkpoint_state_to_load.module_checkpoint_state.named_parameters;
  std::unordered_map<std::string, OrtValue> restored_param_name_to_ort_values;
  std::vector<std::string> restored_trainable_param_names;
  for (const auto& [param_name, param] : param_states) {
    restored_param_name_to_ort_values.insert({param_name, param->Data()});
    if (param->RequiresGrad()) {
      restored_trainable_param_names.emplace_back(param_name);
    }
  }

  // Check loaded parameter's values are the same as the original ones.
  ASSERT_EQ(expected_trainable_param_name_to_ort_value.size(), restored_trainable_param_names.size());
  ASSERT_EQ(expected_trainable_param_name_to_ort_value.size(), 7U);
  ASSERT_EQ(restored_param_name_to_ort_values.size(), 9U);

  std::sort(expected_trainable_param_names.begin(), expected_trainable_param_names.end());
  std::sort(restored_trainable_param_names.begin(), restored_trainable_param_names.end());
  ASSERT_EQ(expected_trainable_param_names, restored_trainable_param_names);

  for (const auto& name : restored_trainable_param_names) {
    // Use at() for both maps so a missing name fails loudly instead of default-inserting an empty OrtValue.
    const auto& restored_ort_value = restored_param_name_to_ort_values.at(name);
    const auto& expected_ort_value = expected_trainable_param_name_to_ort_value.at(name);

    ASSERT_TRUE(restored_ort_value.IsTensor() && expected_ort_value.IsTensor());
    const Tensor& restored_tensor = restored_ort_value.Get<Tensor>();
    const Tensor& expected_tensor = expected_ort_value.Get<Tensor>();
    ASSERT_EQ(expected_tensor.DataType(), restored_tensor.DataType());
    ASSERT_EQ(expected_tensor.SizeInBytes(), restored_tensor.SizeInBytes());

    // Sizes were asserted equal above, so comparing expected_tensor.SizeInBytes() bytes is in bounds for both.
    ASSERT_EQ(std::memcmp(expected_tensor.DataRaw(), restored_tensor.DataRaw(), expected_tensor.SizeInBytes()), 0);
  }
}
|
||||
|
||||
/**
|
||||
* Load ONNX model with parameters set to 0 from file path, Load Checkpoint weights into the Model,
|
||||
* Then compare the new weights to 0 to make sure they were changed after loading checkpoint to model.
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include "orttraining/training_api/checkpoint.h"
|
||||
|
||||
#include "orttraining/test/training_api/core/data_utils.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
#include "test/util/include/temp_dir.h"
|
||||
|
||||
namespace onnxruntime::training::test {
|
||||
|
|
@ -38,6 +39,33 @@ TEST(TrainingCApiTest, SaveCheckpoint) {
|
|||
new_checkpoint_state, model_uri);
|
||||
}
|
||||
|
||||
TEST(TrainingCApiTest, LoadCheckpointFromBuffer) {
  Ort::Env env;

  // Read the reference checkpoint file into an in-memory byte buffer.
  PathString source_ckpt_path = MODEL_FOLDER "checkpoint.ckpt";
  size_t file_size = 0;
  ASSERT_STATUS_OK(Env::Default().GetFileLength(source_ckpt_path.c_str(), file_size));

  std::vector<uint8_t> ckpt_buffer(file_size);
  std::ifstream ckpt_stream(source_ckpt_path, std::ifstream::in | std::ifstream::binary);
  ckpt_stream.read(reinterpret_cast<char*>(ckpt_buffer.data()), file_size);
  ASSERT_TRUE(ckpt_stream);

  // Build a checkpoint state directly from the bytes.
  Ort::CheckpointState state_from_buffer = Ort::CheckpointState::LoadCheckpointFromBuffer(ckpt_buffer);

  // Start from a clean output folder: delete a leftover one before creating the temporary directory.
  auto output_dir = ORT_TSTR("save_checkpoint_dir");
  const bool leftover_dir_present = Env::Default().FolderExists(output_dir);
  if (leftover_dir_present) {
    ORT_ENFORCE(Env::Default().DeleteFolder(output_dir).IsOK());
  }
  onnxruntime::test::TemporaryDirectory scratch_dir{output_dir};

  // Round-trip: save the buffer-loaded state to disk and load it back through the file-based API.
  PathString saved_ckpt_path{
      ConcatPathComponent<PathChar>(scratch_dir.Path(), ORT_TSTR("new_checkpoint.ckpt"))};
  Ort::CheckpointState::SaveCheckpoint(state_from_buffer, saved_ckpt_path);

  Ort::CheckpointState reloaded_state = Ort::CheckpointState::LoadCheckpoint(saved_ckpt_path);
}
|
||||
|
||||
TEST(TrainingCApiTest, AddIntProperty) {
|
||||
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
|
||||
|
||||
|
|
|
|||
|
|
@ -473,9 +473,6 @@ Status FromFile(const PathString& checkpoint_path, InlinedVector<uint8_t>& check
|
|||
ORT_RETURN_IF_NOT(bytes_stream, "Loading checkpoint from ", ToUTF8String(checkpoint_path), " failed. Only ",
|
||||
bytes_stream.gcount(), "/", num_bytes, " bytes could be read.");
|
||||
|
||||
flatbuffers::Verifier verifier(checkpoint_bytes.data(), checkpoint_bytes.size());
|
||||
ORT_RETURN_IF_NOT(fbs::VerifyCheckpointBuffer(verifier), "Checkpoint verification failed.");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -622,10 +619,10 @@ Status ToPropertyBag(const onnxruntime::fbs::PropertyBag& fbs_property_bag,
|
|||
* @param model_proto Model proto to be populated.
|
||||
* @return Status of the operation.
|
||||
*/
|
||||
Status ToModelProto(const PathString& checkpoint_path,
|
||||
Status ToModelProto(gsl::span<const uint8_t> checkpoint_bytes,
|
||||
ONNX_NAMESPACE::ModelProto& model_proto) {
|
||||
InlinedVector<uint8_t> checkpoint_bytes;
|
||||
ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes));
|
||||
flatbuffers::Verifier verifier(checkpoint_bytes.data(), checkpoint_bytes.size());
|
||||
ORT_RETURN_IF_NOT(fbs::VerifyCheckpointBuffer(verifier), "Checkpoint verification failed.");
|
||||
|
||||
const auto* fbs_checkpoint = fbs::GetCheckpoint(checkpoint_bytes.data());
|
||||
ORT_RETURN_IF_NOT(fbs_checkpoint, "Checkpoint is invalid. Expected: Valid checkpoint flatbuffer. Actual: nullptr.");
|
||||
|
|
@ -687,9 +684,9 @@ Status ToModelProto(const PathString& checkpoint_path,
|
|||
* @param state Checkpoint state to be populated.
|
||||
* @return Status of the operation.
|
||||
*/
|
||||
Status ToCheckpointState(const PathString& checkpoint_path, CheckpointState& state) {
|
||||
InlinedVector<uint8_t> checkpoint_bytes;
|
||||
ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes));
|
||||
Status ToCheckpointState(gsl::span<const uint8_t> checkpoint_bytes, CheckpointState& state) {
|
||||
flatbuffers::Verifier verifier(checkpoint_bytes.data(), checkpoint_bytes.size());
|
||||
ORT_RETURN_IF_NOT(fbs::VerifyCheckpointBuffer(verifier), "Checkpoint verification failed.");
|
||||
|
||||
const auto* fbs_checkpoint = fbs::GetCheckpoint(checkpoint_bytes.data());
|
||||
ORT_RETURN_IF_NOT(fbs_checkpoint, "Checkpoint is invalid. Expected: Valid checkpoint flatbuffer. Actual: nullptr.");
|
||||
|
|
@ -737,14 +734,26 @@ Status SaveCheckpoint(const CheckpointState& states, const PathString& checkpoin
|
|||
|
||||
Status LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkpoint_states) {
|
||||
ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines");
|
||||
return load::ToCheckpointState(checkpoint_path, checkpoint_states);
|
||||
|
||||
InlinedVector<uint8_t> checkpoint_bytes;
|
||||
ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes));
|
||||
return load::ToCheckpointState(checkpoint_bytes, checkpoint_states);
|
||||
}
|
||||
|
||||
// Loads a CheckpointState from an in-memory checkpoint buffer.
// checkpoint_bytes: view over the serialized checkpoint; it is only read, never modified.
// checkpoint_state: populated by load::ToCheckpointState on success.
// Returns a failure status on big-endian builds, for an empty buffer, or if
// load::ToCheckpointState reports an error.
Status LoadCheckpointFromBuffer(gsl::span<const uint8_t> checkpoint_bytes, CheckpointState& checkpoint_state) {
  ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines");

  // Reject an empty buffer up front so the caller gets a specific message rather than a
  // generic verification failure from the loader.
  ORT_RETURN_IF_NOT(!checkpoint_bytes.empty(), "Checkpoint buffer is empty. Expected: non-empty checkpoint bytes.");

  return load::ToCheckpointState(checkpoint_bytes, checkpoint_state);
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
Status LoadCheckpointToModel(const PathString& checkpoint_path,
|
||||
ONNX_NAMESPACE::ModelProto& model_proto) {
|
||||
ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines");
|
||||
return load::ToModelProto(checkpoint_path, model_proto);
|
||||
|
||||
InlinedVector<uint8_t> checkpoint_bytes;
|
||||
ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes));
|
||||
return load::ToModelProto(checkpoint_bytes, model_proto);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -66,6 +66,15 @@ Status SaveCheckpoint(gsl::span<const ONNX_NAMESPACE::TensorProto> trainable_ten
|
|||
Status LoadCheckpoint(const PathString& checkpoint_path,
|
||||
CheckpointState& checkpoint_state);
|
||||
|
||||
/**
|
||||
* @brief Load training states from ORT checkpoint bytes buffer.
|
||||
* @param checkpoint_bytes bytes buffer of the checkpoint.
|
||||
* @param checkpoint_state parameter/optimizer and other user defined training states.
|
||||
* @return Status
|
||||
*/
|
||||
Status LoadCheckpointFromBuffer(gsl::span<const uint8_t> checkpoint_bytes,
|
||||
CheckpointState& checkpoint_state);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
/**
|
||||
* @brief Load training states from ORT checkpoint into a ModelProto.
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@
|
|||
*
|
||||
* In order to train a model with onnxruntime, the following training artifacts must be generated:
|
||||
* - The training onnx model
|
||||
* - The checkpoint directory
|
||||
* - The checkpoint file
|
||||
* - The optimizer onnx model
|
||||
* - The eval onnx model (optional)
|
||||
*
|
||||
|
|
@ -123,9 +123,9 @@ struct OrtTrainingApi {
|
|||
/// \name Accessing The Training Session State
|
||||
/// @{
|
||||
|
||||
/** \brief Load a checkpoint state from directory on disk into checkpoint_state.
|
||||
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
|
||||
*
|
||||
* This function will parse a checkpoint directory, pull relevant files and load the training
|
||||
* This function will parse a checkpoint file, pull relevant data and load the training
|
||||
* state into the checkpoint_state. This checkpoint state can then be used to create the
|
||||
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
|
||||
* session will resume training from the given checkpoint state.
|
||||
|
|
@ -133,7 +133,7 @@ struct OrtTrainingApi {
|
|||
* training state (including model parameters, its gradients, the optimizer states and the properties).
|
||||
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
|
||||
*
|
||||
* \param[in] checkpoint_path Path to the checkpoint directory
|
||||
* \param[in] checkpoint_path Path to the checkpoint file
|
||||
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
|
|
@ -142,14 +142,14 @@ struct OrtTrainingApi {
|
|||
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
|
||||
_Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
/** \brief Save the given state to a checkpoint directory on disk.
|
||||
/** \brief Save the given state to a checkpoint file on disk.
|
||||
*
|
||||
* This function serializes the provided checkpoint state to a directory on disk.
|
||||
* This function serializes the provided checkpoint state to a file on disk.
|
||||
* This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume
|
||||
* training from this snapshot of the state.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state to save.
|
||||
* \param[in] checkpoint_path Path to the checkpoint directory.
|
||||
* \param[in] checkpoint_path Path to the checkpoint file.
|
||||
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
|
|
@ -172,7 +172,7 @@ struct OrtTrainingApi {
|
|||
* - The training onnx model
|
||||
* - The evaluation onnx model (optional)
|
||||
* - The optimizer onnx model
|
||||
* - The checkpoint directory
|
||||
* - The checkpoint file
|
||||
*
|
||||
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
|
||||
*
|
||||
|
|
@ -623,6 +623,30 @@ struct OrtTrainingApi {
|
|||
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Accessing The Training Session State
|
||||
/// @{
|
||||
|
||||
/** \brief Load a checkpoint state from a buffer into checkpoint_state.
|
||||
*
|
||||
* This function will parse a checkpoint bytes buffer, pull relevant data and load the training
|
||||
* state into the checkpoint_state. This checkpoint state can then be used to create the
|
||||
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
|
||||
* session will resume training from the given checkpoint state.
|
||||
* \note Note that the training session created with a checkpoint state uses this state to store the entire
|
||||
* training state (including model parameters, its gradients, the optimizer states and the properties).
|
||||
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
|
||||
*
|
||||
* \param[in] checkpoint_buffer Pointer to the checkpoint bytes buffer.
|
||||
* \param[in] num_bytes Number of bytes in the checkpoint buffer.
|
||||
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
|
||||
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
typedef struct OrtTrainingApi OrtTrainingApi;
|
||||
|
|
|
|||
|
|
@ -71,27 +71,40 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
|
|||
/// \name Accessing The Training Session State
|
||||
/// @{
|
||||
|
||||
/** \brief Load a checkpoint state from directory on disk into checkpoint_state.
|
||||
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
|
||||
*
|
||||
* This function will parse a checkpoint directory, pull relevant files and load the training
|
||||
* This function will parse a checkpoint file, pull relevant data and load the training
|
||||
* state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
|
||||
* training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
|
||||
* training from the given checkpoint state.
|
||||
*
|
||||
* \param[in] path_to_checkpoint Path to the checkpoint directory
|
||||
* \param[in] path_to_checkpoint Path to the checkpoint file
|
||||
* \return Ort::CheckpointState object which holds the state of the training session parameters.
|
||||
*
|
||||
*/
|
||||
static CheckpointState LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint);
|
||||
|
||||
/** \brief Save the given state to a checkpoint directory on disk.
|
||||
/** \brief Load a checkpoint state from a buffer.
|
||||
*
|
||||
* This function serializes the provided checkpoint state to a directory on disk.
|
||||
* This function will parse a checkpoint buffer, pull relevant data and load the training
|
||||
* state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
|
||||
* training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
|
||||
* training from the given checkpoint state.
|
||||
*
|
||||
* \param[in] buffer Buffer containing the checkpoint data.
|
||||
* \return Ort::CheckpointState object which holds the state of the training session parameters.
|
||||
*
|
||||
*/
|
||||
static CheckpointState LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer);
|
||||
|
||||
/** \brief Save the given state to a checkpoint file on disk.
|
||||
*
|
||||
* This function serializes the provided checkpoint state to a file on disk.
|
||||
* This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume
|
||||
* training from this snapshot of the state.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state to save.
|
||||
* \param[in] path_to_checkpoint Path to the checkpoint directory.
|
||||
* \param[in] path_to_checkpoint Path to the checkpoint file.
|
||||
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
|
||||
*
|
||||
*/
|
||||
|
|
@ -131,7 +144,7 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
|
|||
* - The training onnx model
|
||||
* - The evaluation onnx model (optional)
|
||||
* - The optimizer onnx model
|
||||
* - The checkpoint directory
|
||||
* - The checkpoint file
|
||||
*
|
||||
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
|
||||
*
|
||||
|
|
|
|||
|
|
@ -175,6 +175,12 @@ inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string<O
|
|||
return CheckpointState(checkpoint_state);
|
||||
}
|
||||
|
||||
/// Loads a checkpoint state from an in-memory byte buffer through the training C API.
/// The buffer's data pointer and size are forwarded to OrtTrainingApi::LoadCheckpointFromBuffer;
/// the returned status is passed to ThrowOnError and the resulting handle is wrapped in a CheckpointState.
inline CheckpointState CheckpointState::LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer) {
  // Initialize the out-pointer so it never holds an indeterminate value.
  OrtCheckpointState* checkpoint_state = nullptr;
  ThrowOnError(GetTrainingApi().LoadCheckpointFromBuffer(buffer.data(), buffer.size(), &checkpoint_state));
  return CheckpointState(checkpoint_state);
}
|
||||
|
||||
inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states,
|
||||
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
|
||||
const bool include_optimizer_state) {
|
||||
|
|
|
|||
|
|
@ -281,6 +281,22 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::SaveCheckpoint, _In_ OrtCheckpointState* ch
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
// C API entry point: builds an OrtCheckpointState from num_bytes bytes starting at checkpoint_buffer.
ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
                    _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) {
  API_IMPL_BEGIN

  // Reset the output so it is null on every failure path below.
  *checkpoint_state = nullptr;

  // Wrap the caller's memory in a byte span; the span is a view, not a copy.
  const gsl::span<const uint8_t> buffer_view(static_cast<const uint8_t*>(checkpoint_buffer), num_bytes);

  auto loaded_state = std::make_unique<onnxruntime::training::api::CheckpointState>();
  ORT_API_RETURN_IF_STATUS_NOT_OK(
      onnxruntime::training::api::LoadCheckpointFromBuffer(buffer_view, *loaded_state));

  // Release ownership from the unique_ptr and expose the object through the opaque handle type.
  *checkpoint_state = reinterpret_cast<OrtCheckpointState*>(loaded_state.release());

  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::GetParametersSize, _Inout_ OrtTrainingSession* sess,
|
||||
_Out_ size_t* out, bool trainable_only) {
|
||||
API_IMPL_BEGIN
|
||||
|
|
@ -527,7 +543,7 @@ static constexpr OrtTrainingApi ort_training_api = {
|
|||
&OrtTrainingApis::TrainingSessionGetEvalModelInputName,
|
||||
&OrtTrainingApis::AddProperty,
|
||||
&OrtTrainingApis::GetProperty,
|
||||
};
|
||||
&OrtTrainingApis::LoadCheckpointFromBuffer};
|
||||
|
||||
ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) {
|
||||
// No constraints on the API version yet.
|
||||
|
|
|
|||
|
|
@ -84,4 +84,7 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state
|
|||
_In_ const char* property_name, _Inout_ OrtAllocator* allocator,
|
||||
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
|
||||
|
||||
ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
|
||||
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
} // namespace OrtTrainingApis
|
||||
|
|
|
|||
Loading…
Reference in a new issue