onnxruntime/onnxruntime/python/onnxruntime_pybind_state.cc
Chi Lo fa4cbcd36b
[TensorRT EP] Add new provider option to exclude nodes from running on TRT (#22681)
Add new provider option `trt_op_types_to_exclude`:
- User can provide op type list to be excluded from running on TRT
- e.g. `trt_op_types_to_exclude="MaxPool"`

There is a known performance issue with the DDS ops (NonMaxSuppression,
NonZero and RoiAlign) from TRT versions 10.0 to 10.7. TRT EP excludes
DDS ops from running on TRT by default, user can override default value
with empty string to include all ops.
2024-11-13 11:34:43 -08:00

2412 lines
118 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
// Licensed under the MIT License.
#include "python/onnxruntime_pybind_exceptions.h"
#include "python/onnxruntime_pybind_mlvalue.h"
#include "python/onnxruntime_pybind_state_common.h"
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API
#include "python/numpy_helper.h"
#include "core/common/inlined_containers.h"
#include "core/common/logging/logging.h"
#include "core/common/logging/severity.h"
#include "core/common/narrow.h"
#include "core/common/optional.h"
#include "core/common/path_string.h"
#include "core/framework/arena_extend_strategy.h"
#include "core/framework/data_transfer_utils.h"
#include "core/framework/data_types_internal.h"
#include "core/framework/provider_options_utils.h"
#include "core/framework/random_seed.h"
#include "core/framework/sparse_tensor.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/TensorSeq.h"
#include "core/graph/graph_viewer.h"
#include "core/platform/env.h"
#include "core/providers/get_execution_providers.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include "core/session/IOBinding.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/provider_bridge_ort.h"
#include "core/session/lora_adapters.h"
#ifdef ENABLE_ATEN
#include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
#endif
#ifdef USE_CUDA
#include <cuda.h> // for CUDA_VERSION
#include <cudnn.h> // for CUDNN_MAJOR
#endif
#if defined(USE_COREML)
#include "core/providers/coreml/coreml_provider_factory.h"
#endif
#include <pybind11/functional.h>
// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
// GCC 4.x.
// (This static var is referenced in GetCudaToHostMemCpyFunction())
const OrtDevice::DeviceType OrtDevice::GPU;
#if defined(_MSC_VER)
#pragma warning(disable : 4267 4996 4503)
#endif // _MSC_VER
#include <iterator>
#include <algorithm>
namespace onnxruntime {
namespace python {
namespace py = pybind11;
using namespace onnxruntime;
using namespace onnxruntime::logging;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
// C26426: "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other
// global initializers.
// TODO: we may delay-init this variable
#pragma warning(disable : 26426)
#endif
// Reference to the default Env, bound during static initialization of this translation unit.
static Env& platform_env = Env::Default();
#if defined(_MSC_VER) && !defined(__clang__)
// Restore the warning state saved by the push above so the 26426 suppression stays scoped to platform_env
// (the previous second `push` left the suppression active and the warning stack unbalanced).
#pragma warning(pop)
#endif
// Signature of the Python callable invoked when an asynchronous run completes. As used by AsyncCallback,
// the arguments are: the fetched outputs as Python objects (empty on failure), the caller-supplied
// user_data object, and an error message (empty string on success).
using PyCallback = std::function<void(std::vector<py::object>, py::object user_data, std::string)>;
struct AsyncResource {
std::vector<OrtValue> feeds;
std::vector<const OrtValue*> feeds_raw;
std::vector<std::string> feed_names;
std::vector<const char*> feed_names_raw;
std::vector<OrtValue*> fetches_raw; // will be released during destruction
std::vector<std::string> fetch_names;
std::vector<const char*> fetch_names_raw;
RunOptions default_run_option;
PyCallback callback;
py::object user_data;
void ReserveFeeds(size_t sz) {
feeds.reserve(sz);
feeds_raw.reserve(sz);
feed_names.reserve(sz);
feed_names_raw.reserve(sz);
}
void ReserveFetches(size_t sz) {
fetches_raw.reserve(sz);
fetch_names.reserve(sz);
fetch_names_raw.reserve(sz);
}
~AsyncResource() {
std::for_each(fetches_raw.begin(), fetches_raw.end(), [](const OrtValue* fetch) {
if (fetch) {
std::unique_ptr<const OrtValue> fetch_recycler(fetch);
}
});
fetches_raw.clear();
}
};
void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr ort_status) {
ORT_ENFORCE(user_data, "user data must not be NULL for callback in python");
auto invoke_callback = [&]() {
std::unique_ptr<AsyncResource> async_resource{reinterpret_cast<AsyncResource*>(user_data)};
Ort::Status status(ort_status);
// return on error
if (!status.IsOK()) {
async_resource->callback({}, async_resource->user_data, status.GetErrorMessage());
return;
}
std::vector<py::object> rfetch;
rfetch.reserve(num_outputs);
size_t pos = 0;
for (size_t ith = 0; ith < num_outputs; ++ith) {
const auto& fet = *outputs[ith];
if (fet.IsAllocated()) {
if (fet.IsTensor()) {
rfetch.push_back(AddTensorAsPyObj(fet, nullptr, nullptr));
} else if (fet.IsSparseTensor()) {
rfetch.push_back(GetPyObjectFromSparseTensor(pos, fet, nullptr));
} else {
rfetch.push_back(AddNonTensorAsPyObj(fet, nullptr, nullptr));
}
} else {
rfetch.push_back(py::none());
}
++pos;
}
async_resource->callback(rfetch, async_resource->user_data, "");
};
if (PyGILState_Check()) {
invoke_callback();
} else {
// acquire GIL to safely:
// 1) invoke python callback
// 2) create, manipulate, and destroy python objects
py::gil_scoped_acquire acquire;
invoke_callback();
}
}
void AppendLoraParametersAsInputs(const RunOptions& run_options,
size_t total_entries,
NameMLValMap& feeds) {
for (const auto* adapter : run_options.active_adapters) {
total_entries += adapter->GetParamNum();
}
feeds.reserve(total_entries + feeds.size());
// Append necessary inputs for active adapters
for (const auto* adapter : run_options.active_adapters) {
auto [begin, end] = adapter->GetParamIterators();
for (; begin != end; ++begin) {
const auto& [name, param] = *begin;
feeds.insert(std::make_pair(name, param.GetMapped()));
}
}
}
// Generic conversion of a non-tensor OrtValue holding a T into a Python object through pybind11's
// type casters. The data-transfer arguments are not used by this primary template.
template <typename T>
static py::object AddNonTensor(const OrtValue& val,
                               const DataTransferManager* /*data_transfer_manager*/,
                               const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* /*mem_cpy_to_host_functions*/) {
  const auto& payload = val.Get<T>();
  return py::cast(payload);
}
// Returns the contents of a string tensor to python as a numpy array of python string objects.
// Strings are always on CPU and must always be copied to python memory.
py::array StringTensorToNumpyArray(const Tensor& tensor) {
  // Object-dtype array with the tensor's shape; every slot is assigned below.
  py::array result(py::dtype(NPY_OBJECT), tensor.Shape().GetDims());

  const auto strings = tensor.DataAsSpan<std::string>();
  auto* slot = reinterpret_cast<py::object*>(result.mutable_data());
  for (const auto& text : strings) {
    *slot = py::cast(text);
    ++slot;
  }
  return result;
}
// Builds a numpy array directly on top of the tensor buffer held by ort_value (no data copy).
// A heap-allocated copy of the OrtValue is attached to the array through a capsule and deleted by
// the capsule's destructor.
pybind11::array PrimitiveTensorToNumpyOverOrtValue(const OrtValue& ort_value) {
  const Tensor& tensor = ort_value.Get<Tensor>();

  // The capsule destructor must be stateless, so it only receives the heap copy created below.
  auto release_ort_value = [](void* data) {
    delete reinterpret_cast<OrtValue*>(data);
  };

  const int numpy_type = OnnxRuntimeTensorToNumpyType(tensor.DataType());

  // Keep the unique_ptr in charge until the capsule has been constructed successfully.
  auto keep_alive = std::make_unique<OrtValue>(ort_value);
  pybind11::capsule owner(keep_alive.get(), release_ort_value);
  keep_alive.release();

  // Not using array_t<T> because it may not handle MLFloat16 properly
  pybind11::array numpy_view(py::dtype(numpy_type), tensor.Shape().GetDims(),
                             tensor.DataRaw(),
                             owner);
  return numpy_view;
}
// Allocates a numpy array with the tensor's dtype and shape and copies the tensor bytes into it,
// using whichever copy mechanism `dtm` holds: a DataTransferManager or a MemCpyFunc.
pybind11::array PrimitiveTensorToNumpyFromDevice(const OrtValue& ort_value, const DataTransferAlternative& dtm) {
  const Tensor& tensor = ort_value.Get<Tensor>();
  const int numpy_type = OnnxRuntimeTensorToNumpyType(tensor.DataType());
  pybind11::array result(py::dtype(numpy_type), tensor.Shape().GetDims());
  void* host_buffer = result.mutable_data();

  if (!std::holds_alternative<const DataTransferManager*>(dtm)) {
    // User-supplied copy function: (destination, source, byte count).
    const auto& copy_to_host = std::get<MemCpyFunc>(dtm);
    copy_to_host(host_buffer, tensor.DataRaw(), tensor.SizeInBytes());
    return result;
  }

  const DataTransferManager* data_transfer = std::get<const DataTransferManager*>(dtm);
  static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};
  const auto host_bytes = gsl::make_span<char>(reinterpret_cast<char*>(host_buffer), tensor.SizeInBytes());
  ORT_THROW_IF_ERROR(CopyTensorDataToByteSpan(*data_transfer, tensor, cpu_alloc_info, host_bytes));
  return result;
}
// Converts a tensor OrtValue into a Python object (a numpy array).
//
// A DataTransferManager is not always available, so the caller may instead supply per-device functions
// that do the same job: copy data from the device to the host.
//   - String tensors (CPU only) are converted element by element.
//   - Other CPU tensors are exposed as a numpy array over the OrtValue memory, without a copy.
//   - Non-CPU tensors are copied to the host with data_transfer_manager when it is provided, otherwise
//     with the function registered for the tensor's device type in mem_cpy_to_host_functions.
// Throws std::runtime_error when a non-CPU tensor is given and neither mechanism is provided.
py::object GetPyObjFromTensor(const OrtValue& ort_value,
                              const DataTransferManager* data_transfer_manager,
                              const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
  ORT_ENFORCE(ort_value.IsTensor(), "This function only supports tensors");

  const auto& tensor = ort_value.Get<Tensor>();

  if (tensor.IsDataTypeString()) {
    ORT_ENFORCE(tensor.Location().device.Type() == OrtDevice::CPU, "Strings can only be on CPU");
    // Create a numpy array of strings (python objects) by copy/converting them
    py::array string_array = StringTensorToNumpyArray(tensor);
    return py::cast<py::object>(string_array);
  }

  const auto device_type = tensor.Location().device.Type();

  // Create a numpy array on top of the OrtValue memory, no copy
  if (device_type == OrtDevice::CPU) {
    py::array cpu_view = PrimitiveTensorToNumpyOverOrtValue(ort_value);
    return py::cast<py::object>(cpu_view);
  }

  // Preferred path for device tensors: the data transfer manager.
  if (data_transfer_manager != nullptr) {
    py::array host_copy = PrimitiveTensorToNumpyFromDevice(ort_value, data_transfer_manager);
    return py::cast<py::object>(host_copy);
  }

  if (mem_cpy_to_host_functions == nullptr) {
    throw std::runtime_error(
        "GetPyObjFromTensor: Either data transfer manager or a "
        "function to copy data to the host is needed to convert non-CPU tensor to numpy array");
  }

  auto mem_cpy_to_host = mem_cpy_to_host_functions->find(device_type);
  ORT_ENFORCE(mem_cpy_to_host != mem_cpy_to_host_functions->end(),
              "Unable to locate a function that can copy data to the host from the device");
  py::array host_copy = PrimitiveTensorToNumpyFromDevice(ort_value, mem_cpy_to_host->second);
  return py::cast<py::object>(host_copy);
}
// Maps the type of `device` to a device name string.
// GPU maps to the CUDA name; NPU maps to the CANN name when built with USE_CANN and to "NPU" otherwise.
// Throws (via ORT_THROW) for any device type not listed here.
const char* GetDeviceName(const OrtDevice& device) {
  switch (device.Type()) {
    case OrtDevice::CPU:
      return CPU;
    case OrtDevice::GPU:
      return CUDA;
    case OrtDevice::DML:
      return DML;
    case OrtDevice::FPGA:
      return "FPGA";
    case OrtDevice::NPU:
#ifdef USE_CANN
      return CANN;
#else
      return "NPU";
#endif
    default:
      // Cast to int so the numeric value is printed: a narrow (char-sized) integer device type would
      // otherwise be streamed as a raw character.
      ORT_THROW("Unknown device type: ", static_cast<int>(device.Type()));
  }
}
// Wraps a sparse-tensor OrtValue in a PySparseTensor and returns it as a Python object that owns it.
//
// CPU sparse tensors are wrapped as-is. Non-CPU sparse tensors are copied with data_transfer_manager
// into a new SparseTensor created with GetAllocator(); when no manager is given, a warning mentioning
// `pos` is logged and the value is wrapped without copying.
// Throws when ort_value is not a sparse tensor, or when sparse tensor support is compiled out.
py::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, const DataTransferManager* data_transfer_manager) {
#if !defined(DISABLE_SPARSE_TENSORS)
  if (!ort_value.IsSparseTensor()) {
    ORT_THROW("Must be a sparse tensor");
  }
  auto& logger = logging::LoggingManager::DefaultLogger();
  const SparseTensor& src_sparse_tensor = ort_value.Get<SparseTensor>();
  const bool on_cpu = src_sparse_tensor.Location().device.Type() == OrtDevice::CPU;

  std::unique_ptr<PySparseTensor> wrapped;
  if (on_cpu || data_transfer_manager == nullptr) {
    if (!on_cpu) {
      LOGS(logger, WARNING) << "Returned OrtValue with sparse tensor at position: " << pos << " is on GPU but no data_transfer_manager provided."
                            << " Returned it will have its data on GPU, you can copy it using numpy_array_to_cpu()";
    }
    wrapped = std::make_unique<PySparseTensor>(ort_value);
  } else {
    auto host_copy = std::make_unique<SparseTensor>(src_sparse_tensor.DataType(), src_sparse_tensor.DenseShape(), GetAllocator());
    auto copy_status = src_sparse_tensor.Copy(*data_transfer_manager, *host_copy);
    OrtPybindThrowIfError(copy_status);
    wrapped = std::make_unique<PySparseTensor>(std::move(host_copy));
  }

  // Give up C++ ownership only after the cast succeeded, so a throwing cast does not leak.
  py::object result = py::cast(wrapped.get(), py::return_value_policy::take_ownership);
  wrapped.release();
  return result;
#else
  ORT_UNUSED_PARAMETER(pos);
  ORT_UNUSED_PARAMETER(ort_value);
  ORT_UNUSED_PARAMETER(data_transfer_manager);
  ORT_THROW("SparseTensor support is disabled in this build.");
#endif  // !defined(DISABLE_SPARSE_TENSORS)
}
// Specialization for tensor sequences: converts every tensor in the sequence with GetPyObjFromTensor
// and returns the results as a Python list, in sequence order.
template <>
py::object AddNonTensor<TensorSeq>(const OrtValue& val,
                                   const DataTransferManager* data_transfer_manager,
                                   const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
  py::list py_list;
  for (const auto& element : val.Get<TensorSeq>()) {
    py_list.append(GetPyObjFromTensor(element, data_transfer_manager, mem_cpy_to_host_functions));
  }
  // XToolChain kills the build
  // local variable 'py_list' will be copied despite being returned by name [-Werror,-Wreturn-std-move]
  // call 'std::move' explicitly to avoid copying
  // We choose to cast it to object explicitly
  return py::cast<py::object>(py_list);
}
// Converts a non-tensor OrtValue (tensor sequence, map, or sequence of maps) into a Python object.
//
// Tensor sequences are always handled. The map and sequence-of-map types are only handled when ML ops
// are compiled in (DISABLE_ML_OPS not defined). Any other type raises an error naming the type.
py::object AddNonTensorAsPyObj(const OrtValue& val,
                               const DataTransferManager* data_transfer_manager,
                               const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
  // Should be in sync with core/framework/datatypes.h
  auto val_type = val.Type();
  if (val_type->IsTensorSequenceType()) {
    return AddNonTensor<TensorSeq>(val, data_transfer_manager, mem_cpy_to_host_functions);
  }

#if !defined(DISABLE_ML_OPS)
  utils::ContainerChecker c_checker(val_type);
  if (c_checker.IsMap()) {
    if (c_checker.IsMapOf<std::string, std::string>()) {
      return AddNonTensor<MapStringToString>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<std::string, int64_t>()) {
      return AddNonTensor<MapStringToInt64>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<std::string, float>()) {
      return AddNonTensor<MapStringToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<std::string, double>()) {
      return AddNonTensor<MapStringToDouble>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<int64_t, std::string>()) {
      return AddNonTensor<MapInt64ToString>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<int64_t, int64_t>()) {
      return AddNonTensor<MapInt64ToInt64>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<int64_t, float>()) {
      return AddNonTensor<MapInt64ToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<int64_t, double>()) {
      return AddNonTensor<MapInt64ToDouble>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
  } else {
    if (c_checker.IsSequenceOf<std::map<std::string, float>>()) {
      return AddNonTensor<VectorMapStringToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsSequenceOf<std::map<int64_t, float>>()) {
      return AddNonTensor<VectorMapInt64ToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
  }
#endif

  // val_type is a pointer; streaming it directly would print an address, so report the type's
  // string form instead.
  ORT_THROW("Non-tensor type is not supported in this build: ", DataTypeImpl::ToString(val_type));
}
// Converts a tensor OrtValue into a Python object by forwarding all arguments to GetPyObjFromTensor.
py::object AddTensorAsPyObj(const OrtValue& val, const DataTransferManager* data_transfer_manager,
                            const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
  py::object converted = GetPyObjFromTensor(val, data_transfer_manager, mem_cpy_to_host_functions);
  return converted;
}
static std::unique_ptr<onnxruntime::IExecutionProvider> LoadExecutionProvider(
const std::string& ep_shared_lib_path,
const ProviderOptions& provider_options = {},
const std::string& entry_symbol_name = "GetProvider") {
void* handle;
const auto path_str = ToPathString(ep_shared_lib_path);
auto error = Env::Default().LoadDynamicLibrary(path_str, false, &handle);
if (!error.IsOK()) {
throw std::runtime_error(error.ErrorMessage());
}
Provider* (*PGetProvider)();
OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, entry_symbol_name, (void**)&PGetProvider));
Provider* provider = PGetProvider();
std::shared_ptr<IExecutionProviderFactory> ep_factory = provider->CreateExecutionProviderFactory(&provider_options);
return ep_factory->CreateProvider();
}
#ifdef USE_CUDA
// Builds the CUDA execution provider info.
// When provider_options_map has an entry for the CUDA provider, the info is populated from those
// options; otherwise it is filled from the fallback values defined outside this function.
const CUDAExecutionProviderInfo GetCudaExecutionProviderInfo(ProviderInfo_CUDA* cuda_provider_info,
                                                             const ProviderOptionsMap& provider_options_map) {
  ORT_ENFORCE(cuda_provider_info);

  CUDAExecutionProviderInfo info;
  const auto it = provider_options_map.find(kCudaExecutionProvider);
  if (it == provider_options_map.end()) {
    // No explicit provider options were given.
    info.device_id = cuda_device_id;
    info.gpu_mem_limit = gpu_mem_limit;
    info.arena_extend_strategy = arena_extend_strategy;
    info.cudnn_conv_algo_search = cudnn_conv_algo_search;
    info.do_copy_in_default_stream = do_copy_in_default_stream;
    info.external_allocator_info = external_allocator_info;
    info.tunable_op = tunable_op;
  } else {
    cuda_provider_info->CUDAExecutionProviderInfo__FromProviderOptions(it->second, info);
  }
  return info;
}
#endif
#ifdef USE_CANN
// Builds the CANN execution provider info.
// When provider_options_map has an entry for the CANN provider, the info is populated from those
// options; otherwise a default-constructed info is returned.
const CANNExecutionProviderInfo GetCannExecutionProviderInfo(ProviderInfo_CANN* cann_provider_info,
                                                             const ProviderOptionsMap& provider_options_map) {
  ORT_ENFORCE(cann_provider_info);

  CANNExecutionProviderInfo info;
  const auto it = provider_options_map.find(kCannExecutionProvider);
  const bool has_options = it != provider_options_map.end();
  if (has_options) {
    cann_provider_info->CANNExecutionProviderInfo__FromProviderOptions(it->second, info);
  }
  return info;
}
#endif
#ifdef USE_ROCM
// Builds the ROCm execution provider info.
// When provider_options_map has an entry for the ROCm provider, the info is populated from those
// options; otherwise it is filled from the fallback values defined outside this function.
const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* rocm_provider_info,
                                                             const ProviderOptionsMap& provider_options_map) {
  ORT_ENFORCE(rocm_provider_info);

  ROCMExecutionProviderInfo info;
  const auto it = provider_options_map.find(kRocmExecutionProvider);
  if (it == provider_options_map.end()) {
    // No explicit provider options were given.
    info.device_id = cuda_device_id;
    info.gpu_mem_limit = gpu_mem_limit;
    info.arena_extend_strategy = arena_extend_strategy;
    info.miopen_conv_exhaustive_search = miopen_conv_exhaustive_search;
    info.do_copy_in_default_stream = do_copy_in_default_stream;
    info.external_allocator_info = external_allocator_info;
    info.tunable_op = tunable_op;
  } else {
    rocm_provider_info->ROCMExecutionProviderInfo__FromProviderOptions(it->second, info);
  }
  return info;
}
#endif
#ifdef USE_TENSORRT
// Adds the custom op domains reported by the TensorRT provider to the session options.
//
// The optional "trt_extra_plugin_lib_paths" provider option is forwarded to the provider when the
// domain list is requested. A domain whose name is already present in so.custom_op_domains_ is not
// added again; a warning is logged instead. Throws when the TensorRT provider info is unavailable.
void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) {
  auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT();
  if (tensorrt_provider_info == nullptr) {
    ORT_THROW("Please install TensorRT libraries as mentioned in the GPU requirements page, make sure they're in the PATH or LD_LIBRARY_PATH, and that your GPU is supported.");
  }

  std::string trt_extra_plugin_lib_paths;
  if (const auto it = options.find("trt_extra_plugin_lib_paths"); it != options.end()) {
    trt_extra_plugin_lib_paths = it->second;
  }

  std::vector<OrtCustomOpDomain*> custom_op_domains;
  tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths);

  for (auto* candidate : custom_op_domains) {
    const bool already_registered =
        std::any_of(so.custom_op_domains_.begin(), so.custom_op_domains_.end(),
                    [candidate](const OrtCustomOpDomain* existing) {
                      return candidate->domain_ == existing->domain_;
                    });
    if (already_registered) {
      LOGS_DEFAULT(WARNING) << "The custom op domain name " << candidate->domain_ << " is already in session option.";
      continue;
    }
    so.custom_op_domains_.push_back(candidate);
  }
}
#endif
std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
const SessionOptions& session_options,
const std::string& type,
const ProviderOptionsMap& provider_options_map) {
if (type == kCpuExecutionProvider) {
return onnxruntime::CPUProviderFactoryCreator::Create(
session_options.enable_cpu_mem_arena)
->CreateProvider();
} else if (type == kTensorrtExecutionProvider) {
#ifdef USE_TENSORRT
// If the environment variable 'ORT_TENSORRT_UNAVAILABLE' exists, then we do not load TensorRT. This is set by _ld_preload for the manylinux case
// as in that case, trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies.
if (Env::Default().GetEnvironmentVar("ORT_TENSORRT_UNAVAILABLE").empty()) {
// provider_options_map is just a reference to the ProviderOptionsMap instance, so it can be released anytime from application.
// So we need these std::string variables defined here as they will be kept alive for the lifetime of TRT EP and we can still access them from OrtTensorRTProviderOptionsV2 instance.
// (The reason is string copy is involved, for example params.trt_engine_cache_path = cache_path.c_str() and those std::string variable is referenced by OrtTensorRTProviderOptionsV2 instance
// and TRT EP instance, so it won't be released.)
std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources,
trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path,
onnx_model_folder_path, trt_op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"};
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtTensorRTProviderOptionsV2 params;
for (auto option : it->second) {
if (option.first == "device_id") {
if (!option.second.empty()) {
params.device_id = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'device_id' should be a number i.e. '0'.\n");
}
} else if (option.first == "user_compute_stream") {
if (!option.second.empty()) {
auto stream = std::stoull(option.second, nullptr, 0);
params.user_compute_stream = reinterpret_cast<void*>(stream);
params.has_user_compute_stream = true;
} else {
params.has_user_compute_stream = false;
ORT_THROW("[ERROR] [TensorRT] The value for the key 'user_compute_stream' should be a string to define the compute stream for the inference to run on.\n");
}
} else if (option.first == "trt_max_partition_iterations") {
if (!option.second.empty()) {
params.trt_max_partition_iterations = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_max_partition_iterations' should be a positive integer number i.e. '1000'.\n");
}
} else if (option.first == "trt_min_subgraph_size") {
if (!option.second.empty()) {
params.trt_min_subgraph_size = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_min_subgraph_size' should be a positive integer number i.e. '1'.\n");
}
} else if (option.first == "trt_max_workspace_size") {
if (!option.second.empty()) {
params.trt_max_workspace_size = std::stoull(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_max_workspace_size' should be a number in byte i.e. '1073741824'.\n");
}
} else if (option.first == "trt_fp16_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_fp16_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_fp16_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_int8_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_int8_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_int8_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_int8_calibration_table_name") {
if (!option.second.empty()) {
calibration_table = option.second;
params.trt_int8_calibration_table_name = calibration_table.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_calibration_table_name' should be a file name i.e. 'cal_table'.\n");
}
} else if (option.first == "trt_int8_use_native_calibration_table") {
if (option.second == "True" || option.second == "true") {
params.trt_int8_use_native_calibration_table = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_int8_use_native_calibration_table = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_use_native_calibration_table' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_dla_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_dla_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_dla_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dla_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_dla_core") {
if (!option.second.empty()) {
params.trt_dla_core = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dla_core' should be a positive integer number i.e. '0'.\n");
}
} else if (option.first == "trt_dump_subgraphs") {
if (option.second == "True" || option.second == "true") {
params.trt_dump_subgraphs = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_dump_subgraphs = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_subgraphs' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_engine_cache_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_engine_cache_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_engine_cache_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_engine_cache_path") {
if (!option.second.empty()) {
cache_path = option.second;
params.trt_engine_cache_path = cache_path.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_path' should be a path string i.e. 'engine_cache'.\n");
}
} else if (option.first == "trt_engine_cache_prefix") {
if (!option.second.empty()) {
cache_prefix = option.second;
params.trt_engine_cache_prefix = cache_prefix.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_prefix' should be a string to customize engine cache prefix i.e. 'FRCNN' or 'yolov4'.\n");
}
} else if (option.first == "trt_weight_stripped_engine_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_weight_stripped_engine_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_weight_stripped_engine_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_weight_stripped_engine_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_onnx_model_folder_path") {
if (!option.second.empty()) {
onnx_model_folder_path = option.second;
params.trt_onnx_model_folder_path = onnx_model_folder_path.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_onnx_model_folder_path' should be a path string i.e. 'engine_cache'.\n");
}
} else if (option.first == "trt_engine_decryption_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_engine_decryption_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_engine_decryption_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_decryption_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_engine_decryption_lib_path") {
if (!option.second.empty()) {
lib_path = option.second;
params.trt_engine_decryption_lib_path = lib_path.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_decryption_lib_path' should be a path string i.e. 'decryption_lib'.\n");
}
} else if (option.first == "trt_force_sequential_engine_build") {
if (option.second == "True" || option.second == "true") {
params.trt_force_sequential_engine_build = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_force_sequential_engine_build = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_sequential_engine_build' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_context_memory_sharing_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_context_memory_sharing_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_context_memory_sharing_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_context_memory_sharing_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_layer_norm_fp32_fallback") {
if (option.second == "True" || option.second == "true") {
params.trt_layer_norm_fp32_fallback = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_layer_norm_fp32_fallback = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_layer_norm_fp32_fallback' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_timing_cache_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_timing_cache_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_timing_cache_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_timing_cache_path") {
if (!option.second.empty()) {
timing_cache_path = option.second;
params.trt_timing_cache_path = timing_cache_path.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_path' should be a path string i.e. 'cache_folder/'.\n");
}
} else if (option.first == "trt_force_timing_cache") {
if (option.second == "True" || option.second == "true") {
params.trt_force_timing_cache = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_force_timing_cache = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_timing_cache' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_detailed_build_log") {
if (option.second == "True" || option.second == "true") {
params.trt_detailed_build_log = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_detailed_build_log = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_detailed_build_log' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_build_heuristics_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_build_heuristics_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_build_heuristics_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_build_heuristics_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_sparsity_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_sparsity_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_sparsity_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_sparsity_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_builder_optimization_level") {
if (!option.second.empty()) {
params.trt_builder_optimization_level = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_builder_optimization_level' should be a number i.e. '0'.\n");
}
} else if (option.first == "trt_auxiliary_streams") {
if (!option.second.empty()) {
params.trt_auxiliary_streams = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_auxiliary_streams' should be a number i.e. '0'.\n");
}
} else if (option.first == "trt_tactic_sources") {
if (!option.second.empty()) {
trt_tactic_sources = option.second;
params.trt_tactic_sources = trt_tactic_sources.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_tactic_sources' should be a string. e.g. \"-CUDNN,+CUBLAS\" available keys: \"CUBLAS\"|\"CUBLAS_LT\"|\"CUDNN\"|\"EDGE_MASK_CONVOLUTIONS\".\n");
}
} else if (option.first == "trt_extra_plugin_lib_paths") {
if (!option.second.empty()) {
trt_extra_plugin_lib_paths = option.second;
params.trt_extra_plugin_lib_paths = trt_extra_plugin_lib_paths.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_extra_plugin_lib_paths' should be a path string.\n");
}
} else if (option.first == "trt_profile_min_shapes") {
if (!option.second.empty()) {
min_profile = option.second;
params.trt_profile_min_shapes = min_profile.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_min_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n");
}
} else if (option.first == "trt_profile_max_shapes") {
if (!option.second.empty()) {
max_profile = option.second;
params.trt_profile_max_shapes = max_profile.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_max_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n");
}
} else if (option.first == "trt_profile_opt_shapes") {
if (!option.second.empty()) {
opt_profile = option.second;
params.trt_profile_opt_shapes = opt_profile.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_opt_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n");
}
} else if (option.first == "trt_cuda_graph_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_cuda_graph_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_cuda_graph_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_dump_ep_context_model") {
if (option.second == "True" || option.second == "true") {
params.trt_dump_ep_context_model = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_dump_ep_context_model = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_ep_context_model' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_ep_context_file_path") {
if (!option.second.empty()) {
ep_context_file_path = option.second;
params.trt_ep_context_file_path = ep_context_file_path.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_file_path' should be a string.\n");
}
} else if (option.first == "trt_ep_context_embed_mode") {
if (!option.second.empty()) {
params.trt_ep_context_embed_mode = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_embed_mode' should be a positive integer number i.e. '1'.\n");
}
} else if (option.first == "trt_engine_hw_compatible") {
if (option.second == "True" || option.second == "true") {
params.trt_engine_hw_compatible = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_engine_hw_compatible = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_hw_compatible' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_op_types_to_exclude") {
trt_op_types_to_exclude = option.second;
params.trt_op_types_to_exclude = trt_op_types_to_exclude.c_str();
} else {
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}
}
if (std::shared_ptr<IExecutionProviderFactory> tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&params)) {
return tensorrt_provider_factory->CreateProvider();
}
} else {
if (std::shared_ptr<IExecutionProviderFactory> tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(cuda_device_id)) {
return tensorrt_provider_factory->CreateProvider();
}
}
}
LOGS_DEFAULT(WARNING) << "Failed to create "
<< type
<< ". Please reference "
<< "https://onnxruntime.ai/docs/execution-providers/"
<< "TensorRT-ExecutionProvider.html#requirements to ensure all dependencies are met.";
#endif
} else if (type == kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX
std::string calibration_table;
std::string save_model_path;
std::string load_model_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtMIGraphXProviderOptions params{
0,
0,
0,
0,
nullptr,
1,
"./compiled_model.mxr",
1,
"./compiled_model.mxr",
1};
for (auto option : it->second) {
if (option.first == "device_id") {
if (!option.second.empty()) {
params.device_id = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n");
}
} else if (option.first == "migraphx_fp16_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_int8_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_int8_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_calibration_table_name") {
if (!option.second.empty()) {
calibration_table = option.second;
params.migraphx_int8_calibration_table_name = calibration_table.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_calibration_table_name' should be a "
"file name i.e. 'cal_table'.\n");
}
} else if (option.first == "migraphx_use_native_calibration_table") {
if (option.second == "True" || option.second == "true") {
params.migraphx_use_native_calibration_table = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_use_native_calibration_table = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_compiled_model") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_model_path") {
if (!option.second.empty()) {
save_model_path = option.second;
params.migraphx_save_model_path = save_model_path.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_load_compiled_model") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_load_model_path") {
if (!option.second.empty()) {
load_model_path = option.second;
params.migraphx_load_model_path = load_model_path.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_exhaustive_tune") {
if (option.second == "True" || option.second == "true") {
params.migraphx_exhaustive_tune = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_exhaustive_tune = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else {
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
}
}
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory =
onnxruntime::MIGraphXProviderFactoryCreator::Create(&params)) {
return migraphx_provider_factory->CreateProvider();
}
} else {
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory =
onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) {
return migraphx_provider_factory->CreateProvider();
}
}
#endif
} else if (type == kCudaExecutionProvider) {
#ifdef USE_CUDA
// If the environment variable 'ORT_CUDA_UNAVAILABLE' is set to a non-empty value, then we do not load CUDA.
// This is set by _ld_preload for the manylinux case as in that case,
// trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies.
if (Env::Default().GetEnvironmentVar("ORT_CUDA_UNAVAILABLE").empty()) {
if (auto* cuda_provider_info = TryGetProviderInfo_CUDA()) {
const CUDAExecutionProviderInfo info = GetCudaExecutionProviderInfo(cuda_provider_info,
provider_options_map);
// This variable is never initialized because the APIs by which it should be initialized are deprecated,
// however they still exist and are in-use. Nevertheless, it is used to return CUDAAllocator,
// hence we must try to initialize it here if we can since FromProviderOptions might contain
// external CUDA allocator.
external_allocator_info = info.external_allocator_info;
return cuda_provider_info->CreateExecutionProviderFactory(info)->CreateProvider();
}
}
LOGS_DEFAULT(WARNING) << "Failed to create "
<< type
<< ". Require cuDNN " << CUDNN_MAJOR << ".* and "
<< "CUDA " << (CUDA_VERSION / 1000) << ".*"
#if defined(_MSC_VER)
<< ", and the latest MSVC runtime"
#endif
<< ". Please install all dependencies as mentioned in the GPU requirements page"
" (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), "
"make sure they're in the PATH, and that your GPU is supported.";
#endif
} else if (type == kRocmExecutionProvider) {
#ifdef USE_ROCM
if (auto* rocm_provider_info = TryGetProviderInfo_ROCM()) {
const ROCMExecutionProviderInfo info = GetRocmExecutionProviderInfo(rocm_provider_info,
provider_options_map);
// This variable is never initialized because the APIs by which it should be initialized are deprecated,
// however they still exist and are in-use. Nevertheless, it is used to return ROCMAllocator, hence we must
// try to initialize it here if we can since FromProviderOptions might contain external ROCM allocator.
external_allocator_info = info.external_allocator_info;
return rocm_provider_info->CreateExecutionProviderFactory(info)->CreateProvider();
} else {
if (!Env::Default().GetEnvironmentVar("ROCM_PATH").empty()) {
ORT_THROW(
"ROCM_PATH is set but ROCM wasn't able to be loaded. Please install the correct version "
"of ROCM and MIOpen as mentioned in the GPU requirements page, make sure they're in the PATH, "
"and that your GPU is supported.");
}
}
#endif
} else if (type == kDnnlExecutionProvider) {
#ifdef USE_DNNL
// Generate dnnl_options
OrtDnnlProviderOptions dnnl_options;
// For Eigen and OpenMP
#if defined(DNNL_OPENMP)
int num_threads = 0;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
for (auto option : it->second) {
if (option.first == "num_of_threads") {
num_threads = std::stoi(option.second);
if (num_threads < 0) {
ORT_THROW(
"[ERROR] [OneDNN] Invalid entry for the key 'num_of_threads',"
" set number of threads or use '0' for default\n");
// If the user doesnt define num_threads, auto detect threads later
}
} else {
ORT_THROW("Invalid OneDNN EP option: ", option.first);
}
}
}
dnnl_options.threadpool_args = static_cast<void*>(&num_threads);
#endif // !defined(DNNL_ORT_THREAD)
dnnl_options.use_arena = session_options.enable_cpu_mem_arena;
return onnxruntime::DnnlProviderFactoryCreator::Create(&dnnl_options)->CreateProvider();
#endif
} else if (type == kOpenVINOExecutionProvider) {
#ifdef USE_OPENVINO
ProviderOptions OV_provider_options_map;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
for (auto option : it->second) {
if (option.first == "device_type") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "precision") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "enable_opencl_throttling") {
if (!(option.second == "True" || option.second == "true" ||
option.second == "False" || option.second == "false")) {
ORT_THROW("Invalid value passed for enable_opencl_throttling: ", option.second);
}
OV_provider_options_map[option.first] = option.second;
} else if (option.first == "disable_dynamic_shapes") {
if (!(option.second == "True" || option.second == "true" ||
option.second == "False" || option.second == "false")) {
ORT_THROW("Invalid value passed for disable_dynamic_shapes: ", option.second);
}
OV_provider_options_map[option.first] = option.second;
} else if (option.first == "enable_dynamic_shapes") {
LOGS_DEFAULT(WARNING) << " Deprecation notice - 'enable_dynamic_shapes' is Deprected. Upgrade the API to disable_dynamic_shapes parameter."
"Please refer https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met.";
std::string value;
if (!(option.second == "True" || option.second == "true" ||
option.second == "False" || option.second == "false")) {
ORT_THROW("Invalid value passed for enable_dynamic_shapes: ", option.second);
}
if (option.second == "True" || option.second == "true") {
value = "false";
} else {
value = "true";
}
OV_provider_options_map["disable_dynamic_shapes"] = value;
} else if (option.first == "num_of_threads") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "model_priority") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "num_streams") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "load_config") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "cache_dir") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "context") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "enable_qdq_optimizer") {
OV_provider_options_map[option.first] = option.second;
continue;
} else {
ORT_THROW("Invalid OpenVINO EP option: ", option.first);
}
}
}
if (std::shared_ptr<IExecutionProviderFactory> openvino_provider_factory = onnxruntime::OpenVINOProviderFactoryCreator::Create(
&OV_provider_options_map, &session_options)) {
auto p = openvino_provider_factory->CreateProvider();
// Reset global variables config to avoid it being accidentally passed on to the next session
openvino_device_type.clear();
return p;
} else {
if (!Env::Default().GetEnvironmentVar("INTEL_OPENVINO_DIR").empty()) {
ORT_THROW("INTEL_OPENVINO_DIR is set but OpenVINO library wasn't able to be loaded. Please install a supported version of OpenVINO as mentioned in the requirements page (https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements), ensure dependency libraries are in the PATH and your hardware is supported.");
} else {
LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please refer https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met.";
}
}
#endif
} else if (type == kTvmExecutionProvider) {
#if USE_TVM
onnxruntime::tvm::TvmEPOptions info{};
const auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
info = onnxruntime::tvm::TvmEPOptionsHelper::FromProviderOptions(it->second);
}
return onnxruntime::TVMProviderFactoryCreator::Create(info)->CreateProvider();
#endif
} else if (type == kVitisAIExecutionProvider) {
#ifdef USE_VITISAI
ProviderOptions info{};
const auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
info = it->second;
}
info["session_options"] = std::to_string((uintptr_t)(void*)&session_options);
return onnxruntime::VitisAIProviderFactoryCreator::Create(info)->CreateProvider();
#endif
} else if (type == kAclExecutionProvider) {
#ifdef USE_ACL
bool enable_fast_math = false;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
for (auto option : it->second) {
if (option.first == "enable_fast_math") {
std::set<std::string> supported_values = {"true", "True", "false", "False"};
if (supported_values.find(option.second) != supported_values.end()) {
enable_fast_math = (option.second == "true") || (option.second == "True");
} else {
ORT_THROW(
"Invalid value for enable_fast_math. "
"Select from 'true' or 'false'\n");
}
} else {
ORT_THROW("Unrecognized option: ", option.first);
}
}
}
return onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math)
->CreateProvider();
#endif
} else if (type == kArmNNExecutionProvider) {
#ifdef USE_ARMNN
return onnxruntime::ArmNNProviderFactoryCreator::Create(
session_options.enable_cpu_mem_arena)
->CreateProvider();
#endif
} else if (type == kDmlExecutionProvider) {
#ifdef USE_DML
auto cit = provider_options_map.find(type);
return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions(
session_options.config_options, cit == provider_options_map.end() ? ProviderOptions{} : cit->second, true)
->CreateProvider();
#endif
} else if (type == kNnapiExecutionProvider) {
#if defined(USE_NNAPI)
#if !defined(__ANDROID__)
LOGS_DEFAULT(WARNING) << "NNAPI execution provider can only be used to generate ORT format model in this build.";
#endif
const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry(
kOrtSessionOptionsConfigNnapiEpPartitioningStopOps);
return onnxruntime::NnapiProviderFactoryCreator::Create(0, partitioning_stop_ops_list)->CreateProvider();
#endif
} else if (type == kRknpuExecutionProvider) {
#ifdef USE_RKNPU
return onnxruntime::RknpuProviderFactoryCreator::Create()->CreateProvider();
#endif
} else if (type == kCoreMLExecutionProvider) {
#if defined(USE_COREML)
#if !defined(__APPLE__)
LOGS_DEFAULT(WARNING) << "CoreML execution provider can only be used to generate ORT format model in this build.";
#endif
uint32_t coreml_flags = 0;
const auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
const ProviderOptions& options = it->second;
auto flags = options.find("flags");
if (flags != options.end()) {
const auto& flags_str = flags->second;
if (flags_str.find("COREML_FLAG_USE_CPU_ONLY") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_ONLY;
} else if (flags_str.find("COREML_FLAG_USE_CPU_AND_GPU") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_AND_GPU;
}
if (flags_str.find("COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES;
}
if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM;
}
}
}
return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider();
#endif
} else if (type == kXnnpackExecutionProvider) {
#if defined(USE_XNNPACK)
auto cit = provider_options_map.find(type);
return onnxruntime::XnnpackProviderFactoryCreator::Create(
cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options)
->CreateProvider();
#endif
} else if (type == kWebGpuExecutionProvider) {
#if defined(USE_WEBGPU)
return onnxruntime::WebGpuProviderFactoryCreator::Create(session_options.config_options)->CreateProvider();
#endif
} else if (type == kCannExecutionProvider) {
#ifdef USE_CANN
if (auto* cann_provider_info = TryGetProviderInfo_CANN()) {
const CANNExecutionProviderInfo info = GetCannExecutionProviderInfo(cann_provider_info,
provider_options_map);
return cann_provider_info->CreateExecutionProviderFactory(info)->CreateProvider();
} else {
ORT_THROW("create CANN ExecutionProvider fail");
}
#endif
} else if (type == kAzureExecutionProvider) {
#ifdef USE_AZURE
return onnxruntime::AzureProviderFactoryCreator::Create({})->CreateProvider();
#endif
} else if (type == kQnnExecutionProvider) {
#ifdef USE_QNN
auto cit = provider_options_map.find(type);
return onnxruntime::QNNProviderFactoryCreator::Create(
cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options)
->CreateProvider();
#endif
} else {
// check whether it is a dynamic load EP:
const auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
auto shared_lib_path_it = it->second.find(kExecutionProviderSharedLibraryPath);
if (shared_lib_path_it != it->second.end()) {
// this is an EP with dynamic loading
// construct the provider option
ProviderOptions provider_options;
std::string entry_symbol = kDefaultExecutionProviderEntry;
for (auto option : it->second) {
if (option.first == kExecutionProviderSharedLibraryEntry) {
entry_symbol = option.second;
} else if (option.first != kExecutionProviderSharedLibraryPath) {
provider_options.insert(option);
}
}
return LoadExecutionProvider(shared_lib_path_it->second, provider_options, entry_symbol);
}
}
// unknown provider
throw std::runtime_error("Unknown Provider Type: " + type);
}
return nullptr;
}
/*
 * Create an execution provider for every requested provider type and register it with the session.
 *
 * Provider types are processed in the order given. A type for which
 * CreateExecutionProviderInstance() yields a null provider is silently skipped; a failed
 * registration is reported through OrtPybindThrowIfError().
 */
