mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Make 'env' argument to Session const (#13362)
### Description <!-- Describe your changes. --> The Env argument does not need to be mutable to call the underlying C API. Update the Ort::Session ctor to have a const Env. All other changes are from clang-format running. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Cleanup
This commit is contained in:
parent
9efa8e20bb
commit
565da71275
2 changed files with 68 additions and 42 deletions
|
|
@ -273,7 +273,7 @@ struct Base<const T>;
|
|||
|
||||
/// <summary>
|
||||
/// Covers unowned pointers owned by either the ORT
|
||||
/// or some other instance of CPP wrappers.
|
||||
/// or some other instance of CPP wrappers.
|
||||
/// Used for ConstXXX and UnownedXXXX types that are copyable.
|
||||
/// Also convenient to wrap raw OrtXX pointers .
|
||||
/// </summary>
|
||||
|
|
@ -332,7 +332,6 @@ using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
|
|||
* constructors to construct an instance of a Status object from exceptions.
|
||||
*/
|
||||
struct Status : detail::Base<OrtStatus> {
|
||||
|
||||
explicit Status(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
|
||||
explicit Status(OrtStatus* status); ///< Takes ownership of OrtStatus instance returned from the C API. Must be non-null
|
||||
explicit Status(const Exception&); ///< Creates status instance out of exception
|
||||
|
|
@ -391,14 +390,14 @@ struct RunOptions : detail::Base<OrtRunOptions> {
|
|||
explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
|
||||
RunOptions(); ///< Wraps OrtApi::CreateRunOptions
|
||||
|
||||
RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
|
||||
int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
|
||||
RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
|
||||
int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
|
||||
|
||||
RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
|
||||
int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
|
||||
RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
|
||||
int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
|
||||
|
||||
RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
|
||||
const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
|
||||
RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
|
||||
const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
|
||||
|
||||
RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
|
||||
|
||||
|
|
@ -642,7 +641,7 @@ struct SessionImpl : ConstSessionImpl<T> {
|
|||
* \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector)
|
||||
*/
|
||||
std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
||||
const char* const* output_names, size_t output_count);
|
||||
const char* const* output_names, size_t output_count);
|
||||
|
||||
/** \brief Run the model returning results in user provided outputs
|
||||
* Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t)
|
||||
|
|
@ -662,11 +661,12 @@ using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>;
|
|||
*
|
||||
*/
|
||||
struct Session : detail::SessionImpl<OrtSession> {
|
||||
explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
|
||||
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
|
||||
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
|
||||
Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
|
||||
Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
|
||||
explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
|
||||
Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
|
||||
Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
|
||||
OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
|
||||
Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
|
||||
Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
|
||||
OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
|
||||
|
||||
ConstSession GetConst() const { return ConstSession{this->p_}; }
|
||||
|
|
|
|||
|
|
@ -54,31 +54,57 @@ inline OrtErrorCode Status::GetErrorCode() const {
|
|||
template <typename T>
|
||||
struct TypeToTensorType;
|
||||
template <>
|
||||
struct TypeToTensorType<float> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
struct TypeToTensorType<float> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<Float16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
|
||||
struct TypeToTensorType<Float16_t> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<BFloat16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; };
|
||||
struct TypeToTensorType<BFloat16_t> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<double> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; };
|
||||
struct TypeToTensorType<double> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<int8_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; };
|
||||
struct TypeToTensorType<int8_t> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<int16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; };
|
||||
struct TypeToTensorType<int16_t> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<int32_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; };
|
||||
struct TypeToTensorType<int32_t> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<int64_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; };
|
||||
struct TypeToTensorType<int64_t> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<uint8_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; };
|
||||
struct TypeToTensorType<uint8_t> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<uint16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; };
|
||||
struct TypeToTensorType<uint16_t> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<uint32_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; };
|
||||
struct TypeToTensorType<uint32_t> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<uint64_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; };
|
||||
struct TypeToTensorType<uint64_t> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
|
||||
};
|
||||
template <>
|
||||
struct TypeToTensorType<bool> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; };
|
||||
struct TypeToTensorType<bool> {
|
||||
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
};
|
||||
|
||||
inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
|
||||
: allocator_(allocator), p_(p), size_(size) {
|
||||
|
|
@ -234,7 +260,7 @@ template <typename T>
|
|||
inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
|
||||
return binding_utils::GetOutputValuesHelper(this->p_, allocator);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
|
||||
ThrowOnError(GetApi().BindInput(this->p_, name, value));
|
||||
|
|
@ -626,7 +652,7 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIG
|
|||
return *this;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
|
||||
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
|
||||
return *this;
|
||||
|
|
@ -794,20 +820,20 @@ inline SessionOptions::SessionOptions() {
|
|||
ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
|
||||
}
|
||||
|
||||
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
|
||||
inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
|
||||
ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
|
||||
}
|
||||
|
||||
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
|
||||
inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
|
||||
OrtPrepackedWeightsContainer* prepacked_weights_container) {
|
||||
ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
|
||||
}
|
||||
|
||||
inline Session::Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
|
||||
inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
|
||||
ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
|
||||
}
|
||||
|
||||
inline Session::Session(Env& env, const void* model_data, size_t model_data_length,
|
||||
inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
|
||||
const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
|
||||
ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
|
||||
prepacked_weights_container, &this->p_));
|
||||
|
|
@ -1167,7 +1193,7 @@ void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* in
|
|||
|
||||
template <typename T>
|
||||
void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
|
||||
const int64_t* indices_data, size_t indices_num) {
|
||||
const int64_t* indices_data, size_t indices_num) {
|
||||
ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
|
||||
values_param.values_shape_len, values_param.data.p_data,
|
||||
indices_data, indices_num));
|
||||
|
|
@ -1175,9 +1201,9 @@ void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtS
|
|||
|
||||
template <typename T>
|
||||
void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
|
||||
const OrtSparseValuesParam& values,
|
||||
const int64_t* inner_indices_data, size_t inner_indices_num,
|
||||
const int64_t* outer_indices_data, size_t outer_indices_num) {
|
||||
const OrtSparseValuesParam& values,
|
||||
const int64_t* inner_indices_data, size_t inner_indices_num,
|
||||
const int64_t* outer_indices_data, size_t outer_indices_num) {
|
||||
ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
|
||||
inner_indices_data, inner_indices_num,
|
||||
outer_indices_data, outer_indices_num));
|
||||
|
|
@ -1185,9 +1211,9 @@ void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
|
|||
|
||||
template <typename T>
|
||||
void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
|
||||
const OrtSparseValuesParam& values,
|
||||
const Shape& indices_shape,
|
||||
const int32_t* indices_data) {
|
||||
const OrtSparseValuesParam& values,
|
||||
const Shape& indices_shape,
|
||||
const int32_t* indices_data) {
|
||||
ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
|
||||
indices_shape.shape, indices_shape.shape_len,
|
||||
indices_data));
|
||||
|
|
@ -1317,7 +1343,7 @@ inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType
|
|||
}
|
||||
|
||||
namespace detail {
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
inline KernelInfo KernelInfoImpl<T>::Copy() const {
|
||||
OrtKernelInfo* info_copy;
|
||||
Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
|
||||
|
|
@ -1398,7 +1424,7 @@ inline void Op::Invoke(const OrtKernelContext* context,
|
|||
size_t output_count) {
|
||||
static_assert(sizeof(Value) == sizeof(OrtValue*),
|
||||
"Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
|
||||
auto ort_input_values = reinterpret_cast<const OrtValue* const *>(input_values);
|
||||
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
|
||||
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
||||
Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
|
||||
ort_output_values, static_cast<int>(output_count)));
|
||||
|
|
|
|||
Loading…
Reference in a new issue