mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-27 03:11:28 +00:00
fix pybind issue introduced by merge
This commit is contained in:
parent
9dbc50c438
commit
8f7bd51f7a
6 changed files with 64 additions and 57 deletions
|
|
@ -255,28 +255,6 @@ void AddTensorAsPyObj(OrtValue& val, std::vector<py::object>& pyobjs) {
|
|||
GetPyObjFromTensor(rtensor, obj);
|
||||
pyobjs.push_back(obj);
|
||||
}
|
||||
class SessionObjectInitializer {
|
||||
public:
|
||||
typedef const SessionOptions& Arg1;
|
||||
// typedef logging::LoggingManager* Arg2;
|
||||
static const std::string default_logger_id;
|
||||
operator Arg1() {
|
||||
return GetDefaultCPUSessionOptions();
|
||||
}
|
||||
|
||||
// operator Arg2() {
|
||||
// static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
|
||||
// Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
||||
// &default_logger_id};
|
||||
// return &default_logging_manager;
|
||||
// }
|
||||
|
||||
static SessionObjectInitializer Get() {
|
||||
return SessionObjectInitializer();
|
||||
}
|
||||
};
|
||||
|
||||
const std::string SessionObjectInitializer::default_logger_id = "Default";
|
||||
|
||||
inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) {
|
||||
auto p = f.CreateProvider();
|
||||
|
|
@ -968,30 +946,16 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
|
|||
|
||||
#endif
|
||||
|
||||
static std::unique_ptr<Environment> env;
|
||||
auto initialize = [&]() {
|
||||
// Initialization of the module
|
||||
([]() -> void {
|
||||
// import_array1() forces a void return value.
|
||||
import_array1();
|
||||
})();
|
||||
// Initialization of the module
|
||||
([]() -> void {
|
||||
// import_array1() forces a void return value.
|
||||
import_array1();
|
||||
})();
|
||||
|
||||
OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique<LoggingManager>(
|
||||
std::unique_ptr<ISink>{new CLogSink{}},
|
||||
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
||||
&SessionObjectInitializer::default_logger_id),
|
||||
env));
|
||||
Environment& env = get_env();
|
||||
|
||||
static bool initialized = false;
|
||||
if (initialized) {
|
||||
return;
|
||||
}
|
||||
initialized = true;
|
||||
};
|
||||
initialize();
|
||||
|
||||
addGlobalMethods(m, *env);
|
||||
addObjectMethods(m, *env);
|
||||
addGlobalMethods(m, env);
|
||||
addObjectMethods(m, env);
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
addObjectMethodsForTraining(m);
|
||||
|
|
@ -1003,5 +967,36 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
|
|||
#endif
|
||||
}
|
||||
|
||||
|
||||
void initialize_env(){
|
||||
auto initialize = [&]() {
|
||||
// Initialization of the module
|
||||
([]() -> void {
|
||||
// import_array1() forces a void return value.
|
||||
import_array1();
|
||||
})();
|
||||
|
||||
OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique<LoggingManager>(
|
||||
std::unique_ptr<ISink>{new CLogSink{}},
|
||||
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
||||
&SessionObjectInitializer::default_logger_id),
|
||||
session_env));
|
||||
|
||||
static bool initialized = false;
|
||||
if (initialized) {
|
||||
return;
|
||||
}
|
||||
initialized = true;
|
||||
};
|
||||
initialize();
|
||||
}
|
||||
|
||||
onnxruntime::Environment& get_env(){
|
||||
if (!session_env){
|
||||
initialize_env();
|
||||
}
|
||||
return *session_env;
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
} // namespace onnxruntime
|
||||
3
onnxruntime/python/onnxruntime_pybind_state_common.cc
Normal file
3
onnxruntime/python/onnxruntime_pybind_state_common.cc
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#include "onnxruntime_pybind_state_common.h"
|
||||
|
||||
const std::string onnxruntime::python::SessionObjectInitializer::default_logger_id = "Default";
|
||||
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/session_options.h"
|
||||
|
||||
#include "core/session/environment.h"
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
||||
|
|
@ -25,23 +26,29 @@ inline AllocatorPtr& GetAllocator() {
|
|||
class SessionObjectInitializer {
|
||||
public:
|
||||
typedef const SessionOptions& Arg1;
|
||||
typedef logging::LoggingManager* Arg2;
|
||||
// typedef logging::LoggingManager* Arg2;
|
||||
static const std::string default_logger_id;
|
||||
operator Arg1() {
|
||||
return GetDefaultCPUSessionOptions();
|
||||
}
|
||||
|
||||
operator Arg2() {
|
||||
static std::string default_logger_id{"Default"};
|
||||
static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
|
||||
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
||||
&default_logger_id};
|
||||
return &default_logging_manager;
|
||||
}
|
||||
// operator Arg2() {
|
||||
// static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
|
||||
// Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
||||
// &default_logger_id};
|
||||
// return &default_logging_manager;
|
||||
// }
|
||||
|
||||
static SessionObjectInitializer Get() {
|
||||
return SessionObjectInitializer();
|
||||
}
|
||||
};
|
||||
|
||||
// static variable used to create inference session and training session.
|
||||
static std::unique_ptr<Environment> session_env;
|
||||
void initialize_env();
|
||||
Environment& get_env();
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -20,8 +20,7 @@ class TrainingSession : public InferenceSession {
|
|||
std::vector<std::pair<size_t /*InputIndex*/, float /*value*/>>>
|
||||
ImmutableWeights;
|
||||
|
||||
TrainingSession(const SessionOptions& session_options,
|
||||
const Environment& env)
|
||||
TrainingSession(const SessionOptions& session_options, const Environment& env)
|
||||
: InferenceSession(session_options, env) {}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ int main(int /*argc*/, char* /*args*/ []) {
|
|||
&default_logger_id};
|
||||
|
||||
std::unique_ptr<Environment> env;
|
||||
TERMINATE_IF_FAILED(Environment::Create(env));
|
||||
TERMINATE_IF_FAILED(Environment::Create(nullptr, env));
|
||||
|
||||
// Step 1: Load the model and generate gradient graph in a training session.
|
||||
SessionOptions so;
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
#include "core/framework/random_seed.h"
|
||||
#include "core/framework/session_options.h"
|
||||
#include "core/session/environment.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/graph/optimizer_config.h"
|
||||
#include "orttraining/core/framework/mpi_setup.h"
|
||||
|
|
@ -169,8 +170,10 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
.def_readwrite("partition_optimizer", &TrainingParameters::partition_optimizer);
|
||||
|
||||
py::class_<onnxruntime::training::TrainingSession, InferenceSession> training_session(m, "TrainingSession");
|
||||
training_session.def(py::init<SessionOptions, SessionObjectInitializer>())
|
||||
.def(py::init<SessionObjectInitializer, SessionObjectInitializer>())
|
||||
training_session.def(py::init([](const SessionOptions& so) {
|
||||
Environment& env = get_env();
|
||||
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
|
||||
}))
|
||||
.def("finalize", [](py::object) {
|
||||
#ifdef USE_HOROVOD
|
||||
training::shutdown_horovod();
|
||||
|
|
|
|||
Loading…
Reference in a new issue