mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
136 lines
4.7 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "core/common/logging/logging.h"
|
|
#include "core/common/logging/sinks/cerr_sink.h"
|
|
#include "core/framework/allocator.h"
|
|
#include "core/framework/session_options.h"
|
|
#include "core/session/environment.h"
|
|
#include "core/session/inference_session.h"
|
|
|
|
namespace onnxruntime {
|
|
namespace python {
|
|
|
|
using namespace onnxruntime;
|
|
using namespace onnxruntime::logging;
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
struct CustomOpLibrary {
|
|
CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so);
|
|
|
|
~CustomOpLibrary();
|
|
|
|
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpLibrary);
|
|
|
|
private:
|
|
void UnloadLibrary();
|
|
|
|
std::string library_path_;
|
|
void* library_handle_ = nullptr;
|
|
};
|
|
#endif
|
|
|
|
// Thin wrapper over internal C++ SessionOptions to accommodate custom op library management for the Python user
|
|
struct PySessionOptions : public SessionOptions {
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
// `PySessionOptions` has a vector of shared_ptrs to CustomOpLibrary, because so that it can be re-used for all
|
|
// `PyInferenceSession`s using the same `PySessionOptions` and that each `PyInferenceSession` need not construct
|
|
// duplicate CustomOpLibrary instances.
|
|
std::vector<std::shared_ptr<CustomOpLibrary>> custom_op_libraries_;
|
|
|
|
// Hold raw `OrtCustomOpDomain` pointers - it is upto the shared library to release the OrtCustomOpDomains
|
|
// that was created when the library is unloaded
|
|
std::vector<OrtCustomOpDomain*> custom_op_domains_;
|
|
#endif
|
|
};
|
|
|
|
// Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
|
|
struct PyInferenceSession {
|
|
PyInferenceSession(Environment& env, const PySessionOptions& so) {
|
|
sess_ = onnxruntime::make_unique<InferenceSession>(so, env);
|
|
}
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
PyInferenceSession(Environment& env, const PySessionOptions& so, const std::string& arg, bool is_arg_file_name) {
|
|
if (is_arg_file_name) {
|
|
// Given arg is the file path. Invoke the corresponding ctor().
|
|
sess_ = onnxruntime::make_unique<InferenceSession>(so, env, arg);
|
|
} else {
|
|
// Given arg is the model content as bytes. Invoke the corresponding ctor().
|
|
std::istringstream buffer(arg);
|
|
sess_ = onnxruntime::make_unique<InferenceSession>(so, env, buffer);
|
|
}
|
|
}
|
|
|
|
void AddCustomOpLibraries(const std::vector<std::shared_ptr<CustomOpLibrary>>& custom_op_libraries) {
|
|
if (!custom_op_libraries.empty()) {
|
|
custom_op_libraries_.reserve(custom_op_libraries.size());
|
|
for (size_t i = 0; i < custom_op_libraries.size(); ++i) {
|
|
custom_op_libraries_.push_back(custom_op_libraries[i]);
|
|
}
|
|
}
|
|
}
|
|
#endif
|
|
|
|
InferenceSession* GetSessionHandle() const { return sess_.get(); }
|
|
|
|
virtual ~PyInferenceSession() {}
|
|
|
|
protected:
|
|
PyInferenceSession(std::unique_ptr<InferenceSession> sess) {
|
|
sess_ = std::move(sess);
|
|
}
|
|
|
|
private:
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
// Hold CustomOpLibrary resources so as to tie it to the life cycle of the InferenceSession needing it.
|
|
// NOTE: Define this above `sess_` so that this is destructed AFTER the InferenceSession instance -
|
|
// this is so that the custom ops held by the InferenceSession gets destroyed prior to the library getting unloaded
|
|
// (if ref count of the shared_ptr reaches 0)
|
|
std::vector<std::shared_ptr<CustomOpLibrary>> custom_op_libraries_;
|
|
#endif
|
|
|
|
std::unique_ptr<InferenceSession> sess_;
|
|
};
|
|
|
|
// Returns a read-only view of a single default-constructed PySessionOptions instance.
// The instance is a function-local static, created on first call and shared by all callers.
inline const PySessionOptions& GetDefaultCPUSessionOptions() {
  static PySessionOptions default_cpu_options;
  return default_cpu_options;
}
|
|
|
|
// Returns a reference to a lazily created, function-local static allocator of type TAllocator.
// The reference is mutable, so callers may rebind the shared pointer.
inline AllocatorPtr& GetAllocator() {
  static AllocatorPtr shared_allocator{std::make_shared<TAllocator>()};
  return shared_allocator;
}
|
|
|
|
class SessionObjectInitializer {
|
|
public:
|
|
typedef const PySessionOptions& Arg1;
|
|
// typedef logging::LoggingManager* Arg2;
|
|
static const std::string default_logger_id;
|
|
operator Arg1() {
|
|
return GetDefaultCPUSessionOptions();
|
|
}
|
|
|
|
// operator Arg2() {
|
|
// static LoggingManager default_logging_manager{std::unique_ptr<ISink>{new CErrSink{}},
|
|
// Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
|
// &default_logger_id};
|
|
// return &default_logging_manager;
|
|
// }
|
|
|
|
static SessionObjectInitializer Get() {
|
|
return SessionObjectInitializer();
|
|
}
|
|
};
|
|
|
|
Environment& GetEnv();
|
|
|
|
// Initialize an InferenceSession.
|
|
// Any provider_options should have entries in matching order to provider_types.
|
|
void InitializeSession(InferenceSession* sess,
|
|
const std::vector<std::string>& provider_types = {},
|
|
const ProviderOptionsVector& provider_options = {});
|
|
|
|
} // namespace python
|
|
} // namespace onnxruntime
|