mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
parent
6412c6a362
commit
2439ced3ec
6 changed files with 4935 additions and 1157 deletions
2683
docs/Doxyfile
Normal file
2683
docs/Doxyfile
Normal file
File diff suppressed because it is too large
Load diff
BIN
docs/images/ONNX_Runtime_logo - Docs.png
Normal file
BIN
docs/images/ONNX_Runtime_logo - Docs.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5 KiB |
File diff suppressed because it is too large
Load diff
|
|
@ -27,9 +27,15 @@
|
|||
#include <iostream>
|
||||
#endif
|
||||
|
||||
/** \brief All C++ Onnxruntime APIs are defined inside this namespace
|
||||
*
|
||||
*/
|
||||
namespace Ort {
|
||||
|
||||
// All C++ methods that can fail will throw an exception of this type
|
||||
/** \brief All C++ methods that can fail will throw an exception of this type
|
||||
*
|
||||
* If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
|
||||
*/
|
||||
struct Exception : std::exception {
|
||||
Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
|
||||
|
||||
|
|
@ -62,7 +68,6 @@ struct Global {
|
|||
};
|
||||
|
||||
// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
|
||||
|
||||
template <typename T>
|
||||
#ifdef ORT_API_MANUAL_INIT
|
||||
const OrtApi* Global<T>::api_{};
|
||||
|
|
@ -71,11 +76,10 @@ inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VER
|
|||
const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
#endif
|
||||
|
||||
// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions
|
||||
/// This returns a reference to the OrtApi interface in use
|
||||
inline const OrtApi& GetApi() { return *Global<void>::api_; }
|
||||
|
||||
// This is a C++ wrapper for GetAvailableProviders() C API and returns
|
||||
// a vector of strings representing the available execution providers.
|
||||
/// This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representing the available execution providers.
|
||||
std::vector<std::string> GetAvailableProviders();
|
||||
|
||||
// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
|
||||
|
|
@ -100,8 +104,9 @@ ORT_DEFINE_RELEASE(ThreadingOptions);
|
|||
ORT_DEFINE_RELEASE(IoBinding);
|
||||
ORT_DEFINE_RELEASE(ArenaCfg);
|
||||
|
||||
/*! \class Ort::Float16_t
|
||||
* \brief it is a structure that represents float16 data.
|
||||
#undef ORT_DEFINE_RELEASE
|
||||
|
||||
/** \brief IEEE 754 half-precision floating point data type
|
||||
* \details It is necessary for type dispatching to make use of C++ API
|
||||
* The type is implicitly convertible to/from uint16_t.
|
||||
* The size of the structure should align with uint16_t and one can freely cast
|
||||
|
|
@ -151,8 +156,7 @@ struct Float16_t {
|
|||
|
||||
static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
|
||||
|
||||
/*! \class Ort::BFloat16_t
|
||||
* \brief is a structure that represents bfloat16 data.
|
||||
/** \brief bfloat16 (Brain Floating Point) data type
|
||||
* \details It is necessary for type dispatching to make use of C++ API
|
||||
* The type is implicitly convertible to/from uint16_t.
|
||||
* The size of the structure should align with uint16_t and one can freely cast
|
||||
|
|
@ -171,7 +175,13 @@ struct BFloat16_t {
|
|||
|
||||
static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
|
||||
|
||||
// This is used internally by the C++ API. This is the common base class used by the wrapper objects.
|
||||
/** \brief Used internally by the C++ API. C++ wrapper types inherit from this
|
||||
*
|
||||
* This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
|
||||
* There is a secondary class 'Unowned<T>' that is used to prevent deletion on destruction (Used for return types that are
|
||||
* not owned by the caller)
|
||||
*
|
||||
*/
|
||||
template <typename T>
|
||||
struct Base {
|
||||
using contained_type = T;
|
||||
|
|
@ -186,6 +196,7 @@ struct Base {
|
|||
operator T*() { return p_; }
|
||||
operator const T*() const { return p_; }
|
||||
|
||||
/// \brief Releases ownership of the contained pointer
|
||||
T* release() {
|
||||
T* p = p_;
|
||||
p_ = nullptr;
|
||||
|
|
@ -196,46 +207,20 @@ struct Base {
|
|||
Base(const Base&) = delete;
|
||||
Base& operator=(const Base&) = delete;
|
||||
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
|
||||
void operator=(Base&& v) noexcept {
|
||||
OrtRelease(p_);
|
||||
p_ = v.p_;
|
||||
v.p_ = nullptr;
|
||||
}
|
||||
void operator=(Base&& v) noexcept { OrtRelease(p_); p_ = v.release(); }
|
||||
|
||||
T* p_{};
|
||||
|
||||
template <typename>
|
||||
friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Base<const T> {
|
||||
using contained_type = const T;
|
||||
|
||||
Base() = default;
|
||||
Base(const T* p) : p_{p} {
|
||||
if (!p)
|
||||
ORT_CXX_API_THROW("Invalid instance ptr", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
~Base() = default;
|
||||
|
||||
operator const T*() const { return p_; }
|
||||
|
||||
protected:
|
||||
Base(const Base&) = delete;
|
||||
Base& operator=(const Base&) = delete;
|
||||
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
|
||||
void operator=(Base&& v) noexcept {
|
||||
p_ = v.p_;
|
||||
v.p_ = nullptr;
|
||||
}
|
||||
|
||||
const T* p_{};
|
||||
template <typename> friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error
|
||||
};
|
||||
|
||||
/** \brief Wraps an object that inherits from Ort::Base and stops it from deleting the contained pointer on destruction
|
||||
*
|
||||
* This has the effect of making it not own the memory held by Ort::Base.
|
||||
*/
|
||||
template <typename T>
|
||||
struct Unowned : T {
|
||||
Unowned(decltype(T::p_) p) : T{p} {}
|
||||
Unowned(typename T::contained_type* p) : T{p} {}
|
||||
Unowned(Unowned&& v) : T{v.p_} {}
|
||||
~Unowned() { this->release(); }
|
||||
};
|
||||
|
|
@ -247,174 +232,238 @@ struct TypeInfo;
|
|||
struct Value;
|
||||
struct ModelMetadata;
|
||||
|
||||
/** \brief The Env (Environment)
|
||||
*
|
||||
* The Env holds the logging state used by all other objects.
|
||||
* <b>Note:</b> One Env must be created before using any other Onnxruntime functionality
|
||||
*/
|
||||
struct Env : Base<OrtEnv> {
|
||||
Env(std::nullptr_t) {}
|
||||
explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used
|
||||
|
||||
/// \brief Wraps OrtApi::CreateEnv
|
||||
Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
|
||||
Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
|
||||
|
||||
/// \brief Wraps OrtApi::CreateEnvWithCustomLogger
|
||||
Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
|
||||
|
||||
/// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
|
||||
Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
|
||||
|
||||
/// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
|
||||
Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
|
||||
OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
|
||||
|
||||
/// \brief C Interop Helper
|
||||
explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
|
||||
|
||||
Env& EnableTelemetryEvents();
|
||||
Env& DisableTelemetryEvents();
|
||||
Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
|
||||
Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
|
||||
|
||||
Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);
|
||||
|
||||
static const OrtApi* s_api;
|
||||
Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
|
||||
};
|
||||
|
||||
/** \brief Custom Op Domain
|
||||
*
|
||||
*/
|
||||
struct CustomOpDomain : Base<OrtCustomOpDomain> {
|
||||
explicit CustomOpDomain(std::nullptr_t) {}
|
||||
explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
|
||||
|
||||
/// \brief Wraps OrtApi::CreateCustomOpDomain
|
||||
explicit CustomOpDomain(const char* domain);
|
||||
|
||||
void Add(OrtCustomOp* op);
|
||||
void Add(OrtCustomOp* op); ///< Wraps CustomOpDomain_Add
|
||||
};
|
||||
|
||||
struct RunOptions : Base<OrtRunOptions> {
|
||||
RunOptions(std::nullptr_t) {}
|
||||
RunOptions();
|
||||
explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
|
||||
RunOptions(); ///< Wraps OrtApi::CreateRunOptions
|
||||
|
||||
RunOptions& SetRunLogVerbosityLevel(int);
|
||||
int GetRunLogVerbosityLevel() const;
|
||||
RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
|
||||
int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
|
||||
|
||||
RunOptions& SetRunLogSeverityLevel(int);
|
||||
int GetRunLogSeverityLevel() const;
|
||||
RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
|
||||
int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
|
||||
|
||||
RunOptions& SetRunTag(const char* run_tag);
|
||||
const char* GetRunTag() const;
|
||||
RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
|
||||
const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
|
||||
|
||||
RunOptions& AddConfigEntry(const char* config_key, const char* config_value);
|
||||
RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
|
||||
|
||||
// terminate ALL currently executing Session::Run calls that were made using this RunOptions instance
|
||||
/** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
|
||||
*
|
||||
* If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error
|
||||
* Wraps OrtApi::RunOptionsSetTerminate
|
||||
*/
|
||||
RunOptions& SetTerminate();
|
||||
// unset the terminate flag so this RunOptions instance can be used in a new Session::Run call
|
||||
|
||||
/** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
|
||||
*
|
||||
* Wraps OrtApi::RunOptionsUnsetTerminate
|
||||
*/
|
||||
RunOptions& UnsetTerminate();
|
||||
};
|
||||
|
||||
/** \brief Options object used when creating a new Session object
|
||||
*
|
||||
* Wraps ::OrtSessionOptions object and methods
|
||||
*/
|
||||
struct SessionOptions : Base<OrtSessionOptions> {
|
||||
explicit SessionOptions(std::nullptr_t) {}
|
||||
SessionOptions();
|
||||
explicit SessionOptions(OrtSessionOptions* p) : Base<OrtSessionOptions>{p} {}
|
||||
explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used
|
||||
SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
|
||||
explicit SessionOptions(OrtSessionOptions* p) : Base<OrtSessionOptions>{p} {} ///< Used for interop with the C API
|
||||
|
||||
SessionOptions Clone() const;
|
||||
SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
|
||||
|
||||
SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads);
|
||||
SessionOptions& SetInterOpNumThreads(int inter_op_num_threads);
|
||||
SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
|
||||
SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
|
||||
SessionOptions& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
|
||||
SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
|
||||
|
||||
SessionOptions& EnableCpuMemArena();
|
||||
SessionOptions& DisableCpuMemArena();
|
||||
SessionOptions& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
|
||||
SessionOptions& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
|
||||
|
||||
SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);
|
||||
SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
|
||||
|
||||
SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
|
||||
SessionOptions& DisableProfiling();
|
||||
SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
|
||||
SessionOptions& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
|
||||
|
||||
SessionOptions& EnableOrtCustomOps();
|
||||
SessionOptions& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
|
||||
|
||||
SessionOptions& EnableMemPattern();
|
||||
SessionOptions& DisableMemPattern();
|
||||
SessionOptions& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
|
||||
SessionOptions& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
|
||||
|
||||
SessionOptions& SetExecutionMode(ExecutionMode execution_mode);
|
||||
SessionOptions& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
|
||||
|
||||
SessionOptions& SetLogId(const char* logid);
|
||||
SessionOptions& SetLogSeverityLevel(int level);
|
||||
SessionOptions& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
|
||||
SessionOptions& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
|
||||
|
||||
SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);
|
||||
SessionOptions& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
|
||||
|
||||
SessionOptions& DisablePerSessionThreads();
|
||||
SessionOptions& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
|
||||
|
||||
SessionOptions& AddConfigEntry(const char* config_key, const char* config_value);
|
||||
SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
|
||||
SessionOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
|
||||
SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
|
||||
|
||||
SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
|
||||
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);
|
||||
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
|
||||
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);
|
||||
SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
|
||||
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
|
||||
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
|
||||
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
|
||||
};
|
||||
|
||||
/** \brief Wrapper around ::OrtModelMetadata
|
||||
*
|
||||
*/
|
||||
struct ModelMetadata : Base<OrtModelMetadata> {
|
||||
explicit ModelMetadata(std::nullptr_t) {}
|
||||
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {}
|
||||
explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
|
||||
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
|
||||
|
||||
char* GetProducerName(OrtAllocator* allocator) const;
|
||||
char* GetGraphName(OrtAllocator* allocator) const;
|
||||
char* GetDomain(OrtAllocator* allocator) const;
|
||||
char* GetDescription(OrtAllocator* allocator) const;
|
||||
char* GetGraphDescription(OrtAllocator* allocator) const;
|
||||
char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const;
|
||||
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const;
|
||||
int64_t GetVersion() const;
|
||||
char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
|
||||
char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
|
||||
char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
|
||||
char* GetDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
|
||||
char* GetGraphDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
|
||||
char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
|
||||
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
|
||||
int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
|
||||
};
|
||||
|
||||
/** \brief Wrapper around ::OrtSession
|
||||
*
|
||||
*/
|
||||
struct Session : Base<OrtSession> {
|
||||
explicit Session(std::nullptr_t) {}
|
||||
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
|
||||
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container);
|
||||
Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);
|
||||
explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
|
||||
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
|
||||
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
|
||||
Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
|
||||
|
||||
// Run that will allocate the output values
|
||||
/** \brief Run the model returning results in an Ort allocated vector.
|
||||
*
|
||||
* Wraps OrtApi::Run
|
||||
*
|
||||
* The caller provides a list of inputs and a list of the desired outputs to return.
|
||||
*
|
||||
* See the output logs for more information on warnings/errors that occur while processing the model.
|
||||
* Common errors are.. (TODO)
|
||||
*
|
||||
* \param[in] run_options
|
||||
* \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
|
||||
* \param[in] input_values Array of Value objects of length input_count that is the list of input values
|
||||
* \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
|
||||
* \param[in] output_names Array of C style strings of length output_count that is the list of output names
|
||||
* \param[in] output_count Number of outputs (the size of the output_names array)
|
||||
* \return A std::vector of Value objects that directly maps to the output_count (eg. output_name[0] is the first entry of the returned vector)
|
||||
*/
|
||||
std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
||||
const char* const* output_names, size_t output_count);
|
||||
// Run for when there is a list of preallocated outputs
|
||||
|
||||
/** \brief Run the model returning results in user provided outputs
|
||||
* Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t)
|
||||
*/
|
||||
void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
||||
const char* const* output_names, Value* output_values, size_t output_count);
|
||||
|
||||
void Run(const RunOptions& run_options, const struct IoBinding&);
|
||||
void Run(const RunOptions& run_options, const struct IoBinding&); ///< Wraps OrtApi::RunWithBinding
|
||||
|
||||
size_t GetInputCount() const;
|
||||
size_t GetOutputCount() const;
|
||||
size_t GetOverridableInitializerCount() const;
|
||||
size_t GetInputCount() const; ///< Returns the number of model inputs
|
||||
size_t GetOutputCount() const; ///< Returns the number of model outputs
|
||||
size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
|
||||
|
||||
char* GetInputName(size_t index, OrtAllocator* allocator) const;
|
||||
char* GetOutputName(size_t index, OrtAllocator* allocator) const;
|
||||
char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const;
|
||||
char* EndProfiling(OrtAllocator* allocator) const;
|
||||
uint64_t GetProfilingStartTimeNs() const;
|
||||
ModelMetadata GetModelMetadata() const;
|
||||
char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName
|
||||
char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName
|
||||
char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
|
||||
char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling
|
||||
uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
|
||||
ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
|
||||
|
||||
TypeInfo GetInputTypeInfo(size_t index) const;
|
||||
TypeInfo GetOutputTypeInfo(size_t index) const;
|
||||
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const;
|
||||
TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
|
||||
TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
|
||||
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
|
||||
};
|
||||
|
||||
/** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
|
||||
*
|
||||
*/
|
||||
struct TensorTypeAndShapeInfo : Base<OrtTensorTypeAndShapeInfo> {
|
||||
explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
|
||||
explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base<OrtTensorTypeAndShapeInfo>{p} {}
|
||||
explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
|
||||
explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base<OrtTensorTypeAndShapeInfo>{p} {} ///< Used for interop with the C API
|
||||
|
||||
ONNXTensorElementDataType GetElementType() const;
|
||||
size_t GetElementCount() const;
|
||||
ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
|
||||
size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
|
||||
|
||||
size_t GetDimensionsCount() const;
|
||||
void GetDimensions(int64_t* values, size_t values_count) const;
|
||||
void GetSymbolicDimensions(const char** values, size_t values_count) const;
|
||||
size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
|
||||
void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
|
||||
void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
|
||||
|
||||
std::vector<int64_t> GetShape() const;
|
||||
std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
|
||||
};
|
||||
|
||||
/** \brief Wrapper around ::OrtSequenceTypeInfo
|
||||
*
|
||||
*/
|
||||
struct SequenceTypeInfo : Base<OrtSequenceTypeInfo> {
|
||||
explicit SequenceTypeInfo(std::nullptr_t) {}
|
||||
explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base<OrtSequenceTypeInfo>{p} {}
|
||||
explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
|
||||
explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
|
||||
|
||||
TypeInfo GetSequenceElementType() const;
|
||||
TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType
|
||||
};
|
||||
|
||||
/** \brief Wrapper around ::OrtMapTypeInfo
|
||||
*
|
||||
*/
|
||||
struct MapTypeInfo : Base<OrtMapTypeInfo> {
|
||||
explicit MapTypeInfo(std::nullptr_t) {}
|
||||
explicit MapTypeInfo(OrtMapTypeInfo* p) : Base<OrtMapTypeInfo>{p} {}
|
||||
explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
|
||||
explicit MapTypeInfo(OrtMapTypeInfo* p) : Base<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
|
||||
|
||||
ONNXTensorElementDataType GetMapKeyType() const;
|
||||
TypeInfo GetMapValueType() const;
|
||||
ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType
|
||||
TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType
|
||||
};
|
||||
|
||||
struct TypeInfo : Base<OrtTypeInfo> {
|
||||
explicit TypeInfo(std::nullptr_t) {}
|
||||
explicit TypeInfo(OrtTypeInfo* p) : Base<OrtTypeInfo>{p} {}
|
||||
explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
|
||||
explicit TypeInfo(OrtTypeInfo* p) : Base<OrtTypeInfo>{p} {} ///< C API Interop
|
||||
|
||||
Unowned<TensorTypeAndShapeInfo> GetTensorTypeAndShapeInfo() const;
|
||||
Unowned<SequenceTypeInfo> GetSequenceTypeInfo() const;
|
||||
Unowned<MapTypeInfo> GetMapTypeInfo() const;
|
||||
Unowned<TensorTypeAndShapeInfo> GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
|
||||
Unowned<SequenceTypeInfo> GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
|
||||
Unowned<MapTypeInfo> GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
|
||||
|
||||
ONNXType GetONNXType() const;
|
||||
};
|
||||
|
|
@ -444,8 +493,10 @@ struct Value : Base<OrtValue> {
|
|||
size_t shape_len;
|
||||
};
|
||||
|
||||
/// \brief Wraps OrtApi::CreateTensorWithDataAsOrtValue
|
||||
template <typename T>
|
||||
static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
|
||||
/// \brief Wraps OrtApi::CreateTensorWithDataAsOrtValue
|
||||
static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
|
||||
ONNXTensorElementDataType type);
|
||||
|
||||
|
|
@ -517,8 +568,10 @@ struct Value : Base<OrtValue> {
|
|||
|
||||
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
||||
|
||||
// \brief Wraps OrtApi::CreateTensorAsOrtValue
|
||||
template <typename T>
|
||||
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
|
||||
// \brief Wraps OrtApi::CreateTensorAsOrtValue
|
||||
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
|
||||
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
|
|
@ -565,7 +618,7 @@ struct Value : Base<OrtValue> {
|
|||
/// at difference device than the allocator, a X-device copy will be performed if possible.
|
||||
/// </summary>
|
||||
/// <param name="data_mem_info">specified buffer memory description</param>
|
||||
/// <param name="values_param">values buffer information</param>
|
||||
/// <param name="values">values buffer information</param>
|
||||
/// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
|
||||
/// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
|
||||
/// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
|
||||
|
|
@ -581,7 +634,7 @@ struct Value : Base<OrtValue> {
|
|||
/// at difference device than the allocator, a X-device copy will be performed if possible.
|
||||
/// </summary>
|
||||
/// <param name="data_mem_info">specified buffer memory description</param>
|
||||
/// <param name="values_param">values buffer information</param>
|
||||
/// <param name="values">values buffer information</param>
|
||||
/// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
|
||||
/// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
|
||||
void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
|
||||
|
|
@ -609,9 +662,9 @@ struct Value : Base<OrtValue> {
|
|||
/// indices have their own enum values even if a give format has more than one kind of indices.
|
||||
/// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
|
||||
/// </summary>
|
||||
/// <param name="">enum requested</param>
|
||||
/// <param name="format">enum requested</param>
|
||||
/// <returns>type and shape information</returns>
|
||||
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat) const;
|
||||
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;
|
||||
|
||||
/// <summary>
|
||||
/// The API retrieves a pointer to the internal indices buffer. The API merely performs
|
||||
|
|
@ -627,21 +680,21 @@ struct Value : Base<OrtValue> {
|
|||
|
||||
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
||||
|
||||
static Value CreateMap(Value& keys, Value& values);
|
||||
static Value CreateSequence(std::vector<Value>& values);
|
||||
static Value CreateMap(Value& keys, Value& values); ///< Wraps OrtApi::CreateValue
|
||||
static Value CreateSequence(std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
|
||||
|
||||
template <typename T>
|
||||
static Value CreateOpaque(const char* domain, const char* type_name, const T&);
|
||||
static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue
|
||||
|
||||
template <typename T>
|
||||
void GetOpaqueData(const char* domain, const char* type_name, T&) const;
|
||||
void GetOpaqueData(const char* domain, const char* type_name, T&) const; ///< Wraps OrtApi::GetOpaqueValue
|
||||
|
||||
explicit Value(std::nullptr_t) {}
|
||||
explicit Value(OrtValue* p) : Base<OrtValue>{p} {}
|
||||
explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
|
||||
explicit Value(OrtValue* p) : Base<OrtValue>{p} {} ///< Used for interop with the C API
|
||||
Value(Value&&) = default;
|
||||
Value& operator=(Value&&) = default;
|
||||
|
||||
bool IsTensor() const;
|
||||
bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
|
||||
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
/// <summary>
|
||||
|
|
@ -679,10 +732,10 @@ struct Value : Base<OrtValue> {
|
|||
void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
|
||||
|
||||
template <typename T>
|
||||
T* GetTensorMutableData();
|
||||
T* GetTensorMutableData(); ///< Wraps OrtApi::GetTensorMutableData
|
||||
|
||||
template <typename T>
|
||||
const T* GetTensorData() const;
|
||||
const T* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData
|
||||
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
/// <summary>
|
||||
|
|
@ -773,33 +826,19 @@ struct AllocatorWithDefaultOptions {
|
|||
OrtAllocator* p_{};
|
||||
};
|
||||
|
||||
template <typename B>
|
||||
struct BaseMemoryInfo : B {
|
||||
BaseMemoryInfo() = default;
|
||||
explicit BaseMemoryInfo(typename B::contained_type* p) : B(p) {}
|
||||
~BaseMemoryInfo() = default;
|
||||
BaseMemoryInfo(BaseMemoryInfo&&) = default;
|
||||
BaseMemoryInfo& operator=(BaseMemoryInfo&&) = default;
|
||||
struct MemoryInfo : Base<OrtMemoryInfo> {
|
||||
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
|
||||
|
||||
explicit MemoryInfo(std::nullptr_t) {}
|
||||
explicit MemoryInfo(OrtMemoryInfo* p) : Base<OrtMemoryInfo>{p} {} ///< Used for interop with the C API
|
||||
MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
|
||||
|
||||
std::string GetAllocatorName() const;
|
||||
OrtAllocatorType GetAllocatorType() const;
|
||||
int GetDeviceId() const;
|
||||
OrtMemType GetMemoryType() const;
|
||||
template <typename U>
|
||||
bool operator==(const BaseMemoryInfo<U>& o) const;
|
||||
};
|
||||
|
||||
struct UnownedMemoryInfo : BaseMemoryInfo<Base<const OrtMemoryInfo> > {
|
||||
explicit UnownedMemoryInfo(std::nullptr_t) {}
|
||||
explicit UnownedMemoryInfo(const OrtMemoryInfo* p) : BaseMemoryInfo(p) {}
|
||||
};
|
||||
|
||||
struct MemoryInfo : BaseMemoryInfo<Base<OrtMemoryInfo> > {
|
||||
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
|
||||
|
||||
explicit MemoryInfo(std::nullptr_t) {}
|
||||
explicit MemoryInfo(OrtMemoryInfo* p) : BaseMemoryInfo(p) {}
|
||||
MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
|
||||
bool operator==(const MemoryInfo& o) const;
|
||||
};
|
||||
|
||||
struct Allocator : public Base<OrtAllocator> {
|
||||
|
|
@ -809,15 +848,10 @@ struct Allocator : public Base<OrtAllocator> {
|
|||
// The return value will own the allocation
|
||||
MemoryAllocation GetAllocation(size_t size);
|
||||
void Free(void* p) const;
|
||||
UnownedMemoryInfo GetInfo() const;
|
||||
Unowned<const MemoryInfo> GetInfo() const;
|
||||
};
|
||||
|
||||
struct IoBinding : public Base<OrtIoBinding> {
|
||||
private:
|
||||
std::vector<std::string> GetOutputNamesHelper(OrtAllocator*) const;
|
||||
std::vector<Value> GetOutputValuesHelper(OrtAllocator*) const;
|
||||
|
||||
public:
|
||||
explicit IoBinding(Session& session);
|
||||
void BindInput(const char* name, const Value&);
|
||||
void BindOutput(const char* name, const Value&);
|
||||
|
|
@ -828,6 +862,10 @@ struct IoBinding : public Base<OrtIoBinding> {
|
|||
std::vector<Value> GetOutputValues(Allocator&) const;
|
||||
void ClearBoundInputs();
|
||||
void ClearBoundOutputs();
|
||||
|
||||
private:
|
||||
std::vector<std::string> GetOutputNamesHelper(OrtAllocator*) const;
|
||||
std::vector<Value> GetOutputValuesHelper(OrtAllocator*) const;
|
||||
};
|
||||
|
||||
/*! \struct Ort::ArenaCfg
|
||||
|
|
@ -835,8 +873,9 @@ struct IoBinding : public Base<OrtIoBinding> {
|
|||
* \details Please see docs/C_API.md for details
|
||||
*/
|
||||
struct ArenaCfg : Base<OrtArenaCfg> {
|
||||
explicit ArenaCfg(std::nullptr_t) {}
|
||||
explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
|
||||
/**
|
||||
* Wraps OrtApi::CreateArenaCfg
|
||||
* \param max_mem - use 0 to allow ORT to choose the default
|
||||
* \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
|
||||
* \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
|
||||
|
|
|
|||
|
|
@ -115,37 +115,31 @@ inline const OrtMemoryInfo* AllocatorWithDefaultOptions::GetInfo() const {
|
|||
return out;
|
||||
}
|
||||
|
||||
template <typename B>
|
||||
inline std::string BaseMemoryInfo<B>::GetAllocatorName() const {
|
||||
inline std::string MemoryInfo::GetAllocatorName() const {
|
||||
const char* name = nullptr;
|
||||
ThrowOnError(GetApi().MemoryInfoGetName(*this, &name));
|
||||
return std::string(name);
|
||||
}
|
||||
|
||||
template <typename B>
|
||||
inline OrtAllocatorType BaseMemoryInfo<B>::GetAllocatorType() const {
|
||||
inline OrtAllocatorType MemoryInfo::GetAllocatorType() const {
|
||||
OrtAllocatorType type;
|
||||
ThrowOnError(GetApi().MemoryInfoGetType(*this, &type));
|
||||
return type;
|
||||
}
|
||||
|
||||
template <typename B>
|
||||
int BaseMemoryInfo<B>::GetDeviceId() const {
|
||||
inline int MemoryInfo::GetDeviceId() const {
|
||||
int id = 0;
|
||||
ThrowOnError(GetApi().MemoryInfoGetId(*this, &id));
|
||||
return id;
|
||||
}
|
||||
|
||||
template <typename B>
|
||||
inline OrtMemType BaseMemoryInfo<B>::GetMemoryType() const {
|
||||
inline OrtMemType MemoryInfo::GetMemoryType() const {
|
||||
OrtMemType type;
|
||||
ThrowOnError(GetApi().MemoryInfoGetMemType(*this, &type));
|
||||
return type;
|
||||
}
|
||||
|
||||
template <typename B>
|
||||
template <typename U>
|
||||
inline bool BaseMemoryInfo<B>::operator==(const BaseMemoryInfo<U>& o) const {
|
||||
inline bool MemoryInfo::operator==(const MemoryInfo& o) const {
|
||||
int comp_result = 0;
|
||||
ThrowOnError(Ort::GetApi().CompareMemoryInfo(*this, o, &comp_result));
|
||||
return comp_result == 0;
|
||||
|
|
@ -182,10 +176,10 @@ inline void Allocator::Free(void* p) const {
|
|||
ThrowOnError(GetApi().AllocatorFree(p_, p));
|
||||
}
|
||||
|
||||
inline UnownedMemoryInfo Allocator::GetInfo() const {
|
||||
inline Unowned<const MemoryInfo> Allocator::GetInfo() const {
|
||||
const OrtMemoryInfo* out = nullptr;
|
||||
ThrowOnError(GetApi().AllocatorGetInfo(p_, &out));
|
||||
return UnownedMemoryInfo(out);
|
||||
return Unowned<const MemoryInfo>(const_cast<OrtMemoryInfo*>(out));
|
||||
}
|
||||
|
||||
inline IoBinding::IoBinding(Session& session) {
|
||||
|
|
@ -374,6 +368,12 @@ inline int RunOptions::GetRunLogVerbosityLevel() const {
|
|||
return out;
|
||||
}
|
||||
|
||||
inline int RunOptions::GetRunLogSeverityLevel() const {
|
||||
int out;
|
||||
ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
|
||||
ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
|
||||
return *this;
|
||||
|
|
|
|||
3
include/onnxruntime/core/session/snippets.dox
Normal file
3
include/onnxruntime/core/session/snippets.dox
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
[OrtStatus Return Value]
|
||||
<returns>If no error, nullptr will be returned. If there is an error, a pointer to an ::OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.</returns>
|
||||
[OrtStatus Return Value]
|
||||
Loading…
Reference in a new issue