diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index 23f7672c2e..94d972edca 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -169,6 +169,7 @@ class TrainingRunner { common::Status ResetLossScaler(); size_t GetRound() const { return round_; } + TrainingSession& GetSession() { return session_; } private: Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader);