Make 'env' argument to Session const (#13362)

### Description
<!-- Describe your changes. -->
The Env argument does not need to be mutable to call the underlying C
API. Update the Ort::Session ctor to have a const Env.

All other changes are from clang-format running. 

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Cleanup
This commit is contained in:
Scott McKay 2022-10-19 14:23:24 +10:00 committed by GitHub
parent 9efa8e20bb
commit 565da71275
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 42 deletions

View file

@ -273,7 +273,7 @@ struct Base<const T>;
/// <summary>
/// Covers unowned pointers owned by either the ORT
/// or some other instance of CPP wrappers.
/// or some other instance of CPP wrappers.
/// Used for ConstXXX and UnownedXXXX types that are copyable.
/// Also convenient to wrap raw OrtXX pointers .
/// </summary>
@ -332,7 +332,6 @@ using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
* constructors to construct an instance of a Status object from exceptions.
*/
struct Status : detail::Base<OrtStatus> {
explicit Status(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
explicit Status(OrtStatus* status); ///< Takes ownership of OrtStatus instance returned from the C API. Must be non-null
explicit Status(const Exception&); ///< Creates status instance out of exception
@ -391,14 +390,14 @@ struct RunOptions : detail::Base<OrtRunOptions> {
explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
RunOptions(); ///< Wraps OrtApi::CreateRunOptions
RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
@ -642,7 +641,7 @@ struct SessionImpl : ConstSessionImpl<T> {
* \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector)
*/
std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, size_t output_count);
const char* const* output_names, size_t output_count);
/** \brief Run the model returning results in user provided outputs
* Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t)
@ -662,11 +661,12 @@ using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>;
*
*/
struct Session : detail::SessionImpl<OrtSession> {
explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
ConstSession GetConst() const { return ConstSession{this->p_}; }

View file

@ -54,31 +54,57 @@ inline OrtErrorCode Status::GetErrorCode() const {
template <typename T>
struct TypeToTensorType;
template <>
struct TypeToTensorType<float> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
struct TypeToTensorType<float> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
template <>
struct TypeToTensorType<Float16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
struct TypeToTensorType<Float16_t> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
};
template <>
struct TypeToTensorType<BFloat16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; };
struct TypeToTensorType<BFloat16_t> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
};
template <>
struct TypeToTensorType<double> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; };
struct TypeToTensorType<double> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
};
template <>
struct TypeToTensorType<int8_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; };
struct TypeToTensorType<int8_t> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
};
template <>
struct TypeToTensorType<int16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; };
struct TypeToTensorType<int16_t> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
};
template <>
struct TypeToTensorType<int32_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; };
struct TypeToTensorType<int32_t> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
};
template <>
struct TypeToTensorType<int64_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; };
struct TypeToTensorType<int64_t> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
template <>
struct TypeToTensorType<uint8_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; };
struct TypeToTensorType<uint8_t> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
};
template <>
struct TypeToTensorType<uint16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; };
struct TypeToTensorType<uint16_t> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
};
template <>
struct TypeToTensorType<uint32_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; };
struct TypeToTensorType<uint32_t> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
};
template <>
struct TypeToTensorType<uint64_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; };
struct TypeToTensorType<uint64_t> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
};
template <>
struct TypeToTensorType<bool> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; };
struct TypeToTensorType<bool> {
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
};
inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
: allocator_(allocator), p_(p), size_(size) {
@ -234,7 +260,7 @@ template <typename T>
inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
return binding_utils::GetOutputValuesHelper(this->p_, allocator);
}
template <typename T>
inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
ThrowOnError(GetApi().BindInput(this->p_, name, value));
@ -626,7 +652,7 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIG
return *this;
}
template<typename T>
template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
return *this;
@ -794,20 +820,20 @@ inline SessionOptions::SessionOptions() {
ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
}
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
}
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
OrtPrepackedWeightsContainer* prepacked_weights_container) {
ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
}
inline Session::Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
}
inline Session::Session(Env& env, const void* model_data, size_t model_data_length,
inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
prepacked_weights_container, &this->p_));
@ -1167,7 +1193,7 @@ void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* in
template <typename T>
void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
const int64_t* indices_data, size_t indices_num) {
const int64_t* indices_data, size_t indices_num) {
ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
values_param.values_shape_len, values_param.data.p_data,
indices_data, indices_num));
@ -1175,9 +1201,9 @@ void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtS
template <typename T>
void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
const OrtSparseValuesParam& values,
const int64_t* inner_indices_data, size_t inner_indices_num,
const int64_t* outer_indices_data, size_t outer_indices_num) {
const OrtSparseValuesParam& values,
const int64_t* inner_indices_data, size_t inner_indices_num,
const int64_t* outer_indices_data, size_t outer_indices_num) {
ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
inner_indices_data, inner_indices_num,
outer_indices_data, outer_indices_num));
@ -1185,9 +1211,9 @@ void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
template <typename T>
void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
const OrtSparseValuesParam& values,
const Shape& indices_shape,
const int32_t* indices_data) {
const OrtSparseValuesParam& values,
const Shape& indices_shape,
const int32_t* indices_data) {
ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
indices_shape.shape, indices_shape.shape_len,
indices_data));
@ -1317,7 +1343,7 @@ inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType
}
namespace detail {
template<typename T>
template <typename T>
inline KernelInfo KernelInfoImpl<T>::Copy() const {
OrtKernelInfo* info_copy;
Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
@ -1398,7 +1424,7 @@ inline void Op::Invoke(const OrtKernelContext* context,
size_t output_count) {
static_assert(sizeof(Value) == sizeof(OrtValue*),
"Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
auto ort_input_values = reinterpret_cast<const OrtValue* const *>(input_values);
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
ort_output_values, static_cast<int>(output_count)));