diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 260817f712..c52ca4d1a4 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -40,6 +40,7 @@ namespace Microsoft.ML.OnnxRuntime public IntPtr TrainingSessionGetEvalModelInputName; public IntPtr AddProperty; public IntPtr GetProperty; + public IntPtr LoadCheckpointFromBuffer; } internal static class NativeTrainingMethods diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc index ba4e643feb..82a8acde33 100644 --- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc @@ -138,6 +138,113 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) { } } +/** + * Load ONNX model from file path, save into ORT checkpoint files, + * Then load it into a bytes buffer and then load the buffer to a checkpoint, compare with the initial parameter values. + */ +TEST(CheckpointApiTest, SaveOnnxModelAsCheckpointThenLoadFromBufferCPU) { + /// Phase 1 - Test Preparation + /// Prepare the data and dest folder for saving checkpoint. + /// Also cooked the data for test result comparison. + + // Model path and trainable parameter name definitions. + auto model_uri = MODEL_FOLDER "transform/computation_reduction/gathernd/e2e.onnx"; + std::vector expected_trainable_param_names{ + "bert.encoder.layer.2.output.LayerNorm.weight", + "bert.encoder.layer.2.output.LayerNorm.bias", + "add1_initializerr", + "cls.predictions.transform.LayerNorm.weight", + "cls.predictions.transform.LayerNorm.bias", + "bert.embeddings.word_embeddings.weight_transposed", + "cls.predictions.bias", + }; + + // Extract a weight value baseline to compare. + // expected_trainable_param_name_to_ort_value is used to compare with the values after restoring from checkpoint. + auto logger_ptr = std::make_unique(logging::LoggingManager::DefaultLogger()); + std::shared_ptr p_model; + ORT_ENFORCE(Model::Load(model_uri, p_model, nullptr, *logger_ptr).IsOK()); + Graph& graph = p_model->MainGraph(); + + std::vector trainable_param_values; + trainable_param_values.reserve(expected_trainable_param_names.size()); + std::vector non_trainable_param_values; + const auto& initializer_tensors = graph.GetAllInitializedTensors(); + for (const auto& [initializer_name, tensor_proto] : initializer_tensors) { + if (std::find(expected_trainable_param_names.begin(), expected_trainable_param_names.end(), initializer_name) != + expected_trainable_param_names.end()) { + trainable_param_values.emplace_back(static_cast(*tensor_proto)); + } else { + non_trainable_param_values.emplace_back(static_cast(*tensor_proto)); + } + } + + std::unordered_map expected_trainable_param_name_to_ort_value; + ORT_ENFORCE(CreateOrtValuesFromTensorProtos(trainable_param_values, expected_trainable_param_name_to_ort_value) + .IsOK()); + + // Remove the temporary directory if it already exists. + auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir"); + TemporaryDirectory tmp_dir{ckpt_test_root_dir}; + + /// Phase 2 - Run save checkpoint APIs. + /// And check the result checkpoint files. + + // Call Save APIs. + PathString checkpoint_path{ + ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("e2e_ckpt_save_cpu"))}; + ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path)); + + /// Phase 3 - Run load checkpoint APIs. + /// And check the result comparable with initial parameter values. + + // Call Load APIs + size_t num_bytes = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(checkpoint_path.c_str(), num_bytes)); + std::vector checkpoint_bytes(num_bytes); + + std::ifstream bytes_stream(checkpoint_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(checkpoint_bytes.data()), num_bytes); + + ASSERT_TRUE(bytes_stream); + + CheckpointState checkpoint_state_to_load; + ASSERT_STATUS_OK(LoadCheckpointFromBuffer(checkpoint_bytes, checkpoint_state_to_load)); + ModuleCheckpointState module_state = checkpoint_state_to_load.module_checkpoint_state; + const auto& param_states = module_state.named_parameters; + std::unordered_map restored_param_name_to_ort_values; + std::vector restored_trainable_param_names; + for (auto it = param_states.begin(); it != param_states.end(); ++it) { + restored_param_name_to_ort_values.insert({it->first, it->second->Data()}); + if (it->second->RequiresGrad()) { + restored_trainable_param_names.emplace_back(it->first); + } + } + + // Check loaded parameter's values are same with original ones. + ASSERT_EQ(expected_trainable_param_name_to_ort_value.size(), restored_trainable_param_names.size()); + ASSERT_EQ(expected_trainable_param_name_to_ort_value.size(), 7); + ASSERT_EQ(restored_param_name_to_ort_values.size(), 9); + + std::sort(expected_trainable_param_names.begin(), expected_trainable_param_names.end()); + std::sort(restored_trainable_param_names.begin(), restored_trainable_param_names.end()); + ASSERT_EQ(expected_trainable_param_names, restored_trainable_param_names); + + for (const auto& name : restored_trainable_param_names) { + const auto& restored_ort_value = restored_param_name_to_ort_values[name]; + const auto& expected_ort_value = expected_trainable_param_name_to_ort_value.at(name); + + ASSERT_TRUE(restored_ort_value.IsTensor() && expected_ort_value.IsTensor()); + const Tensor& restored_tensor = restored_ort_value.Get(); + const Tensor& expected_tensor = expected_ort_value.Get(); + ASSERT_EQ(expected_tensor.DataType(), restored_tensor.DataType()); + ASSERT_EQ(expected_tensor.SizeInBytes(), restored_tensor.SizeInBytes()); + ASSERT_EQ(expected_tensor.DataType(), restored_tensor.DataType()); + + ASSERT_EQ(std::memcmp(expected_tensor.DataRaw(), restored_tensor.DataRaw(), expected_tensor.SizeInBytes()), 0); + } +} + /** * Load ONNX model with parameters set to 0 from file path, Load Checkpoint weights into the Model, * Then compare the new weights to 0 to make sure they were changed after loading checkpoint to model. diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index 9fb1fefa4a..8b44fc65a1 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -10,6 +10,7 @@ #include "orttraining/training_api/checkpoint.h" #include "orttraining/test/training_api/core/data_utils.h" +#include "test/util/include/asserts.h" #include "test/util/include/temp_dir.h" namespace onnxruntime::training::test { @@ -38,6 +39,33 @@ TEST(TrainingCApiTest, SaveCheckpoint) { new_checkpoint_state, model_uri); } +TEST(TrainingCApiTest, LoadCheckpointFromBuffer) { + Ort::Env env; + size_t num_bytes = 0; + PathString checkpoint_path = MODEL_FOLDER "checkpoint.ckpt"; + ASSERT_STATUS_OK(Env::Default().GetFileLength(checkpoint_path.c_str(), num_bytes)); + std::vector checkpoint_bytes(num_bytes); + + std::ifstream bytes_stream(checkpoint_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(checkpoint_bytes.data()), num_bytes); + + ASSERT_TRUE(bytes_stream); + + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpointFromBuffer(checkpoint_bytes); + + auto test_dir = ORT_TSTR("save_checkpoint_dir"); + if (Env::Default().FolderExists(test_dir)) { + ORT_ENFORCE(Env::Default().DeleteFolder(test_dir).IsOK()); + } + onnxruntime::test::TemporaryDirectory tmp_dir{test_dir}; + PathString new_checkpoint_path{ + ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("new_checkpoint.ckpt"))}; + + Ort::CheckpointState::SaveCheckpoint(checkpoint_state, new_checkpoint_path); + + Ort::CheckpointState new_checkpoint_state = Ort::CheckpointState::LoadCheckpoint(new_checkpoint_path); +} + TEST(TrainingCApiTest, AddIntProperty) { Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc index 838e062006..7133450d22 100644 --- a/orttraining/orttraining/training_api/checkpoint.cc +++ b/orttraining/orttraining/training_api/checkpoint.cc @@ -473,9 +473,6 @@ Status FromFile(const PathString& checkpoint_path, InlinedVector& check ORT_RETURN_IF_NOT(bytes_stream, "Loading checkpoint from ", ToUTF8String(checkpoint_path), " failed. Only ", bytes_stream.gcount(), "/", num_bytes, " bytes could be read."); - flatbuffers::Verifier verifier(checkpoint_bytes.data(), checkpoint_bytes.size()); - ORT_RETURN_IF_NOT(fbs::VerifyCheckpointBuffer(verifier), "Checkpoint verification failed."); - return Status::OK(); } @@ -622,10 +619,10 @@ Status ToPropertyBag(const onnxruntime::fbs::PropertyBag& fbs_property_bag, * @param model_proto Model proto to be populated. * @return Status of the operation. */ -Status ToModelProto(const PathString& checkpoint_path, +Status ToModelProto(gsl::span checkpoint_bytes, ONNX_NAMESPACE::ModelProto& model_proto) { - InlinedVector checkpoint_bytes; - ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes)); + flatbuffers::Verifier verifier(checkpoint_bytes.data(), checkpoint_bytes.size()); + ORT_RETURN_IF_NOT(fbs::VerifyCheckpointBuffer(verifier), "Checkpoint verification failed."); const auto* fbs_checkpoint = fbs::GetCheckpoint(checkpoint_bytes.data()); ORT_RETURN_IF_NOT(fbs_checkpoint, "Checkpoint is invalid. Expected: Valid checkpoint flatbuffer. Actual: nullptr."); @@ -687,9 +684,9 @@ Status ToModelProto(const PathString& checkpoint_path, * @param state Checkpoint state to be populated. * @return Status of the operation. */ -Status ToCheckpointState(const PathString& checkpoint_path, CheckpointState& state) { - InlinedVector checkpoint_bytes; - ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes)); +Status ToCheckpointState(gsl::span checkpoint_bytes, CheckpointState& state) { + flatbuffers::Verifier verifier(checkpoint_bytes.data(), checkpoint_bytes.size()); + ORT_RETURN_IF_NOT(fbs::VerifyCheckpointBuffer(verifier), "Checkpoint verification failed."); const auto* fbs_checkpoint = fbs::GetCheckpoint(checkpoint_bytes.data()); ORT_RETURN_IF_NOT(fbs_checkpoint, "Checkpoint is invalid. Expected: Valid checkpoint flatbuffer. Actual: nullptr."); @@ -737,14 +734,26 @@ Status SaveCheckpoint(const CheckpointState& states, const PathString& checkpoin Status LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkpoint_states) { ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines"); - return load::ToCheckpointState(checkpoint_path, checkpoint_states); + + InlinedVector checkpoint_bytes; + ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes)); + return load::ToCheckpointState(checkpoint_bytes, checkpoint_states); +} + +Status LoadCheckpointFromBuffer(gsl::span checkpoint_bytes, CheckpointState& checkpoint_state) { + ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines"); + + return load::ToCheckpointState(checkpoint_bytes, checkpoint_state); } #if !defined(ORT_MINIMAL_BUILD) Status LoadCheckpointToModel(const PathString& checkpoint_path, ONNX_NAMESPACE::ModelProto& model_proto) { ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines"); - return load::ToModelProto(checkpoint_path, model_proto); + + InlinedVector checkpoint_bytes; + ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes)); + return load::ToModelProto(checkpoint_bytes, model_proto); } #endif diff --git a/orttraining/orttraining/training_api/checkpoint.h b/orttraining/orttraining/training_api/checkpoint.h index dbb2a9dbf8..5d8554662f 100644 --- a/orttraining/orttraining/training_api/checkpoint.h +++ b/orttraining/orttraining/training_api/checkpoint.h @@ -66,6 +66,15 @@ Status SaveCheckpoint(gsl::span trainable_ten Status LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkpoint_state); +/** + * @brief Load training states from ORT checkpoint bytes buffer. + * @param checkpoint_bytes bytes buffer of the checkpoint. + * @param checkpoint_state parameter/optimizer and other user defined training states. + * @return Status + */ +Status LoadCheckpointFromBuffer(gsl::span checkpoint_bytes, + CheckpointState& checkpoint_state); + #if !defined(ORT_MINIMAL_BUILD) /** * @brief Load training states from ORT checkpoint into a ModelProto. diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 2d8aafd44f..71cdeebeb2 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -13,7 +13,7 @@ * * In order to train a model with onnxruntime, the following training artifacts must be generated: * - The training onnx model - * - The checkpoint directory + * - The checkpoint file * - The optimizer onnx model * - The eval onnx model model (optional) * @@ -123,9 +123,9 @@ struct OrtTrainingApi { /// \name Accessing The Training Session State /// @{ - /** \brief Load a checkpoint state from directory on disk into checkpoint_state. + /** \brief Load a checkpoint state from a file on disk into checkpoint_state. * - * This function will parse a checkpoint directory, pull relevant files and load the training + * This function will parse a checkpoint file, pull relevant data and load the training * state into the checkpoint_state. This checkpoint state can then be used to create the * training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training * session will resume training from the given checkpoint state. @@ -133,7 +133,7 @@ struct OrtTrainingApi { * training state (including model parameters, its gradients, the optimizer states and the properties). * As a result, it is required that the checkpoint state outlive the lifetime of the training session. * - * \param[in] checkpoint_path Path to the checkpoint directory + * \param[in] checkpoint_path Path to the checkpoint file * \param[out] checkpoint_state Checkpoint state that contains the states of the training session. * * \snippet{doc} snippets.dox OrtStatus Return Value @@ -142,14 +142,14 @@ struct OrtTrainingApi { ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state); - /** \brief Save the given state to a checkpoint directory on disk. + /** \brief Save the given state to a checkpoint file on disk. * - * This function serializes the provided checkpoint state to a directory on disk. + * This function serializes the provided checkpoint state to a file on disk. * This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume * training from this snapshot of the state. * * \param[in] checkpoint_state The checkpoint state to save. - * \param[in] checkpoint_path Path to the checkpoint directory. + * \param[in] checkpoint_path Path to the checkpoint file. * \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not. * * \snippet{doc} snippets.dox OrtStatus Return Value @@ -172,7 +172,7 @@ struct OrtTrainingApi { * - The training onnx model * - The evaluation onnx model (optional) * - The optimizer onnx model - * - The checkpoint directory + * - The checkpoint file * * These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md). * @@ -623,6 +623,30 @@ struct OrtTrainingApi { _Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value); /// @} + + /// \name Accessing The Training Session State + /// @{ + + /** \brief Load a checkpoint state from a buffer into checkpoint_state. + * + * This function will parse a checkpoint bytes buffer, pull relevant data and load the training + * state into the checkpoint_state. This checkpoint state can then be used to create the + * training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training + * session will resume training from the given checkpoint state. + * \note Note that the training session created with a checkpoint state uses this state to store the entire + * training state (including model parameters, its gradients, the optimizer states and the properties). + * As a result, it is required that the checkpoint state outlive the lifetime of the training session. + * + * \param[in] checkpoint_buffer Path to the checkpoint bytes buffer. + * \param[out] checkpoint_state Checkpoint state that contains the states of the training session. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, + _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); + + /// @} }; typedef struct OrtTrainingApi OrtTrainingApi; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 96e9013818..8653244844 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -71,27 +71,40 @@ class CheckpointState : public detail::Base { /// \name Accessing The Training Session State /// @{ - /** \brief Load a checkpoint state from directory on disk into checkpoint_state. + /** \brief Load a checkpoint state from a file on disk into checkpoint_state. * - * This function will parse a checkpoint directory, pull relevant files and load the training + * This function will parse a checkpoint file, pull relevant data and load the training * state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the * training session by instantiating Ort::TrainingSession. By doing so, the training session will resume * training from the given checkpoint state. * - * \param[in] path_to_checkpoint Path to the checkpoint directory + * \param[in] path_to_checkpoint Path to the checkpoint file * \return Ort::CheckpointState object which holds the state of the training session parameters. * */ static CheckpointState LoadCheckpoint(const std::basic_string& path_to_checkpoint); - /** \brief Save the given state to a checkpoint directory on disk. + /** \brief Load a checkpoint state from a buffer. * - * This function serializes the provided checkpoint state to a directory on disk. + * This function will parse a checkpoint buffer, pull relevant data and load the training + * state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the + * training session by instantiating Ort::TrainingSession. By doing so, the training session will resume + * training from the given checkpoint state. + * + * \param[in] buffer Buffer containing the checkpoint data. + * \return Ort::CheckpointState object which holds the state of the training session parameters. + * + */ + static CheckpointState LoadCheckpointFromBuffer(const std::vector& buffer); + + /** \brief Save the given state to a checkpoint file on disk. + * + * This function serializes the provided checkpoint state to a file on disk. * This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume * training from this snapshot of the state. * * \param[in] checkpoint_state The checkpoint state to save. - * \param[in] path_to_checkpoint Path to the checkpoint directory. + * \param[in] path_to_checkpoint Path to the checkpoint file. * \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not. * */ @@ -131,7 +144,7 @@ class CheckpointState : public detail::Base { * - The training onnx model * - The evaluation onnx model (optional) * - The optimizer onnx model - * - The checkpoint directory + * - The checkpoint file * * These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md). * diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 313235545f..393e5b01f7 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -175,6 +175,12 @@ inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string& buffer) { + OrtCheckpointState* checkpoint_state; + ThrowOnError(GetTrainingApi().LoadCheckpointFromBuffer(buffer.data(), buffer.size(), &checkpoint_state)); + return CheckpointState(checkpoint_state); +} + inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states, const std::basic_string& path_to_checkpoint, const bool include_optimizer_state) { diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index d43cfe6b22..773ca93648 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -281,6 +281,22 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::SaveCheckpoint, _In_ OrtCheckpointState* ch API_IMPL_END } +ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, + _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) { + API_IMPL_BEGIN + + *checkpoint_state = nullptr; + auto chkpt_state = std::make_unique(); + const auto* checkpoint_bytes = reinterpret_cast(checkpoint_buffer); + gsl::span checkpoint_span(checkpoint_bytes, num_bytes); + ORT_API_RETURN_IF_STATUS_NOT_OK( + onnxruntime::training::api::LoadCheckpointFromBuffer(checkpoint_span, *chkpt_state)); + *checkpoint_state = reinterpret_cast(chkpt_state.release()); + + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtTrainingApis::GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only) { API_IMPL_BEGIN @@ -527,7 +543,7 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::TrainingSessionGetEvalModelInputName, &OrtTrainingApis::AddProperty, &OrtTrainingApis::GetProperty, -}; + &OrtTrainingApis::LoadCheckpointFromBuffer}; ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) { // No constraints on the API version yet. diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index 3f8edede36..2b383f3b97 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -84,4 +84,7 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state _In_ const char* property_name, _Inout_ OrtAllocator* allocator, _Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value); +ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, + _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); + } // namespace OrtTrainingApis