From 8f7bd51f7af19388ccfa5ecfdf7d2478bb0f54aa Mon Sep 17 00:00:00 2001 From: Xueyun Zhu Date: Mon, 23 Mar 2020 23:23:34 +0000 Subject: [PATCH] fix pybind issue introduced by merge --- .../python/onnxruntime_pybind_state.cc | 83 +++++++++---------- .../python/onnxruntime_pybind_state_common.cc | 3 + .../python/onnxruntime_pybind_state_common.h | 23 +++-- .../core/session/training_session.h | 3 +- .../models/mnist/test_grad_graph_builder.cc | 2 +- .../python/orttraining_pybind_state.cc | 7 +- 6 files changed, 64 insertions(+), 57 deletions(-) create mode 100644 onnxruntime/python/onnxruntime_pybind_state_common.cc diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 0343320c4d..272c59d63d 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -255,28 +255,6 @@ void AddTensorAsPyObj(OrtValue& val, std::vector& pyobjs) { GetPyObjFromTensor(rtensor, obj); pyobjs.push_back(obj); } -class SessionObjectInitializer { - public: - typedef const SessionOptions& Arg1; - // typedef logging::LoggingManager* Arg2; - static const std::string default_logger_id; - operator Arg1() { - return GetDefaultCPUSessionOptions(); - } - - // operator Arg2() { - // static LoggingManager default_logging_manager{std::unique_ptr{new CErrSink{}}, - // Severity::kWARNING, false, LoggingManager::InstanceType::Default, - // &default_logger_id}; - // return &default_logging_manager; - // } - - static SessionObjectInitializer Get() { - return SessionObjectInitializer(); - } -}; - -const std::string SessionObjectInitializer::default_logger_id = "Default"; inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) { auto p = f.CreateProvider(); @@ -968,30 +946,16 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { #endif - static std::unique_ptr env; - auto initialize = [&]() { - // Initialization of the module - ([]() -> void { - // import_array1() forces a void return value. - import_array1(); - })(); + // Initialization of the module + ([]() -> void { + // import_array1() forces a void return value. + import_array1(); + })(); - OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique( - std::unique_ptr{new CLogSink{}}, - Severity::kWARNING, false, LoggingManager::InstanceType::Default, - &SessionObjectInitializer::default_logger_id), - env)); + Environment& env = get_env(); - static bool initialized = false; - if (initialized) { - return; - } - initialized = true; - }; - initialize(); - - addGlobalMethods(m, *env); - addObjectMethods(m, *env); + addGlobalMethods(m, env); + addObjectMethods(m, env); #ifdef ENABLE_TRAINING addObjectMethodsForTraining(m); @@ -1003,5 +967,36 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { #endif } + +void initialize_env(){ + auto initialize = [&]() { + // Initialization of the module + ([]() -> void { + // import_array1() forces a void return value. + import_array1(); + })(); + + OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique( + std::unique_ptr{new CLogSink{}}, + Severity::kWARNING, false, LoggingManager::InstanceType::Default, + &SessionObjectInitializer::default_logger_id), + session_env)); + + static bool initialized = false; + if (initialized) { + return; + } + initialized = true; + }; + initialize(); +} + +onnxruntime::Environment& get_env(){ + if (!session_env){ + initialize_env(); + } + return *session_env; +} + } // namespace python } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc b/onnxruntime/python/onnxruntime_pybind_state_common.cc new file mode 100644 index 0000000000..460e5fbd75 --- /dev/null +++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc @@ -0,0 +1,3 @@ +#include "onnxruntime_pybind_state_common.h" + +const std::string onnxruntime::python::SessionObjectInitializer::default_logger_id = "Default"; \ No newline at end of file diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 4068c2b051..c8c39a0913 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -6,6 +6,7 @@ #include "core/framework/allocator.h" #include "core/framework/session_options.h" +#include "core/session/environment.h" namespace onnxruntime { namespace python { @@ -25,23 +26,29 @@ inline AllocatorPtr& GetAllocator() { class SessionObjectInitializer { public: typedef const SessionOptions& Arg1; - typedef logging::LoggingManager* Arg2; + // typedef logging::LoggingManager* Arg2; + static const std::string default_logger_id; operator Arg1() { return GetDefaultCPUSessionOptions(); } - operator Arg2() { - static std::string default_logger_id{"Default"}; - static LoggingManager default_logging_manager{std::unique_ptr{new CErrSink{}}, - Severity::kWARNING, false, LoggingManager::InstanceType::Default, - &default_logger_id}; - return &default_logging_manager; - } + // operator Arg2() { + // static LoggingManager default_logging_manager{std::unique_ptr{new CErrSink{}}, + // Severity::kWARNING, false, LoggingManager::InstanceType::Default, + // &default_logger_id}; + // return &default_logging_manager; + // } static SessionObjectInitializer Get() { return SessionObjectInitializer(); } }; + +// static variable used to create inference session and training session. +static std::unique_ptr session_env; +void initialize_env(); +Environment& get_env(); + } } diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 4367b37828..0231a1620e 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -20,8 +20,7 @@ class TrainingSession : public InferenceSession { std::vector>> ImmutableWeights; - TrainingSession(const SessionOptions& session_options, - const Environment& env) + TrainingSession(const SessionOptions& session_options, const Environment& env) : InferenceSession(session_options, env) {} /** diff --git a/orttraining/orttraining/models/mnist/test_grad_graph_builder.cc b/orttraining/orttraining/models/mnist/test_grad_graph_builder.cc index 79be0b79e0..b112b9ab62 100644 --- a/orttraining/orttraining/models/mnist/test_grad_graph_builder.cc +++ b/orttraining/orttraining/models/mnist/test_grad_graph_builder.cc @@ -60,7 +60,7 @@ int main(int /*argc*/, char* /*args*/ []) { &default_logger_id}; std::unique_ptr env; - TERMINATE_IF_FAILED(Environment::Create(env)); + TERMINATE_IF_FAILED(Environment::Create(nullptr, env)); // Step 1: Load the model and generate gradient graph in a training session. SessionOptions so; diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index eafbe9f616..7bb7497abc 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -9,6 +9,7 @@ #include "core/framework/random_seed.h" #include "core/framework/session_options.h" +#include "core/session/environment.h" #include "orttraining/core/session/training_session.h" #include "orttraining/core/graph/optimizer_config.h" #include "orttraining/core/framework/mpi_setup.h" @@ -169,8 +170,10 @@ void addObjectMethodsForTraining(py::module& m) { .def_readwrite("partition_optimizer", &TrainingParameters::partition_optimizer); py::class_ training_session(m, "TrainingSession"); - training_session.def(py::init()) - .def(py::init()) + training_session.def(py::init([](const SessionOptions& so) { + Environment& env = get_env(); + return onnxruntime::make_unique(so, env); + })) .def("finalize", [](py::object) { #ifdef USE_HOROVOD training::shutdown_horovod();