mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Training API to export the eval model to an inference model (#13345)
This commit is contained in:
parent
8827c4bdbc
commit
a46c599a40
19 changed files with 463 additions and 129 deletions
|
|
@ -32,6 +32,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public IntPtr CopyBufferToParameters;
|
||||
public IntPtr ReleaseTrainingSession;
|
||||
public IntPtr ReleaseCheckpointState;
|
||||
public IntPtr ExportModelForInferencing;
|
||||
}
|
||||
|
||||
internal static class NativeTrainingMethods
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/training_api/eval_model.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/training_api/eval_model.onnx
vendored
Normal file
Binary file not shown.
Binary file not shown.
|
|
@ -846,24 +846,36 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
const std::vector<OrtValue>& user_inputs, std::vector<OrtValue>& user_outputs) -> void {
|
||||
ORT_THROW_IF_ERROR(model->EvalStep(user_inputs, user_outputs));
|
||||
})
|
||||
.def("reset_grad", [](onnxruntime::training::api::Module* model) -> void {
|
||||
ORT_THROW_IF_ERROR(model->ResetGrad());
|
||||
})
|
||||
.def("copy_parameters_to_buffer", [](onnxruntime::training::api::Module* model, OrtValue& output) -> void {
|
||||
ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output));
|
||||
})
|
||||
.def("copy_buffer_to_parameters", [](onnxruntime::training::api::Module* model, OrtValue& input) -> void {
|
||||
ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input));
|
||||
})
|
||||
.def("get_parameters_size", [](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t {
|
||||
return model->GetParametersSize(trainable_only);
|
||||
})
|
||||
.def("save_checkpoint", [](onnxruntime::training::api::Module* model, const std::string& checkpoint_path) -> void {
|
||||
onnxruntime::training::api::CheckpointState state;
|
||||
ORT_THROW_IF_ERROR(model->GetStateDict(state.module_checkpoint_state));
|
||||
ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(state,
|
||||
ToPathString(checkpoint_path)));
|
||||
});
|
||||
.def("reset_grad",
|
||||
[](onnxruntime::training::api::Module* model) -> void {
|
||||
ORT_THROW_IF_ERROR(model->ResetGrad());
|
||||
})
|
||||
.def("copy_parameters_to_buffer",
|
||||
[](onnxruntime::training::api::Module* model, OrtValue& output) -> void {
|
||||
ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output));
|
||||
})
|
||||
.def("copy_buffer_to_parameters",
|
||||
[](onnxruntime::training::api::Module* model, OrtValue& input) -> void {
|
||||
ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input));
|
||||
})
|
||||
.def("get_parameters_size",
|
||||
[](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t {
|
||||
return model->GetParametersSize(trainable_only);
|
||||
})
|
||||
.def("save_checkpoint",
|
||||
[](onnxruntime::training::api::Module* model, const std::string& checkpoint_path) -> void {
|
||||
onnxruntime::training::api::CheckpointState state;
|
||||
ORT_THROW_IF_ERROR(model->GetStateDict(state.module_checkpoint_state));
|
||||
ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(state,
|
||||
ToPathString(checkpoint_path)));
|
||||
})
|
||||
.def("export_model_for_inferencing",
|
||||
[](onnxruntime::training::api::Module* model, const std::string& inference_model_path,
|
||||
const std::vector<std::string>& graph_output_names) -> void {
|
||||
ORT_ENFORCE(model, "Received a nullptr for expected pointer to class training::api::Module");
|
||||
ORT_THROW_IF_ERROR(model->ExportModelForInferencing(inference_model_path,
|
||||
graph_output_names));
|
||||
});
|
||||
|
||||
py::class_<onnxruntime::training::api::CheckpointState>
|
||||
checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc");
|
||||
|
|
|
|||
|
|
@ -112,3 +112,9 @@ class Module:
|
|||
Copies buffer to parameters.
|
||||
"""
|
||||
self._model.copy_buffer_to_parameters(buffer)
|
||||
|
||||
def export_model_for_inferencing(self, inference_model_uri: str, graph_output_names: list[str]) -> None:
    """
    Exports the model for inferencing.

    The request is forwarded unchanged to the wrapped ``_model`` object.

    Args:
        inference_model_uri: Destination passed to ``_model`` for the exported inference model.
        graph_output_names: Graph output names passed to ``_model`` for the exported model.
    """
    wrapped_model = self._model
    wrapped_model.export_model_for_inferencing(inference_model_uri, graph_output_names)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class SimpleModelWithCrossEntropyLoss(onnxblock.TrainingModel):
|
|||
|
||||
def _create_training_models():
|
||||
# Given
|
||||
device = "cuda"
|
||||
device = "cpu"
|
||||
batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
|
||||
pt_model, onnx_model = _get_models(device, batch_size, input_size, hidden_size, output_size)
|
||||
|
||||
|
|
@ -82,11 +82,11 @@ def test_train_step():
|
|||
fetches = model(forward_inputs)
|
||||
|
||||
# Calculate loss using pytorch model to compare it with Module's output.
|
||||
pt_outputs = pt_model(torch.from_numpy(inputs).to("cuda"))
|
||||
pt_outputs = pt_model(torch.from_numpy(inputs))
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).to("cuda").long())
|
||||
pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).long())
|
||||
|
||||
assert fetches[0] == pt_loss.item()
|
||||
assert np.allclose(fetches[0], pt_loss.detach().numpy())
|
||||
|
||||
|
||||
def test_eval_step():
|
||||
|
|
@ -211,3 +211,25 @@ def test_copy_buffer_to_parameters():
|
|||
|
||||
# Make sure the saved parameters are the same as the old parameters.
|
||||
assert np.array_equal(old_output_params.numpy(), saved_params.numpy())
|
||||
|
||||
|
||||
def test_export_model_for_inferencing():
    """Exports an inference model from a Module and checks that the file was written."""
    # Build the models used by the test.
    base_model, training_onnx, _, eval_onnx, _ = _create_training_models()

    with tempfile.TemporaryDirectory() as work_dir:
        # Write the checkpoint and model files to disk so they can be loaded below.
        ckpt_path, train_path, eval_path = _get_test_models_path(
            work_dir, base_model, training_onnx, eval_model=eval_onnx
        )

        # Load the checkpoint; keep it in a named local for as long as the Module is used.
        state = CheckpointState(ckpt_path)

        # Build the Module from the training model, the checkpoint and the eval model.
        module = Module(train_path, state, eval_path)

        # Export the inference model and make sure it exists on disk.
        exported_path = os.path.join(work_dir, "inference_model.onnx")
        module.export_model_for_inferencing(exported_path, ["output-0"])
        assert os.path.exists(exported_path)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
#include "orttraining/training_api/include/checkpoint.h"
|
||||
#include "orttraining/training_api/include/lr_scheduler.h"
|
||||
#include "orttraining/test/training_api/core/data_utils.h"
|
||||
#include "test/util/include/temp_dir.h"
|
||||
#include "default_providers.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
|
@ -47,6 +48,153 @@ void GenerateRandomInput(gsl::span<const int64_t> dims, OrtValue& input) {
|
|||
onnxruntime::training::api::utils::CreateInputOrtValue<float>(dims, data, &input);
|
||||
}
|
||||
|
||||
void TestModuleExport(const std::vector<std::shared_ptr<IExecutionProvider>>& providers) {
|
||||
auto training_model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
auto eval_model_uri = MODEL_FOLDER "eval_model.onnx";
|
||||
|
||||
onnxruntime::training::api::CheckpointState state;
|
||||
auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt";
|
||||
ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpoint(checkpoint_to_load_path, state));
|
||||
|
||||
std::unique_ptr<Environment> env;
|
||||
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
|
||||
auto model = std::make_unique<onnxruntime::training::api::Module>(
|
||||
ToUTF8String(training_model_uri), state.module_checkpoint_state.named_parameters, onnxruntime::SessionOptions(),
|
||||
*env, providers, ToUTF8String(eval_model_uri));
|
||||
|
||||
auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir");
|
||||
if (Env::Default().FolderExists(test_dir)) {
|
||||
ORT_ENFORCE(Env::Default().DeleteFolder(test_dir).IsOK());
|
||||
}
|
||||
onnxruntime::test::TemporaryDirectory tmp_dir{test_dir};
|
||||
PathString inference_model_path{
|
||||
ConcatPathComponent<PathChar>(tmp_dir.Path(), ORT_TSTR("inference_model.onnx"))};
|
||||
|
||||
std::vector<std::string> graph_output_names({"output-0"});
|
||||
ASSERT_STATUS_OK(model->ExportModelForInferencing(ToUTF8String(inference_model_path), graph_output_names));
|
||||
|
||||
// Load model
|
||||
ONNX_NAMESPACE::ModelProto eval_model;
|
||||
ONNX_NAMESPACE::ModelProto inference_model;
|
||||
ORT_THROW_IF_ERROR(Model::Load(eval_model_uri, eval_model));
|
||||
ORT_THROW_IF_ERROR(Model::Load(inference_model_path, inference_model));
|
||||
|
||||
// Check it has only one graph input
|
||||
ASSERT_EQ(eval_model.graph().input().size(), 6);
|
||||
ASSERT_EQ(inference_model.graph().input().size(), 1);
|
||||
ASSERT_EQ(inference_model.graph().input()[0].name(), "input-0");
|
||||
|
||||
// Check that it does not have any node which has op type SoftmaxCrossEntropyLoss
|
||||
auto softmaxceloss_node_found = [](auto& model) -> bool {
|
||||
for (auto& node : model.graph().node()) {
|
||||
if (node.op_type() == "SoftmaxCrossEntropyLoss") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
ASSERT_EQ(softmaxceloss_node_found(eval_model), true);
|
||||
ASSERT_EQ(softmaxceloss_node_found(inference_model), false);
|
||||
|
||||
// Try running an inference session
|
||||
auto inference_session = std::make_unique<onnxruntime::InferenceSession>(onnxruntime::SessionOptions(), *env);
|
||||
ASSERT_STATUS_OK(inference_session->Load(inference_model_path));
|
||||
ASSERT_STATUS_OK(inference_session->Initialize());
|
||||
std::vector<std::string> input_names({"input-0"});
|
||||
OrtValue graph_input;
|
||||
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, graph_input);
|
||||
std::vector<OrtValue> feeds;
|
||||
feeds.emplace_back(graph_input);
|
||||
std::vector<std::string> output_names({"output-0"});
|
||||
std::vector<OrtValue> outputs;
|
||||
ASSERT_STATUS_OK(inference_session->Run(RunOptions(), input_names, feeds, output_names, &outputs));
|
||||
ASSERT_EQ(outputs.size(), 1U);
|
||||
}
|
||||
|
||||
// Asserts that `output` is within both the absolute tolerance `atol` and the relative tolerance
// `rtol * |expected|` of `expected`. Both bounds must hold for the check to pass.
void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) {
  const float absolute_bound = atol;
  const float relative_bound = rtol * std::abs(expected);
  ASSERT_NEAR(expected, output, absolute_bound);
  ASSERT_NEAR(expected, output, relative_bound);
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
const int64_t total_step_count = 100;
|
||||
const float initial_lr = 1e-3f;
|
||||
const int64_t resume_step = total_step_count / 2;
|
||||
|
||||
void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count,
|
||||
int64_t warmup_step_count) {
|
||||
/// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances.
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
auto optim_uri = MODEL_FOLDER "adamw.onnx";
|
||||
|
||||
onnxruntime::training::api::CheckpointState state;
|
||||
auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt";
|
||||
ASSERT_STATUS_OK(LoadCheckpoint(checkpoint_to_load_path, state));
|
||||
|
||||
onnxruntime::SessionOptions session_option;
|
||||
std::unique_ptr<Environment> env;
|
||||
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCudaExecutionProvider()};
|
||||
auto model = std::make_unique<onnxruntime::training::api::Module>(
|
||||
ToUTF8String(model_uri), state.module_checkpoint_state.named_parameters,
|
||||
session_option, *env, providers);
|
||||
auto optim = std::make_shared<onnxruntime::training::api::Optimizer>(
|
||||
ToUTF8String(optim_uri), model->NamedParameters(), session_option,
|
||||
*env, providers);
|
||||
|
||||
OrtValue input, target;
|
||||
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
|
||||
onnxruntime::training::api::utils::CreateInputOrtValue<int32_t>(
|
||||
std::array<int64_t, 1>{2}, std::vector<int32_t>(2, 1), &target);
|
||||
|
||||
/// Load test data for learning rate schedulers.
|
||||
auto data_uri = ORT_TSTR("testdata/test_data_generation/lr_scheduler/" + test_file_name);
|
||||
std::ifstream in{data_uri};
|
||||
// Element of vector represent a pair of <step_count, list of learning rates>>
|
||||
typedef std::vector<std::pair<int64_t, std::vector<float>>> TestDataDictType;
|
||||
TestDataDictType test_data;
|
||||
const json j = json::parse(in);
|
||||
j.get_to<TestDataDictType>(test_data);
|
||||
|
||||
int64_t resume_step = (*test_data.begin()).first;
|
||||
ASSERT_EQ(total_step_count, static_cast<int64_t>(test_data.size()) + resume_step);
|
||||
|
||||
if (resume_step != 0) {
|
||||
/// Reset optimizer states to match the initial state we want to test.
|
||||
onnxruntime::training::api::OptimizerCheckpointState optimizer_checkpoint_states;
|
||||
auto group_opt_state =
|
||||
optimizer_checkpoint_states.group_named_optimizer_states["group0"] =
|
||||
std::make_shared<onnxruntime::training::api::GroupOptimizerState>();
|
||||
group_opt_state->step = resume_step;
|
||||
group_opt_state->initial_lr = initial_lr;
|
||||
ASSERT_STATUS_OK(optim->LoadStateDict(optimizer_checkpoint_states));
|
||||
}
|
||||
|
||||
// KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate.
|
||||
// If we restored it after creation, it will only affect the learning rate from the second step.
|
||||
auto scheduler = std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
|
||||
optim, warmup_step_count, total_step_count);
|
||||
|
||||
for (auto it = test_data.begin(); it != test_data.end(); ++it) {
|
||||
onnxruntime::training::api::OptimizerCheckpointState optimizer_states;
|
||||
ASSERT_STATUS_OK(optim->GetStateDict(optimizer_states));
|
||||
auto group_optimizer_state = optimizer_states.group_named_optimizer_states["group0"];
|
||||
CompareValue(it->second[0], group_optimizer_state->learning_rate);
|
||||
ASSERT_EQ(it->first, group_optimizer_state->step);
|
||||
|
||||
std::vector<OrtValue> inputs{input, target};
|
||||
std::vector<OrtValue> fetches;
|
||||
ASSERT_STATUS_OK(model->TrainStep(inputs, fetches));
|
||||
ASSERT_STATUS_OK(optim->Step());
|
||||
ASSERT_STATUS_OK(scheduler->Step());
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(TrainingApiTest, ModuleParametersSize) {
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
|
||||
|
|
@ -164,8 +312,18 @@ TEST(TrainingApiTest, ModuleTrainStep) {
|
|||
}
|
||||
}
|
||||
|
||||
// Runs the export-for-inferencing check with the default CPU execution provider.
TEST(TrainingApiTest, ModuleExportModelForInferencingCPU) {
  std::vector<std::shared_ptr<IExecutionProvider>> cpu_providers{
      onnxruntime::test::DefaultCpuExecutionProvider()};
  TestModuleExport(cpu_providers);
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
// Runs the export-for-inferencing check with the default CUDA execution provider.
TEST(TrainingApiTest, ModuleExportModelForInferencingCUDA) {
  std::vector<std::shared_ptr<IExecutionProvider>> cuda_providers{
      onnxruntime::test::DefaultCudaExecutionProvider()};
  TestModuleExport(cuda_providers);
}
|
||||
|
||||
TEST(TrainingApiTest, OptimStep) {
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
auto optim_uri = MODEL_FOLDER "adamw.onnx";
|
||||
|
|
@ -241,83 +399,6 @@ TEST(TrainingApiTest, OptimStep) {
|
|||
}
|
||||
}
|
||||
|
||||
void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) {
|
||||
ASSERT_NEAR(expected, output, atol);
|
||||
ASSERT_NEAR(expected, output, rtol * std::abs(expected));
|
||||
}
|
||||
|
||||
void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count,
|
||||
int64_t warmup_step_count) {
|
||||
/// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances.
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
auto optim_uri = MODEL_FOLDER "adamw.onnx";
|
||||
|
||||
onnxruntime::training::api::CheckpointState state;
|
||||
auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt";
|
||||
ASSERT_STATUS_OK(LoadCheckpoint(checkpoint_to_load_path, state));
|
||||
|
||||
onnxruntime::SessionOptions session_option;
|
||||
std::unique_ptr<Environment> env;
|
||||
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCudaExecutionProvider()};
|
||||
auto model = std::make_unique<onnxruntime::training::api::Module>(
|
||||
ToUTF8String(model_uri), state.module_checkpoint_state.named_parameters,
|
||||
session_option, *env, providers);
|
||||
auto optim = std::make_shared<onnxruntime::training::api::Optimizer>(
|
||||
ToUTF8String(optim_uri), model->NamedParameters(), session_option,
|
||||
*env, providers);
|
||||
|
||||
OrtValue input, target;
|
||||
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
|
||||
onnxruntime::training::api::utils::CreateInputOrtValue<int32_t>(
|
||||
std::array<int64_t, 1>{2}, std::vector<int32_t>(2, 1), &target);
|
||||
|
||||
/// Load test data for learning rate schedulers.
|
||||
auto data_uri = ORT_TSTR("testdata/test_data_generation/lr_scheduler/" + test_file_name);
|
||||
std::ifstream in{data_uri};
|
||||
// Element of vector represent a pair of <step_count, list of learning rates>>
|
||||
typedef std::vector<std::pair<int64_t, std::vector<float>>> TestDataDictType;
|
||||
TestDataDictType test_data;
|
||||
const json j = json::parse(in);
|
||||
j.get_to<TestDataDictType>(test_data);
|
||||
|
||||
int64_t resume_step = (*test_data.begin()).first;
|
||||
ASSERT_EQ(total_step_count, static_cast<int64_t>(test_data.size()) + resume_step);
|
||||
|
||||
if (resume_step != 0) {
|
||||
/// Reset optimizer states to match the initial state we want to test.
|
||||
onnxruntime::training::api::OptimizerCheckpointState optimizer_checkpoint_states;
|
||||
auto group_opt_state =
|
||||
optimizer_checkpoint_states.group_named_optimizer_states["group0"] =
|
||||
std::make_shared<onnxruntime::training::api::GroupOptimizerState>();
|
||||
group_opt_state->step = resume_step;
|
||||
group_opt_state->initial_lr = initial_lr;
|
||||
ASSERT_STATUS_OK(optim->LoadStateDict(optimizer_checkpoint_states));
|
||||
}
|
||||
|
||||
// KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate.
|
||||
// If we restored it after creation, it will only affect the learning rate from the second step.
|
||||
auto scheduler = std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
|
||||
optim, warmup_step_count, total_step_count);
|
||||
|
||||
for (auto it = test_data.begin(); it != test_data.end(); ++it) {
|
||||
onnxruntime::training::api::OptimizerCheckpointState optimizer_states;
|
||||
ASSERT_STATUS_OK(optim->GetStateDict(optimizer_states));
|
||||
auto group_optimizer_state = optimizer_states.group_named_optimizer_states["group0"];
|
||||
CompareValue(it->second[0], group_optimizer_state->learning_rate);
|
||||
ASSERT_EQ(it->first, group_optimizer_state->step);
|
||||
|
||||
std::vector<OrtValue> inputs{input, target};
|
||||
std::vector<OrtValue> fetches;
|
||||
ASSERT_STATUS_OK(model->TrainStep(inputs, fetches));
|
||||
ASSERT_STATUS_OK(optim->Step());
|
||||
ASSERT_STATUS_OK(scheduler->Step());
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t total_step_count = 100;
|
||||
const float initial_lr = 1e-3f;
|
||||
const int64_t resume_step = total_step_count / 2;
|
||||
TEST(TrainingApiTest, LinearLRScheduler_NoWarmUp_Test) {
|
||||
// No warm up.
|
||||
TestLRSchduler("warmup_linear_scheduler_warmupstep-0.json", initial_lr, total_step_count, 0);
|
||||
|
|
@ -360,7 +441,6 @@ TEST(TrainingApiTest, LinearLRScheduler_WarmUp200Step_ResumeFromCheckpoint_Test)
|
|||
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
#include "core/common/logging/sinks/clog_sink.h"
|
||||
#include "core/common/path.h"
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/platform/env.h"
|
||||
|
|
@ -16,6 +15,7 @@
|
|||
#include "orttraining/core/framework/checkpoint_common.h"
|
||||
#include "orttraining/core/framework/protobuf_message_sequence.h"
|
||||
#include "orttraining/training_api/include/checkpoint.h"
|
||||
#include "orttraining/training_api/include/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
|
@ -53,10 +53,6 @@ Status CreateTensorProtosFromOrtValues(
|
|||
[](const NameMLValMap::value_type& v) { return v.first; });
|
||||
std::sort(ordered_tensor_names.begin(), ordered_tensor_names.end());
|
||||
|
||||
// Copy the tensor data and create TensorProto storing the data.
|
||||
InlinedVector<char> tensor_data_buffer{};
|
||||
static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};
|
||||
|
||||
saved_tensor_protos.reserve(ordered_tensor_names.size());
|
||||
|
||||
uint64_t total_bytes = 0;
|
||||
|
|
@ -65,7 +61,6 @@ Status CreateTensorProtosFromOrtValues(
|
|||
const OrtValue& ort_value = name_to_ort_value.at(tensor_name);
|
||||
ORT_RETURN_IF_NOT(ort_value.IsTensor(), "ort_value.IsTensor() was false");
|
||||
const Tensor& src_tensor = ort_value.Get<Tensor>();
|
||||
tensor_data_buffer.resize(src_tensor.SizeInBytes());
|
||||
|
||||
// Currently large model size not considered, so exception thrown here
|
||||
// when protobuf upper limit hit.
|
||||
|
|
@ -74,23 +69,8 @@ Status CreateTensorProtosFromOrtValues(
|
|||
ORT_THROW("checkpoint file size hit upper limit.");
|
||||
}
|
||||
|
||||
auto& tensor_location = src_tensor.Location();
|
||||
if (tensor_location.device.Type() == OrtDevice::CPU ||
|
||||
tensor_location.mem_type == OrtMemTypeCPUInput ||
|
||||
tensor_location.mem_type == OrtMemTypeCPUOutput ||
|
||||
tensor_location.device.Type() == OrtDevice::GPU) {
|
||||
gsl::span<char> dst_span = gsl::make_span(tensor_data_buffer);
|
||||
ORT_RETURN_IF_NOT(src_tensor.SizeInBytes() == static_cast<size_t>(dst_span.size_bytes()), "src size != dst size");
|
||||
Tensor dst_tensor{src_tensor.DataType(), src_tensor.Shape(), dst_span.data(), cpu_alloc_info};
|
||||
ORT_RETURN_IF_ERROR(data_transfer_manager.CopyTensor(src_tensor, dst_tensor));
|
||||
|
||||
// Convert Tensor to TensorProto.
|
||||
ONNX_NAMESPACE::TensorProto tensor_proto;
|
||||
tensor_proto = utils::TensorToTensorProto(dst_tensor, tensor_name);
|
||||
saved_tensor_protos.emplace_back(tensor_proto);
|
||||
} else {
|
||||
ORT_THROW("Unsupported device type for saving tensors");
|
||||
}
|
||||
saved_tensor_protos.emplace_back(utils::CopyTensorToTensorProto(
|
||||
src_tensor, tensor_name, data_transfer_manager));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -106,6 +106,11 @@ struct Module {
|
|||
// Copy parameter values from contiguous buffer held by parameters_buffer onto parameters
|
||||
Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true);
|
||||
|
||||
// Load the eval model from eval_model_path_or_bytes and transform it for the purpose of
|
||||
// inferencing, and serialize to given path
|
||||
Status ExportModelForInferencing(const std::string& inference_model_path,
|
||||
gsl::span<const std::string> graph_output_names) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<onnxruntime::InferenceSession> train_sess_{nullptr};
|
||||
std::unique_ptr<onnxruntime::InferenceSession> eval_sess_{nullptr};
|
||||
|
|
@ -118,6 +123,7 @@ struct Module {
|
|||
std::vector<OrtValue> gradients_;
|
||||
bool accumulate_gradient_ = false;
|
||||
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& named_parameters_;
|
||||
std::string eval_model_path_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
|
|
|
|||
|
|
@ -296,6 +296,25 @@ struct OrtTrainingApi {
|
|||
*
|
||||
*/
|
||||
ORT_CLASS_RELEASE(CheckpointState);
|
||||
|
||||
/** \brief Export a model that can be used for inferencing.
|
||||
*
|
||||
* If the training session was provided with an eval model, the training session can generate
|
||||
* an inference model if it knows the inference graph outputs. The input inference graph outputs
|
||||
* are used to prune the eval model so that the output model's outputs align with the provided outputs.
|
||||
* The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
|
||||
* Note that the function re-loads the eval model from the path provided to CreateTrainingSession and expects
|
||||
* that this path still be valid.
|
||||
*
|
||||
* \param[in] sess The training session.
|
||||
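* \param[in] inference_model_path Path where the inference model should be serialized to.
* \param[in] graph_outputs_len Number of entries in graph_output_names.
* \param[in] graph_output_names Names of the graph outputs that the exported inference model should produce.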
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
|
||||
_In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
|
||||
_In_reads_(graph_outputs_len) const char* const* graph_output_names);
|
||||
};
|
||||
|
||||
typedef struct OrtTrainingApi OrtTrainingApi;
|
||||
|
|
|
|||
|
|
@ -143,6 +143,14 @@ class TrainingSession : public detail::Base<OrtTrainingSession> {
|
|||
*
|
||||
*/
|
||||
void OptimizerStep();
|
||||
|
||||
/** \brief Exports a model that can be used for inferencing with the inference session.
|
||||
*
|
||||
* Wraps OrtTrainingApi::ExportModelForInferencing
|
||||
*
|
||||
*/
|
||||
void ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
|
||||
const std::vector<std::string>& graph_output_names);
|
||||
};
|
||||
|
||||
} // namespace Ort
|
||||
|
|
|
|||
|
|
@ -94,4 +94,14 @@ inline void CheckpointState::SaveCheckpoint(const TrainingSession& session,
|
|||
ThrowOnError(GetTrainingApi().SaveCheckpoint(path_to_checkpoint.c_str(), session, include_optimizer_states));
|
||||
}
|
||||
|
||||
// Exports a model usable for inferencing by forwarding to OrtTrainingApi::ExportModelForInferencing.
//
// inference_model_path: path the inference model is serialized to.
// graph_output_names:   names of the graph outputs, passed to the C API as an array of C strings.
//
// Any error status returned by the C API is passed to ThrowOnError.
inline void TrainingSession::ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
                                                       const std::vector<std::string>& graph_output_names) {
  // Build the C-string view of the names. The vector must start out EMPTY (capacity reserved only):
  // pre-sizing it with nullptr entries and then appending would leave the first
  // graph_output_names.size() slots null, and those are exactly the slots the C API reads.
  std::vector<const char*> output_names;
  output_names.reserve(graph_output_names.size());
  for (const auto& output_name : graph_output_names) {
    output_names.push_back(output_name.c_str());
  }

  ThrowOnError(GetTrainingApi().ExportModelForInferencing(
      p_, inference_model_path.c_str(), output_names.size(), output_names.data()));
}
|
||||
|
||||
} // namespace Ort
|
||||
|
|
|
|||
|
|
@ -60,4 +60,8 @@ ORT_API(void, ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* sessio
|
|||
|
||||
ORT_API(void, ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session);
|
||||
|
||||
ORT_API_STATUS_IMPL(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
|
||||
_In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
|
||||
_In_reads_(graph_outputs_len) const char* const* graph_output_names);
|
||||
|
||||
} // namespace OrtTrainingApis
|
||||
|
|
|
|||
|
|
@ -63,11 +63,14 @@ class TrainingSession {
|
|||
|
||||
Status CreateCheckpointState(CheckpointState& chkpt_state, bool save_optimizer_state) const;
|
||||
|
||||
size_t GetParametersSize(const bool trainable_only=true) const;
|
||||
size_t GetParametersSize(const bool trainable_only = true) const;
|
||||
|
||||
Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only=true);
|
||||
|
||||
Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only=true);
|
||||
Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only = true);
|
||||
|
||||
Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true);
|
||||
|
||||
Status ExportModelForInferencing(const std::string& inference_model_path,
|
||||
gsl::span<const std::string> graph_output_names) const;
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TrainingSession);
|
||||
|
|
|
|||
|
|
@ -85,6 +85,9 @@ T GetValue(OrtValue& ort_value) {
|
|||
return val;
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorProto CopyTensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name,
|
||||
const DataTransferManager& data_transfer_manager);
|
||||
|
||||
} // namespace utils
|
||||
} // namespace api
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/common/string_utils.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/session/environment.h"
|
||||
|
|
@ -20,6 +21,94 @@ namespace {
|
|||
// TODO: consolidate with frontend tooling
|
||||
const std::string ACCUMULATE_GRAD_CONTROL_INPUT_NAME{"lazy_reset_grad"};
|
||||
|
||||
// Returns the set of nodes visited by a reverse depth-first traversal that starts at the producer
// nodes of the given output node args.
std::unordered_set<const Node*> GetReverseReachableNodes(Graph& inference_graph,
                                                         InlinedVector<const NodeArg*>& output_node_args) {
  // Seed the traversal with the producer node index of each output, skipping duplicates.
  InlinedVector<NodeIndex> start_indices;
  start_indices.reserve((output_node_args.size()));
  for (const NodeArg* output_arg : output_node_args) {
    const Node* producer = inference_graph.GetProducerNode(output_arg->Name());
    if (producer == nullptr) {
      continue;  // No producer node was found for this output.
    }

    const NodeIndex producer_index = producer->Index();
    const bool already_seeded =
        std::find(start_indices.begin(), start_indices.end(), producer_index) != start_indices.end();
    if (!already_seeded) {
      start_indices.push_back(producer_index);
    }
  }

  // Record every node the traversal enters; nothing is done on leave.
  std::unordered_set<const Node*> reachable_nodes;
  inference_graph.ReverseDFSFrom(
      start_indices, [&reachable_nodes](const Node* visited) { reachable_nodes.insert(visited); }, {});

  return reachable_nodes;
}
|
||||
|
||||
// Removes from inference_graph every node that GetReverseReachableNodes does not report for the
// given output node args. Always returns Status::OK().
Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector<const NodeArg*>& output_node_args) {
  const auto nodes_to_keep = GetReverseReachableNodes(inference_graph, output_node_args);

  // Walk all graph nodes through a viewer and drop the ones that are not in the keep set.
  GraphViewer viewer(inference_graph);
  for (auto& candidate : viewer.Nodes()) {
    if (nodes_to_keep.count(&candidate) != 0) {
      continue;
    }
    inference_graph.RemoveNode(candidate.Index());
  }

  return Status::OK();
}
|
||||
|
||||
// Restricts the graph outputs to the names in inference_graph_outputs, removes the nodes that are
// no longer needed to produce them, and resolves the graph.
//
// Returns an error status when the name list is empty, when a requested name has no NodeArg in the
// graph, or when node removal / graph resolution fails.
Status TransformModelOutputsForInference(Graph& inference_graph,
                                         gsl::span<const std::string> inference_graph_outputs) {
  ORT_RETURN_IF(inference_graph_outputs.empty(),
                "Expected a non empty vector of graph output names. Got an empty vector.");

  // Look up the NodeArg of every requested output; all of them must exist in the graph.
  InlinedVector<const NodeArg*> requested_output_args;
  requested_output_args.reserve(inference_graph_outputs.size());
  for (const std::string& requested_name : inference_graph_outputs) {
    const NodeArg* requested_arg = inference_graph.GetNodeArg(requested_name);
    ORT_RETURN_IF_NOT(requested_arg, "Expected graph output for inference graph " + requested_name +
                                         " could not be found. Please regenerate the eval graph.");
    requested_output_args.push_back(requested_arg);
  }

  // Install the new outputs first, then prune the nodes that do not feed them.
  inference_graph.SetOutputs(requested_output_args);
  ORT_RETURN_IF_ERROR(RemoveUnusedNodes(inference_graph, requested_output_args));

  ORT_RETURN_IF_ERROR(inference_graph.Resolve());

  return Status::OK();
}
|
||||
|
||||
// Rewrites the graph inputs for inferencing and resolves the graph:
//  - inputs whose name matches a named parameter are added to the graph as initializers holding
//    the parameter's tensor data, and are dropped from the graph inputs;
//  - remaining inputs are kept only if at least one node consumes them.
//
// Throws (ORT_ENFORCE) if a named parameter is already an initializer of the graph.
Status TransformModelInputsForInference(Graph& inference_graph,
                                        const std::unordered_map<
                                            std::string, std::shared_ptr<Parameter>>& named_parameters,
                                        const DataTransferManager& data_transfer_manager) {
  std::vector<const NodeArg*> user_graph_inputs;
  for (auto& input_arg : inference_graph.GetInputs()) {
    const auto parameter_it = named_parameters.find(input_arg->Name());

    if (parameter_it != named_parameters.end()) {
      // Model parameter: embed its current value in the graph as an initializer.
      ORT_ENFORCE(!inference_graph.IsInitializedTensor(parameter_it->first),
                  "The eval graph is invalid. Expected model parameter ",
                  parameter_it->first, " to be a graph input, not a graph initializer.");
      inference_graph.AddInitializedTensor(utils::CopyTensorToTensorProto(
          parameter_it->second->Data().Get<onnxruntime::Tensor>(),
          parameter_it->first, data_transfer_manager));
      continue;
    }

    // Not a parameter: keep it as a graph input only when some node consumes it.
    if (!inference_graph.GetConsumerNodes(input_arg->Name()).empty()) {
      user_graph_inputs.emplace_back(input_arg);
    }
  }

  inference_graph.SetInputs(user_graph_inputs);
  ORT_RETURN_IF_ERROR(inference_graph.Resolve());

  return Status::OK();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) {
|
||||
|
|
@ -162,6 +251,9 @@ Module::Module(const std::string& train_model_path_or_bytes,
|
|||
if (eval_model_path_or_bytes.has_value()) {
|
||||
eval_sess_ = std::make_unique<onnxruntime::InferenceSession>(session_options, env);
|
||||
ORT_THROW_IF_ERROR(eval_sess_->Load(eval_model_path_or_bytes.value()));
|
||||
for (const auto& provider : providers) {
|
||||
ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider));
|
||||
}
|
||||
ORT_THROW_IF_ERROR(eval_sess_->Initialize());
|
||||
utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_);
|
||||
|
||||
|
|
@ -185,6 +277,10 @@ Module::Module(const std::string& train_model_path_or_bytes,
|
|||
}
|
||||
eval_input_names_ = eval_user_input_names;
|
||||
eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end());
|
||||
|
||||
// Keep a copy of the eval model path to be able to later export the model for inferencing.
|
||||
// The inference model will be reconstructed from the eval model.
|
||||
eval_model_path_ = eval_model_path_or_bytes.value();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -232,7 +328,7 @@ Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool tr
|
|||
ORT_ENFORCE(nullptr != init_tensor);
|
||||
auto expected_buffer_size = static_cast<int64_t>(GetParametersSize(trainable_only));
|
||||
ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size,
|
||||
"Parameters buffer size incorrect. Expected:",expected_buffer_size,
|
||||
"Parameters buffer size incorrect. Expected:", expected_buffer_size,
|
||||
", Actual:", init_tensor->Shape().Size());
|
||||
|
||||
const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager();
|
||||
|
|
@ -275,7 +371,7 @@ Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool tr
|
|||
ORT_ENFORCE(nullptr != init_tensor);
|
||||
auto expected_buffer_size = static_cast<int64_t>(GetParametersSize(trainable_only));
|
||||
ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size,
|
||||
"Parameters buffer size incorrect. Expected:",expected_buffer_size,
|
||||
"Parameters buffer size incorrect. Expected:", expected_buffer_size,
|
||||
", Actual:", init_tensor->Shape().Size());
|
||||
|
||||
const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager();
|
||||
|
|
@ -356,6 +452,33 @@ Status Module::GetStateDict(ModuleCheckpointState& module_checkpoint_state) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Exports the eval model as an inference model written to inference_model_path.
//
// The eval model is reloaded from eval_model_path_, its outputs are restricted to graph_output_names (nodes that
// do not contribute to those outputs are pruned), and the module's parameters are moved from graph inputs to
// initializers so that only user inputs remain.
//
// Returns a failure status when no eval model was provided to the module, or when loading, transforming or
// saving the model fails. All failures are reported through the returned Status rather than by throwing, so
// that callers only need to check the return value.
Status Module::ExportModelForInferencing(const std::string& inference_model_path,
                                         gsl::span<const std::string> graph_output_names) const {
  ORT_RETURN_IF(!eval_sess_ || eval_model_path_.empty(),
                "Eval model was not provided. Cannot export a model for inferencing.");

  ONNX_NAMESPACE::ModelProto eval_model;
  ORT_RETURN_IF_ERROR(Model::Load(ToPathString(eval_model_path_), eval_model));

  // Clone the eval model into an inference onnxruntime::Model.
  std::shared_ptr<Model> inference_model;
  ORT_RETURN_IF_ERROR(Model::Load(eval_model, inference_model, nullptr, logging::LoggingManager::DefaultLogger()));

  // The cloned model's outputs are transformed such that the model has outputs as defined by graph_output_names.
  // Any nodes not contributing to the inference outputs will be pruned.
  ORT_RETURN_IF_ERROR(TransformModelOutputsForInference(inference_model->MainGraph(), graph_output_names));

  // The cloned model's inputs are transformed such that the model has only user defined inputs. All parameters
  // are moved to be constant initializers for the model.
  ORT_RETURN_IF_ERROR(TransformModelInputsForInference(inference_model->MainGraph(), named_parameters_,
                                                       eval_sess_->GetDataTransferManager()));

  // Save the model at the desired location.
  ORT_RETURN_IF_ERROR(Model::Save(*inference_model, inference_model_path));

  return Status::OK();
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -339,6 +339,31 @@ ORT_API(void, OrtTrainingApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckp
|
|||
delete reinterpret_cast<onnxruntime::training::api::CheckpointState*>(checkpoint_state);
|
||||
}
|
||||
|
||||
// C API entry point that exports the training session's eval model as an inference model.
//
// sess                  training session to export from.
// inference_model_path  path the inference model is written to; must not be null.
// graph_outputs_len     number of entries in graph_output_names.
// graph_output_names    names of the outputs the inference model should expose; each entry must be a non-null,
//                       non-empty string.
//
// Returns nullptr on success, otherwise an OrtStatus describing the failure. Invalid arguments are reported as
// ORT_INVALID_ARGUMENT before the session is touched.
ORT_API_STATUS_IMPL(OrtTrainingApis::ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
                    _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
                    _In_reads_(graph_outputs_len) const char* const* graph_output_names) {
  API_IMPL_BEGIN

  // Reject null pointers up front: the path is converted to a string and the names array is indexed below.
  if (inference_model_path == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Path to the inference model cannot be null. Please provide a valid path");
  }

  if (graph_outputs_len != 0 && graph_output_names == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Graph output names cannot be null. Please provide valid graph names");
  }

  auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);

  onnxruntime::InlinedVector<std::string> output_names(graph_outputs_len);

  for (size_t i = 0; i != graph_outputs_len; ++i) {
    if (graph_output_names[i] == nullptr || graph_output_names[i][0] == '\0') {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                   "Name of graph output cannot be empty. Please provide valid graph names");
    }

    output_names[i] = graph_output_names[i];
  }

  ORT_API_RETURN_IF_STATUS_NOT_OK(
      session->ExportModelForInferencing(onnxruntime::ToUTF8String(inference_model_path), output_names));

  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
static constexpr OrtTrainingApi ort_training_api = {
|
||||
// NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially
|
||||
// released, it is OK to change the order here, however a corresponding matching change should also be done in the
|
||||
|
|
@ -363,6 +388,7 @@ static constexpr OrtTrainingApi ort_training_api = {
|
|||
&OrtTrainingApis::CopyBufferToParameters,
|
||||
&OrtTrainingApis::ReleaseTrainingSession,
|
||||
&OrtTrainingApis::ReleaseCheckpointState,
|
||||
&OrtTrainingApis::ExportModelForInferencing,
|
||||
};
|
||||
|
||||
ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) {
|
||||
|
|
|
|||
|
|
@ -110,6 +110,11 @@ Status TrainingSession::CopyBufferToParameters(OrtValue& parameters_buffer, cons
|
|||
return module_->CopyBufferToParameters(parameters_buffer, trainable_only);
|
||||
}
|
||||
|
||||
// Exports an inference model to inference_model_path exposing the outputs named in graph_output_names.
// The work is delegated to the underlying module; its status is returned unchanged.
Status TrainingSession::ExportModelForInferencing(const std::string& inference_model_path,
                                                  gsl::span<const std::string> graph_output_names) const {
  Status export_status = module_->ExportModelForInferencing(inference_model_path, graph_output_names);
  return export_status;
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/framework/ort_value.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
||||
#include "orttraining/training_api/include/utils.h"
|
||||
|
||||
|
|
@ -87,6 +88,31 @@ Status OrtValueLike(const SessionState& sess_state, const OrtValue& input_val, O
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Copies the data of src_tensor into a host-side buffer and returns it as a TensorProto named
// tensor_proto_name.
//
// src_tensor             tensor whose contents are copied.
// tensor_proto_name      name passed on to the TensorProto conversion.
// data_transfer_manager  used to copy the tensor data into the temporary CPU buffer.
//
// Throws if the tensor's location is not one of the accepted device / memory types, or if the copy fails.
ONNX_NAMESPACE::TensorProto CopyTensorToTensorProto(const Tensor& src_tensor, const std::string& tensor_proto_name,
                                                    const DataTransferManager& data_transfer_manager) {
  // Accept tensors on a CPU or GPU device, or whose memory type is CPU input/output; reject anything else.
  auto& tensor_location = src_tensor.Location();
  if (tensor_location.device.Type() != OrtDevice::CPU &&
      tensor_location.mem_type != OrtMemTypeCPUInput &&
      tensor_location.mem_type != OrtMemTypeCPUOutput &&
      tensor_location.device.Type() != OrtDevice::GPU) {
    ORT_THROW("Unsupported device type for saving tensors");
  }

  // Copy the tensor data and create TensorProto storing the data.
  InlinedVector<char> tensor_data_buffer{};
  tensor_data_buffer.resize(src_tensor.SizeInBytes());
  static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};

  gsl::span<char> dst_span = gsl::make_span(tensor_data_buffer);
  ORT_ENFORCE(src_tensor.SizeInBytes() == static_cast<size_t>(dst_span.size_bytes()), "src size != dst size");

  // dst_tensor is a view over tensor_data_buffer with the same element type and shape as the source.
  Tensor dst_tensor{src_tensor.DataType(), src_tensor.Shape(), dst_span.data(), cpu_alloc_info};
  ORT_THROW_IF_ERROR(data_transfer_manager.CopyTensor(src_tensor, dst_tensor));

  // Convert Tensor to TensorProto. The conversion happens before tensor_data_buffer goes out of scope.
  return onnxruntime::utils::TensorToTensorProto(dst_tensor, tensor_proto_name);
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace api
|
||||
} // namespace training
|
||||
|
|
|
|||
Loading…
Reference in a new issue