diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index dbc96096cb..0d541dbfa5 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -42,9 +42,6 @@ struct PySessionOptions : public SessionOptions { // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user struct PyInferenceSession { - // Default ctor is present only to be invoked by the PyTrainingSession class - PyInferenceSession() {} - PyInferenceSession(Environment& env, const PySessionOptions& so, const std::string& arg, bool is_arg_file_name) { if (is_arg_file_name) { // Given arg is the file path. Invoke the corresponding ctor(). @@ -70,6 +67,10 @@ struct PyInferenceSession { virtual ~PyInferenceSession() {} protected: + PyInferenceSession(std::unique_ptr sess) { + sess_ = std::move(sess); + } + // Hold CustomOpLibrary resources so as to tie it to the life cycle of the InferenceSession needing it. // NOTE: Declare this above `sess_` so that this is destructed AFTER the InferenceSession instance - // this is so that the custom ops held by the InferenceSession gets destroyed prior to the library getting unloaded diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index e110e6b950..a818d37564 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -206,9 +206,8 @@ void addObjectMethodsForTraining(py::module& m) { // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user struct PyTrainingSession : public PyInferenceSession { - PyTrainingSession(Environment& env, const PySessionOptions& so) { - // `sess_` is inherited from PyinferenceSession - sess_ = onnxruntime::make_unique(so, env); + PyTrainingSession(Environment& env, const PySessionOptions& so) + : PyInferenceSession(onnxruntime::make_unique(so, env)) { } };