onnxruntime/onnxruntime/python/onnxruntime_pybind_state.cc
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
// Licensed under the MIT License.
#include "python/onnxruntime_pybind_exceptions.h"
#include "python/onnxruntime_pybind_mlvalue.h"
#include "python/onnxruntime_pybind_state_common.h"
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API
#include "python/numpy_helper.h"
#include "core/common/inlined_containers.h"
#include "core/common/logging/logging.h"
#include "core/common/logging/severity.h"
#include "core/common/narrow.h"
#include "core/common/optional.h"
#include "core/common/path_string.h"
#include "core/framework/arena_extend_strategy.h"
#include "core/framework/data_transfer_utils.h"
#include "core/framework/data_types_internal.h"
#include "core/framework/provider_options_utils.h"
#include "core/framework/random_seed.h"
#include "core/framework/sparse_tensor.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/TensorSeq.h"
#include "core/graph/graph_viewer.h"
#include "core/platform/env.h"
#include "core/providers/get_execution_providers.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include "core/session/IOBinding.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/provider_bridge_ort.h"
#include "core/session/lora_adapters.h"
#ifdef ENABLE_ATEN
#include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
#endif
#ifdef USE_CUDA
#include <cuda.h> // for CUDA_VERSION
#include <cudnn.h> // for CUDNN_MAJOR
#endif
#if defined(USE_COREML)
#include "core/providers/coreml/coreml_provider_factory.h"
#endif
#include <pybind11/functional.h>
// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
// GCC 4.x.
// (This static var is referenced in GetCudaToHostMemCpyFunction())
const OrtDevice::DeviceType OrtDevice::GPU;
#if defined(_MSC_VER)
#pragma warning(disable : 4267 4996 4503)
#endif // _MSC_VER
#include <iterator>
#include <algorithm>
namespace onnxruntime {
namespace python {
namespace py = pybind11;
using namespace onnxruntime;
using namespace onnxruntime::logging;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
// TODO: we may delay-init this variable
#pragma warning(disable : 26426)
#endif
static Env& platform_env = Env::Default();
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
using PyCallback = std::function<void(std::vector<py::object>, py::object user_data, std::string)>;
struct AsyncResource {
std::vector<OrtValue> feeds;
std::vector<const OrtValue*> feeds_raw;
std::vector<std::string> feed_names;
std::vector<const char*> feed_names_raw;
std::vector<OrtValue*> fetches_raw; // will be released during destruction
std::vector<std::string> fetch_names;
std::vector<const char*> fetch_names_raw;
RunOptions default_run_option;
PyCallback callback;
py::object user_data;
void ReserveFeeds(size_t sz) {
feeds.reserve(sz);
feeds_raw.reserve(sz);
feed_names.reserve(sz);
feed_names_raw.reserve(sz);
}
void ReserveFetches(size_t sz) {
fetches_raw.reserve(sz);
fetch_names.reserve(sz);
fetch_names_raw.reserve(sz);
}
~AsyncResource() {
std::for_each(fetches_raw.begin(), fetches_raw.end(), [](const OrtValue* fetch) {
if (fetch) {
std::unique_ptr<const OrtValue> fetch_recycler(fetch);
}
});
fetches_raw.clear();
}
};
void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr ort_status) {
ORT_ENFORCE(user_data, "user data must not be NULL for callback in python");
auto invoke_callback = [&]() {
std::unique_ptr<AsyncResource> async_resource{reinterpret_cast<AsyncResource*>(user_data)};
Ort::Status status(ort_status);
// return on error
if (!status.IsOK()) {
async_resource->callback({}, async_resource->user_data, status.GetErrorMessage());
return;
}
std::vector<py::object> rfetch;
rfetch.reserve(num_outputs);
size_t pos = 0;
for (size_t ith = 0; ith < num_outputs; ++ith) {
const auto& fet = *outputs[ith];
if (fet.IsAllocated()) {
if (fet.IsTensor()) {
rfetch.push_back(AddTensorAsPyObj(fet, nullptr, nullptr));
} else if (fet.IsSparseTensor()) {
rfetch.push_back(GetPyObjectFromSparseTensor(pos, fet, nullptr));
} else {
rfetch.push_back(AddNonTensorAsPyObj(fet, nullptr, nullptr));
}
} else {
rfetch.push_back(py::none());
}
++pos;
}
async_resource->callback(rfetch, async_resource->user_data, "");
};
if (PyGILState_Check()) {
invoke_callback();
} else {
// acquire GIL to safely:
// 1) invoke python callback
// 2) create, manipulate, and destroy python objects
py::gil_scoped_acquire acquire;
invoke_callback();
}
}
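// Example (Python): a minimal sketch of the async flow this callback serves. The
// run_async signature and the callback's (outputs, user_data, err) shape follow the
// Python API as exposed by these bindings; exact signatures may differ by version.
//   def on_done(outputs, user_data, err):  # err is "" on success
//       if err:
//           raise RuntimeError(err)
//       print(outputs[0])
//   sess = onnxruntime.InferenceSession("model.onnx")
//   sess.run_async(None, {"x": x}, on_done, None)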
void AppendLoraParametersAsInputs(const RunOptions& run_options,
size_t total_entries,
NameMLValMap& feeds) {
for (const auto* adapter : run_options.active_adapters) {
total_entries += adapter->GetParamNum();
}
feeds.reserve(total_entries + feeds.size());
// Append necessary inputs for active adapters
for (const auto* adapter : run_options.active_adapters) {
auto [begin, end] = adapter->GetParamIterators();
for (; begin != end; ++begin) {
const auto& [name, param] = *begin;
feeds.insert(std::make_pair(name, param.GetMapped()));
}
}
}
template <typename T>
static py::object AddNonTensor(const OrtValue& val,
const DataTransferManager* /*data_transfer_manager*/,
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* /*mem_cpy_to_host_functions*/) {
return py::cast(val.Get<T>());
}
// This function is used to return strings from a string tensor to python
// as a numpy array of strings
// Strings are always on CPU and must always be copied to python memory
py::array StringTensorToNumpyArray(const Tensor& tensor) {
// Create the result and allocate memory with the right size
py::array result(py::dtype(NPY_OBJECT), tensor.Shape().GetDims());
const auto span = tensor.DataAsSpan<std::string>();
auto* mutable_data = reinterpret_cast<py::object*>(result.mutable_data());
for (size_t i = 0, lim = span.size(); i < lim; ++i) {
mutable_data[i] = py::cast(span[i]);
}
return result;
}
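// Example (Python): string tensors always round-trip through CPU and come back as
// numpy arrays of dtype=object, matching the NPY_OBJECT array built above:
//   out = sess.run(None, {"words": np.array([["a", "b"]])})[0]
//   assert out.dtype == object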
pybind11::array PrimitiveTensorToNumpyOverOrtValue(const OrtValue& ort_value) {
const Tensor& tensor = ort_value.Get<Tensor>();
// The capsule destructor must be stateless
// We create a copy of OrtValue on the heap.
auto memory_release = [](void* data) {
auto* ort_value = reinterpret_cast<OrtValue*>(data);
delete ort_value;
};
const int numpy_type = OnnxRuntimeTensorToNumpyType(tensor.DataType());
auto ort_value_ptr = std::make_unique<OrtValue>(ort_value);
pybind11::capsule caps(ort_value_ptr.get(), memory_release);
ort_value_ptr.release();
// Not using array_t<T> because it may not handle MLFloat16 properly
pybind11::array result(py::dtype(numpy_type), tensor.Shape().GetDims(),
tensor.DataRaw(),
caps);
return result;
}
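// Example (Python, a sketch): for CPU tensors the returned numpy array is a
// zero-copy view; the capsule created above keeps the underlying OrtValue alive
// for as long as the array (or any view of it) exists:
//   out = sess.run(None, feeds)[0]   # no copy for CPU outputs
//   view = out[1:]                   # still backed by the same OrtValue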
pybind11::array PrimitiveTensorToNumpyFromDevice(const OrtValue& ort_value, const DataTransferAlternative& dtm) {
const Tensor& tensor = ort_value.Get<Tensor>();
const int numpy_type = OnnxRuntimeTensorToNumpyType(tensor.DataType());
pybind11::array result(py::dtype(numpy_type), tensor.Shape().GetDims());
void* data = result.mutable_data();
if (std::holds_alternative<const DataTransferManager*>(dtm)) {
const DataTransferManager* data_transfer = std::get<const DataTransferManager*>(dtm);
static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};
const auto span = gsl::make_span<char>(reinterpret_cast<char*>(data), tensor.SizeInBytes());
ORT_THROW_IF_ERROR(CopyTensorDataToByteSpan(*data_transfer, tensor, cpu_alloc_info, span));
} else {
std::get<MemCpyFunc>(dtm)(data, tensor.DataRaw(), tensor.SizeInBytes());
}
return result;
}
// We may not always have access to a DataTransferManager, so the caller may supply functions that
// do what a DataTransferManager does: copy data from device(s) to the host
py::object GetPyObjFromTensor(const OrtValue& ort_value,
const DataTransferManager* data_transfer_manager,
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
ORT_ENFORCE(ort_value.IsTensor(), "This function only supports tensors");
const auto& tensor = ort_value.Get<Tensor>();
if (tensor.IsDataTypeString()) {
ORT_ENFORCE(tensor.Location().device.Type() == OrtDevice::CPU, "Strings can only be on CPU");
// Create a numpy array of strings (python objects) by copy/converting them
py::array result = StringTensorToNumpyArray(tensor);
return py::cast<py::object>(result);
}
const auto device_type = tensor.Location().device.Type();
// Create a numpy array on top of the OrtValue memory, no copy
if (device_type == OrtDevice::CPU) {
py::array result = PrimitiveTensorToNumpyOverOrtValue(ort_value);
return py::cast<py::object>(result);
}
if (!data_transfer_manager && !mem_cpy_to_host_functions) {
throw std::runtime_error(
"GetPyObjFromTensor: Either data transfer manager or a "
"function to copy data to the host is needed to convert non-CPU tensor to numpy array");
}
py::array result;
if (data_transfer_manager != nullptr) {
result = PrimitiveTensorToNumpyFromDevice(ort_value, data_transfer_manager);
} else {
auto mem_cpy_to_host = mem_cpy_to_host_functions->find(device_type);
ORT_ENFORCE(mem_cpy_to_host != mem_cpy_to_host_functions->end(),
"Unable to locate a function that can copy data to the host from the device");
result = PrimitiveTensorToNumpyFromDevice(ort_value, mem_cpy_to_host->second);
}
return py::cast<py::object>(result);
}
const char* GetDeviceName(const OrtDevice& device) {
switch (device.Type()) {
case OrtDevice::CPU:
return CPU;
case OrtDevice::GPU:
return CUDA;
case OrtDevice::DML:
return DML;
case OrtDevice::FPGA:
return "FPGA";
case OrtDevice::NPU:
#ifdef USE_CANN
return CANN;
#else
return "NPU";
#endif
default:
ORT_THROW("Unknown device type: ", device.Type());
}
}
py::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, const DataTransferManager* data_transfer_manager) {
#if !defined(DISABLE_SPARSE_TENSORS)
if (!ort_value.IsSparseTensor()) {
ORT_THROW("Must be a sparse tensor");
}
auto& logger = logging::LoggingManager::DefaultLogger();
const SparseTensor& src_sparse_tensor = ort_value.Get<SparseTensor>();
std::unique_ptr<PySparseTensor> py_sparse_tensor;
auto device_type = src_sparse_tensor.Location().device.Type();
if (device_type != OrtDevice::CPU) {
if (!data_transfer_manager) {
LOGS(logger, WARNING) << "Returned OrtValue with sparse tensor at position: " << pos << " is on GPU but no data_transfer_manager provided."
<< " Returned it will have its data on GPU, you can copy it using numpy_array_to_cpu()";
py_sparse_tensor = std::make_unique<PySparseTensor>(ort_value);
} else {
auto dst_sparse_tensor = std::make_unique<SparseTensor>(src_sparse_tensor.DataType(), src_sparse_tensor.DenseShape(), GetAllocator());
auto status = src_sparse_tensor.Copy(*data_transfer_manager, *dst_sparse_tensor);
OrtPybindThrowIfError(status);
py_sparse_tensor = std::make_unique<PySparseTensor>(std::move(dst_sparse_tensor));
}
} else {
py_sparse_tensor = std::make_unique<PySparseTensor>(ort_value);
}
py::object result = py::cast(py_sparse_tensor.get(), py::return_value_policy::take_ownership);
py_sparse_tensor.release();
return result;
#else
ORT_UNUSED_PARAMETER(pos);
ORT_UNUSED_PARAMETER(ort_value);
ORT_UNUSED_PARAMETER(data_transfer_manager);
ORT_THROW("SparseTensor support is disabled in this build.");
#endif // !defined(DISABLE_SPARSE_TENSORS)
}
template <>
py::object AddNonTensor<TensorSeq>(const OrtValue& val,
const DataTransferManager* data_transfer_manager,
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
const auto& seq_tensors = val.Get<TensorSeq>();
py::list py_list;
for (const auto& ort_value : seq_tensors) {
py::object obj = GetPyObjFromTensor(ort_value, data_transfer_manager, mem_cpy_to_host_functions);
py_list.append(std::move(obj));
}
// XToolChain kills the build
// local variable 'py_list' will be copied despite being returned by name [-Werror,-Wreturn-std-move]
// call 'std::move' explicitly to avoid copying
// We choose to cast it to object explicitly
return py::cast<py::object>(py_list);
}
py::object AddNonTensorAsPyObj(const OrtValue& val,
const DataTransferManager* data_transfer_manager,
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
// Should be in sync with core/framework/datatypes.h
auto val_type = val.Type();
if (val_type->IsTensorSequenceType()) {
return AddNonTensor<TensorSeq>(val, data_transfer_manager, mem_cpy_to_host_functions);
} else {
#if !defined(DISABLE_ML_OPS)
utils::ContainerChecker c_checker(val_type);
if (c_checker.IsMap()) {
if (c_checker.IsMapOf<std::string, std::string>()) {
return AddNonTensor<MapStringToString>(val, data_transfer_manager, mem_cpy_to_host_functions);
} else if (c_checker.IsMapOf<std::string, int64_t>()) {
return AddNonTensor<MapStringToInt64>(val, data_transfer_manager, mem_cpy_to_host_functions);
} else if (c_checker.IsMapOf<std::string, float>()) {
return AddNonTensor<MapStringToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
} else if (c_checker.IsMapOf<std::string, double>()) {
return AddNonTensor<MapStringToDouble>(val, data_transfer_manager, mem_cpy_to_host_functions);
} else if (c_checker.IsMapOf<int64_t, std::string>()) {
return AddNonTensor<MapInt64ToString>(val, data_transfer_manager, mem_cpy_to_host_functions);
} else if (c_checker.IsMapOf<int64_t, int64_t>()) {
return AddNonTensor<MapInt64ToInt64>(val, data_transfer_manager, mem_cpy_to_host_functions);
} else if (c_checker.IsMapOf<int64_t, float>()) {
return AddNonTensor<MapInt64ToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
} else if (c_checker.IsMapOf<int64_t, double>()) {
return AddNonTensor<MapInt64ToDouble>(val, data_transfer_manager, mem_cpy_to_host_functions);
}
} else {
if (c_checker.IsSequenceOf<std::map<std::string, float>>()) {
return AddNonTensor<VectorMapStringToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
} else if (c_checker.IsSequenceOf<std::map<int64_t, float>>()) {
return AddNonTensor<VectorMapInt64ToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
}
}
#endif
}
ORT_THROW("Non-tensor type is not supported in this build: ", val_type);
}
py::object AddTensorAsPyObj(const OrtValue& val, const DataTransferManager* data_transfer_manager,
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
return GetPyObjFromTensor(val, data_transfer_manager, mem_cpy_to_host_functions);
}
static std::unique_ptr<onnxruntime::IExecutionProvider> LoadExecutionProvider(
const std::string& ep_shared_lib_path,
const ProviderOptions& provider_options = {},
const std::string& entry_symbol_name = "GetProvider") {
void* handle;
const auto path_str = ToPathString(ep_shared_lib_path);
auto error = Env::Default().LoadDynamicLibrary(path_str, false, &handle);
if (!error.IsOK()) {
throw std::runtime_error(error.ErrorMessage());
}
Provider* (*PGetProvider)();
OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, entry_symbol_name, (void**)&PGetProvider));
Provider* provider = PGetProvider();
std::shared_ptr<IExecutionProviderFactory> ep_factory = provider->CreateExecutionProviderFactory(&provider_options);
return ep_factory->CreateProvider();
}
#ifdef USE_CUDA
const CUDAExecutionProviderInfo GetCudaExecutionProviderInfo(ProviderInfo_CUDA* cuda_provider_info,
const ProviderOptionsMap& provider_options_map) {
ORT_ENFORCE(cuda_provider_info);
const auto it = provider_options_map.find(kCudaExecutionProvider);
CUDAExecutionProviderInfo info;
if (it != provider_options_map.end())
cuda_provider_info->CUDAExecutionProviderInfo__FromProviderOptions(it->second, info);
else {
info.device_id = cuda_device_id;
info.gpu_mem_limit = gpu_mem_limit;
info.arena_extend_strategy = arena_extend_strategy;
info.cudnn_conv_algo_search = cudnn_conv_algo_search;
info.do_copy_in_default_stream = do_copy_in_default_stream;
info.external_allocator_info = external_allocator_info;
info.tunable_op = tunable_op;
}
return info;
}
#endif
#ifdef USE_CANN
const CANNExecutionProviderInfo GetCannExecutionProviderInfo(ProviderInfo_CANN* cann_provider_info,
const ProviderOptionsMap& provider_options_map) {
ORT_ENFORCE(cann_provider_info);
const auto it = provider_options_map.find(kCannExecutionProvider);
CANNExecutionProviderInfo info;
if (it != provider_options_map.end())
cann_provider_info->CANNExecutionProviderInfo__FromProviderOptions(it->second, info);
return info;
}
#endif
#ifdef USE_ROCM
const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* rocm_provider_info,
const ProviderOptionsMap& provider_options_map) {
ORT_ENFORCE(rocm_provider_info);
const auto it = provider_options_map.find(kRocmExecutionProvider);
ROCMExecutionProviderInfo info;
if (it != provider_options_map.end())
rocm_provider_info->ROCMExecutionProviderInfo__FromProviderOptions(it->second, info);
else {
info.device_id = cuda_device_id;
info.gpu_mem_limit = gpu_mem_limit;
info.arena_extend_strategy = arena_extend_strategy;
info.miopen_conv_exhaustive_search = miopen_conv_exhaustive_search;
info.do_copy_in_default_stream = do_copy_in_default_stream;
info.external_allocator_info = external_allocator_info;
info.tunable_op = tunable_op;
}
return info;
}
#endif
#ifdef USE_TENSORRT
void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) {
if (auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT()) {
auto is_already_in_domains = [&](std::string& domain_name, std::vector<OrtCustomOpDomain*>& domains) {
for (auto ptr : domains) {
if (domain_name == ptr->domain_) {
return true;
}
}
return false;
};
std::string trt_extra_plugin_lib_paths = "";
const auto it = options.find("trt_extra_plugin_lib_paths");
if (it != options.end()) {
trt_extra_plugin_lib_paths = it->second;
}
std::vector<OrtCustomOpDomain*> custom_op_domains;
tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths);
for (auto ptr : custom_op_domains) {
if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) {
so.custom_op_domains_.push_back(ptr);
} else {
LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option.";
}
}
} else {
ORT_THROW("Please install TensorRT libraries as mentioned in the GPU requirements page, make sure they're in the PATH or LD_LIBRARY_PATH, and that your GPU is supported.");
}
}
#endif
std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
const SessionOptions& session_options,
const std::string& type,
const ProviderOptionsMap& provider_options_map) {
if (type == kCpuExecutionProvider) {
return onnxruntime::CPUProviderFactoryCreator::Create(
session_options.enable_cpu_mem_arena)
->CreateProvider();
} else if (type == kTensorrtExecutionProvider) {
#ifdef USE_TENSORRT
// If the environment variable 'ORT_TENSORRT_UNAVAILABLE' exists, then we do not load TensorRT. This is set by _ld_preload for the manylinux case
// as in that case, trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies.
if (Env::Default().GetEnvironmentVar("ORT_TENSORRT_UNAVAILABLE").empty()) {
// provider_options_map is just a reference to the ProviderOptionsMap instance, so it can be released at any time by the application.
// We therefore define these std::string variables here: they are kept alive for the lifetime of the TRT EP and can still be accessed through the OrtTensorRTProviderOptionsV2 instance.
// (String copies are involved, e.g. params.trt_engine_cache_path = cache_path.c_str(); the std::string is referenced by the OrtTensorRTProviderOptionsV2 instance
// and the TRT EP instance, so it won't be released.)
std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources,
trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path,
onnx_model_folder_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtTensorRTProviderOptionsV2 params;
for (auto option : it->second) {
if (option.first == "device_id") {
if (!option.second.empty()) {
params.device_id = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'device_id' should be a number i.e. '0'.\n");
}
} else if (option.first == "user_compute_stream") {
if (!option.second.empty()) {
auto stream = std::stoull(option.second, nullptr, 0);
params.user_compute_stream = reinterpret_cast<void*>(stream);
params.has_user_compute_stream = true;
} else {
params.has_user_compute_stream = false;
ORT_THROW("[ERROR] [TensorRT] The value for the key 'user_compute_stream' should be a string to define the compute stream for the inference to run on.\n");
}
} else if (option.first == "trt_max_partition_iterations") {
if (!option.second.empty()) {
params.trt_max_partition_iterations = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_max_partition_iterations' should be a positive integer number i.e. '1000'.\n");
}
} else if (option.first == "trt_min_subgraph_size") {
if (!option.second.empty()) {
params.trt_min_subgraph_size = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_min_subgraph_size' should be a positive integer number i.e. '1'.\n");
}
} else if (option.first == "trt_max_workspace_size") {
if (!option.second.empty()) {
params.trt_max_workspace_size = std::stoull(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_max_workspace_size' should be a number in byte i.e. '1073741824'.\n");
}
} else if (option.first == "trt_fp16_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_fp16_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_fp16_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_int8_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_int8_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_int8_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_int8_calibration_table_name") {
if (!option.second.empty()) {
calibration_table = option.second;
params.trt_int8_calibration_table_name = calibration_table.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_calibration_table_name' should be a file name i.e. 'cal_table'.\n");
}
} else if (option.first == "trt_int8_use_native_calibration_table") {
if (option.second == "True" || option.second == "true") {
params.trt_int8_use_native_calibration_table = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_int8_use_native_calibration_table = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_use_native_calibration_table' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_dla_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_dla_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_dla_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dla_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_dla_core") {
if (!option.second.empty()) {
params.trt_dla_core = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dla_core' should be a positive integer number i.e. '0'.\n");
}
} else if (option.first == "trt_dump_subgraphs") {
if (option.second == "True" || option.second == "true") {
params.trt_dump_subgraphs = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_dump_subgraphs = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_subgraphs' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_engine_cache_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_engine_cache_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_engine_cache_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_engine_cache_path") {
if (!option.second.empty()) {
cache_path = option.second;
params.trt_engine_cache_path = cache_path.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_path' should be a path string i.e. 'engine_cache'.\n");
}
} else if (option.first == "trt_engine_cache_prefix") {
if (!option.second.empty()) {
cache_prefix = option.second;
params.trt_engine_cache_prefix = cache_prefix.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_prefix' should be a string to customize engine cache prefix i.e. 'FRCNN' or 'yolov4'.\n");
}
} else if (option.first == "trt_weight_stripped_engine_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_weight_stripped_engine_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_weight_stripped_engine_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_weight_stripped_engine_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_onnx_model_folder_path") {
if (!option.second.empty()) {
onnx_model_folder_path = option.second;
params.trt_onnx_model_folder_path = onnx_model_folder_path.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_onnx_model_folder_path' should be a path string i.e. 'engine_cache'.\n");
}
} else if (option.first == "trt_engine_decryption_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_engine_decryption_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_engine_decryption_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_decryption_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_engine_decryption_lib_path") {
if (!option.second.empty()) {
lib_path = option.second;
params.trt_engine_decryption_lib_path = lib_path.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_decryption_lib_path' should be a path string i.e. 'decryption_lib'.\n");
}
} else if (option.first == "trt_force_sequential_engine_build") {
if (option.second == "True" || option.second == "true") {
params.trt_force_sequential_engine_build = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_force_sequential_engine_build = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_sequential_engine_build' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_context_memory_sharing_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_context_memory_sharing_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_context_memory_sharing_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_context_memory_sharing_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_layer_norm_fp32_fallback") {
if (option.second == "True" || option.second == "true") {
params.trt_layer_norm_fp32_fallback = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_layer_norm_fp32_fallback = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_layer_norm_fp32_fallback' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_timing_cache_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_timing_cache_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_timing_cache_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_timing_cache_path") {
if (!option.second.empty()) {
timing_cache_path = option.second;
params.trt_timing_cache_path = timing_cache_path.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_path' should be a path string i.e. 'cache_folder/'.\n");
}
} else if (option.first == "trt_force_timing_cache") {
if (option.second == "True" || option.second == "true") {
params.trt_force_timing_cache = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_force_timing_cache = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_timing_cache' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_detailed_build_log") {
if (option.second == "True" || option.second == "true") {
params.trt_detailed_build_log = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_detailed_build_log = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_detailed_build_log' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_build_heuristics_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_build_heuristics_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_build_heuristics_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_build_heuristics_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_sparsity_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_sparsity_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_sparsity_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_sparsity_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_builder_optimization_level") {
if (!option.second.empty()) {
params.trt_builder_optimization_level = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_builder_optimization_level' should be a number i.e. '0'.\n");
}
} else if (option.first == "trt_auxiliary_streams") {
if (!option.second.empty()) {
params.trt_auxiliary_streams = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_auxiliary_streams' should be a number i.e. '0'.\n");
}
} else if (option.first == "trt_tactic_sources") {
if (!option.second.empty()) {
trt_tactic_sources = option.second;
params.trt_tactic_sources = trt_tactic_sources.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_tactic_sources' should be a string. e.g. \"-CUDNN,+CUBLAS\" available keys: \"CUBLAS\"|\"CUBLAS_LT\"|\"CUDNN\"|\"EDGE_MASK_CONVOLUTIONS\".\n");
}
} else if (option.first == "trt_extra_plugin_lib_paths") {
if (!option.second.empty()) {
trt_extra_plugin_lib_paths = option.second;
params.trt_extra_plugin_lib_paths = trt_extra_plugin_lib_paths.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_extra_plugin_lib_paths' should be a path string.\n");
}
} else if (option.first == "trt_profile_min_shapes") {
if (!option.second.empty()) {
min_profile = option.second;
params.trt_profile_min_shapes = min_profile.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_min_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n");
}
} else if (option.first == "trt_profile_max_shapes") {
if (!option.second.empty()) {
max_profile = option.second;
params.trt_profile_max_shapes = max_profile.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_max_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n");
}
} else if (option.first == "trt_profile_opt_shapes") {
if (!option.second.empty()) {
opt_profile = option.second;
params.trt_profile_opt_shapes = opt_profile.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_opt_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n");
}
} else if (option.first == "trt_cuda_graph_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_cuda_graph_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_cuda_graph_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_dump_ep_context_model") {
if (option.second == "True" || option.second == "true") {
params.trt_dump_ep_context_model = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_dump_ep_context_model = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_ep_context_model' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_ep_context_file_path") {
if (!option.second.empty()) {
ep_context_file_path = option.second;
params.trt_ep_context_file_path = ep_context_file_path.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_file_path' should be a string.\n");
}
} else if (option.first == "trt_ep_context_embed_mode") {
if (!option.second.empty()) {
params.trt_ep_context_embed_mode = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_embed_mode' should be a positive integer number i.e. '1'.\n");
}
} else if (option.first == "trt_engine_hw_compatible") {
if (option.second == "True" || option.second == "true") {
params.trt_engine_hw_compatible = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_engine_hw_compatible = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_hw_compatible' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else {
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}
}
if (std::shared_ptr<IExecutionProviderFactory> tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&params)) {
return tensorrt_provider_factory->CreateProvider();
}
} else {
if (std::shared_ptr<IExecutionProviderFactory> tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(cuda_device_id)) {
return tensorrt_provider_factory->CreateProvider();
}
}
}
LOGS_DEFAULT(WARNING) << "Failed to create "
<< type
<< ". Please reference "
<< "https://onnxruntime.ai/docs/execution-providers/"
<< "TensorRT-ExecutionProvider.html#requirements to ensure all dependencies are met.";
#endif
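// Example (Python): a sketch of how the options parsed above are supplied. Values
// are stringified on the Python side before reaching this ProviderOptions map, so
// Python booleans/ints become the "True"/"0" strings handled above:
//   providers = [("TensorrtExecutionProvider", {
//       "device_id": 0,
//       "trt_fp16_enable": True,
//       "trt_engine_cache_enable": True,
//       "trt_engine_cache_path": "./trt_cache"}),
//       "CUDAExecutionProvider"]
//   sess = onnxruntime.InferenceSession("model.onnx", providers=providers)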
} else if (type == kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX
std::string calibration_table;
std::string save_model_path;
std::string load_model_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtMIGraphXProviderOptions params{
0,
0,
0,
0,
nullptr,
1,
"./compiled_model.mxr",
1,
"./compiled_model.mxr",
1};
for (auto option : it->second) {
if (option.first == "device_id") {
if (!option.second.empty()) {
params.device_id = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n");
}
} else if (option.first == "migraphx_fp16_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_int8_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_int8_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_calibration_table_name") {
if (!option.second.empty()) {
calibration_table = option.second;
params.migraphx_int8_calibration_table_name = calibration_table.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_calibration_table_name' should be a "
"file name i.e. 'cal_table'.\n");
}
} else if (option.first == "migraphx_use_native_calibration_table") {
if (option.second == "True" || option.second == "true") {
params.migraphx_use_native_calibration_table = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_use_native_calibration_table = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_compiled_model") {
if (option.second == "True" || option.second == "true") {
params.migraphx_save_compiled_model = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_save_compiled_model = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_model_path") {
if (!option.second.empty()) {
save_model_path = option.second;
params.migraphx_save_model_path = save_model_path.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_load_compiled_model") {
if (option.second == "True" || option.second == "true") {
params.migraphx_load_compiled_model = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_load_compiled_model = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_load_model_path") {
if (!option.second.empty()) {
load_model_path = option.second;
params.migraphx_load_model_path = load_model_path.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_exhaustive_tune") {
if (option.second == "True" || option.second == "true") {
params.migraphx_exhaustive_tune = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_exhaustive_tune = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else {
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
}
}
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory =
onnxruntime::MIGraphXProviderFactoryCreator::Create(&params)) {
return migraphx_provider_factory->CreateProvider();
}
} else {
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory =
onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) {
return migraphx_provider_factory->CreateProvider();
}
}
#endif
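// Example (Python): a sketch using the MIGraphX keys handled above:
//   providers = [("MIGraphXExecutionProvider", {
//       "device_id": 0,
//       "migraphx_fp16_enable": True,
//       "migraphx_save_compiled_model": True,
//       "migraphx_save_model_path": "./compiled_model.mxr"})]
//   sess = onnxruntime.InferenceSession("model.onnx", providers=providers)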
} else if (type == kCudaExecutionProvider) {
#ifdef USE_CUDA
// If the environment variable 'ORT_CUDA_UNAVAILABLE' exists, then we do not load CUDA.
// This is set by _ld_preload for the manylinux case as in that case,
// trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies.
if (Env::Default().GetEnvironmentVar("ORT_CUDA_UNAVAILABLE").empty()) {
if (auto* cuda_provider_info = TryGetProviderInfo_CUDA()) {
const CUDAExecutionProviderInfo info = GetCudaExecutionProviderInfo(cuda_provider_info,
provider_options_map);
// This variable is never initialized because the APIs by which it should be initialized are deprecated,
// however they still exist and are in use. Nevertheless, it is used to return CUDAAllocator,
// hence we must try to initialize it here if we can since FromProviderOptions might contain
// external CUDA allocator.
external_allocator_info = info.external_allocator_info;
return cuda_provider_info->CreateExecutionProviderFactory(info)->CreateProvider();
}
}
LOGS_DEFAULT(WARNING) << "Failed to create "
<< type
<< ". Require cuDNN " << CUDNN_MAJOR << ".* and "
<< "CUDA " << (CUDA_VERSION / 1000) << ".*"
#if defined(_MSC_VER)
<< ", and the latest MSVC runtime"
#endif
<< ". Please install all dependencies as mentioned in the GPU requirements page"
" (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), "
"make sure they're in the PATH, and that your GPU is supported.";
#endif
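// Example (Python): a sketch of supplying CUDA EP options; these keys are parsed
// by CUDAExecutionProviderInfo__FromProviderOptions via GetCudaExecutionProviderInfo:
//   providers = [("CUDAExecutionProvider", {
//       "device_id": 0,
//       "gpu_mem_limit": 2 * 1024 * 1024 * 1024,
//       "arena_extend_strategy": "kSameAsRequested",
//       "cudnn_conv_algo_search": "EXHAUSTIVE"})]
//   sess = onnxruntime.InferenceSession("model.onnx", providers=providers)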
} else if (type == kRocmExecutionProvider) {
#ifdef USE_ROCM
if (auto* rocm_provider_info = TryGetProviderInfo_ROCM()) {
const ROCMExecutionProviderInfo info = GetRocmExecutionProviderInfo(rocm_provider_info,
provider_options_map);
// This variable is never initialized because the APIs by which it should be initialized are deprecated,
// however they still exist and are in-use. Nevertheless, it is used to return ROCMAllocator, hence we must
// try to initialize it here if we can since FromProviderOptions might contain external ROCM allocator.
external_allocator_info = info.external_allocator_info;
return rocm_provider_info->CreateExecutionProviderFactory(info)->CreateProvider();
} else {
if (!Env::Default().GetEnvironmentVar("ROCM_PATH").empty()) {
ORT_THROW(
"ROCM_PATH is set but ROCM wasn't able to be loaded. Please install the correct version "
"of ROCM and MIOpen as mentioned in the GPU requirements page, make sure they're in the PATH, "
"and that your GPU is supported.");
}
}
#endif
} else if (type == kDnnlExecutionProvider) {
#ifdef USE_DNNL
// Generate dnnl_options
OrtDnnlProviderOptions dnnl_options;
// For Eigen and OpenMP
#if defined(DNNL_OPENMP)
int num_threads = 0;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
for (auto option : it->second) {
if (option.first == "num_of_threads") {
num_threads = std::stoi(option.second);
if (num_threads < 0) {
ORT_THROW(
"[ERROR] [OneDNN] Invalid entry for the key 'num_of_threads',"
" set number of threads or use '0' for default\n");
// If the user doesnt define num_threads, auto detect threads later
}
} else {
ORT_THROW("Invalid OneDNN EP option: ", option.first);
}
}
}
dnnl_options.threadpool_args = static_cast<void*>(&num_threads);
#endif // defined(DNNL_OPENMP)
dnnl_options.use_arena = session_options.enable_cpu_mem_arena;
return onnxruntime::DnnlProviderFactoryCreator::Create(&dnnl_options)->CreateProvider();
#endif
} else if (type == kOpenVINOExecutionProvider) {
#ifdef USE_OPENVINO
ProviderOptions OV_provider_options_map;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
for (auto option : it->second) {
if (option.first == "device_type") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "precision") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "enable_opencl_throttling") {
if (!(option.second == "True" || option.second == "true" ||
option.second == "False" || option.second == "false")) {
ORT_THROW("Invalid value passed for enable_opencl_throttling: ", option.second);
}
OV_provider_options_map[option.first] = option.second;
} else if (option.first == "disable_dynamic_shapes") {
if (!(option.second == "True" || option.second == "true" ||
option.second == "False" || option.second == "false")) {
ORT_THROW("Invalid value passed for disable_dynamic_shapes: ", option.second);
}
OV_provider_options_map[option.first] = option.second;
} else if (option.first == "enable_dynamic_shapes") {
LOGS_DEFAULT(WARNING) << " Deprecation notice - 'enable_dynamic_shapes' is Deprected. Upgrade the API to disable_dynamic_shapes parameter."
"Please refer https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met.";
std::string value;
if (!(option.second == "True" || option.second == "true" ||
option.second == "False" || option.second == "false")) {
ORT_THROW("Invalid value passed for enable_dynamic_shapes: ", option.second);
}
if (option.second == "True" || option.second == "true") {
value = "false";
} else {
value = "true";
}
OV_provider_options_map["disable_dynamic_shapes"] = value;
} else if (option.first == "num_of_threads") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "model_priority") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "num_streams") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "load_config") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "cache_dir") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "context") {
OV_provider_options_map[option.first] = option.second;
continue;
} else if (option.first == "enable_qdq_optimizer") {
OV_provider_options_map[option.first] = option.second;
continue;
} else {
ORT_THROW("Invalid OpenVINO EP option: ", option.first);
}
}
}
if (std::shared_ptr<IExecutionProviderFactory> openvino_provider_factory = onnxruntime::OpenVINOProviderFactoryCreator::Create(
&OV_provider_options_map, &session_options)) {
auto p = openvino_provider_factory->CreateProvider();
// Reset the global config variable to avoid it being accidentally passed on to the next session
openvino_device_type.clear();
return p;
} else {
if (!Env::Default().GetEnvironmentVar("INTEL_OPENVINO_DIR").empty()) {
ORT_THROW("INTEL_OPENVINO_DIR is set but OpenVINO library wasn't able to be loaded. Please install a supported version of OpenVINO as mentioned in the requirements page (https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements), ensure dependency libraries are in the PATH and your hardware is supported.");
} else {
LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please refer https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met.";
}
}
#endif
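// Example (Python): a sketch using keys accepted above; the exact device strings
// depend on the local OpenVINO installation:
//   providers = [("OpenVINOExecutionProvider", {
//       "device_type": "GPU",
//       "precision": "FP16",
//       "num_of_threads": 8})]
//   sess = onnxruntime.InferenceSession("model.onnx", providers=providers)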
} else if (type == kTvmExecutionProvider) {
#if USE_TVM
onnxruntime::tvm::TvmEPOptions info{};
const auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
info = onnxruntime::tvm::TvmEPOptionsHelper::FromProviderOptions(it->second);
}
return onnxruntime::TVMProviderFactoryCreator::Create(info)->CreateProvider();
#endif
} else if (type == kVitisAIExecutionProvider) {
#ifdef USE_VITISAI
ProviderOptions info{};
const auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
info = it->second;
}
info["session_options"] = std::to_string((uintptr_t)(void*)&session_options);
return onnxruntime::VitisAIProviderFactoryCreator::Create(info)->CreateProvider();
#endif
} else if (type == kAclExecutionProvider) {
#ifdef USE_ACL
bool enable_fast_math = false;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
for (auto option : it->second) {
if (option.first == "enable_fast_math") {
std::set<std::string> supported_values = {"true", "True", "false", "False"};
if (supported_values.find(option.second) != supported_values.end()) {
enable_fast_math = (option.second == "true") || (option.second == "True");
} else {
ORT_THROW(
"Invalid value for enable_fast_math. "
"Select from 'true' or 'false'\n");
}
} else {
ORT_THROW("Unrecognized option: ", option.first);
}
}
}
return onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math)
->CreateProvider();
#endif
} else if (type == kArmNNExecutionProvider) {
#ifdef USE_ARMNN
return onnxruntime::ArmNNProviderFactoryCreator::Create(
session_options.enable_cpu_mem_arena)
->CreateProvider();
#endif
} else if (type == kDmlExecutionProvider) {
#ifdef USE_DML
auto cit = provider_options_map.find(type);
return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions(
session_options.config_options, cit == provider_options_map.end() ? ProviderOptions{} : cit->second, true)
->CreateProvider();
#endif
} else if (type == kNnapiExecutionProvider) {
#if defined(USE_NNAPI)
#if !defined(__ANDROID__)
LOGS_DEFAULT(WARNING) << "NNAPI execution provider can only be used to generate ORT format model in this build.";
#endif
const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry(
kOrtSessionOptionsConfigNnapiEpPartitioningStopOps);
return onnxruntime::NnapiProviderFactoryCreator::Create(0, partitioning_stop_ops_list)->CreateProvider();
#endif
} else if (type == kRknpuExecutionProvider) {
#ifdef USE_RKNPU
return onnxruntime::RknpuProviderFactoryCreator::Create()->CreateProvider();
#endif
} else if (type == kCoreMLExecutionProvider) {
#if defined(USE_COREML)
#if !defined(__APPLE__)
LOGS_DEFAULT(WARNING) << "CoreML execution provider can only be used to generate ORT format model in this build.";
#endif
uint32_t coreml_flags = 0;
const auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
const ProviderOptions& options = it->second;
auto flags = options.find("flags");
if (flags != options.end()) {
const auto& flags_str = flags->second;
if (flags_str.find("COREML_FLAG_USE_CPU_ONLY") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_ONLY;
} else if (flags_str.find("COREML_FLAG_USE_CPU_AND_GPU") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_AND_GPU;
}
if (flags_str.find("COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES;
}
if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM;
}
}
}
return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider();
#endif
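// Example (Python): a sketch; the "flags" value is matched by substring above, so
// multiple flags can be joined in a single string:
//   providers = [("CoreMLExecutionProvider",
//                 {"flags": "COREML_FLAG_CREATE_MLPROGRAM"})]
//   sess = onnxruntime.InferenceSession("model.onnx", providers=providers)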
} else if (type == kXnnpackExecutionProvider) {
#if defined(USE_XNNPACK)
auto cit = provider_options_map.find(type);
return onnxruntime::XnnpackProviderFactoryCreator::Create(
cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options)
->CreateProvider();
#endif
} else if (type == kWebGpuExecutionProvider) {
#if defined(USE_WEBGPU)
return onnxruntime::WebGpuProviderFactoryCreator::Create(session_options.config_options)->CreateProvider();
#endif
} else if (type == kCannExecutionProvider) {
#ifdef USE_CANN
if (auto* cann_provider_info = TryGetProviderInfo_CANN()) {
const CANNExecutionProviderInfo info = GetCannExecutionProviderInfo(cann_provider_info,
provider_options_map);
return cann_provider_info->CreateExecutionProviderFactory(info)->CreateProvider();
} else {
ORT_THROW("create CANN ExecutionProvider fail");
}
#endif
} else if (type == kAzureExecutionProvider) {
#ifdef USE_AZURE
return onnxruntime::AzureProviderFactoryCreator::Create({})->CreateProvider();
#endif
} else if (type == kQnnExecutionProvider) {
#ifdef USE_QNN
auto cit = provider_options_map.find(type);
return onnxruntime::QNNProviderFactoryCreator::Create(
cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options)
->CreateProvider();
#endif
} else {
// check whether it is a dynamic load EP:
const auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
auto shared_lib_path_it = it->second.find(kExecutionProviderSharedLibraryPath);
if (shared_lib_path_it != it->second.end()) {
// this is an EP with dynamic loading
// construct the provider option
ProviderOptions provider_options;
std::string entry_symbol = kDefaultExecutionProviderEntry;
for (auto option : it->second) {
if (option.first == kExecutionProviderSharedLibraryEntry) {
entry_symbol = option.second;
} else if (option.first != kExecutionProviderSharedLibraryPath) {
provider_options.insert(option);
}
}
return LoadExecutionProvider(shared_lib_path_it->second, provider_options, entry_symbol);
}
}
// unknown provider
throw std::runtime_error("Unknown Provider Type: " + type);
}
return nullptr;
}
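// Example (Python): a hypothetical sketch of the dynamically-loaded EP path above.
// The literal key strings are whatever kExecutionProviderSharedLibraryPath and
// kExecutionProviderSharedLibraryEntry are defined as (the strings shown here and
// the "MyEP" names are placeholders, not confirmed API):
//   providers = ["MyEPExecutionProvider"]
//   provider_options = [{"shared_lib_path": "./libmy_ep.so",
//                        "provider_factory_entry_point": "GetProvider"}]
//   sess = onnxruntime.InferenceSession("model.onnx", providers=providers,
//                                       provider_options=provider_options)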
/*
* Register execution provider with options.
*/
static void RegisterExecutionProviders(InferenceSession* sess, const std::vector<std::string>& provider_types,
const ProviderOptionsMap& provider_options_map) {
for (const std::string& type : provider_types) {
auto ep = CreateExecutionProviderInstance(sess->GetSessionOptions(), type, provider_options_map);
if (ep)
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(ep)));
}
}
/**
* Generate a map from execution provider to execution provider options.
*
* @param providers vector of execution providers. [ep1, ep2, ...]
* @param provider_options_vector vector of execution provider options. [option1, option2 ...]
* @param provider_options_map an unordered map mapping each execution provider to its options.
* {'ep1' -> option1, 'ep2' -> option2 ...}
*
*/
static void GenerateProviderOptionsMap(const std::vector<std::string>& providers,
const ProviderOptionsVector& provider_options_vector,
ProviderOptionsMap& provider_options_map) {
if (provider_options_vector.empty() || providers.empty()) {
return;
}
std::size_t j = 0; // index for provider_options_vector
for (const std::string& type : providers) {
if (j < provider_options_vector.size() && !provider_options_vector[j].empty()) {
provider_options_map[type] = provider_options_vector[j];
}
j += 1;
}
}
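// Example (Python): the two parallel lists below are zipped into the
// ProviderOptionsMap consumed by CreateExecutionProviderInstance:
//   providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
//   provider_options = [{"device_id": "0"}, {}]
//   sess = onnxruntime.InferenceSession("model.onnx", providers=providers,
//                                       provider_options=provider_options)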
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
static void RegisterCustomOpDomains(PyInferenceSession* sess, const PySessionOptions& so) {
if (!so.custom_op_domains_.empty()) {
// Register all custom op domains that will be needed for the session
std::vector<OrtCustomOpDomain*> custom_op_domains;
custom_op_domains.reserve(so.custom_op_domains_.size());
for (size_t i = 0; i < so.custom_op_domains_.size(); ++i) {
custom_op_domains.emplace_back(so.custom_op_domains_[i]);
}
OrtPybindThrowIfError(sess->GetSessionHandle()->AddCustomOpDomains(custom_op_domains));
}
}
#endif
void InitializeSession(InferenceSession* sess,
ExecutionProviderRegistrationFn ep_registration_fn,
const std::vector<std::string>& provider_types,
const ProviderOptionsVector& provider_options,
const std::unordered_set<std::string>& disabled_optimizer_names) {
ProviderOptionsMap provider_options_map;
GenerateProviderOptionsMap(provider_types, provider_options, provider_options_map);
ep_registration_fn(sess, provider_types, provider_options_map);
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
if (!disabled_optimizer_names.empty()) {
OrtPybindThrowIfError(sess->FilterEnabledOptimizers({disabled_optimizer_names.cbegin(), disabled_optimizer_names.cend()}));
}
#else
ORT_UNUSED_PARAMETER(disabled_optimizer_names);
#endif
OrtPybindThrowIfError(sess->Initialize());
}
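// Example (Python): a sketch of disabling specific graph optimizers; this assumes
// the InferenceSession constructor forwards a disabled_optimizers set into this
// function, as in recent Python bindings:
//   sess = onnxruntime.InferenceSession("model.onnx",
//                                       disabled_optimizers={"ConstantFolding"})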
bool CheckIfTensor(const std::vector<const NodeArg*>& def_list,
const std::string& name,
/*out*/ onnx::TypeProto& type_proto) {
auto ret_it = std::find_if(std::begin(def_list), std::end(def_list),
[&name](const NodeArg* node_arg) { return name == node_arg->Name(); });
if (ret_it == std::end(def_list)) {
throw std::runtime_error("Failed to find NodeArg with name: " + name + " in the def list");
}
const auto* temp = (*ret_it)->TypeAsProto();
if (!temp) {
throw std::runtime_error("Corresponding type_proto is null");
} else {
type_proto = *temp;
}
return type_proto.has_tensor_type();
}
#if defined(USE_OPENVINO) || \
defined(USE_CUDA) || \
defined(USE_ROCM)
static void LogDeprecationWarning(
const std::string& deprecated, const optional<std::string>& alternative = nullopt) {
LOGS_DEFAULT(WARNING) << "This is DEPRECATED and will be removed in the future: " << deprecated;
LOGS_DEFAULT_IF(alternative.has_value(), WARNING) << "As an alternative, use: " << *alternative;
}
#endif
void addGlobalMethods(py::module& m) {
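// Usage sketch (Python, illustrative only): these module-level functions surface
// through the public `onnxruntime` package, which wraps this pybind module:
//
//   import onnxruntime as ort
//   print(ort.get_device())             # e.g. "CPU" or "GPU"
//   ort.set_seed(42)
//   ort.set_default_logger_severity(0)  # 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
//   ort.set_default_logger_verbosity(4) # only effective when severity is 0 (DEBUG builds)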
m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance.");
m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer.");
m.def(
"get_device", []() -> std::string { return BACKEND_DEVICE; },
"Return the device used to compute the prediction (CPU, MKL, ...)");
m.def(
"set_seed", [](const int64_t seed) { utils::SetRandomSeed(seed); },
"Sets the seed used for random number generation in Onnxruntime.");
m.def(
"set_default_logger_severity", [](int severity) {
ORT_ENFORCE(severity >= 0 && severity <= 4,
"Invalid logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
auto env = GetEnv();
logging::LoggingManager* default_logging_manager = env->GetLoggingManager();
default_logging_manager->SetDefaultLoggerSeverity(static_cast<logging::Severity>(severity));
},
"Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
m.def(
"set_default_logger_verbosity", [](int vlog_level) {
auto env = GetEnv();
logging::LoggingManager* default_logging_manager = env->GetLoggingManager();
default_logging_manager->SetDefaultLoggerVerbosity(vlog_level);
},
"Sets the default logging verbosity level. To activate the verbose log, "
"you need to set the default logging severity to 0:Verbose level.");
m.def(
"get_all_providers", []() -> const std::vector<std::string>& { return GetAllExecutionProviderNames(); },
"Return list of Execution Providers that this version of Onnxruntime can support. "
"The order of elements represents the default priority order of Execution Providers "
"from highest to lowest.");
m.def(
"enable_telemetry_events", []() -> void { platform_env.GetTelemetryProvider().EnableTelemetryEvents(); },
"Enables platform-specific telemetry collection where applicable.");
m.def(
"disable_telemetry_events", []() -> void { platform_env.GetTelemetryProvider().DisableTelemetryEvents(); },
"Disables platform-specific telemetry collection.");
m.def(
"create_and_register_allocator", [](const OrtMemoryInfo& mem_info, const OrtArenaCfg* arena_cfg = nullptr) -> void {
auto env = GetEnv();
auto st = env->CreateAndRegisterAllocator(mem_info, arena_cfg);
if (!st.IsOK()) {
throw std::runtime_error("Error when creating and registering allocator: " + st.ErrorMessage());
}
});
m.def(
"create_and_register_allocator_v2", [](const std::string& provider_type, const OrtMemoryInfo& mem_info, const ProviderOptions& options, const OrtArenaCfg* arena_cfg = nullptr) -> void {
auto env = GetEnv();
auto st = env->CreateAndRegisterAllocatorV2(provider_type, mem_info, options, arena_cfg);
if (!st.IsOK()) {
throw std::runtime_error("Error when creating and registering allocator in create_and_register_allocator_v2: " + st.ErrorMessage());
}
});
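// Usage sketch (Python, illustrative): register a shared CPU arena allocator that
// sessions can opt into via the "session.use_env_allocators" session config entry:
//
//   import onnxruntime as ort
//   mem_info = ort.OrtMemoryInfo("Cpu", ort.OrtAllocatorType.ORT_ARENA_ALLOCATOR, 0,
//                                ort.OrtMemType.DEFAULT)
//   arena_cfg = ort.OrtArenaCfg({"max_mem": 0, "arena_extend_strategy": 0})
//   ort.create_and_register_allocator(mem_info, arena_cfg)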
#ifdef USE_OPENVINO
m.def(
"get_available_openvino_device_ids", []() -> std::vector<std::string> {
if (auto* info = GetProviderInfo_OpenVINO()) {
return info->GetAvailableDevices();
}
return {};
},
"Lists all OpenVINO device ids available.");
/*
* The following APIs to set config options are deprecated. Use Session.set_providers() instead.
*/
// TODO remove deprecated global config
m.def(
"set_openvino_device", [](const std::string& device_type) {
LogDeprecationWarning("set_openvino_device", "OpenVINO execution provider option \"device_type\"");
openvino_device_type = device_type;
},
"Set the preferred OpenVINO device type to be used. If left unset, "
"the device type selected during build time will be used.");
// TODO remove deprecated global config
m.def(
"get_openvino_device", []() -> std::string {
LogDeprecationWarning("get_openvino_device");
return openvino_device_type;
},
"Gets the dynamically selected OpenVINO device type for inference.");
#endif
#if defined(USE_CUDA) || defined(USE_ROCM)
/*
* The following set_* methods are deprecated.
*
* To achieve the same result, please use the following python api:
* InferenceSession.set_providers(list_of_providers, list_of_provider_option_dicts)
*
*/
// TODO remove deprecated global config
m.def("set_cuda_device_id", [](const int id) {
LogDeprecationWarning("set_cuda_device_id", "CUDA/ROCM execution provider option \"device_id\"");
cuda_device_id = static_cast<OrtDevice::DeviceId>(id);
});
// TODO remove deprecated global config
m.def("set_cudnn_conv_algo_search", [](const OrtCudnnConvAlgoSearch algo) {
LogDeprecationWarning("set_cudnn_conv_algo_search", "CUDA execution provider option \"cudnn_conv_algo_search\"");
#ifdef USE_ROCM
ORT_UNUSED_PARAMETER(algo);
ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM");
#else
cudnn_conv_algo_search = algo;
#endif
});
// TODO remove deprecated global config
m.def("set_do_copy_in_default_stream", [](const bool use_single_stream) {
LogDeprecationWarning(
"set_do_copy_in_default_stream", "CUDA execution provider option \"do_copy_in_default_stream\"");
#ifdef USE_ROCM
ORT_UNUSED_PARAMETER(use_single_stream);
ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM");
#else
do_copy_in_default_stream = use_single_stream;
#endif
});
// TODO remove deprecated global config
m.def("set_gpu_mem_limit", [](const int64_t limit) {
LogDeprecationWarning(
"set_gpu_mem_limit",
"CUDA execution provider option \"gpu_mem_limit\", ROCM execution provider option \"gpu_mem_limit\"");
gpu_mem_limit = gsl::narrow<size_t>(limit);
});
// TODO remove deprecated global config
m.def("set_arena_extend_strategy", [](const onnxruntime::ArenaExtendStrategy strategy) {
LogDeprecationWarning("set_arena_extend_strategy", "CUDA/ROCM execution provider option \"arena_extend_strategy\"");
arena_extend_strategy = strategy;
});
#endif
#ifdef USE_TENSORRT
m.def(
"register_tensorrt_plugins_as_custom_ops", [](PySessionOptions& so, const ProviderOptions& options) { RegisterTensorRTPluginsAsCustomOps(so, options); },
"Register TensorRT plugins as custom ops.");
#endif
#ifdef ENABLE_ATEN
m.def("register_aten_op_executor",
[](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
size_t is_tensor_argument_address_int, aten_op_executor_address_int;
ORT_THROW_IF_ERROR(
ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int));
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int));
void* p_is_tensor_argument = reinterpret_cast<void*>(is_tensor_argument_address_int);
void* p_aten_op_executor = reinterpret_cast<void*>(aten_op_executor_address_int);
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor);
});
#endif
}
void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) {
py::enum_<GraphOptimizationLevel>(m, "GraphOptimizationLevel")
.value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL)
.value("ORT_ENABLE_BASIC", GraphOptimizationLevel::ORT_ENABLE_BASIC)
.value("ORT_ENABLE_EXTENDED", GraphOptimizationLevel::ORT_ENABLE_EXTENDED)
.value("ORT_ENABLE_ALL", GraphOptimizationLevel::ORT_ENABLE_ALL);
py::enum_<ExecutionMode>(m, "ExecutionMode")
.value("ORT_SEQUENTIAL", ExecutionMode::ORT_SEQUENTIAL)
.value("ORT_PARALLEL", ExecutionMode::ORT_PARALLEL);
py::enum_<ExecutionOrder>(m, "ExecutionOrder")
.value("DEFAULT", ExecutionOrder::DEFAULT)
.value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED)
.value("MEMORY_EFFICIENT", ExecutionOrder::MEMORY_EFFICIENT);
py::enum_<OrtAllocatorType>(m, "OrtAllocatorType")
.value("INVALID", OrtInvalidAllocator)
.value("ORT_DEVICE_ALLOCATOR", OrtDeviceAllocator)
.value("ORT_ARENA_ALLOCATOR", OrtArenaAllocator);
py::enum_<OrtMemType>(m, "OrtMemType")
.value("CPU_INPUT", OrtMemTypeCPUInput)
.value("CPU_OUTPUT", OrtMemTypeCPUOutput)
.value("CPU", OrtMemTypeCPU)
.value("DEFAULT", OrtMemTypeDefault);
py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc");
device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::DeviceId>())
.def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
.def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc")
.def_static("cpu", []() { return OrtDevice::CPU; })
.def_static("cuda", []() { return OrtDevice::GPU; })
.def_static("cann", []() { return OrtDevice::NPU; })
.def_static("fpga", []() { return OrtDevice::FPGA; })
.def_static("npu", []() { return OrtDevice::NPU; })
.def_static("dml", []() { return OrtDevice::DML; })
.def_static("webgpu", []() { return OrtDevice::GPU; })
.def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });
py::class_<OrtArenaCfg> ort_arena_cfg_binding(m, "OrtArenaCfg");
// Note: Doesn't expose the initial_growth_chunk_size_bytes/max_power_of_two_extend_bytes options.
// This constructor is kept for backwards compatibility; the key-value pair constructor overload exposes all options.
// There is a global var named arena_extend_strategy, which means we can't use that var name here.
// See docs/C_API.md for details on what the following parameters mean and how to choose these values.
ort_arena_cfg_binding.def(py::init([](size_t max_mem, int arena_extend_strategy_local,
int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
auto ort_arena_cfg = std::make_unique<OrtArenaCfg>();
ort_arena_cfg->max_mem = max_mem;
ort_arena_cfg->arena_extend_strategy = arena_extend_strategy_local;
ort_arena_cfg->initial_chunk_size_bytes = initial_chunk_size_bytes;
ort_arena_cfg->max_dead_bytes_per_chunk = max_dead_bytes_per_chunk;
return ort_arena_cfg;
}))
.def(py::init([](const py::dict& options) {
auto ort_arena_cfg = std::make_unique<OrtArenaCfg>();
for (const auto kvp : options) {
std::string key = kvp.first.cast<std::string>();
if (key == "max_mem") {
ort_arena_cfg->max_mem = kvp.second.cast<size_t>();
} else if (key == "arena_extend_strategy") {
ort_arena_cfg->arena_extend_strategy = kvp.second.cast<int>();
} else if (key == "initial_chunk_size_bytes") {
ort_arena_cfg->initial_chunk_size_bytes = kvp.second.cast<int>();
} else if (key == "max_dead_bytes_per_chunk") {
ort_arena_cfg->max_dead_bytes_per_chunk = kvp.second.cast<int>();
} else if (key == "initial_growth_chunk_size_bytes") {
ort_arena_cfg->initial_growth_chunk_size_bytes = kvp.second.cast<int>();
} else if (key == "max_power_of_two_extend_bytes") {
ort_arena_cfg->max_power_of_two_extend_bytes = kvp.second.cast<int>();
} else {
ORT_THROW("Invalid OrtArenaCfg option: ", key);
}
}
return ort_arena_cfg;
}))
.def_readwrite("max_mem", &OrtArenaCfg::max_mem)
.def_readwrite("arena_extend_strategy", &OrtArenaCfg::arena_extend_strategy)
.def_readwrite("initial_chunk_size_bytes", &OrtArenaCfg::initial_chunk_size_bytes)
.def_readwrite("max_dead_bytes_per_chunk", &OrtArenaCfg::max_dead_bytes_per_chunk)
.def_readwrite("initial_growth_chunk_size_bytes", &OrtArenaCfg::initial_growth_chunk_size_bytes)
.def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes);
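// Construction sketch (Python): besides the positional form above, any subset of
// the supported keys can be passed as a dict (values shown are placeholders):
//
//   arena_cfg = ort.OrtArenaCfg({
//       "max_mem": 0,                # 0 lets ORT choose the default upper bound
//       "arena_extend_strategy": 0,  # 0:kNextPowerOfTwo, 1:kSameAsRequested
//       "initial_chunk_size_bytes": 1024,
//       "max_dead_bytes_per_chunk": 128,
//   })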
py::class_<OrtMemoryInfo> ort_memory_info_binding(m, "OrtMemoryInfo");
ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
if (strcmp(name, onnxruntime::CPU) == 0) {
return std::make_unique<OrtMemoryInfo>(onnxruntime::CPU, type, OrtDevice(), id, mem_type);
} else if (strcmp(name, onnxruntime::CUDA) == 0) {
return std::make_unique<OrtMemoryInfo>(
onnxruntime::CUDA, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id)), id,
mem_type);
} else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) {
return std::make_unique<OrtMemoryInfo>(
onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id)),
id, mem_type);
} else {
throw std::runtime_error("Specified device is not supported.");
}
}));
py::class_<PySessionOptions>
sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
sess
.def(py::init())
.def_property(
"enable_cpu_mem_arena",
[](const PySessionOptions* options) -> bool { return options->value.enable_cpu_mem_arena; },
[](PySessionOptions* options, bool enable_cpu_mem_arena) -> void {
options->value.enable_cpu_mem_arena = enable_cpu_mem_arena;
},
R"pbdoc(Enables the memory arena on CPU. Arena may pre-allocate memory for future usage.
Set this option to false if you don't want it. Default is True.)pbdoc")
.def_property(
"enable_profiling",
[](const PySessionOptions* options) -> bool { return options->value.enable_profiling; },
[](PySessionOptions* options, bool enable_profiling) -> void {
options->value.enable_profiling = enable_profiling;
},
R"pbdoc(Enable profiling for this session. Default is false.)pbdoc")
.def_property(
"profile_file_prefix",
[](const PySessionOptions* options) -> std::basic_string<ORTCHAR_T> {
return options->value.profile_file_prefix;
},
[](PySessionOptions* options, std::basic_string<ORTCHAR_T> profile_file_prefix) -> void {
options->value.profile_file_prefix = std::move(profile_file_prefix);
},
R"pbdoc(The prefix of the profile file. The current time will be appended to the file name.)pbdoc")
.def_property(
"optimized_model_filepath",
[](const PySessionOptions* options) -> std::basic_string<ORTCHAR_T> {
return options->value.optimized_model_filepath;
},
[](PySessionOptions* options, std::basic_string<ORTCHAR_T> optimized_model_filepath) -> void {
options->value.optimized_model_filepath = std::move(optimized_model_filepath);
},
R"pbdoc(
File path to serialize optimized model to.
Optimized model is not serialized unless optimized_model_filepath is set.
Serialized model format will default to ONNX unless:
- add_session_config_entry is used to set 'session.save_model_format' to 'ORT', or
- there is no 'session.save_model_format' config entry and optimized_model_filepath ends in '.ort' (case insensitive)
)pbdoc")
.def_property(
"enable_mem_pattern",
[](const PySessionOptions* options) -> bool { return options->value.enable_mem_pattern; },
[](PySessionOptions* options, bool enable_mem_pattern) -> void {
options->value.enable_mem_pattern = enable_mem_pattern;
},
R"pbdoc(Enable the memory pattern optimization. Default is true.)pbdoc")
.def_property(
"enable_mem_reuse",
[](const PySessionOptions* options) -> bool { return options->value.enable_mem_reuse; },
[](PySessionOptions* options, bool enable_mem_reuse) -> void {
options->value.enable_mem_reuse = enable_mem_reuse;
},
R"pbdoc(Enable the memory reuse optimization. Default is true.)pbdoc")
.def_property(
"logid",
[](const PySessionOptions* options) -> std::string {
return options->value.session_logid;
},
[](PySessionOptions* options, std::string logid) -> void {
options->value.session_logid = std::move(logid);
},
R"pbdoc(Logger id to use for session output.)pbdoc")
.def_property(
"log_severity_level",
[](const PySessionOptions* options) -> int { return options->value.session_log_severity_level; },
[](PySessionOptions* options, int log_severity_level) -> void {
options->value.session_log_severity_level = log_severity_level;
},
R"pbdoc(Log severity level. Applies to session load, initialization, etc.
0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal. Default is 2.)pbdoc")
.def_property(
"log_verbosity_level",
[](const PySessionOptions* options) -> int { return options->value.session_log_verbosity_level; },
[](PySessionOptions* options, int log_verbosity_level) -> void {
options->value.session_log_verbosity_level = log_verbosity_level;
},
R"pbdoc(VLOG level if DEBUG build and session_log_severity_level is 0.
Applies to session load, initialization, etc. Default is 0.)pbdoc")
.def_property(
"intra_op_num_threads",
[](const PySessionOptions* options) -> int { return options->value.intra_op_param.thread_pool_size; },
[](PySessionOptions* options, int value) -> void { options->value.intra_op_param.thread_pool_size = value; },
R"pbdoc(Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose.)pbdoc")
.def_property(
"inter_op_num_threads",
[](const PySessionOptions* options) -> int { return options->value.inter_op_param.thread_pool_size; },
[](PySessionOptions* options, int value) -> void { options->value.inter_op_param.thread_pool_size = value; },
R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc")
.def_property(
"execution_mode",
[](const PySessionOptions* options) -> ExecutionMode { return options->value.execution_mode; },
[](PySessionOptions* options, ExecutionMode execution_mode) -> void {
options->value.execution_mode = execution_mode;
},
R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc")
.def_property(
"execution_order",
[](const PySessionOptions* options) -> ExecutionOrder { return options->value.execution_order; },
[](PySessionOptions* options, ExecutionOrder execution_order) -> void {
options->value.execution_order = execution_order;
},
R"pbdoc(Sets the execution order. Default is basic topological order.)pbdoc")
.def_property(
"graph_optimization_level",
[](const PySessionOptions* options) -> GraphOptimizationLevel {
GraphOptimizationLevel retval = ORT_ENABLE_ALL;
switch (options->value.graph_optimization_level) {
case onnxruntime::TransformerLevel::Default:
retval = ORT_DISABLE_ALL;
break;
case onnxruntime::TransformerLevel::Level1:
retval = ORT_ENABLE_BASIC;
break;
case onnxruntime::TransformerLevel::Level2:
retval = ORT_ENABLE_EXTENDED;
break;
case onnxruntime::TransformerLevel::Level3:
retval = ORT_ENABLE_ALL;
break;
default:
retval = ORT_ENABLE_ALL;
LOGS_DEFAULT(WARNING) << "Got invalid graph optimization level; defaulting to ORT_ENABLE_ALL";
break;
}
return retval;
},
[](PySessionOptions* options, GraphOptimizationLevel level) -> void {
switch (level) {
case ORT_DISABLE_ALL:
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Default;
break;
case ORT_ENABLE_BASIC:
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level1;
break;
case ORT_ENABLE_EXTENDED:
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level2;
break;
case ORT_ENABLE_ALL:
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level3;
break;
}
},
R"pbdoc(Graph optimization level for this session.)pbdoc")
.def_property(
"use_deterministic_compute",
[](const PySessionOptions* options) -> bool { return options->value.use_deterministic_compute; },
[](PySessionOptions* options, bool use_deterministic_compute) -> void {
options->value.use_deterministic_compute = use_deterministic_compute;
},
R"pbdoc(Whether to use deterministic compute. Default is false.)pbdoc")
.def(
"add_free_dimension_override_by_denotation",
[](PySessionOptions* options, const char* dim_name, int64_t dim_value)
-> void { options->value.free_dimension_overrides.push_back(
onnxruntime::FreeDimensionOverride{
dim_name,
onnxruntime::FreeDimensionOverrideType::Denotation,
dim_value}); },
R"pbdoc(Specify the dimension size for each denotation associated with an input's free dimension.)pbdoc")
.def(
"add_free_dimension_override_by_name",
[](PySessionOptions* options, const char* dim_name, int64_t dim_value)
-> void { options->value.free_dimension_overrides.push_back(
onnxruntime::FreeDimensionOverride{
dim_name,
onnxruntime::FreeDimensionOverrideType::Name,
dim_value}); },
R"pbdoc(Specify values of named dimensions within model inputs.)pbdoc")
.def(
"add_session_config_entry",
[](PySessionOptions* options, const char* config_key, const char* config_value) -> void {
// config_key and config_value will be copied
const Status status = options->value.config_options.AddConfigEntry(config_key, config_value);
if (!status.IsOK())
throw std::runtime_error(status.ErrorMessage());
},
R"pbdoc(Set a single session configuration entry as a pair of strings.)pbdoc")
.def(
"get_session_config_entry",
[](const PySessionOptions* options, const char* config_key) -> std::string {
const std::string key(config_key);
std::string value;
if (!options->value.config_options.TryGetConfigEntry(key, value))
throw std::runtime_error("SessionOptions does not have configuration with key: " + key);
return value;
},
R"pbdoc(Get a single session configuration value using the given configuration key.)pbdoc")
.def(
"register_custom_ops_library",
[](PySessionOptions* options, const char* library_name) -> void {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
OrtPybindThrowIfError(options->RegisterCustomOpsLibrary(ToPathString(library_name)));
#else
ORT_UNUSED_PARAMETER(options);
ORT_UNUSED_PARAMETER(library_name);
ORT_THROW("Custom Ops are not supported in this build.");
#endif
},
R"pbdoc(Specify the path to the shared library containing the custom op kernels required to run a model.)pbdoc")
.def(
"add_initializer", [](PySessionOptions* options, const char* name, py::object& ml_value_pyobject) -> void {
ORT_ENFORCE(strcmp(Py_TYPE(ml_value_pyobject.ptr())->tp_name, PYTHON_ORTVALUE_OBJECT_NAME) == 0, "The provided Python object must be an OrtValue");
// The user needs to ensure that the python OrtValue being provided as an overriding initializer
// is not destructed as long as any session that uses the provided OrtValue initializer is still in scope
// This is no different than the native APIs
const OrtValue* ml_value = ml_value_pyobject.attr(PYTHON_ORTVALUE_NATIVE_OBJECT_ATTR).cast<OrtValue*>();
ORT_THROW_IF_ERROR(options->value.AddInitializer(name, ml_value));
})
.def("add_external_initializers", [](PySessionOptions* options, py::list& names, const py::list& ort_values) -> void {
#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_EXTERNAL_INITIALIZERS)
const auto init_num = ort_values.size();
ORT_ENFORCE(init_num == names.size(), "Expecting names and ort_values lists to have equal length");
InlinedVector<std::string> names_ptrs;
InlinedVector<OrtValue> values_ptrs;
names_ptrs.reserve(init_num);
values_ptrs.reserve(init_num);
for (size_t i = 0; i < init_num; ++i) {
names_ptrs.emplace_back(py::str(names[i]));
values_ptrs.emplace_back(*ort_values[i].attr(PYTHON_ORTVALUE_NATIVE_OBJECT_ATTR).cast<const OrtValue*>());
}
ORT_THROW_IF_ERROR(options->value.AddExternalInitializers(names_ptrs, values_ptrs));
#else
ORT_UNUSED_PARAMETER(options);
ORT_UNUSED_PARAMETER(names);
ORT_UNUSED_PARAMETER(ort_values);
ORT_THROW("External initializers are not supported in this build.");
#endif
});
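// Configuration sketch (Python; "model.onnx" is a placeholder path):
//
//   so = ort.SessionOptions()
//   so.intra_op_num_threads = 4
//   so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
//   so.add_session_config_entry("session.save_model_format", "ONNX")
//   sess = ort.InferenceSession("model.onnx", sess_options=so,
//                               providers=["CPUExecutionProvider"])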
py::class_<RunOptions>(m, "RunOptions", R"pbdoc(Configuration information for a single Run.)pbdoc")
.def(py::init())
.def_readwrite("log_severity_level", &RunOptions::run_log_severity_level,
R"pbdoc(Log severity level for a particular Run() invocation. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal. Default is 2.)pbdoc")
.def_readwrite("log_verbosity_level", &RunOptions::run_log_verbosity_level,
R"pbdoc(VLOG level if DEBUG build and run_log_severity_level is 0.
Applies to a particular Run() invocation. Default is 0.)pbdoc")
.def_readwrite("logid", &RunOptions::run_tag,
"To identify logs generated by a particular Run() invocation.")
.def_readwrite("terminate", &RunOptions::terminate,
R"pbdoc(Set to True to terminate any currently executing calls that are using this
RunOptions instance. The individual calls will exit gracefully and return an error status.)pbdoc")
#ifdef ENABLE_TRAINING
.def_readwrite("training_mode", &RunOptions::training_mode,
R"pbdoc(Choose to run in training or inferencing mode)pbdoc")
#endif
.def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc")
.def(
"add_run_config_entry",
[](RunOptions* options, const char* config_key, const char* config_value) -> void {
// config_key and config_value will be copied
const Status status = options->config_options.AddConfigEntry(config_key, config_value);
if (!status.IsOK())
throw std::runtime_error(status.ErrorMessage());
},
R"pbdoc(Set a single run configuration entry as a pair of strings.)pbdoc")
.def(
"get_run_config_entry",
[](const RunOptions* options, const char* config_key) -> std::string {
const std::string key(config_key);
std::string value;
if (!options->config_options.TryGetConfigEntry(key, value))
throw std::runtime_error("RunOptions does not have configuration with key: " + key);
return value;
},
R"pbdoc(Get a single run configuration value using the given configuration key.)pbdoc")
.def(
"add_active_adapter", [](RunOptions* options, lora::LoraAdapter* adapter) {
options->active_adapters.push_back(adapter);
},
R"pbdoc(Adds the specified adapter as an active adapter for this run.)pbdoc");
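// Usage sketch (Python; `sess` and the feed array `x` are assumed to exist):
//
//   ro = ort.RunOptions()
//   ro.logid = "request-42"
//   ro.log_severity_level = 1
//   ro.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu:0")
//   outputs = sess.run(None, {"x": x}, run_options=ro)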
py::class_<ModelMetadata>(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model.
It is usually used to identify the model that produced a prediction and
to facilitate comparisons.)pbdoc")
.def_readwrite("producer_name", &ModelMetadata::producer_name, "producer name")
.def_readwrite("graph_name", &ModelMetadata::graph_name, "graph name")
.def_readwrite("domain", &ModelMetadata::domain, "ONNX domain")
.def_readwrite("description", &ModelMetadata::description, "description of the model")
.def_readwrite("graph_description", &ModelMetadata::graph_description, "description of the graph hosted in the model")
.def_readwrite("version", &ModelMetadata::version, "version of the model")
.def_readwrite("custom_metadata_map", &ModelMetadata::custom_metadata_map, "additional metadata");
py::class_<onnxruntime::NodeArg>(m, "NodeArg", R"pbdoc(Node argument definition, for both input and output,
including the argument name and type information (covering both element type and shape).)pbdoc")
.def_property_readonly("name", &onnxruntime::NodeArg::Name, "node name")
.def_property_readonly(
"type", [](const onnxruntime::NodeArg& na) -> std::string {
return *(na.Type());
},
"node type")
.def("__str__", [](const onnxruntime::NodeArg& na) -> std::string {
std::ostringstream res;
res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
auto shape = na.Shape();
if (shape == nullptr || shape->dim_size() == 0) {
res << "[]";
} else {
res << "[";
for (int i = 0; i < shape->dim_size(); ++i) {
if (utils::HasDimValue(shape->dim(i))) {
res << shape->dim(i).dim_value();
} else if (utils::HasDimParam(shape->dim(i))) {
res << "'" << shape->dim(i).dim_param() << "'";
} else {
res << "None";
}
if (i < shape->dim_size() - 1) {
res << ", ";
}
}
res << "]";
}
res << ")";
return std::string(res.str()); }, "converts the node into a readable string")
.def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
auto shape = na.Shape();
std::vector<py::object> arr;
if (shape == nullptr || shape->dim_size() == 0) {
return arr;
}
arr.resize(shape->dim_size());
for (int i = 0; i < shape->dim_size(); ++i) {
if (utils::HasDimValue(shape->dim(i))) {
arr[i] = py::cast(shape->dim(i).dim_value());
} else if (utils::HasDimParam(shape->dim(i))) {
arr[i] = py::cast(shape->dim(i).dim_param());
} else {
arr[i] = py::none();
}
}
return arr; }, "node shape (assuming the node holds a tensor)");
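// Inspection sketch (Python): the public InferenceSession exposes these NodeArgs
// via get_inputs()/get_outputs() (values shown are placeholders):
//
//   for arg in sess.get_inputs():
//       print(arg.name, arg.type, arg.shape)  # e.g. x tensor(float) [1, 'seq', 768]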
py::class_<SessionObjectInitializer> sessionObjectInitializer(m, "SessionObjectInitializer");
py::class_<PyInferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
// In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char*
// without any conversion. So this init method can be used for model file path (string) and model content (bytes)
.def(py::init([](const PySessionOptions& so, const std::string arg, bool is_arg_file_name,
bool load_config_from_model = false) {
auto env = GetEnv();
std::unique_ptr<PyInferenceSession> sess;
// separate creation of the session from model loading unless we have to read the config from the model.
// in a minimal build we only support load via Load(...) and not at session creation time
if (load_config_from_model) {
#if !defined(ORT_MINIMAL_BUILD)
sess = std::make_unique<PyInferenceSession>(std::move(env), so, arg, is_arg_file_name);
RegisterCustomOpDomains(sess.get(), so);
OrtPybindThrowIfError(sess->GetSessionHandle()->Load());
#else
ORT_THROW("Loading configuration from an ONNX model is not supported in this build.");
#endif
} else {
sess = std::make_unique<PyInferenceSession>(std::move(env), so);
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
RegisterCustomOpDomains(sess.get(), so);
#endif
if (is_arg_file_name) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(arg));
} else {
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(arg.data(), narrow<int>(arg.size())));
}
}
return sess;
}))
.def(
"initialize_session",
[ep_registration_fn](PyInferenceSession* sess,
const std::vector<std::string>& provider_types = {},
const ProviderOptionsVector& provider_options = {},
const std::unordered_set<std::string>& disabled_optimizer_names = {}) {
InitializeSession(sess->GetSessionHandle(),
ep_registration_fn,
provider_types,
provider_options,
disabled_optimizer_names);
},
R"pbdoc(Initialize the session with the given execution providers and options.)pbdoc")
.def("run",
[](PyInferenceSession* sess, const std::vector<std::string>& output_names,
const std::map<std::string, const py::object>& pyfeeds, RunOptions* run_options = nullptr)
-> py::list {
NameMLValMap feeds;
if (run_options != nullptr && !run_options->active_adapters.empty()) {
AppendLoraParametersAsInputs(*run_options, pyfeeds.size(), feeds);
} else {
feeds.reserve(pyfeeds.size());
}
for (const auto& feed : pyfeeds) {
// No need to process 'None's sent in by the user
// to feed Optional inputs in the graph.
// We just won't include anything in the feed and ORT
// will handle such implicit 'None's internally.
if (!feed.second.is(py::none())) {
OrtValue ml_value;
auto px = sess->GetSessionHandle()->GetModelInputs();
if (!px.first.IsOK() || !px.second) {
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
}
CreateGenericMLValue(px.second, GetAllocator(), feed.first, feed.second, &ml_value);
ThrowIfPyErrOccured();
feeds.insert(std::make_pair(feed.first, std::move(ml_value)));
}
}
std::vector<OrtValue> fetches;
fetches.reserve(output_names.size());
common::Status status;
{
// release GIL to allow multiple python threads to invoke Run() in parallel.
py::gil_scoped_release release;
if (run_options != nullptr) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(*run_options, feeds, output_names, &fetches));
} else {
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(feeds, output_names, &fetches));
}
}
py::list result;
size_t pos = 0;
for (const auto& fet : fetches) {
if (fet.IsAllocated()) {
if (fet.IsTensor()) {
result.append(AddTensorAsPyObj(fet, nullptr, nullptr));
} else if (fet.IsSparseTensor()) {
result.append(GetPyObjectFromSparseTensor(pos, fet, nullptr));
} else {
result.append(AddNonTensorAsPyObj(fet, nullptr, nullptr));
}
} else { // Send back None because the corresponding OrtValue was empty
result.append(py::none());
}
++pos;
}
return result;
})
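// Run sketch (Python; `sess` is assumed to be an InferenceSession with input "x"
// and output "y"):
//
//   import numpy as np
//   x = np.zeros((1, 3), dtype=np.float32)
//   [y] = sess.run(["y"], {"x": x})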
.def("run_async",
[](PyInferenceSession* sess,
const std::vector<std::string>& output_names,
const std::map<std::string, py::object>& pyfeeds,
PyCallback callback, py::object user_data = {},
RunOptions* run_options = nullptr)
-> void {
if (run_options != nullptr && !run_options->active_adapters.empty()) {
LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING)
<< "run_async has active adapters specified, but they will have no effect";
}
std::unique_ptr<AsyncResource> async_resource = std::make_unique<AsyncResource>();
async_resource->callback = callback;
async_resource->user_data = user_data;
// prepare feeds
async_resource->ReserveFeeds(pyfeeds.size());
for (const auto& feed : pyfeeds) {
if (!feed.second.is(py::none())) {
OrtValue ml_value;
auto px = sess->GetSessionHandle()->GetModelInputs();
if (!px.first.IsOK() || !px.second) {
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
}
CreateGenericMLValue(px.second, GetAllocator(), feed.first, feed.second, &ml_value);
ThrowIfPyErrOccured();
async_resource->feeds.push_back(ml_value);
async_resource->feeds_raw.push_back(&async_resource->feeds.back());
async_resource->feed_names.push_back(feed.first);
async_resource->feed_names_raw.push_back(async_resource->feed_names.back().c_str());
}
}
// prepare fetches
async_resource->ReserveFetches(output_names.size());
for (const auto& output_name : output_names) {
async_resource->fetch_names.push_back(output_name);
async_resource->fetch_names_raw.push_back(async_resource->fetch_names.back().c_str());
async_resource->fetches_raw.push_back({});
}
const RunOptions* run_async_option = run_options ? run_options : &async_resource->default_run_option;
common::Status status = sess->GetSessionHandle()->RunAsync(run_async_option,
gsl::span(async_resource->feed_names_raw.data(), async_resource->feed_names_raw.size()),
gsl::span(async_resource->feeds_raw.data(), async_resource->feeds_raw.size()),
gsl::span(async_resource->fetch_names_raw.data(), async_resource->fetch_names_raw.size()),
gsl::span(async_resource->fetches_raw.data(), async_resource->fetches_raw.size()),
AsyncCallback,
async_resource.get());
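// On success, intentionally leak the unique_ptr: the raw AsyncResource pointer
// handed to RunAsync is reclaimed by AsyncCallback once the run completes.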
if (status.IsOK()) {
async_resource.release();
}
OrtPybindThrowIfError(status);
})
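// Async sketch (Python): the callback receives (results, user_data, err), where
// err is an empty string on success (assumes `sess` and the feed `x` exist):
//
//   def on_done(results, user_data, err):
//       if err:
//           raise RuntimeError(err)
//       print(user_data, results[0])
//
//   sess.run_async(["y"], {"x": x}, on_done, "tag-1")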
/// This method accepts a dictionary of feeds (name -> OrtValue) and the list of output_names
/// and returns a list of python objects representing OrtValues. Each name may represent either
/// a Tensor, SparseTensor or a TensorSequence.
.def("run_with_ort_values", [](PyInferenceSession* sess, const py::dict& feeds, const std::vector<std::string>& output_names, RunOptions* run_options = nullptr) -> std::vector<OrtValue> {
NameMLValMap ort_feeds;
if (run_options != nullptr && !run_options->active_adapters.empty()) {
AppendLoraParametersAsInputs(*run_options, feeds.size(), ort_feeds);
} else {
ort_feeds.reserve(feeds.size());
}
// item is always a copy since dict returns a value and not a ref
// and Apple XToolChain barks
for (const auto& item : feeds) {
auto name = item.first.cast<std::string>();
const OrtValue* ort_value = item.second.cast<const OrtValue*>();
ort_feeds.emplace(name, *ort_value);
}
std::vector<OrtValue> fetches;
fetches.reserve(output_names.size());
{
// release GIL to allow multiple python threads to invoke Run() in parallel.
py::gil_scoped_release release;
if (run_options != nullptr) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(*run_options, ort_feeds, output_names, &fetches));
} else {
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(ort_feeds, output_names, &fetches));
}
}
return fetches;
})
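// OrtValue sketch (Python): keep feeds and fetches as OrtValues and defer the
// copy back to numpy until explicitly requested:
//
//   x_ort = ort.OrtValue.ortvalue_from_numpy(x)
//   outs = sess.run_with_ort_values(["y"], {"x": x_ort})
//   y = outs[0].numpy()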
.def("run_with_ortvaluevector", [](PyInferenceSession* sess, RunOptions run_options, const std::vector<std::string>& feed_names, const std::vector<OrtValue>& feeds, const std::vector<std::string>& fetch_names, std::vector<OrtValue>& fetches, const std::vector<OrtDevice>& fetch_devices) -> void {
if (!run_options.active_adapters.empty()) {
LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING)
<< "run_with_ortvaluevector has active adapters specified, but they will have no effect";
}
// release GIL to allow multiple python threads to invoke Run() in parallel.
py::gil_scoped_release release;
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(run_options, feed_names, feeds, fetch_names, &fetches, &fetch_devices));
})
.def("end_profiling", [](const PyInferenceSession* sess) -> std::string {
return sess->GetSessionHandle()->EndProfiling();
})
.def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t {
return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs();
})
.def("get_providers", [](const PyInferenceSession* sess) -> const std::vector<std::string>& { return sess->GetSessionHandle()->GetRegisteredProviderTypes(); }, py::return_value_policy::reference_internal)
.def("get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { return sess->GetSessionHandle()->GetAllProviderOptions(); }, py::return_value_policy::reference_internal)
.def_property_readonly("session_options", [](const PyInferenceSession* sess) -> PySessionOptions* {
auto session_options = std::make_unique<PySessionOptions>();
session_options->value = sess->GetSessionHandle()->GetSessionOptions();
return session_options.release(); }, py::return_value_policy::take_ownership)
.def_property_readonly("inputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetSessionHandle()->GetModelInputs();
OrtPybindThrowIfError(res.first);
return *(res.second); }, py::return_value_policy::reference_internal)
.def_property_readonly("outputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetSessionHandle()->GetModelOutputs();
OrtPybindThrowIfError(res.first);
return *(res.second); }, py::return_value_policy::reference_internal)
.def_property_readonly("overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetSessionHandle()->GetOverridableInitializers();
OrtPybindThrowIfError(res.first);
return *(res.second); }, py::return_value_policy::reference_internal)
.def_property_readonly("model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& {
auto res = sess->GetSessionHandle()->GetModelMetadata();
OrtPybindThrowIfError(res.first);
return *(res.second); }, py::return_value_policy::reference_internal)
.def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
Status status;
if (run_options != nullptr && !run_options->active_adapters.empty()) {
LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING)
<< "run_with_iobinding has active adapters specified, but they will have no effect";
}
// release GIL to allow multiple python threads to invoke Run() in parallel.
py::gil_scoped_release release;
if (!run_options)
status = sess->GetSessionHandle()->Run(*io_binding.Get());
else
status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
if (!status.IsOK())
throw std::runtime_error("Error in execution: " + status.ErrorMessage()); })
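// IOBinding sketch (Python): bind inputs/outputs once, then run without
// per-call feed conversion:
//
//   io = sess.io_binding()
//   io.bind_cpu_input("x", x)
//   io.bind_output("y")
//   sess.run_with_iobinding(io)
//   [y] = io.copy_outputs_to_cpu()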
.def("get_tuning_results", [](PyInferenceSession* sess) -> py::list {
#if !defined(ORT_MINIMAL_BUILD)
auto results = sess->GetSessionHandle()->GetTuningResults();
py::list ret;
for (const auto& trs : results) {
py::dict py_trs;
py_trs["ep"] = trs.ep;
py_trs["results"] = trs.results;
py_trs["validators"] = trs.validators;
ret.append(std::move(py_trs));
}
return ret;
#else
ORT_UNUSED_PARAMETER(sess);
ORT_THROW("TunableOp and get_tuning_results are not supported in this build.");
#endif
})
.def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void {
#if !defined(ORT_MINIMAL_BUILD)
std::vector<TuningResults> tuning_results;
for (auto handle : results) {
auto py_trs = handle.cast<py::dict>();
TuningResults trs;
trs.ep = py_trs["ep"].cast<py::str>();
for (const auto [py_op_sig, py_kernel_map] : py_trs["results"].cast<py::dict>()) {
KernelMap kernel_map;
for (const auto [py_params_sig, py_kernel_id] : py_kernel_map.cast<py::dict>()) {
kernel_map[py_params_sig.cast<py::str>()] = py_kernel_id.cast<py::int_>();
}
trs.results[py_op_sig.cast<py::str>()] = kernel_map;
}
for (const auto [k, v] : py_trs["validators"].cast<py::dict>()) {
trs.validators[k.cast<py::str>()] = v.cast<py::str>();
}
tuning_results.emplace_back(std::move(trs));
}
Status status = sess->GetSessionHandle()->SetTuningResults(tuning_results, error_on_invalid);
if (!status.IsOK()) {
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
}
#else
ORT_UNUSED_PARAMETER(sess);
ORT_UNUSED_PARAMETER(results);
ORT_UNUSED_PARAMETER(error_on_invalid);
ORT_THROW("TunableOp and set_tuning_results are not supported in this build.");
#endif
});
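// Tuning-results sketch (Python, TunableOp-enabled builds only): each entry is a
// dict {"ep": ..., "results": {op_sig: {params_sig: kernel_id}}, "validators": {...}}
// and can be round-tripped between sessions:
//
//   trs = sess.get_tuning_results()
//   other_sess.set_tuning_results(trs)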
py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())
.value("kNextPowerOfTwo", onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo)
.value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested)
.export_values();
}
bool CreateInferencePybindStateModule(py::module& m) {
m.doc() = "pybind11 stateful interface to ONNX runtime";
RegisterExceptions(m);
import_array1(false);
auto env = GetEnv();
addGlobalMethods(m);
addObjectMethods(m, RegisterExecutionProviders);
addOrtValueMethods(m);
addSparseTensorMethods(m);
addIoBindingMethods(m);
addAdapterFormatMethods(m);
#if !defined(__APPLE__) && !defined(ORT_MINIMAL_BUILD)
if (!InitProvidersSharedLibrary()) {
const logging::Logger& default_logger = logging::LoggingManager::DefaultLogger();
LOGS(default_logger, WARNING) << "Init provider bridge failed.";
}
#endif
addGlobalSchemaFunctions(m);
addOpSchemaSubmodule(m);
addOpKernelSubmodule(m);
return true;
}
// This function is only used by orttraining module
bool InitArray() {
import_array1(false);
return true;
}
namespace {
// This class provides a static shell for on-demand and thread-safe construction
// of an Environment object for both the Inference and Training python layers.
// The Environment class contains objects such as the default logger that must be available
// for the entire duration of a program that makes use of the onnxruntime library.
// Because Python is a garbage-collected language and the order of destruction of objects
// is not guaranteed, we design this class with the following important features.
// 1) We make this class a singleton that is a function-local static. Function-local statics
// are constructed when the function is called the very first time. This has several important
// properties.
// - First, it is constructed before it is first needed, possibly by another static object,
// and destroyed after that object is destroyed.
// - Second, it is constructed in a thread-safe manner.
// - Last, this order of construction/destruction is enforced across compilation units, unlike
// static objects that are simply declared in order within a single unit, whose lifespan is
// unconnected to that of statics in other compilation units. The runtime achieves this
// automatically by chaining atexit() callbacks.
// 2) We make Environment owned by a shared_ptr. This is done because python objects such as Inference and Training
// sessions depend on this global. We acquire a shared_ptr instance when those objects are instantiated
// and release it automatically when they are garbage collected. Although with this change all of the
// globals appear to be destroyed after the module is unloaded, with GC running before that, it is cheap and gives
// peace of mind, as there have been situations in the past when GC was still running after Env was gone.
// The TrainingEnv global also holds a shared reference to this global.
// 3) We guard against attempts to resurrect the singleton, so we can detect code that runs when it
// should not and make the necessary adjustments.
// For all the related details and why this is needed, see "Modern C++ Design" by A. Alexandrescu, Chapter 6.
class EnvInitializer {
public:
static std::shared_ptr<onnxruntime::Environment> SharedInstance() {
// Guard against attempts to resurrect the singleton
if (EnvInitializer::destroyed) {
ORT_THROW("Detected an attempt to resurrect destroyed Environment");
}
static EnvInitializer env_holder;
return env_holder.Get();
}
private:
EnvInitializer() {
std::unique_ptr<Environment> env_ptr;
Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON);
OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
std::make_unique<CLogSink>(),
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id),
env_ptr));
session_env_ = std::shared_ptr<Environment>(env_ptr.release());
destroyed = false;
}
~EnvInitializer() {
destroyed = true;
}
std::shared_ptr<Environment> Get() const {
return session_env_;
}
std::shared_ptr<Environment> session_env_;
static bool destroyed;
};
bool EnvInitializer::destroyed = false;
} // namespace
std::shared_ptr<onnxruntime::Environment> GetEnv() {
return EnvInitializer::SharedInstance();
}
} // namespace python
} // namespace onnxruntime