Remove public default ctor in PyInferenceSession and replace it with a protected ctor (#4990)

This commit is contained in:
Hariharan Seshadri 2020-09-01 17:10:36 -07:00 committed by GitHub
parent c6a3620ba8
commit d30dd41c0e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 6 deletions

View file

@ -42,9 +42,6 @@ struct PySessionOptions : public SessionOptions {
// Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
struct PyInferenceSession {
// Default ctor is present only to be invoked by the PyTrainingSession class
PyInferenceSession() {}
PyInferenceSession(Environment& env, const PySessionOptions& so, const std::string& arg, bool is_arg_file_name) {
if (is_arg_file_name) {
// Given arg is the file path. Invoke the corresponding ctor().
@ -70,6 +67,10 @@ struct PyInferenceSession {
virtual ~PyInferenceSession() {}
protected:
PyInferenceSession(std::unique_ptr<InferenceSession> sess) {
sess_ = std::move(sess);
}
// Hold CustomOpLibrary resources so as to tie it to the life cycle of the InferenceSession needing it.
// NOTE: Declare this above `sess_` so that this is destructed AFTER the InferenceSession instance -
// this is so that the custom ops held by the InferenceSession gets destroyed prior to the library getting unloaded

View file

@ -206,9 +206,8 @@ void addObjectMethodsForTraining(py::module& m) {
// Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
struct PyTrainingSession : public PyInferenceSession {
PyTrainingSession(Environment& env, const PySessionOptions& so) {
// `sess_` is inherited from PyinferenceSession
sess_ = onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
PyTrainingSession(Environment& env, const PySessionOptions& so)
: PyInferenceSession(onnxruntime::make_unique<TrainingSession>(so, env)) {
}
};