diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index dc102b8904..2cdfb993a4 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -273,7 +273,7 @@ struct Base; /// /// Covers unowned pointers owned by either the ORT -/// or some other instance of CPP wrappers. +/// or some other instance of CPP wrappers. /// Used for ConstXXX and UnownedXXXX types that are copyable. /// Also convenient to wrap raw OrtXX pointers . /// @@ -332,7 +332,6 @@ using AllocatedStringPtr = std::unique_ptr; * constructors to construct an instance of a Status object from exceptions. */ struct Status : detail::Base { - explicit Status(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used explicit Status(OrtStatus* status); ///< Takes ownership of OrtStatus instance returned from the C API. Must be non-null explicit Status(const Exception&); ///< Creates status instance out of exception @@ -391,14 +390,14 @@ struct RunOptions : detail::Base { explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used RunOptions(); ///< Wraps OrtApi::CreateRunOptions - RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel - int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel + RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel + int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel - RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel - int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel + RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel + int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel - RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag - const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag + RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag + const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry @@ -642,7 +641,7 @@ struct SessionImpl : ConstSessionImpl { * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector) */ std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, size_t output_count); + const char* const* output_names, size_t output_count); /** \brief Run the model returning results in user provided outputs * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t) @@ -662,11 +661,12 @@ using UnownedSession = detail::SessionImpl>; * */ struct Session : detail::SessionImpl { - explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used - Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession - Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer - Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray - Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, + explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer ConstSession GetConst() const { return ConstSession{this->p_}; } diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 4101d4fb90..d330f6573a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -54,31 +54,57 @@ inline OrtErrorCode Status::GetErrorCode() const { template struct TypeToTensorType; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; +}; inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size) : allocator_(allocator), p_(p), size_(size) { @@ -234,7 +260,7 @@ template inline std::vector ConstIoBindingImpl::GetOutputValues(OrtAllocator* allocator) const { return binding_utils::GetOutputValuesHelper(this->p_, allocator); } - + template inline void IoBindingImpl::BindInput(const char* name, const Value& value) { ThrowOnError(GetApi().BindInput(this->p_, name, value)); @@ -626,7 +652,7 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIG return *this; } -template +template inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) { ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options)); return *this; @@ -794,20 +820,20 @@ inline SessionOptions::SessionOptions() { ThrowOnError(GetApi().CreateSessionOptions(&this->p_)); } -inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { +inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_)); } -inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, +inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) { ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_)); } -inline Session::Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) { +inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) { ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_)); } -inline Session::Session(Env& env, const void* model_data, size_t model_data_length, +inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) { ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options, prepacked_weights_container, &this->p_)); @@ -1167,7 +1193,7 @@ void ValueImpl::UseBlockSparseIndices(const Shape& indices_shape, int32_t* in template void ValueImpl::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param, - const int64_t* indices_data, size_t indices_num) { + const int64_t* indices_data, size_t indices_num) { ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape, values_param.values_shape_len, values_param.data.p_data, indices_data, indices_num)); @@ -1175,9 +1201,9 @@ void ValueImpl::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtS template void ValueImpl::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, - const OrtSparseValuesParam& values, - const int64_t* inner_indices_data, size_t inner_indices_num, - const int64_t* outer_indices_data, size_t outer_indices_num) { + const OrtSparseValuesParam& values, + const int64_t* inner_indices_data, size_t inner_indices_num, + const int64_t* outer_indices_data, size_t outer_indices_num) { ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, inner_indices_data, inner_indices_num, outer_indices_data, outer_indices_num)); @@ -1185,9 +1211,9 @@ void ValueImpl::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, template void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, - const OrtSparseValuesParam& values, - const Shape& indices_shape, - const int32_t* indices_data) { + const OrtSparseValuesParam& values, + const Shape& indices_shape, + const int32_t* indices_data) { ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, indices_shape.shape, indices_shape.shape_len, indices_data)); @@ -1317,7 +1343,7 @@ inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType } namespace detail { -template +template inline KernelInfo KernelInfoImpl::Copy() const { OrtKernelInfo* info_copy; Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy)); @@ -1398,7 +1424,7 @@ inline void Op::Invoke(const OrtKernelContext* context, size_t output_count) { static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); - auto ort_input_values = reinterpret_cast(input_values); + auto ort_input_values = reinterpret_cast(input_values); auto ort_output_values = reinterpret_cast(output_values); Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast(input_count), ort_output_values, static_cast(output_count)));