diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index d87d2ad8ae..9fb1fefa4a 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -19,8 +19,9 @@ namespace onnxruntime::training::test { TEST(TrainingCApiTest, SaveCheckpoint) { auto model_uri = MODEL_FOLDER "training_model.onnx"; + Ort::Env env; Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); - Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); auto test_dir = ORT_TSTR("save_checkpoint_dir"); if (Env::Default().FolderExists(test_dir)) { @@ -33,7 +34,8 @@ TEST(TrainingCApiTest, SaveCheckpoint) { Ort::CheckpointState::SaveCheckpoint(checkpoint_state, checkpoint_path); Ort::CheckpointState new_checkpoint_state = Ort::CheckpointState::LoadCheckpoint(checkpoint_path); - Ort::TrainingSession new_training_session = Ort::TrainingSession(Ort::SessionOptions(), new_checkpoint_state, model_uri); + Ort::TrainingSession new_training_session = Ort::TrainingSession(env, Ort::SessionOptions(), + new_checkpoint_state, model_uri); } TEST(TrainingCApiTest, AddIntProperty) { @@ -75,8 +77,9 @@ TEST(TrainingCApiTest, AddStringProperty) { TEST(TrainingCApiTest, InputNames) { auto model_uri = MODEL_FOLDER "training_model.onnx"; + Ort::Env env; Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); - Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); const auto input_names = training_session.InputNames(true); ASSERT_EQ(input_names.size(), 2U); @@ -87,8 +90,9 @@ TEST(TrainingCApiTest, InputNames) { TEST(TrainingCApiTest, OutputNames) { auto model_uri = MODEL_FOLDER "training_model.onnx"; + Ort::Env env; Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); - Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); const auto output_names = training_session.OutputNames(true); ASSERT_EQ(output_names.size(), 1U); @@ -98,8 +102,9 @@ TEST(TrainingCApiTest, OutputNames) { TEST(TrainingCApiTest, ToBuffer) { auto model_uri = MODEL_FOLDER "training_model.onnx"; + Ort::Env env; Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); - Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); Ort::Value buffer = training_session.ToBuffer(true); @@ -121,8 +126,9 @@ TEST(TrainingCApiTest, ToBuffer) { TEST(TrainingCApiTest, FromBuffer) { auto model_uri = MODEL_FOLDER "training_model.onnx"; + Ort::Env env; Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); - Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); OrtValue* buffer_impl = std::make_unique().release(); GenerateRandomInput(std::array{397510}, *buffer_impl); diff --git a/orttraining/orttraining/test/training_api/trainer/trainer.cc b/orttraining/orttraining/test/training_api/trainer/trainer.cc index e330a89524..ff4a824374 100644 --- a/orttraining/orttraining/test/training_api/trainer/trainer.cc +++ b/orttraining/orttraining/test/training_api/trainer/trainer.cc @@ -233,7 +233,7 @@ int RunTraining(const TestRunnerParameters& params) { #endif bool do_eval = params.model_evaluation_graph_path.has_value(); - Ort::TrainingSession session(soptions, checkpoint_state, params.model_training_graph_path, + Ort::TrainingSession session(env, soptions, checkpoint_state, params.model_training_graph_path, params.model_evaluation_graph_path, params.optimizer_training_graph_path.size() > 0 ? std::optional(params.optimizer_training_graph_path) diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index ffaf579e6a..05eb162410 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -101,7 +101,7 @@ class TrainingSession : public detail::Base { size_t training_model_output_count_, eval_model_output_count_; public: - TrainingSession(const SessionOptions& session_options, CheckpointState& checkpoint_state, + TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state, const std::basic_string& train_model_path, const std::optional>& eval_model_path = std::nullopt, const std::optional>& optimizer_model_path = std::nullopt); diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index f947293722..8929e0d185 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -7,12 +7,11 @@ namespace Ort { -inline TrainingSession::TrainingSession(const SessionOptions& session_options, +inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state, const std::basic_string& train_model_path, const std::optional>& eval_model_path, const std::optional>& optimizer_model_path) { - Env env = Env(); ThrowOnError(GetTrainingApi().CreateTrainingSession( env, session_options, checkpoint_state, train_model_path.c_str(),