mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Add env to the TrainingSession constructor (#15635)
This commit is contained in:
parent
fab3e33105
commit
fd6ecc3909
4 changed files with 15 additions and 10 deletions
|
|
@ -19,8 +19,9 @@ namespace onnxruntime::training::test {
|
|||
TEST(TrainingCApiTest, SaveCheckpoint) {
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
|
||||
Ort::Env env;
|
||||
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
|
||||
auto test_dir = ORT_TSTR("save_checkpoint_dir");
|
||||
if (Env::Default().FolderExists(test_dir)) {
|
||||
|
|
@ -33,7 +34,8 @@ TEST(TrainingCApiTest, SaveCheckpoint) {
|
|||
Ort::CheckpointState::SaveCheckpoint(checkpoint_state, checkpoint_path);
|
||||
|
||||
Ort::CheckpointState new_checkpoint_state = Ort::CheckpointState::LoadCheckpoint(checkpoint_path);
|
||||
Ort::TrainingSession new_training_session = Ort::TrainingSession(Ort::SessionOptions(), new_checkpoint_state, model_uri);
|
||||
Ort::TrainingSession new_training_session = Ort::TrainingSession(env, Ort::SessionOptions(),
|
||||
new_checkpoint_state, model_uri);
|
||||
}
|
||||
|
||||
TEST(TrainingCApiTest, AddIntProperty) {
|
||||
|
|
@ -75,8 +77,9 @@ TEST(TrainingCApiTest, AddStringProperty) {
|
|||
TEST(TrainingCApiTest, InputNames) {
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
|
||||
Ort::Env env;
|
||||
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
|
||||
const auto input_names = training_session.InputNames(true);
|
||||
ASSERT_EQ(input_names.size(), 2U);
|
||||
|
|
@ -87,8 +90,9 @@ TEST(TrainingCApiTest, InputNames) {
|
|||
TEST(TrainingCApiTest, OutputNames) {
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
|
||||
Ort::Env env;
|
||||
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
|
||||
const auto output_names = training_session.OutputNames(true);
|
||||
ASSERT_EQ(output_names.size(), 1U);
|
||||
|
|
@ -98,8 +102,9 @@ TEST(TrainingCApiTest, OutputNames) {
|
|||
TEST(TrainingCApiTest, ToBuffer) {
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
|
||||
Ort::Env env;
|
||||
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
|
||||
Ort::Value buffer = training_session.ToBuffer(true);
|
||||
|
||||
|
|
@ -121,8 +126,9 @@ TEST(TrainingCApiTest, ToBuffer) {
|
|||
TEST(TrainingCApiTest, FromBuffer) {
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
|
||||
Ort::Env env;
|
||||
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
|
||||
OrtValue* buffer_impl = std::make_unique<OrtValue>().release();
|
||||
GenerateRandomInput(std::array<int64_t, 1>{397510}, *buffer_impl);
|
||||
|
|
|
|||
|
|
@ -233,7 +233,7 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
#endif
|
||||
|
||||
bool do_eval = params.model_evaluation_graph_path.has_value();
|
||||
Ort::TrainingSession session(soptions, checkpoint_state, params.model_training_graph_path,
|
||||
Ort::TrainingSession session(env, soptions, checkpoint_state, params.model_training_graph_path,
|
||||
params.model_evaluation_graph_path,
|
||||
params.optimizer_training_graph_path.size() > 0
|
||||
? std::optional<PathString>(params.optimizer_training_graph_path)
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class TrainingSession : public detail::Base<OrtTrainingSession> {
|
|||
size_t training_model_output_count_, eval_model_output_count_;
|
||||
|
||||
public:
|
||||
TrainingSession(const SessionOptions& session_options, CheckpointState& checkpoint_state,
|
||||
TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
|
||||
const std::basic_string<ORTCHAR_T>& train_model_path,
|
||||
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
|
||||
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
|
||||
|
|
|
|||
|
|
@ -7,12 +7,11 @@
|
|||
|
||||
namespace Ort {
|
||||
|
||||
inline TrainingSession::TrainingSession(const SessionOptions& session_options,
|
||||
inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
|
||||
CheckpointState& checkpoint_state,
|
||||
const std::basic_string<ORTCHAR_T>& train_model_path,
|
||||
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path,
|
||||
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path) {
|
||||
Env env = Env();
|
||||
ThrowOnError(GetTrainingApi().CreateTrainingSession(
|
||||
env, session_options, checkpoint_state,
|
||||
train_model_path.c_str(),
|
||||
|
|
|
|||
Loading…
Reference in a new issue