fix pybind issue introduced by merge

This commit is contained in:
Xueyun Zhu 2020-03-23 23:23:34 +00:00
parent 9dbc50c438
commit 8f7bd51f7a
6 changed files with 64 additions and 57 deletions

View file

@ -255,28 +255,6 @@ void AddTensorAsPyObj(OrtValue& val, std::vector<py::object>& pyobjs) {
GetPyObjFromTensor(rtensor, obj);
pyobjs.push_back(obj);
}
class SessionObjectInitializer {
public:
typedef const SessionOptions& Arg1;
// typedef logging::LoggingManager* Arg2;
static const std::string default_logger_id;
operator Arg1() {
return GetDefaultCPUSessionOptions();
}
// operator Arg2() {
// static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
// Severity::kWARNING, false, LoggingManager::InstanceType::Default,
// &default_logger_id};
// return &default_logging_manager;
// }
static SessionObjectInitializer Get() {
return SessionObjectInitializer();
}
};
const std::string SessionObjectInitializer::default_logger_id = "Default";
inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) {
auto p = f.CreateProvider();
@ -968,30 +946,16 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
#endif
static std::unique_ptr<Environment> env;
auto initialize = [&]() {
// Initialization of the module
([]() -> void {
// import_array1() forces a void return value.
import_array1();
})();
// Initialization of the module
([]() -> void {
// import_array1() forces a void return value.
import_array1();
})();
OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique<LoggingManager>(
std::unique_ptr<ISink>{new CLogSink{}},
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id),
env));
Environment& env = get_env();
static bool initialized = false;
if (initialized) {
return;
}
initialized = true;
};
initialize();
addGlobalMethods(m, *env);
addObjectMethods(m, *env);
addGlobalMethods(m, env);
addObjectMethods(m, env);
#ifdef ENABLE_TRAINING
addObjectMethodsForTraining(m);
@ -1003,5 +967,36 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
#endif
}
void initialize_env(){
auto initialize = [&]() {
// Initialization of the module
([]() -> void {
// import_array1() forces a void return value.
import_array1();
})();
OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique<LoggingManager>(
std::unique_ptr<ISink>{new CLogSink{}},
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id),
session_env));
static bool initialized = false;
if (initialized) {
return;
}
initialized = true;
};
initialize();
}
onnxruntime::Environment& get_env(){
if (!session_env){
initialize_env();
}
return *session_env;
}
} // namespace python
} // namespace onnxruntime

View file

@ -0,0 +1,3 @@
#include "onnxruntime_pybind_state_common.h"
const std::string onnxruntime::python::SessionObjectInitializer::default_logger_id = "Default";

View file

@ -6,6 +6,7 @@
#include "core/framework/allocator.h"
#include "core/framework/session_options.h"
#include "core/session/environment.h"
namespace onnxruntime {
namespace python {
@ -25,23 +26,29 @@ inline AllocatorPtr& GetAllocator() {
class SessionObjectInitializer {
public:
typedef const SessionOptions& Arg1;
typedef logging::LoggingManager* Arg2;
// typedef logging::LoggingManager* Arg2;
static const std::string default_logger_id;
operator Arg1() {
return GetDefaultCPUSessionOptions();
}
operator Arg2() {
static std::string default_logger_id{"Default"};
static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&default_logger_id};
return &default_logging_manager;
}
// operator Arg2() {
// static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
// Severity::kWARNING, false, LoggingManager::InstanceType::Default,
// &default_logger_id};
// return &default_logging_manager;
// }
static SessionObjectInitializer Get() {
return SessionObjectInitializer();
}
};
// static variable used to create inference session and training session.
static std::unique_ptr<Environment> session_env;
void initialize_env();
Environment& get_env();
}
}

View file

@ -20,8 +20,7 @@ class TrainingSession : public InferenceSession {
std::vector<std::pair<size_t /*InputIndex*/, float /*value*/>>>
ImmutableWeights;
TrainingSession(const SessionOptions& session_options,
const Environment& env)
TrainingSession(const SessionOptions& session_options, const Environment& env)
: InferenceSession(session_options, env) {}
/**

View file

@ -60,7 +60,7 @@ int main(int /*argc*/, char* /*args*/ []) {
&default_logger_id};
std::unique_ptr<Environment> env;
TERMINATE_IF_FAILED(Environment::Create(env));
TERMINATE_IF_FAILED(Environment::Create(nullptr, env));
// Step 1: Load the model and generate gradient graph in a training session.
SessionOptions so;

View file

@ -9,6 +9,7 @@
#include "core/framework/random_seed.h"
#include "core/framework/session_options.h"
#include "core/session/environment.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/graph/optimizer_config.h"
#include "orttraining/core/framework/mpi_setup.h"
@ -169,8 +170,10 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("partition_optimizer", &TrainingParameters::partition_optimizer);
py::class_<onnxruntime::training::TrainingSession, InferenceSession> training_session(m, "TrainingSession");
training_session.def(py::init<SessionOptions, SessionObjectInitializer>())
.def(py::init<SessionObjectInitializer, SessionObjectInitializer>())
training_session.def(py::init([](const SessionOptions& so) {
Environment& env = get_env();
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
}))
.def("finalize", [](py::object) {
#ifdef USE_HOROVOD
training::shutdown_horovod();