Add env to the TrainingSession constructor (#15635)

This commit is contained in:
Baiju Meswani 2023-04-21 21:05:46 -07:00 committed by GitHub
parent fab3e33105
commit fd6ecc3909
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 10 deletions

View file

@ -19,8 +19,9 @@ namespace onnxruntime::training::test {
TEST(TrainingCApiTest, SaveCheckpoint) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
Ort::Env env;
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri);
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
auto test_dir = ORT_TSTR("save_checkpoint_dir");
if (Env::Default().FolderExists(test_dir)) {
@ -33,7 +34,8 @@ TEST(TrainingCApiTest, SaveCheckpoint) {
Ort::CheckpointState::SaveCheckpoint(checkpoint_state, checkpoint_path);
Ort::CheckpointState new_checkpoint_state = Ort::CheckpointState::LoadCheckpoint(checkpoint_path);
Ort::TrainingSession new_training_session = Ort::TrainingSession(Ort::SessionOptions(), new_checkpoint_state, model_uri);
Ort::TrainingSession new_training_session = Ort::TrainingSession(env, Ort::SessionOptions(),
new_checkpoint_state, model_uri);
}
TEST(TrainingCApiTest, AddIntProperty) {
@ -75,8 +77,9 @@ TEST(TrainingCApiTest, AddStringProperty) {
TEST(TrainingCApiTest, InputNames) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
Ort::Env env;
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri);
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
const auto input_names = training_session.InputNames(true);
ASSERT_EQ(input_names.size(), 2U);
@ -87,8 +90,9 @@ TEST(TrainingCApiTest, InputNames) {
TEST(TrainingCApiTest, OutputNames) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
Ort::Env env;
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri);
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
const auto output_names = training_session.OutputNames(true);
ASSERT_EQ(output_names.size(), 1U);
@ -98,8 +102,9 @@ TEST(TrainingCApiTest, OutputNames) {
TEST(TrainingCApiTest, ToBuffer) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
Ort::Env env;
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri);
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
Ort::Value buffer = training_session.ToBuffer(true);
@ -121,8 +126,9 @@ TEST(TrainingCApiTest, ToBuffer) {
TEST(TrainingCApiTest, FromBuffer) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
Ort::Env env;
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(Ort::SessionOptions(), checkpoint_state, model_uri);
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
OrtValue* buffer_impl = std::make_unique<OrtValue>().release();
GenerateRandomInput(std::array<int64_t, 1>{397510}, *buffer_impl);

View file

@ -233,7 +233,7 @@ int RunTraining(const TestRunnerParameters& params) {
#endif
bool do_eval = params.model_evaluation_graph_path.has_value();
Ort::TrainingSession session(soptions, checkpoint_state, params.model_training_graph_path,
Ort::TrainingSession session(env, soptions, checkpoint_state, params.model_training_graph_path,
params.model_evaluation_graph_path,
params.optimizer_training_graph_path.size() > 0
? std::optional<PathString>(params.optimizer_training_graph_path)

View file

@ -101,7 +101,7 @@ class TrainingSession : public detail::Base<OrtTrainingSession> {
size_t training_model_output_count_, eval_model_output_count_;
public:
TrainingSession(const SessionOptions& session_options, CheckpointState& checkpoint_state,
TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
const std::basic_string<ORTCHAR_T>& train_model_path,
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);

View file

@ -7,12 +7,11 @@
namespace Ort {
inline TrainingSession::TrainingSession(const SessionOptions& session_options,
inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
CheckpointState& checkpoint_state,
const std::basic_string<ORTCHAR_T>& train_model_path,
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path,
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path) {
Env env = Env();
ThrowOnError(GetTrainingApi().CreateTrainingSession(
env, session_options, checkpoint_state,
train_model_path.c_str(),