Training API to export the eval model to an inference model (#13345)

This commit is contained in:
Baiju Meswani 2022-10-27 09:34:01 -07:00 committed by GitHub
parent 8827c4bdbc
commit a46c599a40
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 463 additions and 129 deletions

View file

@ -32,6 +32,7 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr CopyBufferToParameters;
public IntPtr ReleaseTrainingSession;
public IntPtr ReleaseCheckpointState;
public IntPtr ExportModelForInferencing;
}
internal static class NativeTrainingMethods

Binary file not shown.

View file

@ -846,24 +846,36 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
const std::vector<OrtValue>& user_inputs, std::vector<OrtValue>& user_outputs) -> void {
ORT_THROW_IF_ERROR(model->EvalStep(user_inputs, user_outputs));
})
.def("reset_grad", [](onnxruntime::training::api::Module* model) -> void {
ORT_THROW_IF_ERROR(model->ResetGrad());
})
.def("copy_parameters_to_buffer", [](onnxruntime::training::api::Module* model, OrtValue& output) -> void {
ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output));
})
.def("copy_buffer_to_parameters", [](onnxruntime::training::api::Module* model, OrtValue& input) -> void {
ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input));
})
.def("get_parameters_size", [](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t {
return model->GetParametersSize(trainable_only);
})
.def("save_checkpoint", [](onnxruntime::training::api::Module* model, const std::string& checkpoint_path) -> void {
onnxruntime::training::api::CheckpointState state;
ORT_THROW_IF_ERROR(model->GetStateDict(state.module_checkpoint_state));
ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(state,
ToPathString(checkpoint_path)));
});
.def("reset_grad",
[](onnxruntime::training::api::Module* model) -> void {
ORT_THROW_IF_ERROR(model->ResetGrad());
})
.def("copy_parameters_to_buffer",
[](onnxruntime::training::api::Module* model, OrtValue& output) -> void {
ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output));
})
.def("copy_buffer_to_parameters",
[](onnxruntime::training::api::Module* model, OrtValue& input) -> void {
ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input));
})
.def("get_parameters_size",
[](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t {
return model->GetParametersSize(trainable_only);
})
.def("save_checkpoint",
[](onnxruntime::training::api::Module* model, const std::string& checkpoint_path) -> void {
onnxruntime::training::api::CheckpointState state;
ORT_THROW_IF_ERROR(model->GetStateDict(state.module_checkpoint_state));
ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(state,
ToPathString(checkpoint_path)));
})
.def("export_model_for_inferencing",
[](onnxruntime::training::api::Module* model, const std::string& inference_model_path,
const std::vector<std::string>& graph_output_names) -> void {
ORT_ENFORCE(model, "Received a nullptr for expected pointer to class training::api::Module");
ORT_THROW_IF_ERROR(model->ExportModelForInferencing(inference_model_path,
graph_output_names));
});
py::class_<onnxruntime::training::api::CheckpointState>
checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc");

View file

@ -112,3 +112,9 @@ class Module:
Copies buffer to parameters.
"""
self._model.copy_buffer_to_parameters(buffer)
def export_model_for_inferencing(self, inference_model_uri: str, graph_output_names: list[str]) -> None:
"""
Exports the model for inferencing.

Forwards both arguments unchanged to the underlying `_model.export_model_for_inferencing` binding.