static void RegisterExecutionProviders(InferenceSession* sess, const std::vector<std::string>& provider_types,
                                       const ProviderOptionsMap& provider_options_map) {
  // Harmless no-op cast: the options map is in fact consumed in the loop below.
  ORT_UNUSED_PARAMETER(provider_options_map);

  for (const auto& provider_name : provider_types) {
    auto provider = CreateExecutionProviderInstance(sess->GetSessionOptions(), provider_name, provider_options_map);
    if (!provider) {
      // Nothing was created for this type; move on to the next one.
      continue;
    }
    OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(provider)));
  }
}
/**
* Generate a map for mapping execution provider to excution provider options.
*
* @param providers vector of excution providers. [ep1, ep2, ...]
* @param provider_options_vector vector of excution provider options. [option1, option2 ...]
* @param provider_options_map an unordered map for mapping excution provider to excution provider options.
* {'ep1' -> option1, 'ep2' -> option2 ...}
*
*/
static void GenerateProviderOptionsMap(const std::vector<std::string>& providers,
const ProviderOptionsVector& provider_options_vector,
ProviderOptionsMap& provider_options_map) {
if (provider_options_vector.empty() || providers.empty()) {
return;
}
std::size_t j = 0; // index for provider_options_vector
for (const std::string& type : providers) {
if (j < provider_options_vector.size() && !provider_options_vector[j].empty()) {
provider_options_map[type] = provider_options_vector[j];
}
j += 1;
}
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
// Adds every custom op domain recorded on the Python session options to the underlying
// inference session. Does nothing when no custom op domain has been recorded; a failure
// reported by AddCustomOpDomains() is raised through OrtPybindThrowIfError().
static void RegisterCustomOpDomains(PyInferenceSession* sess, const PySessionOptions& so) {
  const auto& recorded_domains = so.custom_op_domains_;
  if (recorded_domains.empty()) {
    return;
  }

  // Collect the domains that will be needed for the session into a plain pointer vector.
  std::vector<OrtCustomOpDomain*> domain_ptrs;
  domain_ptrs.reserve(recorded_domains.size());
  for (const auto& domain : recorded_domains) {
    domain_ptrs.emplace_back(domain);
  }

  OrtPybindThrowIfError(sess->GetSessionHandle()->AddCustomOpDomains(domain_ptrs));
}
#endif
// Registers the requested execution providers on `sess`, optionally filters the enabled
// optimizers, and then initializes the session.
//
// Steps, in order:
//   1. provider_types and provider_options are paired by position into a per-provider map.
//   2. ep_registration_fn is invoked with the session, the provider list and that map.
//   3. If disabled_optimizer_names is non-empty, it is forwarded to FilterEnabledOptimizers().
//      This step is compiled out in minimal builds that are not extended minimal builds.
//   4. sess->Initialize() is called.
// Any non-OK status from steps 3 and 4 is raised through OrtPybindThrowIfError().
void InitializeSession(InferenceSession* sess,
                       ExecutionProviderRegistrationFn ep_registration_fn,
                       const std::vector<std::string>& provider_types,
                       const ProviderOptionsVector& provider_options,
                       const std::unordered_set<std::string>& disabled_optimizer_names) {
  ProviderOptionsMap options_by_provider;
  GenerateProviderOptionsMap(provider_types, provider_options, options_by_provider);
  ep_registration_fn(sess, provider_types, options_by_provider);

#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
  ORT_UNUSED_PARAMETER(disabled_optimizer_names);
#else
  if (!disabled_optimizer_names.empty()) {
    OrtPybindThrowIfError(sess->FilterEnabledOptimizers({disabled_optimizer_names.cbegin(), disabled_optimizer_names.cend()}));
  }
#endif

  OrtPybindThrowIfError(sess->Initialize());
}
// Looks up the NodeArg called `name` in `def_list`, copies its type into `type_proto`, and
// reports whether that type is a tensor type.
//
// Throws std::runtime_error when no NodeArg with that name is present, or when the matching
// NodeArg has no type proto. `type_proto` is only written when the lookup succeeds.
bool CheckIfTensor(const std::vector<const NodeArg*>& def_list,
                   const std::string& name,
                   /*out*/ onnx::TypeProto& type_proto) {
  const auto has_requested_name = [&name](const NodeArg* candidate) { return name == candidate->Name(); };
  const auto match = std::find_if(def_list.begin(), def_list.end(), has_requested_name);
  if (match == def_list.end()) {
    throw std::runtime_error("Failed to find NodeArg with name: " + name + " in the def list");
  }

  const auto* declared_type = (*match)->TypeAsProto();
  if (declared_type == nullptr) {
    throw std::runtime_error("Corresponding type_proto is null");
  }

  type_proto = *declared_type;
  return type_proto.has_tensor_type();
}
#if defined(USE_OPENVINO) || defined(USE_CUDA) || defined(USE_ROCM)
// Logs a warning that `deprecated` is deprecated. When `alternative` holds a value, a second
// warning naming the suggested replacement is logged as well.
static void LogDeprecationWarning(
    const std::string& deprecated, const optional<std::string>& alternative = nullopt) {
  LOGS_DEFAULT(WARNING) << "This is DEPRECATED and will be removed in the future: " << deprecated;
  if (alternative.has_value()) {
    LOGS_DEFAULT(WARNING) << "As an alternative, use: " << *alternative;
  }
}
#endif
// Registers the module-level functions of the Python binding on `m`.
//
// Always registered: default session options / session initializer accessors, device and
// provider queries, random seed and default-logger controls, telemetry toggles, and the
// allocator registration helpers. Additional groups are registered only when the matching
// build flag is defined: OpenVINO helpers (USE_OPENVINO), deprecated CUDA/ROCM global
// setters (USE_CUDA or USE_ROCM), TensorRT plugin registration (USE_TENSORRT), and the ATen
// operator executor hook (ENABLE_ATEN).
void addGlobalMethods(py::module& m) {
  m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance.");
  m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer.");
  m.def(
      "get_device", []() -> std::string { return BACKEND_DEVICE; },
      "Return the device used to compute the prediction (CPU, MKL, ...)");
  m.def(
      "set_seed", [](const int64_t seed) { utils::SetRandomSeed(seed); },
      "Sets the seed used for random number generation in Onnxruntime.");

  // Default logger controls.
  m.def(
      "set_default_logger_severity", [](int severity) {
        ORT_ENFORCE(severity >= 0 && severity <= 4,
                    "Invalid logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
        const auto requested_severity = static_cast<logging::Severity>(severity);
        auto ort_env = GetEnv();
        ort_env->GetLoggingManager()->SetDefaultLoggerSeverity(requested_severity);
      },
      "Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
  m.def(
      "set_default_logger_verbosity", [](int vlog_level) {
        auto ort_env = GetEnv();
        ort_env->GetLoggingManager()->SetDefaultLoggerVerbosity(vlog_level);
      },
      "Sets the default logging verbosity level. To activate the verbose log, "
      "you need to set the default logging severity to 0:Verbose level.");

  m.def(
      "get_all_providers", []() -> const std::vector<std::string>& { return GetAllExecutionProviderNames(); },
      "Return list of Execution Providers that this version of Onnxruntime can support. "
      "The order of elements represents the default priority order of Execution Providers "
      "from highest to lowest.");

  // Telemetry toggles.
  m.def(
      "enable_telemetry_events", []() -> void { platform_env.GetTelemetryProvider().EnableTelemetryEvents(); },
      "Enables platform-specific telemetry collection where applicable.");
  m.def(
      "disable_telemetry_events", []() -> void { platform_env.GetTelemetryProvider().DisableTelemetryEvents(); },
      "Disables platform-specific telemetry collection.");

  // Allocator registration on the shared environment; a non-OK status becomes a runtime_error.
  m.def(
      "create_and_register_allocator", [](const OrtMemoryInfo& mem_info, const OrtArenaCfg* arena_cfg = nullptr) -> void {
        auto ort_env = GetEnv();
        const auto status = ort_env->CreateAndRegisterAllocator(mem_info, arena_cfg);
        if (!status.IsOK()) {
          throw std::runtime_error("Error when creating and registering allocator: " + status.ErrorMessage());
        }
      });
  m.def(
      "create_and_register_allocator_v2", [](const std::string& provider_type, const OrtMemoryInfo& mem_info, const ProviderOptions& options, const OrtArenaCfg* arena_cfg = nullptr) -> void {
        auto ort_env = GetEnv();
        const auto status = ort_env->CreateAndRegisterAllocatorV2(provider_type, mem_info, options, arena_cfg);
        if (!status.IsOK()) {
          throw std::runtime_error("Error when creating and registering allocator in create_and_register_allocator_v2: " + status.ErrorMessage());
        }
      });

#ifdef USE_OPENVINO
  m.def(
      "get_available_openvino_device_ids", []() -> std::vector<std::string> {
        auto* openvino_info = GetProviderInfo_OpenVINO();
        if (openvino_info == nullptr) {
          return {};
        }
        return openvino_info->GetAvailableDevices();
      },
      "Lists all OpenVINO device ids available.");

  /*
   * The following APIs to set config options are deprecated. Use Session.set_providers() instead.
   */
  // TODO remove deprecated global config
  m.def(
      "set_openvino_device", [](const std::string& device_type) {
        LogDeprecationWarning("set_openvino_device", "OpenVINO execution provider option \"device_type\"");
        openvino_device_type = device_type;
      },
      "Set the preferred OpenVINO device type to be used. If left unset, "
      "the device type selected during build time will be used.");
  // TODO remove deprecated global config
  m.def(
      "get_openvino_device", []() -> std::string {
        LogDeprecationWarning("get_openvino_device");
        return openvino_device_type;
      },
      "Gets the dynamically selected OpenVINO device type for inference.");
#endif

#if defined(USE_CUDA) || defined(USE_ROCM)
  /*
   * The following set_* methods are deprecated.
   *
   * To achieve same result, please use the following python api:
   * InferenceSession.set_providers(list_of_providers, list_of_provider_option_dicts)
   *
   */
  // TODO remove deprecated global config
  m.def("set_cuda_device_id", [](const int id) {
    LogDeprecationWarning("set_cuda_device_id", "CUDA/ROCM execution provider option \"device_id\"");
    cuda_device_id = static_cast<OrtDevice::DeviceId>(id);
  });
  // TODO remove deprecated global config
  m.def("set_cudnn_conv_algo_search", [](const OrtCudnnConvAlgoSearch algo) {
    LogDeprecationWarning("set_cudnn_conv_algo_search", "CUDA execution provider option \"cudnn_conv_algo_search\"");
#ifdef USE_ROCM
    ORT_UNUSED_PARAMETER(algo);
    ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM");
#else
    cudnn_conv_algo_search = algo;
#endif
  });
  // TODO remove deprecated global config
  m.def("set_do_copy_in_default_stream", [](const bool use_single_stream) {
    LogDeprecationWarning(
        "set_do_copy_in_default_stream", "CUDA execution provider option \"do_copy_in_default_stream\"");
#ifdef USE_ROCM
    ORT_UNUSED_PARAMETER(use_single_stream);
    ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM");
#else
    do_copy_in_default_stream = use_single_stream;
#endif
  });
  // TODO remove deprecated global config
  m.def("set_gpu_mem_limit", [](const int64_t limit) {
    LogDeprecationWarning(
        "set_gpu_mem_limit",
        "CUDA execution provider option \"gpu_mem_limit\", ROCM execution provider option \"gpu_mem_limit\"");
    gpu_mem_limit = gsl::narrow<size_t>(limit);
  });
  // TODO remove deprecated global config
  m.def("set_arena_extend_strategy", [](const onnxruntime::ArenaExtendStrategy strategy) {
    LogDeprecationWarning("set_arena_extend_strategy", "CUDA/ROCM execution provider option \"arena_extend_strategy\"");
    arena_extend_strategy = strategy;
  });
#endif

#ifdef USE_TENSORRT
  m.def(
      "register_tensorrt_plugins_as_custom_ops",
      [](PySessionOptions& so, const ProviderOptions& options) { RegisterTensorRTPluginsAsCustomOps(so, options); },
      "Register TensorRT plugins as custom ops.");
#endif

#ifdef ENABLE_ATEN
  m.def("register_aten_op_executor",
        [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
          // Both addresses arrive as strings; each is parsed into an integer and then
          // reinterpreted as an opaque pointer handed to the ATen operator executor.
          size_t is_tensor_argument_address = 0;
          size_t aten_op_executor_address = 0;
          ORT_THROW_IF_ERROR(
              ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address));
          ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address));
          void* is_tensor_argument_fn = reinterpret_cast<void*>(is_tensor_argument_address);
          void* aten_op_executor_fn = reinterpret_cast<void*>(aten_op_executor_address);
          contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(is_tensor_argument_fn, aten_op_executor_fn);
        });
#endif
}
// Registers the Python-visible enums and classes of the binding on module `m`:
// GraphOptimizationLevel, ExecutionMode, ExecutionOrder, OrtAllocatorType, OrtMemType, OrtDevice,
// OrtArenaCfg, OrtMemoryInfo, SessionOptions, RunOptions, ModelMetadata, NodeArg,
// SessionObjectInitializer, InferenceSession and ArenaExtendStrategy.
//
// `ep_registration_fn` is captured by the "initialize_session" binding and forwarded to
// InitializeSession() together with the provider types/options supplied from Python.
void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) {
  // ---- Plain enums -------------------------------------------------------------------------
  py::enum_<GraphOptimizationLevel>(m, "GraphOptimizationLevel")
      .value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL)
      .value("ORT_ENABLE_BASIC", GraphOptimizationLevel::ORT_ENABLE_BASIC)
      .value("ORT_ENABLE_EXTENDED", GraphOptimizationLevel::ORT_ENABLE_EXTENDED)
      .value("ORT_ENABLE_ALL", GraphOptimizationLevel::ORT_ENABLE_ALL);

  py::enum_<ExecutionMode>(m, "ExecutionMode")
      .value("ORT_SEQUENTIAL", ExecutionMode::ORT_SEQUENTIAL)
      .value("ORT_PARALLEL", ExecutionMode::ORT_PARALLEL);

  py::enum_<ExecutionOrder>(m, "ExecutionOrder")
      .value("DEFAULT", ExecutionOrder::DEFAULT)
      .value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED)
      .value("MEMORY_EFFICIENT", ExecutionOrder::MEMORY_EFFICIENT);

  py::enum_<OrtAllocatorType>(m, "OrtAllocatorType")
      .value("INVALID", OrtInvalidAllocator)
      .value("ORT_DEVICE_ALLOCATOR", OrtDeviceAllocator)
      .value("ORT_ARENA_ALLOCATOR", OrtArenaAllocator);

  py::enum_<OrtMemType>(m, "OrtMemType")
      .value("CPU_INPUT", OrtMemTypeCPUInput)
      .value("CPU_OUTPUT", OrtMemTypeCPUOutput)
      .value("CPU", OrtMemTypeCPU)
      .value("DEFAULT", OrtMemTypeDefault);

  // ---- OrtDevice ---------------------------------------------------------------------------
  // The static helpers return the OrtDevice device-type / memory-type constants.
  // Note that "cuda" and "webgpu" both map to OrtDevice::GPU, and "cann" and "npu" both map to OrtDevice::NPU.
  py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc");
  device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::DeviceId>())
      .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
      .def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc")
      .def_static("cpu", []() { return OrtDevice::CPU; })
      .def_static("cuda", []() { return OrtDevice::GPU; })
      .def_static("cann", []() { return OrtDevice::NPU; })
      .def_static("fpga", []() { return OrtDevice::FPGA; })
      .def_static("npu", []() { return OrtDevice::NPU; })
      .def_static("dml", []() { return OrtDevice::DML; })
      .def_static("webgpu", []() { return OrtDevice::GPU; })
      .def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });

  // ---- OrtArenaCfg -------------------------------------------------------------------------
  py::class_<OrtArenaCfg> ort_arena_cfg_binding(m, "OrtArenaCfg");
  // Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option.
  // This constructor kept for backwards compatibility, key-value pair constructor overload exposes all options
  // There is a global var: arena_extend_strategy, which means we can't use that var name here
  // See docs/C_API.md for details on what the following parameters mean and how to choose these values
  ort_arena_cfg_binding.def(py::init([](size_t max_mem, int arena_extend_strategy_local,
                                        int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
                         auto ort_arena_cfg = std::make_unique<OrtArenaCfg>();
                         ort_arena_cfg->max_mem = max_mem;
                         ort_arena_cfg->arena_extend_strategy = arena_extend_strategy_local;
                         ort_arena_cfg->initial_chunk_size_bytes = initial_chunk_size_bytes;
                         ort_arena_cfg->max_dead_bytes_per_chunk = max_dead_bytes_per_chunk;
                         return ort_arena_cfg;
                       }))
      // Key-value constructor: every recognised key sets the field of the same name;
      // an unrecognised key raises via ORT_THROW.
      .def(py::init([](const py::dict& feeds) {
        auto ort_arena_cfg = std::make_unique<OrtArenaCfg>();
        for (const auto kvp : feeds) {
          std::string key = kvp.first.cast<std::string>();
          if (key == "max_mem") {
            ort_arena_cfg->max_mem = kvp.second.cast<size_t>();
          } else if (key == "arena_extend_strategy") {
            ort_arena_cfg->arena_extend_strategy = kvp.second.cast<int>();
          } else if (key == "initial_chunk_size_bytes") {
            ort_arena_cfg->initial_chunk_size_bytes = kvp.second.cast<int>();
          } else if (key == "max_dead_bytes_per_chunk") {
            ort_arena_cfg->max_dead_bytes_per_chunk = kvp.second.cast<int>();
          } else if (key == "initial_growth_chunk_size_bytes") {
            ort_arena_cfg->initial_growth_chunk_size_bytes = kvp.second.cast<int>();
          } else if (key == "max_power_of_two_extend_bytes") {
            ort_arena_cfg->max_power_of_two_extend_bytes = kvp.second.cast<int>();
          } else {
            ORT_THROW("Invalid OrtArenaCfg option: ", key);
          }
        }
        return ort_arena_cfg;
      }))
      .def_readwrite("max_mem", &OrtArenaCfg::max_mem)
      .def_readwrite("arena_extend_strategy", &OrtArenaCfg::arena_extend_strategy)
      .def_readwrite("initial_chunk_size_bytes", &OrtArenaCfg::initial_chunk_size_bytes)
      .def_readwrite("max_dead_bytes_per_chunk", &OrtArenaCfg::max_dead_bytes_per_chunk)
      .def_readwrite("initial_growth_chunk_size_bytes", &OrtArenaCfg::initial_growth_chunk_size_bytes)
      .def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes);

  // ---- OrtMemoryInfo -----------------------------------------------------------------------
  // Only the CPU, CUDA and CUDA_PINNED allocator names are accepted; any other name throws.
  py::class_<OrtMemoryInfo> ort_memory_info_binding(m, "OrtMemoryInfo");
  ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
    if (strcmp(name, onnxruntime::CPU) == 0) {
      return std::make_unique<OrtMemoryInfo>(onnxruntime::CPU, type, OrtDevice(), id, mem_type);
    } else if (strcmp(name, onnxruntime::CUDA) == 0) {
      return std::make_unique<OrtMemoryInfo>(
          onnxruntime::CUDA, type,
          OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id)), id,
          mem_type);
    } else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) {
      return std::make_unique<OrtMemoryInfo>(
          onnxruntime::CUDA_PINNED, type,
          OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id)),
          id, mem_type);
    } else {
      throw std::runtime_error("Specified device is not supported.");
    }
  }));

  // ---- SessionOptions ----------------------------------------------------------------------
  // Each property reads/writes the matching field of the wrapped options object (`options->value`).
  py::class_<PySessionOptions>
      sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
  sess
      .def(py::init())
      .def_property(
          "enable_profiling",
          [](const PySessionOptions* options) -> bool { return options->value.enable_profiling; },
          [](PySessionOptions* options, bool enable_profiling) -> void {
            options->value.enable_profiling = enable_profiling;
          },
          R"pbdoc(Enable profiling for this session. Default is false.)pbdoc")
      .def_property(
          "profile_file_prefix",
          [](const PySessionOptions* options) -> std::basic_string<ORTCHAR_T> {
            return options->value.profile_file_prefix;
          },
          [](PySessionOptions* options, std::basic_string<ORTCHAR_T> profile_file_prefix) -> void {
            options->value.profile_file_prefix = std::move(profile_file_prefix);
          },
          R"pbdoc(The prefix of the profile file. The current time will be appended to the file name.)pbdoc")
      .def_property(
          "optimized_model_filepath",
          [](const PySessionOptions* options) -> std::basic_string<ORTCHAR_T> {
            return options->value.optimized_model_filepath;
          },
          [](PySessionOptions* options, std::basic_string<ORTCHAR_T> optimized_model_filepath) -> void {
            options->value.optimized_model_filepath = std::move(optimized_model_filepath);
          },
          R"pbdoc(
File path to serialize optimized model to.
Optimized model is not serialized unless optimized_model_filepath is set.
Serialized model format will default to ONNX unless:
- add_session_config_entry is used to set 'session.save_model_format' to 'ORT', or
- there is no 'session.save_model_format' config entry and optimized_model_filepath ends in '.ort' (case insensitive)
)pbdoc")
      // This property must be registered exactly once: a second registration under the same name
      // would simply replace the first one.
      .def_property(
          "enable_cpu_mem_arena",
          [](const PySessionOptions* options) -> bool { return options->value.enable_cpu_mem_arena; },
          [](PySessionOptions* options, bool enable_cpu_mem_arena) -> void {
            options->value.enable_cpu_mem_arena = enable_cpu_mem_arena;
          },
          R"pbdoc(Enable memory arena on CPU. Default is true.)pbdoc")
      .def_property(
          "enable_mem_pattern",
          [](const PySessionOptions* options) -> bool { return options->value.enable_mem_pattern; },
          [](PySessionOptions* options, bool enable_mem_pattern) -> void {
            options->value.enable_mem_pattern = enable_mem_pattern;
          },
          R"pbdoc(Enable the memory pattern optimization. Default is true.)pbdoc")
      .def_property(
          "enable_mem_reuse",
          [](const PySessionOptions* options) -> bool { return options->value.enable_mem_reuse; },
          [](PySessionOptions* options, bool enable_mem_reuse) -> void {
            options->value.enable_mem_reuse = enable_mem_reuse;
          },
          R"pbdoc(Enable the memory reuse optimization. Default is true.)pbdoc")
      .def_property(
          "logid",
          [](const PySessionOptions* options) -> std::string {
            return options->value.session_logid;
          },
          [](PySessionOptions* options, std::string logid) -> void {
            options->value.session_logid = std::move(logid);
          },
          R"pbdoc(Logger id to use for session output.)pbdoc")
      .def_property(
          "log_severity_level",
          [](const PySessionOptions* options) -> int { return options->value.session_log_severity_level; },
          [](PySessionOptions* options, int log_severity_level) -> void {
            options->value.session_log_severity_level = log_severity_level;
          },
          R"pbdoc(Log severity level. Applies to session load, initialization, etc.
0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.)pbdoc")
      .def_property(
          "log_verbosity_level",
          [](const PySessionOptions* options) -> int { return options->value.session_log_verbosity_level; },
          [](PySessionOptions* options, int log_verbosity_level) -> void {
            options->value.session_log_verbosity_level = log_verbosity_level;
          },
          R"pbdoc(VLOG level if DEBUG build and session_log_severity_level is 0.
Applies to session load, initialization, etc. Default is 0.)pbdoc")
      .def_property(
          "intra_op_num_threads",
          [](const PySessionOptions* options) -> int { return options->value.intra_op_param.thread_pool_size; },
          [](PySessionOptions* options, int value) -> void { options->value.intra_op_param.thread_pool_size = value; },
          R"pbdoc(Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose.)pbdoc")
      .def_property(
          "inter_op_num_threads",
          [](const PySessionOptions* options) -> int { return options->value.inter_op_param.thread_pool_size; },
          [](PySessionOptions* options, int value) -> void { options->value.inter_op_param.thread_pool_size = value; },
          R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc")
      .def_property(
          "execution_mode",
          [](const PySessionOptions* options) -> ExecutionMode { return options->value.execution_mode; },
          [](PySessionOptions* options, ExecutionMode execution_mode) -> void {
            options->value.execution_mode = execution_mode;
          },
          R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc")
      .def_property(
          "execution_order",
          [](const PySessionOptions* options) -> ExecutionOrder { return options->value.execution_order; },
          [](PySessionOptions* options, ExecutionOrder execution_order) -> void {
            options->value.execution_order = execution_order;
          },
          R"pbdoc(Sets the execution order. Default is basic topological order.)pbdoc")
      // The stored value is an onnxruntime::TransformerLevel; the getter/setter translate
      // to/from the public GraphOptimizationLevel enum.
      .def_property(
          "graph_optimization_level",
          [](const PySessionOptions* options) -> GraphOptimizationLevel {
            GraphOptimizationLevel retval = ORT_ENABLE_ALL;
            switch (options->value.graph_optimization_level) {
              case onnxruntime::TransformerLevel::Default:
                retval = ORT_DISABLE_ALL;
                break;
              case onnxruntime::TransformerLevel::Level1:
                retval = ORT_ENABLE_BASIC;
                break;
              case onnxruntime::TransformerLevel::Level2:
                retval = ORT_ENABLE_EXTENDED;
                break;
              case onnxruntime::TransformerLevel::Level3:
                retval = ORT_ENABLE_ALL;
                break;
              default:
                retval = ORT_ENABLE_ALL;
                LOGS_DEFAULT(WARNING) << "Got invalid graph optimization level; defaulting to ORT_ENABLE_ALL";
                break;
            }
            return retval;
          },
          [](PySessionOptions* options, GraphOptimizationLevel level) -> void {
            switch (level) {
              case ORT_DISABLE_ALL:
                options->value.graph_optimization_level = onnxruntime::TransformerLevel::Default;
                break;
              case ORT_ENABLE_BASIC:
                options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level1;
                break;
              case ORT_ENABLE_EXTENDED:
                options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level2;
                break;
              case ORT_ENABLE_ALL:
                options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level3;
                break;
            }
          },
          R"pbdoc(Graph optimization level for this session.)pbdoc")
      .def_property(
          "use_deterministic_compute",
          [](const PySessionOptions* options) -> bool { return options->value.use_deterministic_compute; },
          [](PySessionOptions* options, bool use_deterministic_compute) -> void {
            options->value.use_deterministic_compute = use_deterministic_compute;
          },
          R"pbdoc(Whether to use deterministic compute. Default is false.)pbdoc")
      .def(
          "add_free_dimension_override_by_denotation",
          [](PySessionOptions* options, const char* dim_name, int64_t dim_value) -> void {
            options->value.free_dimension_overrides.push_back(
                onnxruntime::FreeDimensionOverride{
                    dim_name,
                    onnxruntime::FreeDimensionOverrideType::Denotation,
                    dim_value});
          },
          R"pbdoc(Specify the dimension size for each denotation associated with an input's free dimension.)pbdoc")
      .def(
          "add_free_dimension_override_by_name",
          [](PySessionOptions* options, const char* dim_name, int64_t dim_value) -> void {
            options->value.free_dimension_overrides.push_back(
                onnxruntime::FreeDimensionOverride{
                    dim_name,
                    onnxruntime::FreeDimensionOverrideType::Name,
                    dim_value});
          },
          R"pbdoc(Specify values of named dimensions within model inputs.)pbdoc")
      .def(
          "add_session_config_entry",
          [](PySessionOptions* options, const char* config_key, const char* config_value) -> void {
            // config_key and config_value will be copied
            const Status status = options->value.config_options.AddConfigEntry(config_key, config_value);
            if (!status.IsOK())
              throw std::runtime_error(status.ErrorMessage());
          },
          R"pbdoc(Set a single session configuration entry as a pair of strings.)pbdoc")
      .def(
          "get_session_config_entry",
          [](const PySessionOptions* options, const char* config_key) -> std::string {
            const std::string key(config_key);
            std::string value;
            if (!options->value.config_options.TryGetConfigEntry(key, value))
              throw std::runtime_error("SessionOptions does not have configuration with key: " + key);
            return value;
          },
          R"pbdoc(Get a single session configuration value using the given configuration key.)pbdoc")
      .def(
          "register_custom_ops_library",
          [](PySessionOptions* options, const char* library_name) -> void {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
            OrtPybindThrowIfError(options->RegisterCustomOpsLibrary(ToPathString(library_name)));
#else
            ORT_UNUSED_PARAMETER(options);
            ORT_UNUSED_PARAMETER(library_name);
            ORT_THROW("Custom Ops are not supported in this build.");
#endif
          },
          R"pbdoc(Specify the path to the shared library containing the custom op kernels required to run a model.)pbdoc")
      .def(
          "add_initializer", [](PySessionOptions* options, const char* name, py::object& ml_value_pyobject) -> void {
            ORT_ENFORCE(strcmp(Py_TYPE(ml_value_pyobject.ptr())->tp_name, PYTHON_ORTVALUE_OBJECT_NAME) == 0, "The provided Python object must be an OrtValue");
            // The user needs to ensure that the python OrtValue being provided as an overriding initializer
            // is not destructed as long as any session that uses the provided OrtValue initializer is still in scope
            // This is no different than the native APIs
            const OrtValue* ml_value = ml_value_pyobject.attr(PYTHON_ORTVALUE_NATIVE_OBJECT_ATTR).cast<OrtValue*>();
            ORT_THROW_IF_ERROR(options->value.AddInitializer(name, ml_value));
          })
      .def("add_external_initializers", [](PySessionOptions* options, py::list& names, const py::list& ort_values) -> void {
#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_EXTERNAL_INITIALIZERS)
        // `names` and `ort_values` are parallel lists; each pair is copied into local vectors
        // before being handed to the session options.
        const auto init_num = ort_values.size();
        ORT_ENFORCE(init_num == names.size(), "Expecting names and ort_values lists to have equal length");
        InlinedVector<std::string> names_ptrs;
        InlinedVector<OrtValue> values_ptrs;
        names_ptrs.reserve(init_num);
        values_ptrs.reserve(init_num);
        for (size_t i = 0; i < init_num; ++i) {
          names_ptrs.emplace_back(py::str(names[i]));
          values_ptrs.emplace_back(*ort_values[i].attr(PYTHON_ORTVALUE_NATIVE_OBJECT_ATTR).cast<const OrtValue*>());
        }
        ORT_THROW_IF_ERROR(options->value.AddExternalInitializers(names_ptrs, values_ptrs));
#else
        ORT_UNUSED_PARAMETER(options);
        ORT_UNUSED_PARAMETER(names);
        ORT_UNUSED_PARAMETER(ort_values);
        ORT_THROW("External initializers are not supported in this build.");
#endif
      });

  // ---- RunOptions --------------------------------------------------------------------------
  py::class_<RunOptions>(m, "RunOptions", R"pbdoc(Configuration information for a single Run.)pbdoc")
      .def(py::init())
      .def_readwrite("log_severity_level", &RunOptions::run_log_severity_level,
                     R"pbdoc(Log severity level for a particular Run() invocation. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.)pbdoc")
      .def_readwrite("log_verbosity_level", &RunOptions::run_log_verbosity_level,
                     R"pbdoc(VLOG level if DEBUG build and run_log_severity_level is 0.
Applies to a particular Run() invocation. Default is 0.)pbdoc")
      .def_readwrite("logid", &RunOptions::run_tag,
                     "To identify logs generated by a particular Run() invocation.")
      .def_readwrite("terminate", &RunOptions::terminate,
                     R"pbdoc(Set to True to terminate any currently executing calls that are using this
RunOptions instance. The individual calls will exit gracefully and return an error status.)pbdoc")
#ifdef ENABLE_TRAINING
      .def_readwrite("training_mode", &RunOptions::training_mode,
                     R"pbdoc(Choose to run in training or inferencing mode)pbdoc")
#endif
      .def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
                     R"pbdoc(Only execute the nodes needed by fetch list)pbdoc")
      .def(
          "add_run_config_entry",
          [](RunOptions* options, const char* config_key, const char* config_value) -> void {
            // config_key and config_value will be copied
            const Status status = options->config_options.AddConfigEntry(config_key, config_value);
            if (!status.IsOK())
              throw std::runtime_error(status.ErrorMessage());
          },
          R"pbdoc(Set a single run configuration entry as a pair of strings.)pbdoc")
      .def(
          "get_run_config_entry",
          [](const RunOptions* options, const char* config_key) -> std::string {
            const std::string key(config_key);
            std::string value;
            if (!options->config_options.TryGetConfigEntry(key, value))
              throw std::runtime_error("RunOptions does not have configuration with key: " + key);
            return value;
          },
          R"pbdoc(Get a single run configuration value using the given configuration key.)pbdoc")
      .def(
          "add_active_adapter", [](RunOptions* options, lora::LoraAdapter* adapter) {
            // Only the raw pointer is stored here.
            options->active_adapters.push_back(adapter);
          },
          R"pbdoc(Adds specified adapter as an active adapter)pbdoc");

  // ---- ModelMetadata -----------------------------------------------------------------------
  py::class_<ModelMetadata>(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model.
It is usually used to identify the model used to run the prediction and
facilitate the comparison.)pbdoc")
      .def_readwrite("producer_name", &ModelMetadata::producer_name, "producer name")
      .def_readwrite("graph_name", &ModelMetadata::graph_name, "graph name")
      .def_readwrite("domain", &ModelMetadata::domain, "ONNX domain")
      .def_readwrite("description", &ModelMetadata::description, "description of the model")
      .def_readwrite("graph_description", &ModelMetadata::graph_description, "description of the graph hosted in the model")
      .def_readwrite("version", &ModelMetadata::version, "version of the model")
      .def_readwrite("custom_metadata_map", &ModelMetadata::custom_metadata_map, "additional metadata");

  // ---- NodeArg -----------------------------------------------------------------------------
  py::class_<onnxruntime::NodeArg>(m, "NodeArg", R"pbdoc(Node argument definition, for both input and output,
including arg name, arg type (contains both type and shape).)pbdoc")
      .def_property_readonly("name", &onnxruntime::NodeArg::Name, "node name")
      .def_property_readonly(
          "type", [](const onnxruntime::NodeArg& na) -> std::string {
            return *(na.Type());
          },
          "node type")
      // Renders "NodeArg(name='..', type='..', shape=[..])". Each dimension is printed as its
      // integer value, its quoted symbolic name, or None when neither is present.
      .def(
          "__str__", [](const onnxruntime::NodeArg& na) -> std::string {
            std::ostringstream res;
            res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
            auto shape = na.Shape();
            if (shape == nullptr || shape->dim_size() == 0) {
              res << "[]";
            } else {
              res << "[";
              for (int i = 0; i < shape->dim_size(); ++i) {
                if (utils::HasDimValue(shape->dim(i))) {
                  res << shape->dim(i).dim_value();
                } else if (utils::HasDimParam(shape->dim(i))) {
                  res << "'" << shape->dim(i).dim_param() << "'";
                } else {
                  res << "None";
                }
                if (i < shape->dim_size() - 1) {
                  res << ", ";
                }
              }
              res << "]";
            }
            res << ")";
            return std::string(res.str());
          },
          "converts the node into a readable string")
      // Returns one Python object per dimension: int for a fixed value, str for a symbolic
      // dimension, None otherwise. A missing or empty shape yields an empty list.
      .def_property_readonly(
          "shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
            auto shape = na.Shape();
            std::vector<py::object> arr;
            if (shape == nullptr || shape->dim_size() == 0) {
              return arr;
            }
            arr.resize(shape->dim_size());
            for (int i = 0; i < shape->dim_size(); ++i) {
              if (utils::HasDimValue(shape->dim(i))) {
                arr[i] = py::cast(shape->dim(i).dim_value());
              } else if (utils::HasDimParam(shape->dim(i))) {
                arr[i] = py::cast(shape->dim(i).dim_param());
              } else {
                arr[i] = py::none();
              }
            }
            return arr;
          },
          "node shape (assuming the node holds a tensor)");

  py::class_<SessionObjectInitializer> sessionObjectInitializer(m, "SessionObjectInitializer");

  // ---- InferenceSession --------------------------------------------------------------------
  py::class_<PyInferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
      // In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char*
      // without any conversion. So this init method can be used for model file path (string) and model content (bytes)
      .def(py::init([](const PySessionOptions& so, const std::string arg, bool is_arg_file_name,
                       bool load_config_from_model = false) {
        auto env = GetEnv();
        std::unique_ptr<PyInferenceSession> sess;

        // separate creation of the session from model loading unless we have to read the config from the model.
        // in a minimal build we only support load via Load(...) and not at session creation time
        if (load_config_from_model) {
#if !defined(ORT_MINIMAL_BUILD)
          sess = std::make_unique<PyInferenceSession>(std::move(env), so, arg, is_arg_file_name);

          RegisterCustomOpDomains(sess.get(), so);

          OrtPybindThrowIfError(sess->GetSessionHandle()->Load());
#else
          ORT_THROW("Loading configuration from an ONNX model is not supported in this build.");
#endif
        } else {
          sess = std::make_unique<PyInferenceSession>(std::move(env), so);
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
          RegisterCustomOpDomains(sess.get(), so);
#endif

          // `arg` is either a path or the serialized model bytes, as selected by is_arg_file_name.
          if (is_arg_file_name) {
            OrtPybindThrowIfError(sess->GetSessionHandle()->Load(arg));
          } else {
            OrtPybindThrowIfError(sess->GetSessionHandle()->Load(arg.data(), narrow<int>(arg.size())));
          }
        }

        return sess;
      }))
      .def(
          "initialize_session",
          [ep_registration_fn](PyInferenceSession* sess,
                               const std::vector<std::string>& provider_types = {},
                               const ProviderOptionsVector& provider_options = {},
                               const std::unordered_set<std::string>& disabled_optimizer_names = {}) {
            InitializeSession(sess->GetSessionHandle(),
                              ep_registration_fn,
                              provider_types,
                              provider_options,
                              disabled_optimizer_names);
          },
          R"pbdoc(Load a model saved in ONNX or ORT format.)pbdoc")
      // Converts the Python feeds to OrtValues, runs the session with the GIL released and
      // converts every fetch back to a Python object (None for an unallocated fetch).
      .def("run",
           [](PyInferenceSession* sess, const std::vector<std::string>& output_names,
              const std::map<std::string, const py::object>& pyfeeds, RunOptions* run_options = nullptr)
               -> py::list {
             NameMLValMap feeds;
             if (run_options != nullptr && !run_options->active_adapters.empty()) {
               AppendLoraParametersAsInputs(*run_options, pyfeeds.size(), feeds);
             } else {
               feeds.reserve(pyfeeds.size());
             }

             for (const auto& feed : pyfeeds) {
               // No need to process 'None's sent in by the user
               // to feed Optional inputs in the graph.
               // We just won't include anything in the feed and ORT
               // will handle such implicit 'None's internally.
               if (!feed.second.is(py::none())) {
                 OrtValue ml_value;
                 auto px = sess->GetSessionHandle()->GetModelInputs();
                 if (!px.first.IsOK() || !px.second) {
                   throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
                 }
                 CreateGenericMLValue(px.second, GetAllocator(), feed.first, feed.second, &ml_value);
                 ThrowIfPyErrOccured();
                 feeds.insert(std::make_pair(feed.first, std::move(ml_value)));
               }
             }

             std::vector<OrtValue> fetches;
             fetches.reserve(output_names.size());

             {
               // release GIL to allow multiple python threads to invoke Run() in parallel.
               py::gil_scoped_release release;
               if (run_options != nullptr) {
                 OrtPybindThrowIfError(sess->GetSessionHandle()->Run(*run_options, feeds, output_names, &fetches));
               } else {
                 OrtPybindThrowIfError(sess->GetSessionHandle()->Run(feeds, output_names, &fetches));
               }
             }

             py::list result;
             size_t pos = 0;  // index of the current fetch, forwarded to the sparse tensor conversion
             for (const auto& fet : fetches) {
               if (fet.IsAllocated()) {
                 if (fet.IsTensor()) {
                   result.append(AddTensorAsPyObj(fet, nullptr, nullptr));
                 } else if (fet.IsSparseTensor()) {
                   result.append(GetPyObjectFromSparseTensor(pos, fet, nullptr));
                 } else {
                   result.append(AddNonTensorAsPyObj(fet, nullptr, nullptr));
                 }
               } else {  // Send back None because the corresponding OrtValue was empty
                 result.append(py::none());
               }
               ++pos;
             }
             return result;
           })
      // Asynchronous run. All feed/fetch storage lives in a heap-allocated AsyncResource whose raw
      // pointer is handed to RunAsync(); the unique_ptr gives up ownership only if RunAsync()
      // reports success (AsyncCallback is presumably responsible for freeing it afterwards -
      // confirm against AsyncCallback).
      .def("run_async",
           [](PyInferenceSession* sess,
              const std::vector<std::string>& output_names,
              const std::map<std::string, py::object>& pyfeeds,
              PyCallback callback, py::object user_data = {},
              RunOptions* run_options = nullptr)
               -> void {
             if (run_options != nullptr && !run_options->active_adapters.empty()) {
               LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING)
                   << "run_async has active adapters specified, but won't have an effect";
             }
             std::unique_ptr<AsyncResource> async_resource = std::make_unique<AsyncResource>();
             async_resource->callback = callback;
             async_resource->user_data = user_data;
             // prepare feeds
             // NOTE(review): the *_raw vectors store pointers into the owning vectors, so this
             // presumably relies on ReserveFeeds()/ReserveFetches() reserving enough capacity that
             // the push_backs below never reallocate - confirm against AsyncResource.
             async_resource->ReserveFeeds(pyfeeds.size());
             for (const auto& feed : pyfeeds) {
               if (!feed.second.is(py::none())) {
                 OrtValue ml_value;
                 auto px = sess->GetSessionHandle()->GetModelInputs();
                 if (!px.first.IsOK() || !px.second) {
                   throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
                 }
                 CreateGenericMLValue(px.second, GetAllocator(), feed.first, feed.second, &ml_value);
                 ThrowIfPyErrOccured();
                 async_resource->feeds.push_back(ml_value);
                 async_resource->feeds_raw.push_back(&async_resource->feeds.back());
                 async_resource->feed_names.push_back(feed.first);
                 async_resource->feed_names_raw.push_back(async_resource->feed_names.back().c_str());
               }
             }
             // prepare fetches
             async_resource->ReserveFetches(output_names.size());
             for (const auto& output_name : output_names) {
               async_resource->fetch_names.push_back(output_name);
               async_resource->fetch_names_raw.push_back(async_resource->fetch_names.back().c_str());
               async_resource->fetches_raw.push_back({});
             }
             const RunOptions* run_async_option = run_options ? run_options : &async_resource->default_run_option;
             common::Status status = sess->GetSessionHandle()->RunAsync(run_async_option,
                                                                        gsl::span(async_resource->feed_names_raw.data(), async_resource->feed_names_raw.size()),
                                                                        gsl::span(async_resource->feeds_raw.data(), async_resource->feeds_raw.size()),
                                                                        gsl::span(async_resource->fetch_names_raw.data(), async_resource->fetch_names_raw.size()),
                                                                        gsl::span(async_resource->fetches_raw.data(), async_resource->fetches_raw.size()),
                                                                        AsyncCallback,
                                                                        async_resource.get());
             if (status.IsOK()) {
               // Give up ownership only on success; on failure the unique_ptr frees the resource.
               async_resource.release();
             }
             OrtPybindThrowIfError(status);
           })
      /// This method accepts a dictionary of feeds (name -> OrtValue) and the list of output_names
      /// and returns a list of python objects representing OrtValues. Each name may represent either
      /// a Tensor, SparseTensor or a TensorSequence.
      .def("run_with_ort_values", [](PyInferenceSession* sess, const py::dict& feeds, const std::vector<std::string>& output_names, RunOptions* run_options = nullptr) -> std::vector<OrtValue> {
        NameMLValMap ort_feeds;
        if (run_options != nullptr && !run_options->active_adapters.empty()) {
          AppendLoraParametersAsInputs(*run_options, feeds.size(), ort_feeds);
        } else {
          ort_feeds.reserve(feeds.size());
        }

        // item is always a copy since dict returns a value and not a ref
        // and Apple XToolChain barks
        for (const auto& item : feeds) {
          auto name = item.first.cast<std::string>();
          const OrtValue* ort_value = item.second.cast<const OrtValue*>();
          ort_feeds.emplace(name, *ort_value);
        }

        std::vector<OrtValue> fetches;
        fetches.reserve(output_names.size());
        {
          // release GIL to allow multiple python threads to invoke Run() in parallel.
          py::gil_scoped_release release;
          if (run_options != nullptr) {
            OrtPybindThrowIfError(sess->GetSessionHandle()->Run(*run_options, ort_feeds, output_names, &fetches));
          } else {
            OrtPybindThrowIfError(sess->GetSessionHandle()->Run(ort_feeds, output_names, &fetches));
          }
        }
        return fetches;
      })
      // Runs with pre-built OrtValue vectors; results are written into `fetches`.
      .def("run_with_ortvaluevector", [](PyInferenceSession* sess, RunOptions run_options, const std::vector<std::string>& feed_names, const std::vector<OrtValue>& feeds, const std::vector<std::string>& fetch_names, std::vector<OrtValue>& fetches, const std::vector<OrtDevice>& fetch_devices) -> void {
        if (!run_options.active_adapters.empty()) {
          LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING)
              << "run_with_ortvaluevector has active adapters specified, but won't have an effect";
        }

        // release GIL to allow multiple python threads to invoke Run() in parallel.
        py::gil_scoped_release release;
        OrtPybindThrowIfError(sess->GetSessionHandle()->Run(run_options, feed_names, feeds, fetch_names, &fetches, &fetch_devices));
      })
      .def("end_profiling", [](const PyInferenceSession* sess) -> std::string {
        return sess->GetSessionHandle()->EndProfiling();
      })
      // Exposed as a read-only property despite the "get_" prefix.
      .def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t {
        return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs();
      })
      .def(
          "get_providers", [](const PyInferenceSession* sess) -> const std::vector<std::string>& {
            return sess->GetSessionHandle()->GetRegisteredProviderTypes();
          },
          py::return_value_policy::reference_internal)
      .def(
          "get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& {
            return sess->GetSessionHandle()->GetAllProviderOptions();
          },
          py::return_value_policy::reference_internal)
      // Returns a fresh PySessionOptions holding a copy of the session's options; Python takes ownership.
      .def_property_readonly(
          "session_options", [](const PyInferenceSession* sess) -> PySessionOptions* {
            auto session_options = std::make_unique<PySessionOptions>();
            session_options->value = sess->GetSessionHandle()->GetSessionOptions();
            return session_options.release();
          },
          py::return_value_policy::take_ownership)
      .def_property_readonly(
          "inputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
            auto res = sess->GetSessionHandle()->GetModelInputs();
            OrtPybindThrowIfError(res.first);
            return *(res.second);
          },
          py::return_value_policy::reference_internal)
      .def_property_readonly(
          "outputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
            auto res = sess->GetSessionHandle()->GetModelOutputs();
            OrtPybindThrowIfError(res.first);
            return *(res.second);
          },
          py::return_value_policy::reference_internal)
      .def_property_readonly(
          "overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
            auto res = sess->GetSessionHandle()->GetOverridableInitializers();
            OrtPybindThrowIfError(res.first);
            return *(res.second);
          },
          py::return_value_policy::reference_internal)
      .def_property_readonly(
          "model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& {
            auto res = sess->GetSessionHandle()->GetModelMetadata();
            OrtPybindThrowIfError(res.first);
            return *(res.second);
          },
          py::return_value_policy::reference_internal)
      .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
        Status status;
        if (run_options != nullptr && !run_options->active_adapters.empty()) {
          LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING)
              << "run_with_iobinding has active adapters specified, but won't have an effect";
        }

        // release GIL to allow multiple python threads to invoke Run() in parallel.
        py::gil_scoped_release release;
        if (!run_options)
          status = sess->GetSessionHandle()->Run(*io_binding.Get());
        else
          status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
        if (!status.IsOK())
          throw std::runtime_error("Error in execution: " + status.ErrorMessage());
      })
      // Returns a list of dicts, one per tuning result, with keys "ep", "results" and "validators".
      .def("get_tuning_results", [](PyInferenceSession* sess) -> py::list {
#if !defined(ORT_MINIMAL_BUILD)
        auto results = sess->GetSessionHandle()->GetTuningResults();

        py::list ret;
        for (const auto& trs : results) {
          py::dict py_trs;
          py_trs["ep"] = trs.ep;
          py_trs["results"] = trs.results;
          py_trs["validators"] = trs.validators;
          ret.append(std::move(py_trs));
        }

        return ret;
#else
        ORT_UNUSED_PARAMETER(sess);
        ORT_THROW("TunableOp and get_tuning_results are not supported in this build.");
#endif
      })
      // Inverse of get_tuning_results: rebuilds TuningResults from dicts of the same layout.
      .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void {
#if !defined(ORT_MINIMAL_BUILD)
        std::vector<TuningResults> tuning_results;
        for (auto handle : results) {
          auto py_trs = handle.cast<py::dict>();
          TuningResults trs;
          trs.ep = py_trs["ep"].cast<py::str>();

          for (const auto [py_op_sig, py_kernel_map] : py_trs["results"].cast<py::dict>()) {
            KernelMap kernel_map;
            for (const auto [py_params_sig, py_kernel_id] : py_kernel_map.cast<py::dict>()) {
              kernel_map[py_params_sig.cast<py::str>()] = py_kernel_id.cast<py::int_>();
            }
            trs.results[py_op_sig.cast<py::str>()] = kernel_map;
          }

          for (const auto [k, v] : py_trs["validators"].cast<py::dict>()) {
            trs.validators[k.cast<py::str>()] = v.cast<py::str>();
          }

          tuning_results.emplace_back(std::move(trs));
        }

        Status status = sess->GetSessionHandle()->SetTuningResults(tuning_results, error_on_invalid);
        if (!status.IsOK()) {
          throw std::runtime_error("Error in execution: " + status.ErrorMessage());
        }
#else
        ORT_UNUSED_PARAMETER(sess);
        ORT_UNUSED_PARAMETER(results);
        ORT_UNUSED_PARAMETER(error_on_invalid);
        ORT_THROW("TunableOp and set_tuning_results are not supported in this build.");
#endif
      });

  // ---- ArenaExtendStrategy -----------------------------------------------------------------
  // export_values() additionally exposes the enumerators directly on the module.
  py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())
      .value("kNextPowerOfTwo", onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo)
      .value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested)
      .export_values();
}
// Registers the inference-side Python bindings on module `m`.
// Returns true once every registration step has run. The NumPy import macro below
// makes this function return false early if the NumPy C-API cannot be imported.
bool CreateInferencePybindStateModule(py::module& m) {
  m.doc() = "pybind11 stateful interface to ONNX runtime";
  RegisterExceptions(m);

  // NumPy C-API bootstrap; on failure the macro returns `false` from this function.
  import_array1(false);

  // Obtain the shared Environment up front and hold the reference while bindings are registered.
  auto environment = GetEnv();

  // Core binding groups, registered in a fixed order.
  addGlobalMethods(m);
  addObjectMethods(m, RegisterExecutionProviders);
  addOrtValueMethods(m);
  addSparseTensorMethods(m);
  addIoBindingMethods(m);
  addAdapterFormatMethods(m);

#if !defined(__APPLE__) && !defined(ORT_MINIMAL_BUILD)
  // A provider bridge failure is only reported; module creation continues.
  if (!InitProvidersSharedLibrary()) {
    const logging::Logger& fallback_logger = logging::LoggingManager::DefaultLogger();
    LOGS(fallback_logger, WARNING) << "Init provider bridge failed.";
  }
#endif

  // Schema and kernel introspection bindings.
  addGlobalSchemaFunctions(m);
  addOpSchemaSubmodule(m);
  addOpKernelSubmodule(m);

  return true;
}
// Imports the NumPy C-API for the caller.
// This function is only used by orttraining module.
// Returns true on success. import_array1 is NumPy's import macro; per the NumPy C-API
// documentation it makes the enclosing function return its argument (false here) when
// the import fails, so the `return true` below is reached only on success.
bool InitArray() {
import_array1(false);
return true;
}
namespace {
// This class provides a static shell for on-demand and thread-safe construction
// of Environment object for both Inference and Training python layers.
// Environment class contains objects such as default logger, that must be available
// for the entire duration of a program that makes use of onnxruntime library.
// Because Python is a garbage collected language and the order of destruction of objects
// is not guaranteed we design this class with the following important features.
// 1) we make this class a singleton that is a function local static. The function local statics
// are constructed when the function is called the very first time. This fact has several important
// properties.
// - First, it is constructed before it is first needed possibly by another static object
// and destroyed after that object is destroyed.
// - Second, it is constructed in a thread safe manner.
// - Last, this order of construction/destruction is enforced across the compilation units, as opposed
// to the static objects that are simply declared in order in a single unit, but whose lifespan is
// unconnected to that of objects in other compilation units. This is achieved automatically by the
// run-time executing atexit() to build a chain.
// 2) We make Environment owned by a shared_ptr. This is done because python objects such as Inference and Training
// sessions depend on this global. We acquire a shared_ptr instance when those objects are instantiated
// and release it automatically when they are garbage collected. Although with this change all of the
// globals seem to have been destroyed after module is unloaded and GC runs before that, it is cheap and gives
// peace of mind as there were situations when GC was still running in the past after Env was gone.
// TrainingEnv global also holds shared reference to this global.
// 3) We guard against singleton resurrection attempts to detect code that runs when it should
// not and make necessary adjustments.
// For all the related details and why it is needed see "Modern C++ design" by A. Alexandrescu Chapter 6.
class EnvInitializer {
public:
static std::shared_ptr<onnxruntime::Environment> SharedInstance() {
// Guard against attempts to resurrect the singleton
if (EnvInitializer::destroyed) {
ORT_THROW("Detected an attempt to resurrect destroyed Environment");
}
static EnvInitializer env_holder;
return env_holder.Get();
}
private:
EnvInitializer() {
std::unique_ptr<Environment> env_ptr;
Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON);
OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
std::make_unique<CLogSink>(),
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id),
env_ptr));
session_env_ = std::shared_ptr<Environment>(env_ptr.release());
destroyed = false;
}
~EnvInitializer() {
destroyed = true;
}
std::shared_ptr<Environment> Get() const {
return session_env_;
}
std::shared_ptr<Environment> session_env_;
static bool destroyed;
};
bool EnvInitializer::destroyed = false;
} // namespace
// Returns a shared reference to the Environment held by the EnvInitializer singleton.
// The singleton is constructed on the first call; see EnvInitializer::SharedInstance().
std::shared_ptr<onnxruntime::Environment> GetEnv() {
  auto shared_env = EnvInitializer::SharedInstance();
  return shared_env;
}
} // namespace python
} // namespace onnxruntime