mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
introduce macro ORT_API_MANUAL_INIT in C++ API (#4536)
* introduce macro ORT_API_MANUAL_INIT in C++ API * resolve comments
This commit is contained in:
parent
21d2728974
commit
fdc5c308c4
4 changed files with 92 additions and 91 deletions
|
|
@ -41,20 +41,21 @@ struct Exception : std::exception {
|
|||
// it transparent to the users of the API.
|
||||
template <typename T>
|
||||
struct Global {
|
||||
static const OrtApi& api_;
|
||||
static const OrtApi* api_;
|
||||
};
|
||||
|
||||
#ifdef EXCLUDE_REFERENCE_TO_ORT_DLL
|
||||
OrtApi stub_api;
|
||||
// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
|
||||
|
||||
template <typename T>
|
||||
const OrtApi& Global<T>::api_ = stub_api;
|
||||
#ifdef ORT_API_MANUAL_INIT
|
||||
const OrtApi* Global<T>::api_{};
|
||||
inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
|
||||
#else
|
||||
template <typename T>
|
||||
const OrtApi& Global<T>::api_ = *OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
#endif
|
||||
|
||||
// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions
|
||||
inline const OrtApi& GetApi() { return Global<void>::api_; }
|
||||
inline const OrtApi& GetApi() { return *Global<void>::api_; }
|
||||
|
||||
// This is a C++ wrapper for GetAvailableProviders() C API and returns
|
||||
// a vector of strings representing the available execution providers.
|
||||
|
|
@ -63,7 +64,7 @@ std::vector<std::string> GetAvailableProviders();
|
|||
// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
|
||||
// This can't be done in the C API since C doesn't have function overloading.
|
||||
#define ORT_DEFINE_RELEASE(NAME) \
|
||||
inline void OrtRelease(Ort##NAME* ptr) { Global<void>::api_.Release##NAME(ptr); }
|
||||
inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
|
||||
|
||||
ORT_DEFINE_RELEASE(MemoryInfo);
|
||||
ORT_DEFINE_RELEASE(CustomOpDomain);
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) {
|
|||
}
|
||||
|
||||
inline void ThrowOnError(OrtStatus* status) {
|
||||
ThrowOnError(Global<void>::api_, status);
|
||||
ThrowOnError(GetApi(), status);
|
||||
}
|
||||
|
||||
// This template converts a C++ type into it's ONNXTensorElementDataType
|
||||
|
|
@ -49,186 +49,186 @@ template <>
|
|||
struct TypeToTensorType<bool> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; };
|
||||
|
||||
inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
|
||||
ThrowOnError(Global<void>::api_.GetAllocatorWithDefaultOptions(&p_));
|
||||
ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&p_));
|
||||
}
|
||||
|
||||
inline void* AllocatorWithDefaultOptions::Alloc(size_t size) {
|
||||
void* out;
|
||||
ThrowOnError(Global<void>::api_.AllocatorAlloc(p_, size, &out));
|
||||
ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline void AllocatorWithDefaultOptions::Free(void* p) {
|
||||
ThrowOnError(Global<void>::api_.AllocatorFree(p_, p));
|
||||
ThrowOnError(GetApi().AllocatorFree(p_, p));
|
||||
}
|
||||
|
||||
inline const OrtMemoryInfo* AllocatorWithDefaultOptions::GetInfo() const {
|
||||
const OrtMemoryInfo* out;
|
||||
ThrowOnError(Global<void>::api_.AllocatorGetInfo(p_, &out));
|
||||
ThrowOnError(GetApi().AllocatorGetInfo(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
|
||||
OrtMemoryInfo* p;
|
||||
ThrowOnError(Global<void>::api_.CreateCpuMemoryInfo(type, mem_type, &p));
|
||||
ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
|
||||
return MemoryInfo(p);
|
||||
}
|
||||
|
||||
inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
|
||||
ThrowOnError(Global<void>::api_.CreateMemoryInfo(name, type, id, mem_type, &p_));
|
||||
ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &p_));
|
||||
}
|
||||
|
||||
inline Env::Env(OrtLoggingLevel default_warning_level, _In_ const char* logid) {
|
||||
ThrowOnError(Global<void>::api_.CreateEnv(default_warning_level, logid, &p_));
|
||||
ThrowOnError(GetApi().CreateEnv(default_warning_level, logid, &p_));
|
||||
}
|
||||
|
||||
inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
|
||||
ThrowOnError(Global<void>::api_.CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_));
|
||||
ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_));
|
||||
}
|
||||
|
||||
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel default_warning_level, _In_ const char* logid) {
|
||||
ThrowOnError(Global<void>::api_.CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_));
|
||||
ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_));
|
||||
}
|
||||
|
||||
inline Env& Env::EnableTelemetryEvents() {
|
||||
ThrowOnError(Global<void>::api_.EnableTelemetryEvents(p_));
|
||||
ThrowOnError(GetApi().EnableTelemetryEvents(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline Env& Env::DisableTelemetryEvents() {
|
||||
ThrowOnError(Global<void>::api_.DisableTelemetryEvents(p_));
|
||||
ThrowOnError(GetApi().DisableTelemetryEvents(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline CustomOpDomain::CustomOpDomain(const char* domain) {
|
||||
ThrowOnError(Global<void>::api_.CreateCustomOpDomain(domain, &p_));
|
||||
ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
|
||||
}
|
||||
|
||||
inline void CustomOpDomain::Add(OrtCustomOp* op) {
|
||||
ThrowOnError(Global<void>::api_.CustomOpDomain_Add(p_, op));
|
||||
ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
|
||||
}
|
||||
|
||||
inline RunOptions::RunOptions() {
|
||||
ThrowOnError(Global<void>::api_.CreateRunOptions(&p_));
|
||||
ThrowOnError(GetApi().CreateRunOptions(&p_));
|
||||
}
|
||||
|
||||
inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
|
||||
ThrowOnError(Global<void>::api_.RunOptionsSetRunLogVerbosityLevel(p_, level));
|
||||
ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
|
||||
ThrowOnError(Global<void>::api_.RunOptionsSetRunLogSeverityLevel(p_, level));
|
||||
ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int RunOptions::GetRunLogVerbosityLevel() const {
|
||||
int out;
|
||||
ThrowOnError(Global<void>::api_.RunOptionsGetRunLogVerbosityLevel(p_, &out));
|
||||
ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
|
||||
ThrowOnError(Global<void>::api_.RunOptionsSetRunTag(p_, run_tag));
|
||||
ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline const char* RunOptions::GetRunTag() const {
|
||||
const char* out;
|
||||
ThrowOnError(Global<void>::api_.RunOptionsGetRunTag(p_, &out));
|
||||
ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline RunOptions& RunOptions::SetTerminate() {
|
||||
ThrowOnError(Global<void>::api_.RunOptionsSetTerminate(p_));
|
||||
ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline RunOptions& RunOptions::UnsetTerminate() {
|
||||
ThrowOnError(Global<void>::api_.RunOptionsUnsetTerminate(p_));
|
||||
ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions::SessionOptions() {
|
||||
ThrowOnError(Global<void>::api_.CreateSessionOptions(&p_));
|
||||
ThrowOnError(GetApi().CreateSessionOptions(&p_));
|
||||
}
|
||||
|
||||
inline SessionOptions SessionOptions::Clone() const {
|
||||
OrtSessionOptions* out;
|
||||
ThrowOnError(Global<void>::api_.CloneSessionOptions(p_, &out));
|
||||
ThrowOnError(GetApi().CloneSessionOptions(p_, &out));
|
||||
return SessionOptions{out};
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetIntraOpNumThreads(int intra_op_num_threads) {
|
||||
ThrowOnError(Global<void>::api_.SetIntraOpNumThreads(p_, intra_op_num_threads));
|
||||
ThrowOnError(GetApi().SetIntraOpNumThreads(p_, intra_op_num_threads));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetInterOpNumThreads(int inter_op_num_threads) {
|
||||
ThrowOnError(Global<void>::api_.SetInterOpNumThreads(p_, inter_op_num_threads));
|
||||
ThrowOnError(GetApi().SetInterOpNumThreads(p_, inter_op_num_threads));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
|
||||
ThrowOnError(Global<void>::api_.SetSessionGraphOptimizationLevel(p_, graph_optimization_level));
|
||||
ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(p_, graph_optimization_level));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
|
||||
ThrowOnError(Global<void>::api_.SetOptimizedModelFilePath(p_, optimized_model_filepath));
|
||||
ThrowOnError(GetApi().SetOptimizedModelFilePath(p_, optimized_model_filepath));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
|
||||
ThrowOnError(Global<void>::api_.EnableProfiling(p_, profile_file_prefix));
|
||||
ThrowOnError(GetApi().EnableProfiling(p_, profile_file_prefix));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::DisableProfiling() {
|
||||
ThrowOnError(Global<void>::api_.DisableProfiling(p_));
|
||||
ThrowOnError(GetApi().DisableProfiling(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::EnableMemPattern() {
|
||||
ThrowOnError(Global<void>::api_.EnableMemPattern(p_));
|
||||
ThrowOnError(GetApi().EnableMemPattern(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::DisableMemPattern() {
|
||||
ThrowOnError(Global<void>::api_.DisableMemPattern(p_));
|
||||
ThrowOnError(GetApi().DisableMemPattern(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::EnableCpuMemArena() {
|
||||
ThrowOnError(Global<void>::api_.EnableCpuMemArena(p_));
|
||||
ThrowOnError(GetApi().EnableCpuMemArena(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::DisableCpuMemArena() {
|
||||
ThrowOnError(Global<void>::api_.DisableCpuMemArena(p_));
|
||||
ThrowOnError(GetApi().DisableCpuMemArena(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetExecutionMode(ExecutionMode execution_mode) {
|
||||
ThrowOnError(Global<void>::api_.SetSessionExecutionMode(p_, execution_mode));
|
||||
ThrowOnError(GetApi().SetSessionExecutionMode(p_, execution_mode));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetLogId(const char* logid) {
|
||||
ThrowOnError(Global<void>::api_.SetSessionLogId(p_, logid));
|
||||
ThrowOnError(GetApi().SetSessionLogId(p_, logid));
|
||||
return *this;
|
||||
}
|
||||
inline SessionOptions& SessionOptions::Add(OrtCustomOpDomain* custom_op_domain) {
|
||||
ThrowOnError(Global<void>::api_.AddCustomOpDomain(p_, custom_op_domain));
|
||||
ThrowOnError(GetApi().AddCustomOpDomain(p_, custom_op_domain));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
|
||||
ThrowOnError(Global<void>::api_.CreateSession(env, model_path, options, &p_));
|
||||
ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_));
|
||||
}
|
||||
|
||||
inline Session::Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
|
||||
ThrowOnError(Global<void>::api_.CreateSessionFromArray(env, model_data, model_data_length, options, &p_));
|
||||
ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &p_));
|
||||
}
|
||||
|
||||
inline std::vector<Value> Session::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
||||
|
|
@ -245,141 +245,141 @@ inline void Session::Run(const RunOptions& run_options, const char* const* input
|
|||
static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
|
||||
auto ort_input_values = reinterpret_cast<const OrtValue**>(const_cast<Value*>(input_values));
|
||||
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
||||
ThrowOnError(Global<void>::api_.Run(p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
|
||||
ThrowOnError(GetApi().Run(p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
|
||||
}
|
||||
|
||||
inline size_t Session::GetInputCount() const {
|
||||
size_t out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetInputCount(p_, &out));
|
||||
ThrowOnError(GetApi().SessionGetInputCount(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline size_t Session::GetOutputCount() const {
|
||||
size_t out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetOutputCount(p_, &out));
|
||||
ThrowOnError(GetApi().SessionGetOutputCount(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline size_t Session::GetOverridableInitializerCount() const {
|
||||
size_t out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetOverridableInitializerCount(p_, &out));
|
||||
ThrowOnError(GetApi().SessionGetOverridableInitializerCount(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* Session::GetInputName(size_t index, OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetInputName(p_, index, allocator, &out));
|
||||
ThrowOnError(GetApi().SessionGetInputName(p_, index, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* Session::GetOutputName(size_t index, OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetOutputName(p_, index, allocator, &out));
|
||||
ThrowOnError(GetApi().SessionGetOutputName(p_, index, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* Session::GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetOverridableInitializerName(p_, index, allocator, &out));
|
||||
ThrowOnError(GetApi().SessionGetOverridableInitializerName(p_, index, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* Session::EndProfiling(OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.SessionEndProfiling(p_, allocator, &out));
|
||||
ThrowOnError(GetApi().SessionEndProfiling(p_, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline ModelMetadata Session::GetModelMetadata() const {
|
||||
OrtModelMetadata* out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetModelMetadata(p_, &out));
|
||||
ThrowOnError(GetApi().SessionGetModelMetadata(p_, &out));
|
||||
return ModelMetadata{out};
|
||||
}
|
||||
|
||||
inline char* ModelMetadata::GetProducerName(OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetProducerName(p_, allocator, &out));
|
||||
ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* ModelMetadata::GetGraphName(OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetGraphName(p_, allocator, &out));
|
||||
ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* ModelMetadata::GetDomain(OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetDomain(p_, allocator, &out));
|
||||
ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetDescription(p_, allocator, &out));
|
||||
ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
|
||||
ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char** ModelMetadata::GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const {
|
||||
char** out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
|
||||
ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline int64_t ModelMetadata::GetVersion() const {
|
||||
int64_t out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetVersion(p_, &out));
|
||||
ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline TypeInfo Session::GetInputTypeInfo(size_t index) const {
|
||||
OrtTypeInfo* out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetInputTypeInfo(p_, index, &out));
|
||||
ThrowOnError(GetApi().SessionGetInputTypeInfo(p_, index, &out));
|
||||
return TypeInfo{out};
|
||||
}
|
||||
|
||||
inline TypeInfo Session::GetOutputTypeInfo(size_t index) const {
|
||||
OrtTypeInfo* out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetOutputTypeInfo(p_, index, &out));
|
||||
ThrowOnError(GetApi().SessionGetOutputTypeInfo(p_, index, &out));
|
||||
return TypeInfo{out};
|
||||
}
|
||||
|
||||
inline TypeInfo Session::GetOverridableInitializerTypeInfo(size_t index) const {
|
||||
OrtTypeInfo* out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetOverridableInitializerTypeInfo(p_, index, &out));
|
||||
ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(p_, index, &out));
|
||||
return TypeInfo{out};
|
||||
}
|
||||
|
||||
inline ONNXTensorElementDataType TensorTypeAndShapeInfo::GetElementType() const {
|
||||
ONNXTensorElementDataType out;
|
||||
ThrowOnError(Global<void>::api_.GetTensorElementType(p_, &out));
|
||||
ThrowOnError(GetApi().GetTensorElementType(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline size_t TensorTypeAndShapeInfo::GetElementCount() const {
|
||||
size_t out;
|
||||
ThrowOnError(Global<void>::api_.GetTensorShapeElementCount(p_, &out));
|
||||
ThrowOnError(GetApi().GetTensorShapeElementCount(p_, &out));
|
||||
return static_cast<size_t>(out);
|
||||
}
|
||||
|
||||
inline size_t TensorTypeAndShapeInfo::GetDimensionsCount() const {
|
||||
size_t out;
|
||||
ThrowOnError(Global<void>::api_.GetDimensionsCount(p_, &out));
|
||||
ThrowOnError(GetApi().GetDimensionsCount(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline void TensorTypeAndShapeInfo::GetDimensions(int64_t* values, size_t values_count) const {
|
||||
ThrowOnError(Global<void>::api_.GetDimensions(p_, values, values_count));
|
||||
ThrowOnError(GetApi().GetDimensions(p_, values, values_count));
|
||||
}
|
||||
|
||||
inline void TensorTypeAndShapeInfo::GetSymbolicDimensions(const char** values, size_t values_count) const {
|
||||
ThrowOnError(Global<void>::api_.GetSymbolicDimensions(p_, values, values_count));
|
||||
ThrowOnError(GetApi().GetSymbolicDimensions(p_, values, values_count));
|
||||
}
|
||||
|
||||
inline std::vector<int64_t> TensorTypeAndShapeInfo::GetShape() const {
|
||||
|
|
@ -390,13 +390,13 @@ inline std::vector<int64_t> TensorTypeAndShapeInfo::GetShape() const {
|
|||
|
||||
inline Unowned<TensorTypeAndShapeInfo> TypeInfo::GetTensorTypeAndShapeInfo() const {
|
||||
const OrtTensorTypeAndShapeInfo* out;
|
||||
ThrowOnError(Global<void>::api_.CastTypeInfoToTensorInfo(p_, &out));
|
||||
ThrowOnError(GetApi().CastTypeInfoToTensorInfo(p_, &out));
|
||||
return Unowned<TensorTypeAndShapeInfo>{const_cast<OrtTensorTypeAndShapeInfo*>(out)};
|
||||
}
|
||||
|
||||
inline ONNXType TypeInfo::GetONNXType() const {
|
||||
ONNXType out;
|
||||
ThrowOnError(Global<void>::api_.GetOnnxTypeFromTypeInfo(p_, &out));
|
||||
ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
|
|
@ -408,7 +408,7 @@ inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_
|
|||
inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
|
||||
ONNXTensorElementDataType type) {
|
||||
OrtValue* out;
|
||||
ThrowOnError(Global<void>::api_.CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
|
||||
ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
|
||||
return Value{out};
|
||||
}
|
||||
|
||||
|
|
@ -419,80 +419,80 @@ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape,
|
|||
|
||||
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
|
||||
OrtValue* out;
|
||||
ThrowOnError(Global<void>::api_.CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
|
||||
ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
|
||||
return Value{out};
|
||||
}
|
||||
|
||||
inline Value Value::CreateMap(Value& keys, Value& values) {
|
||||
OrtValue* out;
|
||||
OrtValue* inputs[2] = {keys, values};
|
||||
ThrowOnError(Global<void>::api_.CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
|
||||
ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
|
||||
return Value{out};
|
||||
}
|
||||
|
||||
inline Value Value::CreateSequence(std::vector<Value>& values) {
|
||||
OrtValue* out;
|
||||
std::vector<OrtValue*> values_ort{values.data(), values.data() + values.size()};
|
||||
ThrowOnError(Global<void>::api_.CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
|
||||
ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
|
||||
return Value{out};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
|
||||
OrtValue* out;
|
||||
ThrowOnError(Global<void>::api_.CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
|
||||
ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
|
||||
return Value{out};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void Value::GetOpaqueData(const char* domain, const char* type_name, T& out) {
|
||||
ThrowOnError(Global<void>::api_.GetOpaqueValue(domain, type_name, p_, &out, sizeof(T)));
|
||||
ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, p_, &out, sizeof(T)));
|
||||
}
|
||||
|
||||
inline bool Value::IsTensor() const {
|
||||
int out;
|
||||
ThrowOnError(Global<void>::api_.IsTensor(p_, &out));
|
||||
ThrowOnError(GetApi().IsTensor(p_, &out));
|
||||
return out != 0;
|
||||
}
|
||||
|
||||
inline size_t Value::GetCount() const {
|
||||
size_t out;
|
||||
ThrowOnError(Global<void>::api_.GetValueCount(p_, &out));
|
||||
ThrowOnError(GetApi().GetValueCount(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline Value Value::GetValue(int index, OrtAllocator* allocator) const {
|
||||
OrtValue* out;
|
||||
ThrowOnError(Global<void>::api_.GetValue(p_, index, allocator, &out));
|
||||
ThrowOnError(GetApi().GetValue(p_, index, allocator, &out));
|
||||
return Value{out};
|
||||
}
|
||||
|
||||
inline size_t Value::GetStringTensorDataLength() const {
|
||||
size_t out;
|
||||
ThrowOnError(Global<void>::api_.GetStringTensorDataLength(p_, &out));
|
||||
ThrowOnError(GetApi().GetStringTensorDataLength(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline void Value::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
|
||||
ThrowOnError(Global<void>::api_.GetStringTensorContent(p_, buffer, buffer_length, offsets, offsets_count));
|
||||
ThrowOnError(GetApi().GetStringTensorContent(p_, buffer, buffer_length, offsets, offsets_count));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* Value::GetTensorMutableData() {
|
||||
T* out;
|
||||
ThrowOnError(Global<void>::api_.GetTensorMutableData(p_, (void**)&out));
|
||||
ThrowOnError(GetApi().GetTensorMutableData(p_, (void**)&out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline TypeInfo Value::GetTypeInfo() const {
|
||||
OrtTypeInfo* output;
|
||||
ThrowOnError(Global<void>::api_.GetTypeInfo(p_, &output));
|
||||
ThrowOnError(GetApi().GetTypeInfo(p_, &output));
|
||||
return TypeInfo{output};
|
||||
}
|
||||
|
||||
inline TensorTypeAndShapeInfo Value::GetTensorTypeAndShapeInfo() const {
|
||||
OrtTensorTypeAndShapeInfo* output;
|
||||
ThrowOnError(Global<void>::api_.GetTensorTypeAndShape(p_, &output));
|
||||
ThrowOnError(GetApi().GetTensorTypeAndShape(p_, &output));
|
||||
return TensorTypeAndShapeInfo{output};
|
||||
}
|
||||
|
||||
|
|
@ -614,7 +614,7 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
|
|||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::DisablePerSessionThreads() {
|
||||
ThrowOnError(Global<void>::api_.DisablePerSessionThreads(p_));
|
||||
ThrowOnError(GetApi().DisablePerSessionThreads(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
#endif
|
||||
|
||||
#include "dnnl_func_kernel.h"
|
||||
#define EXCLUDE_REFERENCE_TO_ORT_DLL
|
||||
#define ORT_API_MANUAL_INIT
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/providers/dnnl/dnnl_common.h"
|
||||
#include "core/providers/dnnl/subgraph/dnnl_conv.h"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
#include "custom_op_library.h"
|
||||
|
||||
#define EXCLUDE_REFERENCE_TO_ORT_DLL
|
||||
#define ORT_API_MANUAL_INIT
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
#undef EXCLUDE_REFERENCE_TO_ORT_DLL
|
||||
#undef ORT_API_MANUAL_INIT
|
||||
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
|
|
|
|||
Loading…
Reference in a new issue