Args:
inference_model_uri: Location the exported inference model is written to.
graph_output_names: Names of the graph outputs the exported inference model should have.
"""
self._model.export_model_for_inferencing(inference_model_uri, graph_output_names)

View file

@ -21,7 +21,7 @@ class SimpleModelWithCrossEntropyLoss(onnxblock.TrainingModel):
def _create_training_models():
# Given
device = "cuda"
device = "cpu"
batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
pt_model, onnx_model = _get_models(device, batch_size, input_size, hidden_size, output_size)
@ -82,11 +82,11 @@ def test_train_step():
fetches = model(forward_inputs)
# Calculate loss using pytorch model to compare it with Module's output.
pt_outputs = pt_model(torch.from_numpy(inputs).to("cuda"))
pt_outputs = pt_model(torch.from_numpy(inputs))
loss_fn = torch.nn.CrossEntropyLoss()
pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).to("cuda").long())
pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).long())
assert fetches[0] == pt_loss.item()
assert np.allclose(fetches[0], pt_loss.detach().numpy())
def test_eval_step():
@ -211,3 +211,25 @@ def test_copy_buffer_to_parameters():
# Make sure the saved parameters are the same as the old parameters.
assert np.array_equal(old_output_params.numpy(), saved_params.numpy())
def test_export_model_for_inferencing():
    """Exporting an inference model from a Module writes a file at the requested path."""
    # Build the torch/onnx models used as test fixtures.
    simple_model, onnx_model, _, eval_model, _ = _create_training_models()

    with tempfile.TemporaryDirectory() as temp_dir:
        # Persist the checkpoint plus the training and eval models so a Module can load them.
        saved_paths = _get_test_models_path(temp_dir, simple_model, onnx_model, eval_model=eval_model)
        checkpoint_file_path, model_file_path, eval_model_file_path = saved_paths

        # The Module needs the checkpoint state and the eval model to be able to export.
        module = Module(model_file_path, CheckpointState(checkpoint_file_path), eval_model_file_path)

        # Export and make sure the inference model landed on disk.
        exported_model_path = os.path.join(temp_dir, "inference_model.onnx")
        module.export_model_for_inferencing(exported_model_path, ["output-0"])
        assert os.path.exists(exported_model_path)

View file

@ -17,6 +17,7 @@
#include "orttraining/training_api/include/checkpoint.h"
#include "orttraining/training_api/include/lr_scheduler.h"
#include "orttraining/test/training_api/core/data_utils.h"
#include "test/util/include/temp_dir.h"
#include "default_providers.h"
using json = nlohmann::json;
@ -47,6 +48,153 @@ void GenerateRandomInput(gsl::span<const int64_t> dims, OrtValue& input) {
onnxruntime::training::api::utils::CreateInputOrtValue<float>(dims, data, &input);
}
void TestModuleExport(const std::vector<std::shared_ptr<IExecutionProvider>>& providers) {
auto training_model_uri = MODEL_FOLDER "training_model.onnx";
auto eval_model_uri = MODEL_FOLDER "eval_model.onnx";
onnxruntime::training::api::CheckpointState state;
auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt";
ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpoint(checkpoint_to_load_path, state));
std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
auto model = std::make_unique<onnxruntime::training::api::Module>(
ToUTF8String(training_model_uri), state.module_checkpoint_state.named_parameters, onnxruntime::SessionOptions(),
*env, providers, ToUTF8String(eval_model_uri));
auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir");
if (Env::Default().FolderExists(test_dir)) {
ORT_ENFORCE(Env::Default().DeleteFolder(test_dir).IsOK());
}
onnxruntime::test::TemporaryDirectory tmp_dir{test_dir};
PathString inference_model_path{
ConcatPathComponent<PathChar>(tmp_dir.Path(), ORT_TSTR("inference_model.onnx"))};
std::vector<std::string> graph_output_names({"output-0"});
ASSERT_STATUS_OK(model->ExportModelForInferencing(ToUTF8String(inference_model_path), graph_output_names));
// Load model
ONNX_NAMESPACE::ModelProto eval_model;
ONNX_NAMESPACE::ModelProto inference_model;
ORT_THROW_IF_ERROR(Model::Load(eval_model_uri, eval_model));
ORT_THROW_IF_ERROR(Model::Load(inference_model_path, inference_model));
// Check it has only one graph input
ASSERT_EQ(eval_model.graph().input().size(), 6);
ASSERT_EQ(inference_model.graph().input().size(), 1);
ASSERT_EQ(inference_model.graph().input()[0].name(), "input-0");
// Check that it does not have any node which has op type SoftmaxCrossEntropyLoss
auto softmaxceloss_node_found = [](auto& model) -> bool {
for (auto& node : model.graph().node()) {
if (node.op_type() == "SoftmaxCrossEntropyLoss") {
return true;
}
}
return false;
};
ASSERT_EQ(softmaxceloss_node_found(eval_model), true);
ASSERT_EQ(softmaxceloss_node_found(inference_model), false);
// Try running an inference session
auto inference_session = std::make_unique<onnxruntime::InferenceSession>(onnxruntime::SessionOptions(), *env);
ASSERT_STATUS_OK(inference_session->Load(inference_model_path));
ASSERT_STATUS_OK(inference_session->Initialize());
std::vector<std::string> input_names({"input-0"});
OrtValue graph_input;
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, graph_input);
std::vector<OrtValue> feeds;
feeds.emplace_back(graph_input);
std::vector<std::string> output_names({"output-0"});
std::vector<OrtValue> outputs;
ASSERT_STATUS_OK(inference_session->Run(RunOptions(), input_names, feeds, output_names, &outputs));
ASSERT_EQ(outputs.size(), 1U);
}
// Asserts that `output` lies within the absolute tolerance `atol` of `expected` and also within
// the relative tolerance `rtol * |expected|` of it. Both checks must hold.
void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) {
  const float absolute_tolerance = atol;
  ASSERT_NEAR(expected, output, absolute_tolerance);

  const float relative_tolerance = rtol * std::abs(expected);
  ASSERT_NEAR(expected, output, relative_tolerance);
}
#if defined(USE_CUDA) || defined(USE_ROCM)
const int64_t total_step_count = 100;
const float initial_lr = 1e-3f;
const int64_t resume_step = total_step_count / 2;
// Drives a LinearLRScheduler for every step recorded in the given test data file and checks
// that the optimizer's learning rate and step count match the recorded values.
//
// test_file_name    - file under testdata/test_data_generation/lr_scheduler/ holding a list of
//                     <step, learning rates> pairs.
// initial_lr        - initial learning rate restored into the optimizer when resuming.
// total_step_count  - total number of training steps passed to the scheduler.
// warmup_step_count - number of warm up steps passed to the scheduler.
void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count,
                    int64_t warmup_step_count) {
  /// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances.
  auto model_uri = MODEL_FOLDER "training_model.onnx";
  auto optim_uri = MODEL_FOLDER "adamw.onnx";

  onnxruntime::training::api::CheckpointState state;
  auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt";
  ASSERT_STATUS_OK(LoadCheckpoint(checkpoint_to_load_path, state));

  onnxruntime::SessionOptions session_option;
  std::unique_ptr<Environment> env;
  ASSERT_STATUS_OK(Environment::Create(nullptr, env));
  const std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCudaExecutionProvider()};
  auto model = std::make_unique<onnxruntime::training::api::Module>(
      ToUTF8String(model_uri), state.module_checkpoint_state.named_parameters,
      session_option, *env, providers);
  auto optim = std::make_shared<onnxruntime::training::api::Optimizer>(
      ToUTF8String(optim_uri), model->NamedParameters(), session_option,
      *env, providers);

  OrtValue input, target;
  GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
  onnxruntime::training::api::utils::CreateInputOrtValue<int32_t>(
      std::array<int64_t, 1>{2}, std::vector<int32_t>(2, 1), &target);

  /// Load test data for learning rate schedulers.
  auto data_uri = ORT_TSTR("testdata/test_data_generation/lr_scheduler/" + test_file_name);
  std::ifstream in{data_uri};
  ASSERT_TRUE(in.is_open()) << "Unable to open learning rate scheduler test data: " << test_file_name;

  // Each element of the vector is a pair of <step_count, list of learning rates>.
  using TestDataDictType = std::vector<std::pair<int64_t, std::vector<float>>>;
  TestDataDictType test_data;
  const json j = json::parse(in);
  j.get_to<TestDataDictType>(test_data);
  // The first entry is read below, so empty data must fail the test instead of being dereferenced.
  ASSERT_FALSE(test_data.empty()) << "No test data found in " << test_file_name;

  // The first recorded step tells us which step the run resumes from.
  const int64_t first_step = test_data.front().first;
  ASSERT_EQ(total_step_count, static_cast<int64_t>(test_data.size()) + first_step);

  if (first_step != 0) {
    /// Reset optimizer states to match the initial state we want to test.
    onnxruntime::training::api::OptimizerCheckpointState optimizer_checkpoint_states;
    auto group_opt_state =
        optimizer_checkpoint_states.group_named_optimizer_states["group0"] =
            std::make_shared<onnxruntime::training::api::GroupOptimizerState>();
    group_opt_state->step = first_step;
    group_opt_state->initial_lr = initial_lr;
    ASSERT_STATUS_OK(optim->LoadStateDict(optimizer_checkpoint_states));
  }

  // KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate.
  // If we restored it after creation, it will only affect the learning rate from the second step.
  auto scheduler = std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
      optim, warmup_step_count, total_step_count);

  for (const auto& [expected_step, expected_learning_rates] : test_data) {
    ASSERT_FALSE(expected_learning_rates.empty());

    onnxruntime::training::api::OptimizerCheckpointState optimizer_states;
    ASSERT_STATUS_OK(optim->GetStateDict(optimizer_states));
    auto group_optimizer_state = optimizer_states.group_named_optimizer_states["group0"];
    CompareValue(expected_learning_rates[0], group_optimizer_state->learning_rate);
    ASSERT_EQ(expected_step, group_optimizer_state->step);

    std::vector<OrtValue> inputs{input, target};
    std::vector<OrtValue> fetches;
    ASSERT_STATUS_OK(model->TrainStep(inputs, fetches));
    ASSERT_STATUS_OK(optim->Step());
    ASSERT_STATUS_OK(scheduler->Step());
  }
}
#endif
} // namespace
TEST(TrainingApiTest, ModuleParametersSize) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
@ -164,8 +312,18 @@ TEST(TrainingApiTest, ModuleTrainStep) {
}
}
// Runs the inference model export test with only the default CPU execution provider.
TEST(TrainingApiTest, ModuleExportModelForInferencingCPU) {
  const std::vector<std::shared_ptr<IExecutionProvider>> cpu_providers{
      onnxruntime::test::DefaultCpuExecutionProvider()};
  TestModuleExport(cpu_providers);
}
#if defined(USE_CUDA) || defined(USE_ROCM)
// Runs the inference model export test with the default CUDA execution provider.
TEST(TrainingApiTest, ModuleExportModelForInferencingCUDA) {
  const std::vector<std::shared_ptr<IExecutionProvider>> cuda_providers{
      onnxruntime::test::DefaultCudaExecutionProvider()};
  TestModuleExport(cuda_providers);
}
TEST(TrainingApiTest, OptimStep) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
auto optim_uri = MODEL_FOLDER "adamw.onnx";
@ -241,83 +399,6 @@ TEST(TrainingApiTest, OptimStep) {
}
}
void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) {
ASSERT_NEAR(expected, output, atol);
ASSERT_NEAR(expected, output, rtol * std::abs(expected));
}
void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count,
int64_t warmup_step_count) {
/// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances.
auto model_uri = MODEL_FOLDER "training_model.onnx";
auto optim_uri = MODEL_FOLDER "adamw.onnx";
onnxruntime::training::api::CheckpointState state;
auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt";
ASSERT_STATUS_OK(LoadCheckpoint(checkpoint_to_load_path, state));
onnxruntime::SessionOptions session_option;
std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
const std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCudaExecutionProvider()};
auto model = std::make_unique<onnxruntime::training::api::Module>(
ToUTF8String(model_uri), state.module_checkpoint_state.named_parameters,
session_option, *env, providers);
auto optim = std::make_shared<onnxruntime::training::api::Optimizer>(
ToUTF8String(optim_uri), model->NamedParameters(), session_option,
*env, providers);
OrtValue input, target;
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
onnxruntime::training::api::utils::CreateInputOrtValue<int32_t>(
std::array<int64_t, 1>{2}, std::vector<int32_t>(2, 1), &target);
/// Load test data for learning rate schedulers.
auto data_uri = ORT_TSTR("testdata/test_data_generation/lr_scheduler/" + test_file_name);
std::ifstream in{data_uri};
// Element of vector represent a pair of <step_count, list of learning rates>>
typedef std::vector<std::pair<int64_t, std::vector<float>>> TestDataDictType;
TestDataDictType test_data;
const json j = json::parse(in);
j.get_to<TestDataDictType>(test_data);
int64_t resume_step = (*test_data.begin()).first;
ASSERT_EQ(total_step_count, static_cast<int64_t>(test_data.size()) + resume_step);
if (resume_step != 0) {
/// Reset optimizer states to match the initial state we want to test.
onnxruntime::training::api::OptimizerCheckpointState optimizer_checkpoint_states;
auto group_opt_state =
optimizer_checkpoint_states.group_named_optimizer_states["group0"] =
std::make_shared<onnxruntime::training::api::GroupOptimizerState>();
group_opt_state->step = resume_step;
group_opt_state->initial_lr = initial_lr;
ASSERT_STATUS_OK(optim->LoadStateDict(optimizer_checkpoint_states));
}
// KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate.
// If we restored it after creation, it will only affect the learning rate from the second step.
auto scheduler = std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
optim, warmup_step_count, total_step_count);
for (auto it = test_data.begin(); it != test_data.end(); ++it) {
onnxruntime::training::api::OptimizerCheckpointState optimizer_states;
ASSERT_STATUS_OK(optim->GetStateDict(optimizer_states));
auto group_optimizer_state = optimizer_states.group_named_optimizer_states["group0"];
CompareValue(it->second[0], group_optimizer_state->learning_rate);
ASSERT_EQ(it->first, group_optimizer_state->step);
std::vector<OrtValue> inputs{input, target};
std::vector<OrtValue> fetches;
ASSERT_STATUS_OK(model->TrainStep(inputs, fetches));
ASSERT_STATUS_OK(optim->Step());
ASSERT_STATUS_OK(scheduler->Step());
}
}
const int64_t total_step_count = 100;
const float initial_lr = 1e-3f;
const int64_t resume_step = total_step_count / 2;
TEST(TrainingApiTest, LinearLRScheduler_NoWarmUp_Test) {
// No warm up.
TestLRSchduler("warmup_linear_scheduler_warmupstep-0.json", initial_lr, total_step_count, 0);
@ -360,7 +441,6 @@ TEST(TrainingApiTest, LinearLRScheduler_WarmUp200Step_ResumeFromCheckpoint_Test)
#endif
} // namespace
} // namespace test
} // namespace training
} // namespace onnxruntime

View file

@ -6,7 +6,6 @@
#include "core/common/logging/sinks/clog_sink.h"
#include "core/common/path.h"
#include "core/framework/framework_common.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/platform/env.h"
@ -16,6 +15,7 @@
#include "orttraining/core/framework/checkpoint_common.h"
#include "orttraining/core/framework/protobuf_message_sequence.h"
#include "orttraining/training_api/include/checkpoint.h"
#include "orttraining/training_api/include/utils.h"
namespace onnxruntime {
namespace training {
@ -53,10 +53,6 @@ Status CreateTensorProtosFromOrtValues(
[](const NameMLValMap::value_type& v) { return v.first; });
std::sort(ordered_tensor_names.begin(), ordered_tensor_names.end());
// Copy the tensor data and create TensorProto storing the data.
InlinedVector<char> tensor_data_buffer{};
static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};
saved_tensor_protos.reserve(ordered_tensor_names.size());
uint64_t total_bytes = 0;
@ -65,7 +61,6 @@ Status CreateTensorProtosFromOrtValues(
const OrtValue& ort_value = name_to_ort_value.at(tensor_name);
ORT_RETURN_IF_NOT(ort_value.IsTensor(), "ort_value.IsTensor() was false");
const Tensor& src_tensor = ort_value.Get<Tensor>();
tensor_data_buffer.resize(src_tensor.SizeInBytes());
// Currently large model size not considered, so exception thrown here
// when protobuf upper limit hit.
@ -74,23 +69,8 @@ Status CreateTensorProtosFromOrtValues(
ORT_THROW("checkpoint file size hit upper limit.");
}
auto& tensor_location = src_tensor.Location();
if (tensor_location.device.Type() == OrtDevice::CPU ||
tensor_location.mem_type == OrtMemTypeCPUInput ||
tensor_location.mem_type == OrtMemTypeCPUOutput ||
tensor_location.device.Type() == OrtDevice::GPU) {
gsl::span<char> dst_span = gsl::make_span(tensor_data_buffer);
ORT_RETURN_IF_NOT(src_tensor.SizeInBytes() == static_cast<size_t>(dst_span.size_bytes()), "src size != dst size");
Tensor dst_tensor{src_tensor.DataType(), src_tensor.Shape(), dst_span.data(), cpu_alloc_info};
ORT_RETURN_IF_ERROR(data_transfer_manager.CopyTensor(src_tensor, dst_tensor));
// Convert Tensor to TensorProto.
ONNX_NAMESPACE::TensorProto tensor_proto;
tensor_proto = utils::TensorToTensorProto(dst_tensor, tensor_name);
saved_tensor_protos.emplace_back(tensor_proto);
} else {
ORT_THROW("Unsupported device type for saving tensors");
}
saved_tensor_protos.emplace_back(utils::CopyTensorToTensorProto(
src_tensor, tensor_name, data_transfer_manager));
}
return Status::OK();

View file

@ -106,6 +106,11 @@ struct Module {
// Copy parameter values from contiguous buffer held by parameters_buffer onto parameters
Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true);
// Re-loads the eval model from the path given at construction (eval_model_path_or_bytes),
// transforms it for the purpose of inferencing and serializes the result to inference_model_path.
// graph_output_names lists the outputs the inference model should have.
// Returns a failure Status when no eval model was provided or when graph_output_names is empty.
Status ExportModelForInferencing(const std::string& inference_model_path,
gsl::span<const std::string> graph_output_names) const;
private:
std::unique_ptr<onnxruntime::InferenceSession> train_sess_{nullptr};
std::unique_ptr<onnxruntime::InferenceSession> eval_sess_{nullptr};
@ -118,6 +123,7 @@ struct Module {
std::vector<OrtValue> gradients_;
bool accumulate_gradient_ = false;
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& named_parameters_;
std::string eval_model_path_;
};
} // namespace api

View file

@ -296,6 +296,25 @@ struct OrtTrainingApi {
*
*/
ORT_CLASS_RELEASE(CheckpointState);
/** \brief Export a model that can be used for inferencing.
*
* If the training session was provided with an eval model, the training session can generate
* an inference model if it knows the inference graph outputs. The input inference graph outputs
* are used to prune the eval model so that the output model's outputs align with the provided outputs.
* The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
* Note that the function re-loads the eval model from the path provided to CreateTrainingSession and expects
* that this path still be valid.
*
* \param[in] sess The training session.
* \param[in] inference_model_path Path where the inference model should be serialized to.
* \param[in] graph_outputs_len Number of entries in the graph_output_names array.
* \param[in] graph_output_names Names of the outputs the exported inference graph should have.
*            Each name must be a non-empty string.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
_In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
_In_reads_(graph_outputs_len) const char* const* graph_output_names);
};
typedef struct OrtTrainingApi OrtTrainingApi;

View file

@ -143,6 +143,14 @@ class TrainingSession : public detail::Base<OrtTrainingSession> {
*
*/
void OptimizerStep();
/** \brief Exports a model that can be used for inferencing with the inference session.
*
* Wraps OrtTrainingApi::ExportModelForInferencing
*
* \param[in] inference_model_path Path where the inference model should be serialized to.
* \param[in] graph_output_names Names of the outputs the exported inference graph should have.
*
*/
void ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
const std::vector<std::string>& graph_output_names);
};
} // namespace Ort

View file

@ -94,4 +94,14 @@ inline void CheckpointState::SaveCheckpoint(const TrainingSession& session,
ThrowOnError(GetTrainingApi().SaveCheckpoint(path_to_checkpoint.c_str(), session, include_optimizer_states));
}
// Exports an inference model by forwarding to OrtTrainingApi::ExportModelForInferencing.
// Throws (through ThrowOnError) when the underlying C API call reports a failure.
inline void TrainingSession::ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
                                                       const std::vector<std::string>& graph_output_names) {
  // Build the array of C strings expected by the C API. The vector must start out empty:
  // constructing it with size() null entries and then calling push_back would leave the
  // first size() slots (the ones the C API actually reads) as nullptr.
  std::vector<const char*> output_names;
  output_names.reserve(graph_output_names.size());
  for (const auto& output_name : graph_output_names) {
    output_names.push_back(output_name.c_str());
  }

  ThrowOnError(GetTrainingApi().ExportModelForInferencing(
      p_, inference_model_path.c_str(), output_names.size(), output_names.data()));
}
} // namespace Ort

View file

@ -60,4 +60,8 @@ ORT_API(void, ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* sessio
ORT_API(void, ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session);
// Exports an inference model for the given training session to inference_model_path.
// graph_output_names holds graph_outputs_len names of the outputs the exported model should have.
ORT_API_STATUS_IMPL(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
_In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
_In_reads_(graph_outputs_len) const char* const* graph_output_names);
} // namespace OrtTrainingApis

View file

@ -63,11 +63,14 @@ class TrainingSession {
Status CreateCheckpointState(CheckpointState& chkpt_state, bool save_optimizer_state) const;
size_t GetParametersSize(const bool trainable_only=true) const;
size_t GetParametersSize(const bool trainable_only = true) const;
Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only=true);
Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only=true);
Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only = true);
Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true);
// Exports a model usable for inferencing to inference_model_path, with graph_output_names
// as the outputs of the exported graph.
Status ExportModelForInferencing(const std::string& inference_model_path,
gsl::span<const std::string> graph_output_names) const;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TrainingSession);

View file

@ -85,6 +85,9 @@ T GetValue(OrtValue& ort_value) {
return val;
}
// Copies the data of the given tensor into a newly created TensorProto named tensor_proto_name,
// using data_transfer_manager to perform the copy.
ONNX_NAMESPACE::TensorProto CopyTensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name,
const DataTransferManager& data_transfer_manager);
} // namespace utils
} // namespace api
} // namespace training

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/common/safeint.h"
#include "core/common/string_utils.h"
#include "core/framework/execution_provider.h"
#include "core/session/inference_session.h"
#include "core/session/environment.h"
@ -20,6 +21,94 @@ namespace {
// TODO: consolidate with frontend tooling
const std::string ACCUMULATE_GRAD_CONTROL_INPUT_NAME{"lazy_reset_grad"};
// Returns every node visited by a reverse depth first traversal that starts at the producer
// nodes of the given output node args.
std::unordered_set<const Node*> GetReverseReachableNodes(Graph& inference_graph,
                                                         InlinedVector<const NodeArg*>& output_node_args) {
  // Seed the traversal with the producers of the requested outputs, each producer only once.
  InlinedVector<NodeIndex> start_indices;
  start_indices.reserve((output_node_args.size()));

  std::unordered_set<const Node*> reachable;
  for (auto output_arg : output_node_args) {
    auto* producer = inference_graph.GetProducerNode(output_arg->Name());
    if (!producer) {
      // No producer node was found for this arg, so there is nothing to start from.
      continue;
    }

    const bool already_seeded =
        std::find(start_indices.begin(), start_indices.end(), producer->Index()) != start_indices.end();
    if (!already_seeded) {
      start_indices.push_back(producer->Index());
    }
  }

  // Record each node as the traversal enters it; nothing needs to happen on leave.
  inference_graph.ReverseDFSFrom(start_indices, [&reachable](const Node* node) { reachable.insert(node); }, {});

  return reachable;
}
// Removes every node of inference_graph that is not reverse-reachable from output_node_args.
Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector<const NodeArg*>& output_node_args) {
  auto nodes_to_keep = GetReverseReachableNodes(inference_graph, output_node_args);

  // Walk over all graph nodes and drop the ones that do not contribute to the outputs.
  GraphViewer viewer(inference_graph);
  for (auto& candidate : viewer.Nodes()) {
    const bool is_reachable = nodes_to_keep.find(&candidate) != nodes_to_keep.end();
    if (is_reachable) {
      continue;
    }
    inference_graph.RemoveNode(candidate.Index());
  }

  return Status::OK();
}
// Restricts the graph outputs to inference_graph_outputs and prunes the nodes that no longer
// contribute to any output. Fails when no output names are given or when a requested output
// has no matching node arg in the graph.
Status TransformModelOutputsForInference(Graph& inference_graph,
                                         gsl::span<const std::string> inference_graph_outputs) {
  ORT_RETURN_IF(inference_graph_outputs.empty(),
                "Expected a non empty vector of graph output names. Got an empty vector.");

  // Resolve every requested output name to its node arg.
  InlinedVector<const NodeArg*> requested_output_args;
  requested_output_args.reserve(inference_graph_outputs.size());
  for (const std::string& requested_name : inference_graph_outputs) {
    const NodeArg* output_node_arg = inference_graph.GetNodeArg(requested_name);
    ORT_RETURN_IF_NOT(output_node_arg, "Expected graph output for inference graph " + requested_name +
                                           " could not be found. Please regenerate the eval graph.");
    requested_output_args.push_back(output_node_arg);
  }

  // Install the new outputs first, then drop everything that does not feed them.
  inference_graph.SetOutputs(requested_output_args);
  ORT_RETURN_IF_ERROR(RemoveUnusedNodes(inference_graph, requested_output_args));

  ORT_RETURN_IF_ERROR(inference_graph.Resolve());

  return Status::OK();
}
// Rewrites the graph inputs so that only user inputs remain: every graph input that names a
// model parameter is added to the graph as an initializer holding the parameter's current data,
// and graph inputs without any consumer node are dropped.
Status TransformModelInputsForInference(Graph& inference_graph,
                                        const std::unordered_map<
                                            std::string, std::shared_ptr<Parameter>>& named_parameters,
                                        const DataTransferManager& data_transfer_manager) {
  std::vector<const NodeArg*> remaining_inputs;

  for (auto& graph_input_node_arg : inference_graph.GetInputs()) {
    auto named_parameter_it = named_parameters.find(graph_input_node_arg->Name());

    if (named_parameter_it != named_parameters.end()) {
      // The input is a model parameter: bake its data into the graph as an initializer.
      ORT_ENFORCE(!inference_graph.IsInitializedTensor(named_parameter_it->first),
                  "The eval graph is invalid. Expected model parameter ",
                  named_parameter_it->first, " to be a graph input, not a graph initializer.");
      inference_graph.AddInitializedTensor(utils::CopyTensorToTensorProto(
          named_parameter_it->second->Data().Get<onnxruntime::Tensor>(),
          named_parameter_it->first, data_transfer_manager));
      continue;
    }

    // Not a parameter: keep it as a user input only when some node still consumes it.
    const bool has_consumers = !inference_graph.GetConsumerNodes(graph_input_node_arg->Name()).empty();
    if (has_consumers) {
      remaining_inputs.emplace_back(graph_input_node_arg);
    }
  }

  inference_graph.SetInputs(remaining_inputs);
  ORT_RETURN_IF_ERROR(inference_graph.Resolve());

  return Status::OK();
}
} // namespace
Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) {
@ -162,6 +251,9 @@ Module::Module(const std::string& train_model_path_or_bytes,
if (eval_model_path_or_bytes.has_value()) {
eval_sess_ = std::make_unique<onnxruntime::InferenceSession>(session_options, env);
ORT_THROW_IF_ERROR(eval_sess_->Load(eval_model_path_or_bytes.value()));
for (const auto& provider : providers) {
ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider));
}
ORT_THROW_IF_ERROR(eval_sess_->Initialize());
utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_);
@ -185,6 +277,10 @@ Module::Module(const std::string& train_model_path_or_bytes,
}
eval_input_names_ = eval_user_input_names;
eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end());
// Keep a copy of the eval model path to be able to later export the model for inferencing.
// The inference model will be reconstructed from the eval model.
eval_model_path_ = eval_model_path_or_bytes.value();
}
}
@ -232,7 +328,7 @@ Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool tr
ORT_ENFORCE(nullptr != init_tensor);
auto expected_buffer_size = static_cast<int64_t>(GetParametersSize(trainable_only));
ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size,
"Parameters buffer size incorrect. Expected:",expected_buffer_size,
"Parameters buffer size incorrect. Expected:", expected_buffer_size,
", Actual:", init_tensor->Shape().Size());
const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager();
@ -275,7 +371,7 @@ Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool tr
ORT_ENFORCE(nullptr != init_tensor);
auto expected_buffer_size = static_cast<int64_t>(GetParametersSize(trainable_only));
ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size,
"Parameters buffer size incorrect. Expected:",expected_buffer_size,
"Parameters buffer size incorrect. Expected:", expected_buffer_size,
", Actual:", init_tensor->Shape().Size());
const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager();
@ -356,6 +452,33 @@ Status Module::GetStateDict(ModuleCheckpointState& module_checkpoint_state) {
return Status::OK();
}
// Builds an inference model from the eval model and serializes it to inference_model_path.
//
// The eval model is re-loaded from the path recorded at construction, its outputs are reduced
// to graph_output_names (pruning nodes that no longer contribute), and the model parameters are
// moved from graph inputs to initializers.
//
// Returns a failure Status when no eval model was provided or when any load, transform or save
// step fails. Failures are reported through the returned Status instead of being thrown, so
// callers that check the Status see every error.
Status Module::ExportModelForInferencing(const std::string& inference_model_path,
                                         gsl::span<const std::string> graph_output_names) const {
  ORT_RETURN_IF(!eval_sess_ || eval_model_path_.empty(),
                "Eval model was not provided. Cannot export a model for inferencing.");

  ONNX_NAMESPACE::ModelProto eval_model;
  ORT_RETURN_IF_ERROR(Model::Load(ToPathString(eval_model_path_), eval_model));

  // Clone the eval model into an inference onnxruntime::Model.
  std::shared_ptr<Model> inference_model;
  ORT_RETURN_IF_ERROR(Model::Load(eval_model, inference_model, nullptr, logging::LoggingManager::DefaultLogger()));

  // The cloned model's outputs are transformed such that the model has outputs as defined by graph_output_names.
  // Any nodes not contributing to the inference outputs will be pruned.
  ORT_RETURN_IF_ERROR(TransformModelOutputsForInference(inference_model->MainGraph(), graph_output_names));

  // The cloned model's inputs are transformed such that the model has only user defined inputs. All parameters
  // are moved to be constant initializers for the model.
  ORT_RETURN_IF_ERROR(TransformModelInputsForInference(inference_model->MainGraph(), named_parameters_,
                                                       eval_sess_->GetDataTransferManager()));

  // Save the model at the desired location.
  ORT_RETURN_IF_ERROR(Model::Save(*inference_model, inference_model_path));

  return Status::OK();
}
} // namespace api
} // namespace training
} // namespace onnxruntime

View file

@ -339,6 +339,31 @@ ORT_API(void, OrtTrainingApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckp
delete reinterpret_cast<onnxruntime::training::api::CheckpointState*>(checkpoint_state);
}
// C API entry point that exports an inference model for the given training session.
//
// Validates the raw pointer arguments, converts the output names to std::string and forwards to
// TrainingSession::ExportModelForInferencing. Returns an ORT_INVALID_ARGUMENT status when the
// model path is null, when the name array is null while graph_outputs_len is non-zero, or when
// any individual name is null or empty.
ORT_API_STATUS_IMPL(OrtTrainingApis::ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
                    _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
                    _In_reads_(graph_outputs_len) const char* const* graph_output_names) {
  API_IMPL_BEGIN

  if (inference_model_path == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Path to the inference model cannot be null. Please provide a valid path");
  }

  // The array is indexed below, so it must be valid whenever names are announced.
  if (graph_outputs_len != 0 && graph_output_names == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Graph output names cannot be null. Please provide valid graph names");
  }

  auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);

  onnxruntime::InlinedVector<std::string> output_names(graph_outputs_len);

  for (size_t i = 0; i != graph_outputs_len; ++i) {
    if (graph_output_names[i] == nullptr || graph_output_names[i][0] == '\0') {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                   "Name of graph output cannot be empty. Please provide valid graph names");
    }

    output_names[i] = graph_output_names[i];
  }

  ORT_API_RETURN_IF_STATUS_NOT_OK(
      session->ExportModelForInferencing(onnxruntime::ToUTF8String(inference_model_path), output_names));

  return nullptr;
  API_IMPL_END
}
static constexpr OrtTrainingApi ort_training_api = {
// NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially
// released, it is OK to change the order here, however a corresponding matching change should also be done in the
@ -363,6 +388,7 @@ static constexpr OrtTrainingApi ort_training_api = {
&OrtTrainingApis::CopyBufferToParameters,
&OrtTrainingApis::ReleaseTrainingSession,
&OrtTrainingApis::ReleaseCheckpointState,
&OrtTrainingApis::ExportModelForInferencing,
};
ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) {

View file

@ -110,6 +110,11 @@ Status TrainingSession::CopyBufferToParameters(OrtValue& parameters_buffer, cons
return module_->CopyBufferToParameters(parameters_buffer, trainable_only);
}
// Exports a model usable for inferencing by delegating to the underlying module's
// ExportModelForInferencing with the same arguments.
Status TrainingSession::ExportModelForInferencing(const std::string& inference_model_path,
gsl::span<const std::string> graph_output_names) const {
return module_->ExportModelForInferencing(inference_model_path, graph_output_names);
}
} // namespace api
} // namespace training
} // namespace onnxruntime

View file

@ -6,6 +6,7 @@
#include "core/framework/ort_value.h"
#include "core/framework/tensor.h"
#include "core/framework/allocator.h"
#include "core/framework/tensorprotoutils.h"
#include "orttraining/training_api/include/utils.h"
@ -87,6 +88,31 @@ Status OrtValueLike(const SessionState& sess_state, const OrtValue& input_val, O
return Status::OK();
}
// Copies the data of src_tensor into a temporary CPU buffer via data_transfer_manager and
// returns a TensorProto named tensor_proto_name created from that copy.
//
// Throws when the tensor lives at an unsupported location (neither a CPU or GPU device nor CPU
// input/output memory) or when the copy fails.
ONNX_NAMESPACE::TensorProto CopyTensorToTensorProto(const Tensor& src_tensor, const std::string& tensor_proto_name,
                                                    const DataTransferManager& data_transfer_manager) {
  auto& tensor_location = src_tensor.Location();
  const bool is_supported_location = tensor_location.device.Type() == OrtDevice::CPU ||
                                     tensor_location.mem_type == OrtMemTypeCPUInput ||
                                     tensor_location.mem_type == OrtMemTypeCPUOutput ||
                                     tensor_location.device.Type() == OrtDevice::GPU;
  if (!is_supported_location) {
    ORT_THROW("Unsupported device type for saving tensors");
  }

  // Stage the tensor data in a CPU buffer that matches the source tensor's size.
  InlinedVector<char> tensor_data_buffer{};
  tensor_data_buffer.resize(src_tensor.SizeInBytes());
  static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};
  gsl::span<char> dst_span = gsl::make_span(tensor_data_buffer);
  ORT_ENFORCE(src_tensor.SizeInBytes() == static_cast<size_t>(dst_span.size_bytes()), "src size != dst size");

  // dst_tensor wraps the staging buffer (it does not own it) and is the copy target.
  Tensor dst_tensor{src_tensor.DataType(), src_tensor.Shape(), dst_span.data(), cpu_alloc_info};
  ORT_THROW_IF_ERROR(data_transfer_manager.CopyTensor(src_tensor, dst_tensor));

  // Convert the CPU copy of the tensor to a TensorProto.
  return onnxruntime::utils::TensorToTensorProto(dst_tensor, tensor_proto_name);
}
} // namespace utils
} // namespace api
} // namespace training