mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-15 20:50:42 +00:00
- Revert https://github.com/microsoft/onnxruntime/pull/22681 - But still implicitly exclude DDS ops for TRT 10. Will later provide better PR to add trt_op_types_to_exclude provider option.
2409 lines
117 KiB
C++
2409 lines
117 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
// Licensed under the MIT License.
|
|
|
|
#include "python/onnxruntime_pybind_exceptions.h"
|
|
#include "python/onnxruntime_pybind_mlvalue.h"
|
|
#include "python/onnxruntime_pybind_state_common.h"
|
|
|
|
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
|
#define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API
|
|
#include "python/numpy_helper.h"
|
|
#include "core/common/inlined_containers.h"
|
|
#include "core/common/logging/logging.h"
|
|
#include "core/common/logging/severity.h"
|
|
#include "core/common/narrow.h"
|
|
#include "core/common/optional.h"
|
|
#include "core/common/path_string.h"
|
|
#include "core/framework/arena_extend_strategy.h"
|
|
#include "core/framework/data_transfer_utils.h"
|
|
#include "core/framework/data_types_internal.h"
|
|
#include "core/framework/provider_options_utils.h"
|
|
#include "core/framework/random_seed.h"
|
|
#include "core/framework/sparse_tensor.h"
|
|
#include "core/framework/tensorprotoutils.h"
|
|
#include "core/framework/TensorSeq.h"
|
|
#include "core/graph/graph_viewer.h"
|
|
#include "core/platform/env.h"
|
|
#include "core/providers/get_execution_providers.h"
|
|
#include "core/providers/tensorrt/tensorrt_provider_options.h"
|
|
#include "core/session/IOBinding.h"
|
|
#include "core/session/abi_session_options_impl.h"
|
|
#include "core/session/onnxruntime_session_options_config_keys.h"
|
|
#include "core/session/provider_bridge_ort.h"
|
|
|
|
#include "core/session/lora_adapters.h"
|
|
|
|
#ifdef ENABLE_ATEN
|
|
#include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
|
|
#endif
|
|
|
|
#ifdef USE_CUDA
|
|
#include <cuda.h> // for CUDA_VERSION
|
|
#include <cudnn.h> // for CUDNN_MAJOR
|
|
#endif
|
|
|
|
#if defined(USE_COREML)
|
|
#include "core/providers/coreml/coreml_provider_factory.h"
|
|
#endif
|
|
|
|
#include <pybind11/functional.h>
|
|
|
|
// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
|
|
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
|
|
// GCC 4.x.
|
|
// (This static var is referenced in GetCudaToHostMemCpyFunction())
|
|
const OrtDevice::DeviceType OrtDevice::GPU;
|
|
|
|
#if defined(_MSC_VER)
|
|
#pragma warning(disable : 4267 4996 4503)
|
|
#endif // _MSC_VER
|
|
|
|
#include <iterator>
|
|
#include <algorithm>
|
|
|
|
namespace onnxruntime {
|
|
namespace python {
|
|
|
|
namespace py = pybind11;
|
|
using namespace onnxruntime;
|
|
using namespace onnxruntime::logging;
|
|
|
|
#if defined(_MSC_VER) && !defined(__clang__)
|
|
#pragma warning(push)
|
|
// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
|
|
// TODO: we may delay-init this variable
|
|
#pragma warning(disable : 26426)
|
|
#endif
|
|
static Env& platform_env = Env::Default();
|
|
#if defined(_MSC_VER) && !defined(__clang__)
|
|
#pragma warning(push)
|
|
#endif
|
|
|
|
// Signature of the Python completion callback used for asynchronous runs.
// Arguments, in order: the outputs converted to Python objects, the user data
// object supplied by the caller, and an error message (AsyncCallback passes an
// empty string when the run succeeded).
using PyCallback = std::function<void(std::vector<py::object> outputs, py::object user_data, std::string error_message)>;
|
|
|
|
struct AsyncResource {
|
|
std::vector<OrtValue> feeds;
|
|
std::vector<const OrtValue*> feeds_raw;
|
|
|
|
std::vector<std::string> feed_names;
|
|
std::vector<const char*> feed_names_raw;
|
|
|
|
std::vector<OrtValue*> fetches_raw; // will be released during destruction
|
|
|
|
std::vector<std::string> fetch_names;
|
|
std::vector<const char*> fetch_names_raw;
|
|
|
|
RunOptions default_run_option;
|
|
PyCallback callback;
|
|
py::object user_data;
|
|
|
|
void ReserveFeeds(size_t sz) {
|
|
feeds.reserve(sz);
|
|
feeds_raw.reserve(sz);
|
|
feed_names.reserve(sz);
|
|
feed_names_raw.reserve(sz);
|
|
}
|
|
|
|
void ReserveFetches(size_t sz) {
|
|
fetches_raw.reserve(sz);
|
|
fetch_names.reserve(sz);
|
|
fetch_names_raw.reserve(sz);
|
|
}
|
|
|
|
~AsyncResource() {
|
|
std::for_each(fetches_raw.begin(), fetches_raw.end(), [](const OrtValue* fetch) {
|
|
if (fetch) {
|
|
std::unique_ptr<const OrtValue> fetch_recycler(fetch);
|
|
}
|
|
});
|
|
fetches_raw.clear();
|
|
}
|
|
};
|
|
|
|
void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr ort_status) {
|
|
ORT_ENFORCE(user_data, "user data must not be NULL for callback in python");
|
|
|
|
auto invoke_callback = [&]() {
|
|
std::unique_ptr<AsyncResource> async_resource{reinterpret_cast<AsyncResource*>(user_data)};
|
|
Ort::Status status(ort_status);
|
|
|
|
// return on error
|
|
if (!status.IsOK()) {
|
|
async_resource->callback({}, async_resource->user_data, status.GetErrorMessage());
|
|
return;
|
|
}
|
|
|
|
std::vector<py::object> rfetch;
|
|
rfetch.reserve(num_outputs);
|
|
size_t pos = 0;
|
|
for (size_t ith = 0; ith < num_outputs; ++ith) {
|
|
const auto& fet = *outputs[ith];
|
|
if (fet.IsAllocated()) {
|
|
if (fet.IsTensor()) {
|
|
rfetch.push_back(AddTensorAsPyObj(fet, nullptr, nullptr));
|
|
} else if (fet.IsSparseTensor()) {
|
|
rfetch.push_back(GetPyObjectFromSparseTensor(pos, fet, nullptr));
|
|
} else {
|
|
rfetch.push_back(AddNonTensorAsPyObj(fet, nullptr, nullptr));
|
|
}
|
|
} else {
|
|
rfetch.push_back(py::none());
|
|
}
|
|
++pos;
|
|
}
|
|
async_resource->callback(rfetch, async_resource->user_data, "");
|
|
};
|
|
|
|
if (PyGILState_Check()) {
|
|
invoke_callback();
|
|
} else {
|
|
// acquire GIL to safely:
|
|
// 1) invoke python callback
|
|
// 2) create, manipulate, and destroy python objects
|
|
py::gil_scoped_acquire acquire;
|
|
invoke_callback();
|
|
}
|
|
}
|
|
|
|
void AppendLoraParametersAsInputs(const RunOptions& run_options,
|
|
size_t total_entries,
|
|
NameMLValMap& feeds) {
|
|
for (const auto* adapter : run_options.active_adapters) {
|
|
total_entries += adapter->GetParamNum();
|
|
}
|
|
feeds.reserve(total_entries + feeds.size());
|
|
|
|
// Append necessary inputs for active adapters
|
|
for (const auto* adapter : run_options.active_adapters) {
|
|
auto [begin, end] = adapter->GetParamIterators();
|
|
for (; begin != end; ++begin) {
|
|
const auto& [name, param] = *begin;
|
|
feeds.insert(std::make_pair(name, param.GetMapped()));
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
static py::object AddNonTensor(const OrtValue& val,
|
|
const DataTransferManager* /*data_transfer_manager*/,
|
|
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* /*mem_cpy_to_host_functions*/) {
|
|
return py::cast(val.Get<T>());
|
|
}
|
|
|
|
// This function is used to return strings from a string tensor to python
|
|
// as a numpy array of strings
|
|
// Strings are always on CPU and must always be copied to python memory
|
|
py::array StringTensorToNumpyArray(const Tensor& tensor) {
|
|
// Create the result and allocate memory with the right size
|
|
py::array result(py::dtype(NPY_OBJECT), tensor.Shape().GetDims());
|
|
const auto span = tensor.DataAsSpan<std::string>();
|
|
auto* mutable_data = reinterpret_cast<py::object*>(result.mutable_data());
|
|
for (size_t i = 0, lim = span.size(); i < lim; ++i) {
|
|
mutable_data[i] = py::cast(span[i]);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// Wraps the tensor held by `ort_value` in a numpy array that points at the tensor's
// own buffer (tensor.DataRaw()) instead of copying it. A heap-allocated copy of the
// OrtValue is attached to the array as a capsule and deleted by the capsule's
// destructor.
pybind11::array PrimitiveTensorToNumpyOverOrtValue(const OrtValue& ort_value) {
  const Tensor& tensor = ort_value.Get<Tensor>();
  const int numpy_type = OnnxRuntimeTensorToNumpyType(tensor.DataType());

  // The capsule destructor must be stateless, so it only receives the raw pointer
  // to the heap copy made below.
  auto destroy_holder = [](void* holder) {
    delete reinterpret_cast<OrtValue*>(holder);
  };

  // Keep the copy in a unique_ptr until the capsule exists, so it is not leaked if
  // capsule construction throws; afterwards the capsule is the owner.
  auto holder = std::make_unique<OrtValue>(ort_value);
  pybind11::capsule owner(holder.get(), destroy_holder);
  holder.release();

  // Not using array_t<T> because it may not handle MLFloat16 properly
  pybind11::array result(py::dtype(numpy_type), tensor.Shape().GetDims(),
                         tensor.DataRaw(),
                         owner);
  return result;
}
|
|
|
|
// Creates a new numpy array with the tensor's shape and element type and copies the
// tensor's bytes into it.
//
//   ort_value - value holding the source Tensor.
//   dtm       - how to copy: either a DataTransferManager pointer (the copy goes through
//               CopyTensorDataToByteSpan and throws on error) or a MemCpyFunc that is
//               called with (destination, source, size in bytes).
pybind11::array PrimitiveTensorToNumpyFromDevice(const OrtValue& ort_value, const DataTransferAlternative& dtm) {
  const Tensor& tensor = ort_value.Get<Tensor>();
  const int numpy_type = OnnxRuntimeTensorToNumpyType(tensor.DataType());
  pybind11::array result(py::dtype(numpy_type), tensor.Shape().GetDims());
  void* destination = result.mutable_data();

  if (!std::holds_alternative<const DataTransferManager*>(dtm)) {
    // Caller supplied a plain copy function.
    std::get<MemCpyFunc>(dtm)(destination, tensor.DataRaw(), tensor.SizeInBytes());
    return result;
  }

  const DataTransferManager* data_transfer = std::get<const DataTransferManager*>(dtm);
  static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};
  const auto byte_span = gsl::make_span<char>(reinterpret_cast<char*>(destination), tensor.SizeInBytes());
  ORT_THROW_IF_ERROR(CopyTensorDataToByteSpan(*data_transfer, tensor, cpu_alloc_info, byte_span));
  return result;
}
|
|
|
|
// In all cases, we may not have access to a DataTransferManager, hence the user may specify functions that
|
|
// pretty much does what a DataTransferManager does - copy data from device(s) to the host
|
|
py::object GetPyObjFromTensor(const OrtValue& ort_value,
|
|
const DataTransferManager* data_transfer_manager,
|
|
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
|
|
ORT_ENFORCE(ort_value.IsTensor(), "This function only supports tensors");
|
|
|
|
const auto& tensor = ort_value.Get<Tensor>();
|
|
if (tensor.IsDataTypeString()) {
|
|
ORT_ENFORCE(tensor.Location().device.Type() == OrtDevice::CPU, "Strings can only be on CPU");
|
|
// Create a numpy array of strings (python objects) by copy/converting them
|
|
py::array result = StringTensorToNumpyArray(tensor);
|
|
return py::cast<py::object>(result);
|
|
}
|
|
|
|
const auto device_type = tensor.Location().device.Type();
|
|
// Create an numpy array on top of the OrtValue memory, no copy
|
|
if (device_type == OrtDevice::CPU) {
|
|
py::array result = PrimitiveTensorToNumpyOverOrtValue(ort_value);
|
|
return py::cast<py::object>(result);
|
|
}
|
|
|
|
if (!data_transfer_manager && !mem_cpy_to_host_functions) {
|
|
throw std::runtime_error(
|
|
"GetPyObjFromTensor: Either data transfer manager or a "
|
|
"function to copy data to the host is needed to convert non-CPU tensor to numpy array");
|
|
}
|
|
|
|
py::array result;
|
|
if (data_transfer_manager != nullptr) {
|
|
result = PrimitiveTensorToNumpyFromDevice(ort_value, data_transfer_manager);
|
|
} else {
|
|
auto mem_cpy_to_host = mem_cpy_to_host_functions->find(device_type);
|
|
ORT_ENFORCE(mem_cpy_to_host != mem_cpy_to_host_functions->end(),
|
|
"Unable to locate a function that can copy data to the host from the device");
|
|
result = PrimitiveTensorToNumpyFromDevice(ort_value, mem_cpy_to_host->second);
|
|
}
|
|
return py::cast<py::object>(result);
|
|
}
|
|
|
|
const char* GetDeviceName(const OrtDevice& device) {
|
|
switch (device.Type()) {
|
|
case OrtDevice::CPU:
|
|
return CPU;
|
|
case OrtDevice::GPU:
|
|
return CUDA;
|
|
case OrtDevice::DML:
|
|
return DML;
|
|
case OrtDevice::FPGA:
|
|
return "FPGA";
|
|
case OrtDevice::NPU:
|
|
#ifdef USE_CANN
|
|
return CANN;
|
|
#else
|
|
return "NPU";
|
|
#endif
|
|
default:
|
|
ORT_THROW("Unknown device type: ", device.Type());
|
|
}
|
|
}
|
|
|
|
// Wraps a sparse-tensor OrtValue in a PySparseTensor and returns it as a Python object
// that owns the wrapper.
//
//   pos                   - position of the value among the outputs; used only in the
//                           warning message.
//   ort_value             - must hold a sparse tensor, otherwise this throws.
//   data_transfer_manager - optional. When the tensor is not on CPU and a manager is
//                           given, the data is copied into a new SparseTensor created
//                           with GetAllocator() and the copy is wrapped. Without a
//                           manager a warning is logged and the value is wrapped as is.
//
// In builds that define DISABLE_SPARSE_TENSORS this function always throws.
py::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, const DataTransferManager* data_transfer_manager) {
#if !defined(DISABLE_SPARSE_TENSORS)
  if (!ort_value.IsSparseTensor()) {
    ORT_THROW("Must be a sparse tensor");
  }
  auto& logger = logging::LoggingManager::DefaultLogger();
  const SparseTensor& source = ort_value.Get<SparseTensor>();
  const bool resides_on_cpu = source.Location().device.Type() == OrtDevice::CPU;

  std::unique_ptr<PySparseTensor> wrapper;
  if (!resides_on_cpu && data_transfer_manager != nullptr) {
    // Device data and a way to move it: wrap a copy instead of the original value.
    auto copy_target = std::make_unique<SparseTensor>(source.DataType(), source.DenseShape(), GetAllocator());
    auto copy_status = source.Copy(*data_transfer_manager, *copy_target);
    OrtPybindThrowIfError(copy_status);
    wrapper = std::make_unique<PySparseTensor>(std::move(copy_target));
  } else {
    if (!resides_on_cpu) {
      LOGS(logger, WARNING) << "Returned OrtValue with sparse tensor at position: " << pos << " is on GPU but no data_transfer_manager provided."
                            << " Returned it will have its data on GPU, you can copy it using numpy_array_to_cpu()";
    }
    wrapper = std::make_unique<PySparseTensor>(ort_value);
  }

  // Python takes ownership of the raw pointer; give it up here only after the cast succeeded.
  py::object result = py::cast(wrapper.get(), py::return_value_policy::take_ownership);
  wrapper.release();
  return result;
#else
  ORT_UNUSED_PARAMETER(pos);
  ORT_UNUSED_PARAMETER(ort_value);
  ORT_UNUSED_PARAMETER(data_transfer_manager);
  ORT_THROW("SparseTensor support is disabled in this build.");
#endif  // !defined(DISABLE_SPARSE_TENSORS)
}
|
|
|
|
template <>
|
|
py::object AddNonTensor<TensorSeq>(const OrtValue& val,
|
|
const DataTransferManager* data_transfer_manager,
|
|
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
|
|
const auto& seq_tensors = val.Get<TensorSeq>();
|
|
py::list py_list;
|
|
for (const auto& ort_value : seq_tensors) {
|
|
py::object obj = GetPyObjFromTensor(ort_value, data_transfer_manager, mem_cpy_to_host_functions);
|
|
py_list.append(std::move(obj));
|
|
}
|
|
// XToolChain kills the build
|
|
// local variable 'py_list' will be copied despite being returned by name [-Werror,-Wreturn-std-move]
|
|
// call 'std::move' explicitly to avoid copying
|
|
// We choose to cast it to object explicitly
|
|
return py::cast<py::object>(py_list);
|
|
}
|
|
|
|
// Converts a non-tensor OrtValue to a Python object by dispatching on its type.
//
// Tensor sequences are always handled. When the build does not define DISABLE_ML_OPS,
// the map types (string or int64 keys with string, int64, float or double values) and
// the two sequence-of-map types (string->float, int64->float) are handled as well.
// Any other type throws.
//
// Should be in sync with core/framework/datatypes.h
py::object AddNonTensorAsPyObj(const OrtValue& val,
                               const DataTransferManager* data_transfer_manager,
                               const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
  auto val_type = val.Type();
  if (val_type->IsTensorSequenceType()) {
    return AddNonTensor<TensorSeq>(val, data_transfer_manager, mem_cpy_to_host_functions);
  }

#if !defined(DISABLE_ML_OPS)
  utils::ContainerChecker c_checker(val_type);
  if (c_checker.IsMap()) {
    // String-keyed maps.
    if (c_checker.IsMapOf<std::string, std::string>()) {
      return AddNonTensor<MapStringToString>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<std::string, int64_t>()) {
      return AddNonTensor<MapStringToInt64>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<std::string, float>()) {
      return AddNonTensor<MapStringToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<std::string, double>()) {
      return AddNonTensor<MapStringToDouble>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    // int64-keyed maps.
    if (c_checker.IsMapOf<int64_t, std::string>()) {
      return AddNonTensor<MapInt64ToString>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<int64_t, int64_t>()) {
      return AddNonTensor<MapInt64ToInt64>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<int64_t, float>()) {
      return AddNonTensor<MapInt64ToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsMapOf<int64_t, double>()) {
      return AddNonTensor<MapInt64ToDouble>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
  } else {
    // Sequences of maps.
    if (c_checker.IsSequenceOf<std::map<std::string, float>>()) {
      return AddNonTensor<VectorMapStringToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
    if (c_checker.IsSequenceOf<std::map<int64_t, float>>()) {
      return AddNonTensor<VectorMapInt64ToFloat>(val, data_transfer_manager, mem_cpy_to_host_functions);
    }
  }
#endif

  ORT_THROW("Non-tensor type is not supported in this build: ", val_type);
}
|
|
|
|
// Converts a tensor OrtValue to a Python object. Forwards all three arguments
// unchanged to GetPyObjFromTensor and returns its result.
py::object AddTensorAsPyObj(const OrtValue& val, const DataTransferManager* data_transfer_manager,
                            const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
  py::object converted = GetPyObjFromTensor(val, data_transfer_manager, mem_cpy_to_host_functions);
  return converted;
}
|
|
|
|
static std::unique_ptr<onnxruntime::IExecutionProvider> LoadExecutionProvider(
|
|
const std::string& ep_shared_lib_path,
|
|
const ProviderOptions& provider_options = {},
|
|
const std::string& entry_symbol_name = "GetProvider") {
|
|
void* handle;
|
|
const auto path_str = ToPathString(ep_shared_lib_path);
|
|
auto error = Env::Default().LoadDynamicLibrary(path_str, false, &handle);
|
|
if (!error.IsOK()) {
|
|
throw std::runtime_error(error.ErrorMessage());
|
|
}
|
|
|
|
Provider* (*PGetProvider)();
|
|
OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, entry_symbol_name, (void**)&PGetProvider));
|
|
|
|
Provider* provider = PGetProvider();
|
|
std::shared_ptr<IExecutionProviderFactory> ep_factory = provider->CreateExecutionProviderFactory(&provider_options);
|
|
return ep_factory->CreateProvider();
|
|
}
|
|
|
|
#ifdef USE_CUDA
|
|
// Builds the CUDA execution provider settings.
//
// When `provider_options_map` has an entry for kCudaExecutionProvider, `info` is filled
// from those options through `cuda_provider_info`. Otherwise the fields below are
// taken from same-named variables defined outside this function.
//
// `cuda_provider_info` must not be null.
const CUDAExecutionProviderInfo GetCudaExecutionProviderInfo(ProviderInfo_CUDA* cuda_provider_info,
                                                             const ProviderOptionsMap& provider_options_map) {
  ORT_ENFORCE(cuda_provider_info);
  const auto cuda_options = provider_options_map.find(kCudaExecutionProvider);
  CUDAExecutionProviderInfo info;
  if (cuda_options == provider_options_map.end()) {
    // No explicit options were given for this provider.
    info.device_id = cuda_device_id;
    info.gpu_mem_limit = gpu_mem_limit;
    info.arena_extend_strategy = arena_extend_strategy;
    info.cudnn_conv_algo_search = cudnn_conv_algo_search;
    info.do_copy_in_default_stream = do_copy_in_default_stream;
    info.external_allocator_info = external_allocator_info;
    info.tunable_op = tunable_op;
  } else {
    cuda_provider_info->CUDAExecutionProviderInfo__FromProviderOptions(cuda_options->second, info);
  }
  return info;
}
|
|
#endif
|
|
|
|
#ifdef USE_CANN
|
|
// Builds the CANN execution provider settings.
//
// When `provider_options_map` has an entry for kCannExecutionProvider, `info` is filled
// from those options through `cann_provider_info`; otherwise a default-constructed
// `info` is returned.
//
// `cann_provider_info` must not be null.
const CANNExecutionProviderInfo GetCannExecutionProviderInfo(ProviderInfo_CANN* cann_provider_info,
                                                             const ProviderOptionsMap& provider_options_map) {
  ORT_ENFORCE(cann_provider_info);
  const auto cann_options = provider_options_map.find(kCannExecutionProvider);
  CANNExecutionProviderInfo info;
  const bool has_options = cann_options != provider_options_map.end();
  if (has_options) {
    cann_provider_info->CANNExecutionProviderInfo__FromProviderOptions(cann_options->second, info);
  }
  return info;
}
|
|
#endif
|
|
|
|
#ifdef USE_ROCM
|
|
// Builds the ROCm execution provider settings.
//
// When `provider_options_map` has an entry for kRocmExecutionProvider, `info` is filled
// from those options through `rocm_provider_info`. Otherwise the fields below are
// taken from variables defined outside this function (the device id comes from
// `cuda_device_id`).
//
// `rocm_provider_info` must not be null.
const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* rocm_provider_info,
                                                             const ProviderOptionsMap& provider_options_map) {
  ORT_ENFORCE(rocm_provider_info);
  const auto rocm_options = provider_options_map.find(kRocmExecutionProvider);
  ROCMExecutionProviderInfo info;
  if (rocm_options == provider_options_map.end()) {
    // No explicit options were given for this provider.
    info.device_id = cuda_device_id;
    info.gpu_mem_limit = gpu_mem_limit;
    info.arena_extend_strategy = arena_extend_strategy;
    info.miopen_conv_exhaustive_search = miopen_conv_exhaustive_search;
    info.do_copy_in_default_stream = do_copy_in_default_stream;
    info.external_allocator_info = external_allocator_info;
    info.tunable_op = tunable_op;
  } else {
    rocm_provider_info->ROCMExecutionProviderInfo__FromProviderOptions(rocm_options->second, info);
  }
  return info;
}
|
|
#endif
|
|
|
|
#ifdef USE_TENSORRT
|
|
// Adds the custom op domains reported by the TensorRT provider to the session options.
//
//   so      - session options whose `custom_op_domains_` list is extended.
//   options - provider options; the value of "trt_extra_plugin_lib_paths", if present,
//             is passed on when the domain list is requested (empty string otherwise).
//
// A domain whose name already appears in `so.custom_op_domains_` is not added again;
// a warning is logged instead. Throws when the TensorRT provider info is unavailable.
void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) {
  auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT();
  if (!tensorrt_provider_info) {
    ORT_THROW("Please install TensorRT libraries as mentioned in the GPU requirements page, make sure they're in the PATH or LD_LIBRARY_PATH, and that your GPU is supported.");
  }

  std::string trt_extra_plugin_lib_paths = "";
  const auto plugin_paths = options.find("trt_extra_plugin_lib_paths");
  if (plugin_paths != options.end()) {
    trt_extra_plugin_lib_paths = plugin_paths->second;
  }

  std::vector<OrtCustomOpDomain*> custom_op_domains;
  tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths);

  for (auto* candidate : custom_op_domains) {
    // Compare against the live list so duplicates within this batch are caught too.
    const bool already_registered =
        std::any_of(so.custom_op_domains_.begin(), so.custom_op_domains_.end(),
                    [candidate](const OrtCustomOpDomain* registered) {
                      return candidate->domain_ == registered->domain_;
                    });
    if (already_registered) {
      LOGS_DEFAULT(WARNING) << "The custom op domain name " << candidate->domain_ << " is already in session option.";
    } else {
      so.custom_op_domains_.push_back(candidate);
    }
  }
}
|
|
#endif
|
|
|
|
std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|
const SessionOptions& session_options,
|
|
const std::string& type,
|
|
const ProviderOptionsMap& provider_options_map) {
|
|
if (type == kCpuExecutionProvider) {
|
|
return onnxruntime::CPUProviderFactoryCreator::Create(
|
|
session_options.enable_cpu_mem_arena)
|
|
->CreateProvider();
|
|
} else if (type == kTensorrtExecutionProvider) {
|
|
#ifdef USE_TENSORRT
|
|
// If the environment variable 'ORT_TENSORRT_UNAVAILABLE' exists, then we do not load TensorRT. This is set by _ld_preload for the manylinux case
|
|
// as in that case, trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies.
|
|
if (Env::Default().GetEnvironmentVar("ORT_TENSORRT_UNAVAILABLE").empty()) {
|
|
// provider_options_map is just a reference to the ProviderOptionsMap instance, so it can be released anytime from application.
|
|
// So we need these std::string variables defined here as they will be kept alive for the lifetime of TRT EP and we can still access them from OrtTensorRTProviderOptionsV2 instance.
|
|
// (The reason is string copy is involved, for example params.trt_engine_cache_path = cache_path.c_str() and those std::string variable is referenced by OrtTensorRTProviderOptionsV2 instance
|
|
// and TRT EP instance, so it won't be released.)
|
|
std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources,
|
|
trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path,
|
|
onnx_model_folder_path;
|
|
auto it = provider_options_map.find(type);
|
|
if (it != provider_options_map.end()) {
|
|
OrtTensorRTProviderOptionsV2 params;
|
|
for (auto option : it->second) {
|
|
if (option.first == "device_id") {
|
|
if (!option.second.empty()) {
|
|
params.device_id = std::stoi(option.second);
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'device_id' should be a number i.e. '0'.\n");
|
|
}
|
|
} else if (option.first == "user_compute_stream") {
|
|
if (!option.second.empty()) {
|
|
auto stream = std::stoull(option.second, nullptr, 0);
|
|
params.user_compute_stream = reinterpret_cast<void*>(stream);
|
|
params.has_user_compute_stream = true;
|
|
} else {
|
|
params.has_user_compute_stream = false;
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'user_compute_stream' should be a string to define the compute stream for the inference to run on.\n");
|
|
}
|
|
} else if (option.first == "trt_max_partition_iterations") {
|
|
if (!option.second.empty()) {
|
|
params.trt_max_partition_iterations = std::stoi(option.second);
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_max_partition_iterations' should be a positive integer number i.e. '1000'.\n");
|
|
}
|
|
} else if (option.first == "trt_min_subgraph_size") {
|
|
if (!option.second.empty()) {
|
|
params.trt_min_subgraph_size = std::stoi(option.second);
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_min_subgraph_size' should be a positive integer number i.e. '1'.\n");
|
|
}
|
|
} else if (option.first == "trt_max_workspace_size") {
|
|
if (!option.second.empty()) {
|
|
params.trt_max_workspace_size = std::stoull(option.second);
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_max_workspace_size' should be a number in byte i.e. '1073741824'.\n");
|
|
}
|
|
} else if (option.first == "trt_fp16_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_fp16_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_fp16_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_fp16_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_int8_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_int8_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_int8_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_int8_calibration_table_name") {
|
|
if (!option.second.empty()) {
|
|
calibration_table = option.second;
|
|
params.trt_int8_calibration_table_name = calibration_table.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_calibration_table_name' should be a file name i.e. 'cal_table'.\n");
|
|
}
|
|
} else if (option.first == "trt_int8_use_native_calibration_table") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_int8_use_native_calibration_table = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_int8_use_native_calibration_table = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_use_native_calibration_table' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_dla_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_dla_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_dla_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dla_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_dla_core") {
|
|
if (!option.second.empty()) {
|
|
params.trt_dla_core = std::stoi(option.second);
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dla_core' should be a positive integer number i.e. '0'.\n");
|
|
}
|
|
} else if (option.first == "trt_dump_subgraphs") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_dump_subgraphs = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_dump_subgraphs = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_subgraphs' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_engine_cache_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_engine_cache_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_engine_cache_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_engine_cache_path") {
|
|
if (!option.second.empty()) {
|
|
cache_path = option.second;
|
|
params.trt_engine_cache_path = cache_path.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_path' should be a path string i.e. 'engine_cache'.\n");
|
|
}
|
|
} else if (option.first == "trt_engine_cache_prefix") {
|
|
if (!option.second.empty()) {
|
|
cache_prefix = option.second;
|
|
params.trt_engine_cache_prefix = cache_prefix.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_prefix' should be a string to customize engine cache prefix i.e. 'FRCNN' or 'yolov4'.\n");
|
|
}
|
|
} else if (option.first == "trt_weight_stripped_engine_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_weight_stripped_engine_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_weight_stripped_engine_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_weight_stripped_engine_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_onnx_model_folder_path") {
|
|
if (!option.second.empty()) {
|
|
onnx_model_folder_path = option.second;
|
|
params.trt_onnx_model_folder_path = onnx_model_folder_path.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_onnx_model_folder_path' should be a path string i.e. 'engine_cache'.\n");
|
|
}
|
|
} else if (option.first == "trt_engine_decryption_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_engine_decryption_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_engine_decryption_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_decryption_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_engine_decryption_lib_path") {
|
|
if (!option.second.empty()) {
|
|
lib_path = option.second;
|
|
params.trt_engine_decryption_lib_path = lib_path.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_decryption_lib_path' should be a path string i.e. 'decryption_lib'.\n");
|
|
}
|
|
} else if (option.first == "trt_force_sequential_engine_build") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_force_sequential_engine_build = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_force_sequential_engine_build = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_sequential_engine_build' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_context_memory_sharing_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_context_memory_sharing_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_context_memory_sharing_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_context_memory_sharing_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_layer_norm_fp32_fallback") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_layer_norm_fp32_fallback = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_layer_norm_fp32_fallback = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_layer_norm_fp32_fallback' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_timing_cache_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_timing_cache_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_timing_cache_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_timing_cache_path") {
|
|
if (!option.second.empty()) {
|
|
timing_cache_path = option.second;
|
|
params.trt_timing_cache_path = timing_cache_path.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_path' should be a path string i.e. 'cache_folder/'.\n");
|
|
}
|
|
} else if (option.first == "trt_force_timing_cache") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_force_timing_cache = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_force_timing_cache = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_timing_cache' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_detailed_build_log") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_detailed_build_log = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_detailed_build_log = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_detailed_build_log' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_build_heuristics_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_build_heuristics_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_build_heuristics_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_build_heuristics_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_sparsity_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_sparsity_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_sparsity_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_sparsity_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_builder_optimization_level") {
|
|
if (!option.second.empty()) {
|
|
params.trt_builder_optimization_level = std::stoi(option.second);
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_builder_optimization_level' should be a number i.e. '0'.\n");
|
|
}
|
|
} else if (option.first == "trt_auxiliary_streams") {
|
|
if (!option.second.empty()) {
|
|
params.trt_auxiliary_streams = std::stoi(option.second);
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_auxiliary_streams' should be a number i.e. '0'.\n");
|
|
}
|
|
} else if (option.first == "trt_tactic_sources") {
|
|
if (!option.second.empty()) {
|
|
trt_tactic_sources = option.second;
|
|
params.trt_tactic_sources = trt_tactic_sources.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_tactic_sources' should be a string. e.g. \"-CUDNN,+CUBLAS\" available keys: \"CUBLAS\"|\"CUBLAS_LT\"|\"CUDNN\"|\"EDGE_MASK_CONVOLUTIONS\".\n");
|
|
}
|
|
} else if (option.first == "trt_extra_plugin_lib_paths") {
|
|
if (!option.second.empty()) {
|
|
trt_extra_plugin_lib_paths = option.second;
|
|
params.trt_extra_plugin_lib_paths = trt_extra_plugin_lib_paths.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_extra_plugin_lib_paths' should be a path string.\n");
|
|
}
|
|
} else if (option.first == "trt_profile_min_shapes") {
|
|
if (!option.second.empty()) {
|
|
min_profile = option.second;
|
|
params.trt_profile_min_shapes = min_profile.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_min_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n");
|
|
}
|
|
} else if (option.first == "trt_profile_max_shapes") {
|
|
if (!option.second.empty()) {
|
|
max_profile = option.second;
|
|
params.trt_profile_max_shapes = max_profile.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_max_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n");
|
|
}
|
|
} else if (option.first == "trt_profile_opt_shapes") {
|
|
if (!option.second.empty()) {
|
|
opt_profile = option.second;
|
|
params.trt_profile_opt_shapes = opt_profile.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_opt_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n");
|
|
}
|
|
} else if (option.first == "trt_cuda_graph_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_cuda_graph_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_cuda_graph_enable = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_dump_ep_context_model") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_dump_ep_context_model = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_dump_ep_context_model = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_ep_context_model' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "trt_ep_context_file_path") {
|
|
if (!option.second.empty()) {
|
|
ep_context_file_path = option.second;
|
|
params.trt_ep_context_file_path = ep_context_file_path.c_str();
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_file_path' should be a string.\n");
|
|
}
|
|
} else if (option.first == "trt_ep_context_embed_mode") {
|
|
if (!option.second.empty()) {
|
|
params.trt_ep_context_embed_mode = std::stoi(option.second);
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_embed_mode' should be a positive integer number i.e. '1'.\n");
|
|
}
|
|
} else if (option.first == "trt_engine_hw_compatible") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.trt_engine_hw_compatible = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.trt_engine_hw_compatible = false;
|
|
} else {
|
|
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_hw_compatible' should be 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else {
|
|
ORT_THROW("Invalid TensorRT EP option: ", option.first);
|
|
}
|
|
}
|
|
if (std::shared_ptr<IExecutionProviderFactory> tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(¶ms)) {
|
|
return tensorrt_provider_factory->CreateProvider();
|
|
}
|
|
} else {
|
|
if (std::shared_ptr<IExecutionProviderFactory> tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(cuda_device_id)) {
|
|
return tensorrt_provider_factory->CreateProvider();
|
|
}
|
|
}
|
|
}
|
|
LOGS_DEFAULT(WARNING) << "Failed to create "
|
|
<< type
|
|
<< ". Please reference "
|
|
<< "https://onnxruntime.ai/docs/execution-providers/"
|
|
<< "TensorRT-ExecutionProvider.html#requirements to ensure all dependencies are met.";
|
|
#endif
|
|
} else if (type == kMIGraphXExecutionProvider) {
|
|
#ifdef USE_MIGRAPHX
|
|
std::string calibration_table;
|
|
std::string save_model_path;
|
|
std::string load_model_path;
|
|
auto it = provider_options_map.find(type);
|
|
if (it != provider_options_map.end()) {
|
|
OrtMIGraphXProviderOptions params{
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
nullptr,
|
|
1,
|
|
"./compiled_model.mxr",
|
|
1,
|
|
"./compiled_model.mxr",
|
|
1};
|
|
for (auto option : it->second) {
|
|
if (option.first == "device_id") {
|
|
if (!option.second.empty()) {
|
|
params.device_id = std::stoi(option.second);
|
|
} else {
|
|
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n");
|
|
}
|
|
} else if (option.first == "migraphx_fp16_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.migraphx_fp16_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.migraphx_fp16_enable = false;
|
|
} else {
|
|
ORT_THROW(
|
|
"[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be"
|
|
" 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "migraphx_int8_enable") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.migraphx_int8_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.migraphx_int8_enable = false;
|
|
} else {
|
|
ORT_THROW(
|
|
"[ERROR] [MIGraphX] The value for the key 'migx_int8_enable' should be"
|
|
" 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "migraphx_int8_calibration_table_name") {
|
|
if (!option.second.empty()) {
|
|
calibration_table = option.second;
|
|
params.migraphx_int8_calibration_table_name = calibration_table.c_str();
|
|
} else {
|
|
ORT_THROW(
|
|
"[ERROR] [MIGraphX] The value for the key 'migx_int8_calibration_table_name' should be a "
|
|
"file name i.e. 'cal_table'.\n");
|
|
}
|
|
} else if (option.first == "migraphx_use_native_calibration_table") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.migraphx_use_native_calibration_table = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.migraphx_use_native_calibration_table = false;
|
|
} else {
|
|
ORT_THROW(
|
|
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
|
|
" 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "migraphx_save_compiled_model") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.migraphx_fp16_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.migraphx_fp16_enable = false;
|
|
} else {
|
|
ORT_THROW(
|
|
"[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
|
|
" 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "migraphx_save_model_path") {
|
|
if (!option.second.empty()) {
|
|
save_model_path = option.second;
|
|
params.migraphx_save_model_path = save_model_path.c_str();
|
|
} else {
|
|
ORT_THROW(
|
|
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
|
|
"file name i.e. 'compiled_model.mxr'.\n");
|
|
}
|
|
} else if (option.first == "migraphx_load_compiled_model") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.migraphx_fp16_enable = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.migraphx_fp16_enable = false;
|
|
} else {
|
|
ORT_THROW(
|
|
"[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
|
|
" 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else if (option.first == "migraphx_load_model_path") {
|
|
if (!option.second.empty()) {
|
|
load_model_path = option.second;
|
|
params.migraphx_load_model_path = load_model_path.c_str();
|
|
} else {
|
|
ORT_THROW(
|
|
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
|
|
"file name i.e. 'compiled_model.mxr'.\n");
|
|
}
|
|
} else if (option.first == "migraphx_exhaustive_tune") {
|
|
if (option.second == "True" || option.second == "true") {
|
|
params.migraphx_exhaustive_tune = true;
|
|
} else if (option.second == "False" || option.second == "false") {
|
|
params.migraphx_exhaustive_tune = false;
|
|
} else {
|
|
ORT_THROW(
|
|
"[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be"
|
|
" 'True' or 'False'. Default value is 'False'.\n");
|
|
}
|
|
} else {
|
|
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
|
|
}
|
|
}
|
|
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory =
|
|
onnxruntime::MIGraphXProviderFactoryCreator::Create(¶ms)) {
|
|
return migraphx_provider_factory->CreateProvider();
|
|
}
|
|
} else {
|
|
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory =
|
|
onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) {
|
|
return migraphx_provider_factory->CreateProvider();
|
|
}
|
|
}
|
|
#endif
|
|
} else if (type == kCudaExecutionProvider) {
|
|
#ifdef USE_CUDA
|
|
// If the environment variable 'ORT_CUDA_UNAVAILABLE' is set (non-empty), then we do not load cuda.
|
|
// This is set by _ld_preload for the manylinux case as in that case,
|
|
// trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies.
|
|
if (Env::Default().GetEnvironmentVar("ORT_CUDA_UNAVAILABLE").empty()) {
|
|
if (auto* cuda_provider_info = TryGetProviderInfo_CUDA()) {
|
|
const CUDAExecutionProviderInfo info = GetCudaExecutionProviderInfo(cuda_provider_info,
|
|
provider_options_map);
|
|
|
|
// This variable is never initialized because the APIs by which it should be initialized are deprecated,
|
|
// however they still exist and are in-use. Nevertheless, it is used to return CUDAAllocator,
|
|
// hence we must try to initialize it here if we can since FromProviderOptions might contain
|
|
// external CUDA allocator.
|
|
external_allocator_info = info.external_allocator_info;
|
|
return cuda_provider_info->CreateExecutionProviderFactory(info)->CreateProvider();
|
|
}
|
|
}
|
|
LOGS_DEFAULT(WARNING) << "Failed to create "
|
|
<< type
|
|
<< ". Require cuDNN " << CUDNN_MAJOR << ".* and "
|
|
<< "CUDA " << (CUDA_VERSION / 1000) << ".*"
|
|
#if defined(_MSC_VER)
|
|
<< ", and the latest MSVC runtime"
|
|
#endif
|
|
<< ". Please install all dependencies as mentioned in the GPU requirements page"
|
|
" (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), "
|
|
"make sure they're in the PATH, and that your GPU is supported.";
|
|
#endif
|
|
} else if (type == kRocmExecutionProvider) {
|
|
#ifdef USE_ROCM
|
|
if (auto* rocm_provider_info = TryGetProviderInfo_ROCM()) {
|
|
const ROCMExecutionProviderInfo info = GetRocmExecutionProviderInfo(rocm_provider_info,
|
|
provider_options_map);
|
|
|
|
// This variable is never initialized because the APIs by which it should be initialized are deprecated,
|
|
// however they still exist and are in-use. Nevertheless, it is used to return ROCMAllocator, hence we must
|
|
// try to initialize it here if we can since FromProviderOptions might contain external ROCM allocator.
|
|
external_allocator_info = info.external_allocator_info;
|
|
return rocm_provider_info->CreateExecutionProviderFactory(info)->CreateProvider();
|
|
} else {
|
|
if (!Env::Default().GetEnvironmentVar("ROCM_PATH").empty()) {
|
|
ORT_THROW(
|
|
"ROCM_PATH is set but ROCM wasn't able to be loaded. Please install the correct version "
|
|
"of ROCM and MIOpen as mentioned in the GPU requirements page, make sure they're in the PATH, "
|
|
"and that your GPU is supported.");
|
|
}
|
|
}
|
|
#endif
|
|
} else if (type == kDnnlExecutionProvider) {
|
|
#ifdef USE_DNNL
|
|
// Generate dnnl_options
|
|
OrtDnnlProviderOptions dnnl_options;
|
|
// For Eigen and OpenMP
|
|
#if defined(DNNL_OPENMP)
|
|
int num_threads = 0;
|
|
auto it = provider_options_map.find(type);
|
|
if (it != provider_options_map.end()) {
|
|
for (auto option : it->second) {
|
|
if (option.first == "num_of_threads") {
|
|
num_threads = std::stoi(option.second);
|
|
if (num_threads < 0) {
|
|
ORT_THROW(
|
|
"[ERROR] [OneDNN] Invalid entry for the key 'num_of_threads',"
|
|
" set number of threads or use '0' for default\n");
|
|
// If the user doesnt define num_threads, auto detect threads later
|
|
}
|
|
} else {
|
|
ORT_THROW("Invalid OneDNN EP option: ", option.first);
|
|
}
|
|
}
|
|
}
|
|
dnnl_options.threadpool_args = static_cast<void*>(&num_threads);
|
|
#endif // !defined(DNNL_ORT_THREAD)
|
|
dnnl_options.use_arena = session_options.enable_cpu_mem_arena;
|
|
|
|
return onnxruntime::DnnlProviderFactoryCreator::Create(&dnnl_options)->CreateProvider();
|
|
#endif
|
|
} else if (type == kOpenVINOExecutionProvider) {
|
|
#ifdef USE_OPENVINO
|
|
ProviderOptions OV_provider_options_map;
|
|
auto it = provider_options_map.find(type);
|
|
if (it != provider_options_map.end()) {
|
|
for (auto option : it->second) {
|
|
if (option.first == "device_type") {
|
|
OV_provider_options_map[option.first] = option.second;
|
|
continue;
|
|
} else if (option.first == "precision") {
|
|
OV_provider_options_map[option.first] = option.second;
|
|
continue;
|
|
} else if (option.first == "enable_opencl_throttling") {
|
|
if (!(option.second == "True" || option.second == "true" ||
|
|
option.second == "False" || option.second == "false")) {
|
|
ORT_THROW("Invalid value passed for enable_opencl_throttling: ", option.second);
|
|
}
|
|
OV_provider_options_map[option.first] = option.second;
|
|
} else if (option.first == "disable_dynamic_shapes") {
|
|
if (!(option.second == "True" || option.second == "true" ||
|
|
option.second == "False" || option.second == "false")) {
|
|
ORT_THROW("Invalid value passed for disable_dynamic_shapes: ", option.second);
|
|
}
|
|
OV_provider_options_map[option.first] = option.second;
|
|
} else if (option.first == "enable_dynamic_shapes") {
|
|
LOGS_DEFAULT(WARNING) << " Deprecation notice - 'enable_dynamic_shapes' is Deprected. Upgrade the API to disable_dynamic_shapes parameter."
|
|
"Please refer https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met.";
|
|
std::string value;
|
|
if (!(option.second == "True" || option.second == "true" ||
|
|
option.second == "False" || option.second == "false")) {
|
|
ORT_THROW("Invalid value passed for enable_dynamic_shapes: ", option.second);
|
|
}
|
|
if (option.second == "True" || option.second == "true") {
|
|
value = "false";
|
|
} else {
|
|
value = "true";
|
|
}
|
|
OV_provider_options_map["disable_dynamic_shapes"] = value;
|
|
} else if (option.first == "num_of_threads") {
|
|
OV_provider_options_map[option.first] = option.second;
|
|
continue;
|
|
} else if (option.first == "model_priority") {
|
|
OV_provider_options_map[option.first] = option.second;
|
|
continue;
|
|
} else if (option.first == "num_streams") {
|
|
OV_provider_options_map[option.first] = option.second;
|
|
continue;
|
|
} else if (option.first == "load_config") {
|
|
OV_provider_options_map[option.first] = option.second;
|
|
continue;
|
|
} else if (option.first == "cache_dir") {
|
|
OV_provider_options_map[option.first] = option.second;
|
|
continue;
|
|
} else if (option.first == "context") {
|
|
OV_provider_options_map[option.first] = option.second;
|
|
continue;
|
|
} else if (option.first == "enable_qdq_optimizer") {
|
|
OV_provider_options_map[option.first] = option.second;
|
|
continue;
|
|
} else {
|
|
ORT_THROW("Invalid OpenVINO EP option: ", option.first);
|
|
}
|
|
}
|
|
}
|
|
if (std::shared_ptr<IExecutionProviderFactory> openvino_provider_factory = onnxruntime::OpenVINOProviderFactoryCreator::Create(
|
|
&OV_provider_options_map, &session_options)) {
|
|
auto p = openvino_provider_factory->CreateProvider();
|
|
// Reset global variables config to avoid it being accidentally passed on to the next session
|
|
openvino_device_type.clear();
|
|
return p;
|
|
} else {
|
|
if (!Env::Default().GetEnvironmentVar("INTEL_OPENVINO_DIR").empty()) {
|
|
ORT_THROW("INTEL_OPENVINO_DIR is set but OpenVINO library wasn't able to be loaded. Please install a supported version of OpenVINO as mentioned in the requirements page (https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements), ensure dependency libraries are in the PATH and your hardware is supported.");
|
|
} else {
|
|
LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please refer https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met.";
|
|
}
|
|
}
|
|
#endif
|
|
} else if (type == kTvmExecutionProvider) {
|
|
#if USE_TVM
|
|
onnxruntime::tvm::TvmEPOptions info{};
|
|
const auto it = provider_options_map.find(type);
|
|
if (it != provider_options_map.end()) {
|
|
info = onnxruntime::tvm::TvmEPOptionsHelper::FromProviderOptions(it->second);
|
|
}
|
|
|
|
return onnxruntime::TVMProviderFactoryCreator::Create(info)->CreateProvider();
|
|
#endif
|
|
} else if (type == kVitisAIExecutionProvider) {
|
|
#ifdef USE_VITISAI
|
|
ProviderOptions info{};
|
|
const auto it = provider_options_map.find(type);
|
|
if (it != provider_options_map.end()) {
|
|
info = it->second;
|
|
}
|
|
info["session_options"] = std::to_string((uintptr_t)(void*)&session_options);
|
|
return onnxruntime::VitisAIProviderFactoryCreator::Create(info)->CreateProvider();
|
|
#endif
|
|
} else if (type == kAclExecutionProvider) {
|
|
#ifdef USE_ACL
|
|
bool enable_fast_math = false;
|
|
auto it = provider_options_map.find(type);
|
|
if (it != provider_options_map.end()) {
|
|
for (auto option : it->second) {
|
|
if (option.first == "enable_fast_math") {
|
|
std::set<std::string> supported_values = {"true", "True", "false", "False"};
|
|
if (supported_values.find(option.second) != supported_values.end()) {
|
|
enable_fast_math = (option.second == "true") || (option.second == "True");
|
|
} else {
|
|
ORT_THROW(
|
|
"Invalid value for enable_fast_math. "
|
|
"Select from 'true' or 'false'\n");
|
|
}
|
|
} else {
|
|
ORT_THROW("Unrecognized option: ", option.first);
|
|
}
|
|
}
|
|
}
|
|
return onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math)
|
|
->CreateProvider();
|
|
#endif
|
|
} else if (type == kArmNNExecutionProvider) {
|
|
#ifdef USE_ARMNN
|
|
return onnxruntime::ArmNNProviderFactoryCreator::Create(
|
|
session_options.enable_cpu_mem_arena)
|
|
->CreateProvider();
|
|
#endif
|
|
} else if (type == kDmlExecutionProvider) {
|
|
#ifdef USE_DML
|
|
auto cit = provider_options_map.find(type);
|
|
return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions(
|
|
session_options.config_options, cit == provider_options_map.end() ? ProviderOptions{} : cit->second, true)
|
|
->CreateProvider();
|
|
#endif
|
|
} else if (type == kNnapiExecutionProvider) {
|
|
#if defined(USE_NNAPI)
|
|
#if !defined(__ANDROID__)
|
|
LOGS_DEFAULT(WARNING) << "NNAPI execution provider can only be used to generate ORT format model in this build.";
|
|
#endif
|
|
const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry(
|
|
kOrtSessionOptionsConfigNnapiEpPartitioningStopOps);
|
|
return onnxruntime::NnapiProviderFactoryCreator::Create(0, partitioning_stop_ops_list)->CreateProvider();
|
|
#endif
|
|
} else if (type == kRknpuExecutionProvider) {
|
|
#ifdef USE_RKNPU
|
|
return onnxruntime::RknpuProviderFactoryCreator::Create()->CreateProvider();
|
|
#endif
|
|
} else if (type == kCoreMLExecutionProvider) {
|
|
#if defined(USE_COREML)
|
|
#if !defined(__APPLE__)
|
|
LOGS_DEFAULT(WARNING) << "CoreML execution provider can only be used to generate ORT format model in this build.";
|
|
#endif
|
|
uint32_t coreml_flags = 0;
|
|
|
|
const auto it = provider_options_map.find(type);
|
|
if (it != provider_options_map.end()) {
|
|
const ProviderOptions& options = it->second;
|
|
auto flags = options.find("flags");
|
|
if (flags != options.end()) {
|
|
const auto& flags_str = flags->second;
|
|
|
|
if (flags_str.find("COREML_FLAG_USE_CPU_ONLY") != std::string::npos) {
|
|
coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_ONLY;
|
|
} else if (flags_str.find("COREML_FLAG_USE_CPU_AND_GPU") != std::string::npos) {
|
|
coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_AND_GPU;
|
|
}
|
|
|
|
if (flags_str.find("COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES") != std::string::npos) {
|
|
coreml_flags |= COREMLFlags::COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES;
|
|
}
|
|
|
|
if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) {
|
|
coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM;
|
|
}
|
|
}
|
|
}
|
|
|
|
return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider();
|
|
#endif
|
|
} else if (type == kXnnpackExecutionProvider) {
|
|
#if defined(USE_XNNPACK)
|
|
auto cit = provider_options_map.find(type);
|
|
return onnxruntime::XnnpackProviderFactoryCreator::Create(
|
|
cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options)
|
|
->CreateProvider();
|
|
#endif
|
|
} else if (type == kWebGpuExecutionProvider) {
|
|
#if defined(USE_WEBGPU)
|
|
return onnxruntime::WebGpuProviderFactoryCreator::Create(session_options.config_options)->CreateProvider();
|
|
#endif
|
|
} else if (type == kCannExecutionProvider) {
|
|
#ifdef USE_CANN
|
|
if (auto* cann_provider_info = TryGetProviderInfo_CANN()) {
|
|
const CANNExecutionProviderInfo info = GetCannExecutionProviderInfo(cann_provider_info,
|
|
provider_options_map);
|
|
return cann_provider_info->CreateExecutionProviderFactory(info)->CreateProvider();
|
|
} else {
|
|
ORT_THROW("create CANN ExecutionProvider fail");
|
|
}
|
|
#endif
|
|
} else if (type == kAzureExecutionProvider) {
|
|
#ifdef USE_AZURE
|
|
return onnxruntime::AzureProviderFactoryCreator::Create({})->CreateProvider();
|
|
#endif
|
|
} else if (type == kQnnExecutionProvider) {
|
|
#ifdef USE_QNN
|
|
auto cit = provider_options_map.find(type);
|
|
return onnxruntime::QNNProviderFactoryCreator::Create(
|
|
cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options)
|
|
->CreateProvider();
|
|
#endif
|
|
} else {
|
|
// check whether it is a dynamic load EP:
|
|
const auto it = provider_options_map.find(type);
|
|
if (it != provider_options_map.end()) {
|
|
auto shared_lib_path_it = it->second.find(kExecutionProviderSharedLibraryPath);
|
|
if (shared_lib_path_it != it->second.end()) {
|
|
// this is an EP with dynamic loading
|
|
// construct the provider option
|
|
ProviderOptions provider_options;
|
|
std::string entry_symbol = kDefaultExecutionProviderEntry;
|
|
for (auto option : it->second) {
|
|
if (option.first == kExecutionProviderSharedLibraryEntry) {
|
|
entry_symbol = option.second;
|
|
} else if (option.first != kExecutionProviderSharedLibraryPath) {
|
|
provider_options.insert(option);
|
|
}
|
|
}
|
|
return LoadExecutionProvider(shared_lib_path_it->second, provider_options, entry_symbol);
|
|
}
|
|
}
|
|
// unknown provider
|
|
throw std::runtime_error("Unknown Provider Type: " + type);
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/*
|
|
* Register execution provider with options.
|
|
*/
|
|
static void RegisterExecutionProviders(InferenceSession* sess, const std::vector<std::string>& provider_types,
|
|
const ProviderOptionsMap& provider_options_map) {
|
|
ORT_UNUSED_PARAMETER(provider_options_map);
|
|
|
|
for (const std::string& type : provider_types) {
|
|
auto ep = CreateExecutionProviderInstance(sess->GetSessionOptions(), type, provider_options_map);
|
|
if (ep)
|
|
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(ep)));
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Generate a map for mapping execution provider to excution provider options.
|
|
*
|
|
* @param providers vector of excution providers. [ep1, ep2, ...]
|
|
* @param provider_options_vector vector of excution provider options. [option1, option2 ...]
|
|
* @param provider_options_map an unordered map for mapping excution provider to excution provider options.
|
|
* {'ep1' -> option1, 'ep2' -> option2 ...}
|
|
*
|
|
*/
|
|
static void GenerateProviderOptionsMap(const std::vector<std::string>& providers,
|
|
const ProviderOptionsVector& provider_options_vector,
|
|
ProviderOptionsMap& provider_options_map) {
|
|
if (provider_options_vector.empty() || providers.empty()) {
|
|
return;
|
|
}
|
|
|
|
std::size_t j = 0; // index for provider_options_vector
|
|
|
|
for (const std::string& type : providers) {
|
|
if (j < provider_options_vector.size() && !provider_options_vector[j].empty()) {
|
|
provider_options_map[type] = provider_options_vector[j];
|
|
}
|
|
|
|
j += 1;
|
|
}
|
|
}
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
|
static void RegisterCustomOpDomains(PyInferenceSession* sess, const PySessionOptions& so) {
|
|
if (!so.custom_op_domains_.empty()) {
|
|
// Register all custom op domains that will be needed for the session
|
|
std::vector<OrtCustomOpDomain*> custom_op_domains;
|
|
custom_op_domains.reserve(so.custom_op_domains_.size());
|
|
for (size_t i = 0; i < so.custom_op_domains_.size(); ++i) {
|
|
custom_op_domains.emplace_back(so.custom_op_domains_[i]);
|
|
}
|
|
OrtPybindThrowIfError(sess->GetSessionHandle()->AddCustomOpDomains(custom_op_domains));
|
|
}
|
|
}
|
|
#endif
|
|
|
|
// Prepares `sess` for use in three steps:
//   1. builds the provider -> options map from the parallel provider_types / provider_options lists
//      and invokes ep_registration_fn with it;
//   2. when optimizer filtering is compiled in and disabled_optimizer_names is non-empty, forwards
//      those names to FilterEnabledOptimizers;
//   3. calls Initialize() on the session.
// The status of steps 2 and 3 is handed to OrtPybindThrowIfError.
void InitializeSession(InferenceSession* sess,
                       ExecutionProviderRegistrationFn ep_registration_fn,
                       const std::vector<std::string>& provider_types,
                       const ProviderOptionsVector& provider_options,
                       const std::unordered_set<std::string>& disabled_optimizer_names) {
  ProviderOptionsMap options_by_provider;
  GenerateProviderOptionsMap(provider_types, provider_options, options_by_provider);
  ep_registration_fn(sess, provider_types, options_by_provider);

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
  const bool has_disabled_optimizers = !disabled_optimizer_names.empty();
  if (has_disabled_optimizers) {
    OrtPybindThrowIfError(
        sess->FilterEnabledOptimizers({disabled_optimizer_names.cbegin(), disabled_optimizer_names.cend()}));
  }
#else
  // Optimizer filtering is not compiled into this build configuration.
  ORT_UNUSED_PARAMETER(disabled_optimizer_names);
#endif

  const auto init_status = sess->Initialize();
  OrtPybindThrowIfError(init_status);
}
|
|
|
|
// Finds the first NodeArg in def_list whose Name() equals `name`, copies its type proto into
// type_proto, and returns whether that type proto has a tensor type.
//
// Throws std::runtime_error when no NodeArg with that name is present, or when the matching
// NodeArg's TypeAsProto() returns null. type_proto is only written on the success path.
bool CheckIfTensor(const std::vector<const NodeArg*>& def_list,
                   const std::string& name,
                   /*out*/ onnx::TypeProto& type_proto) {
  const NodeArg* match = nullptr;
  for (const NodeArg* candidate : def_list) {
    if (name == candidate->Name()) {
      match = candidate;
      break;  // first match wins
    }
  }

  if (match == nullptr) {
    throw std::runtime_error("Failed to find NodeArg with name: " + name + " in the def list");
  }

  const auto* found_type = match->TypeAsProto();
  if (found_type == nullptr) {
    throw std::runtime_error("Corresponding type_proto is null");
  }

  type_proto = *found_type;
  return type_proto.has_tensor_type();
}
|
|
|
|
#if defined(USE_OPENVINO) || \
    defined(USE_CUDA) ||     \
    defined(USE_ROCM)
// Logs a WARNING through the default logger that `deprecated` is deprecated and will be removed.
// When `alternative` holds a value, a second WARNING naming the replacement is logged as well.
static void LogDeprecationWarning(
    const std::string& deprecated, const optional<std::string>& alternative = nullopt) {
  LOGS_DEFAULT(WARNING) << "This is DEPRECATED and will be removed in the future: " << deprecated;

  if (alternative.has_value()) {
    LOGS_DEFAULT(WARNING) << "As an alternative, use: " << *alternative;
  }
}
#endif
|
|
|
|
// Registers the module-level functions on the Python module `m`: default-option
// factories, random seeding, default-logger controls, telemetry switches,
// allocator registration, and the helpers that only exist when the matching
// build flag (USE_OPENVINO, USE_CUDA/USE_ROCM, USE_TENSORRT, ENABLE_ATEN) is set.
void addGlobalMethods(py::module& m) {
  m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance.");
  m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer.");
  m.def(
      "get_device",
      []() -> std::string {
        return BACKEND_DEVICE;
      },
      "Return the device used to compute the prediction (CPU, MKL, ...)");
  m.def(
      "set_seed",
      [](const int64_t seed) {
        utils::SetRandomSeed(seed);
      },
      "Sets the seed used for random number generation in Onnxruntime.");
  m.def(
      "set_default_logger_severity",
      [](int severity) {
        // Reject values outside the range of severities listed in the message.
        ORT_ENFORCE(severity >= 0 && severity <= 4,
                    "Invalid logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
        auto environment = GetEnv();
        logging::LoggingManager* logging_manager = environment->GetLoggingManager();
        logging_manager->SetDefaultLoggerSeverity(static_cast<logging::Severity>(severity));
      },
      "Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
  m.def(
      "set_default_logger_verbosity",
      [](int vlog_level) {
        auto environment = GetEnv();
        logging::LoggingManager* logging_manager = environment->GetLoggingManager();
        logging_manager->SetDefaultLoggerVerbosity(vlog_level);
      },
      "Sets the default logging verbosity level. To activate the verbose log, "
      "you need to set the default logging severity to 0:Verbose level.");
  m.def(
      "get_all_providers",
      []() -> const std::vector<std::string>& {
        return GetAllExecutionProviderNames();
      },
      "Return list of Execution Providers that this version of Onnxruntime can support. "
      "The order of elements represents the default priority order of Execution Providers "
      "from highest to lowest.");
  m.def(
      "enable_telemetry_events",
      []() -> void {
        platform_env.GetTelemetryProvider().EnableTelemetryEvents();
      },
      "Enables platform-specific telemetry collection where applicable.");
  m.def(
      "disable_telemetry_events",
      []() -> void {
        platform_env.GetTelemetryProvider().DisableTelemetryEvents();
      },
      "Disables platform-specific telemetry collection.");
  m.def(
      "create_and_register_allocator",
      [](const OrtMemoryInfo& mem_info, const OrtArenaCfg* arena_cfg = nullptr) -> void {
        auto environment = GetEnv();
        auto status = environment->CreateAndRegisterAllocator(mem_info, arena_cfg);
        if (status.IsOK()) {
          return;
        }
        throw std::runtime_error("Error when creating and registering allocator: " + status.ErrorMessage());
      });
  m.def(
      "create_and_register_allocator_v2",
      [](const std::string& provider_type, const OrtMemoryInfo& mem_info, const ProviderOptions& options,
         const OrtArenaCfg* arena_cfg = nullptr) -> void {
        auto environment = GetEnv();
        auto status = environment->CreateAndRegisterAllocatorV2(provider_type, mem_info, options, arena_cfg);
        if (status.IsOK()) {
          return;
        }
        throw std::runtime_error(
            "Error when creating and registering allocator in create_and_register_allocator_v2: " +
            status.ErrorMessage());
      });

#ifdef USE_OPENVINO
  m.def(
      "get_available_openvino_device_ids",
      []() -> std::vector<std::string> {
        auto* provider_info = GetProviderInfo_OpenVINO();
        if (provider_info == nullptr) {
          // No provider info: report an empty device list.
          return {};
        }
        return provider_info->GetAvailableDevices();
      },
      "Lists all OpenVINO device ids available.");
  /*
   * The following APIs to set config options are deprecated. Use Session.set_providers() instead.
   */
  // TODO remove deprecated global config
  m.def(
      "set_openvino_device",
      [](const std::string& device_type) {
        LogDeprecationWarning("set_openvino_device", "OpenVINO execution provider option \"device_type\"");
        openvino_device_type = device_type;
      },
      "Set the preferred OpenVINO device type to be used. If left unset, "
      "the device type selected during build time will be used.");
  // TODO remove deprecated global config
  m.def(
      "get_openvino_device",
      []() -> std::string {
        LogDeprecationWarning("get_openvino_device");
        return openvino_device_type;
      },
      "Gets the dynamically selected OpenVINO device type for inference.");
#endif

#if defined(USE_CUDA) || defined(USE_ROCM)
  /*
   * The following set_* methods are deprecated.
   *
   * To achieve same result, please use the following python api:
   * InferenceSession.set_providers(list_of_providers, list_of_provider_option_dicts)
   *
   */
  // Each setter below logs a deprecation warning and then writes a global
  // configuration variable.
  // TODO remove deprecated global config
  m.def("set_cuda_device_id", [](const int id) {
    LogDeprecationWarning("set_cuda_device_id", "CUDA/ROCM execution provider option \"device_id\"");
    cuda_device_id = static_cast<OrtDevice::DeviceId>(id);
  });
  // TODO remove deprecated global config
  m.def("set_cudnn_conv_algo_search", [](const OrtCudnnConvAlgoSearch algo) {
    LogDeprecationWarning("set_cudnn_conv_algo_search", "CUDA execution provider option \"cudnn_conv_algo_search\"");
#ifdef USE_ROCM
    ORT_UNUSED_PARAMETER(algo);
    ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM");
#else
    cudnn_conv_algo_search = algo;
#endif
  });
  // TODO remove deprecated global config
  m.def("set_do_copy_in_default_stream", [](const bool use_single_stream) {
    LogDeprecationWarning(
        "set_do_copy_in_default_stream", "CUDA execution provider option \"do_copy_in_default_stream\"");
#ifdef USE_ROCM
    ORT_UNUSED_PARAMETER(use_single_stream);
    ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM");
#else
    do_copy_in_default_stream = use_single_stream;
#endif
  });
  // TODO remove deprecated global config
  m.def("set_gpu_mem_limit", [](const int64_t limit) {
    LogDeprecationWarning(
        "set_gpu_mem_limit",
        "CUDA execution provider option \"gpu_mem_limit\", ROCM execution provider option \"gpu_mem_limit\"");
    gpu_mem_limit = gsl::narrow<size_t>(limit);
  });
  // TODO remove deprecated global config
  m.def("set_arena_extend_strategy", [](const onnxruntime::ArenaExtendStrategy strategy) {
    LogDeprecationWarning("set_arena_extend_strategy", "CUDA/ROCM execution provider option \"arena_extend_strategy\"");
    arena_extend_strategy = strategy;
  });
#endif

#ifdef USE_TENSORRT
  m.def(
      "register_tensorrt_plugins_as_custom_ops",
      [](PySessionOptions& so, const ProviderOptions& options) {
        RegisterTensorRTPluginsAsCustomOps(so, options);
      },
      "Register TensorRT plugins as custom ops.");
#endif

#ifdef ENABLE_ATEN
  m.def("register_aten_op_executor",
        [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
          // Both arguments are addresses encoded as strings: parse each into an
          // integer, then reinterpret the integers as pointers.
          size_t is_tensor_argument_address;
          size_t aten_op_executor_address;
          ORT_THROW_IF_ERROR(
              ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address));
          ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address));
          void* is_tensor_argument_fn = reinterpret_cast<void*>(is_tensor_argument_address);
          void* aten_op_executor_fn = reinterpret_cast<void*>(aten_op_executor_address);
          contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(is_tensor_argument_fn, aten_op_executor_fn);
        });
#endif
}
|
|
|
|
void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) {
|
|
py::enum_<GraphOptimizationLevel>(m, "GraphOptimizationLevel")
|
|
.value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL)
|
|
.value("ORT_ENABLE_BASIC", GraphOptimizationLevel::ORT_ENABLE_BASIC)
|
|
.value("ORT_ENABLE_EXTENDED", GraphOptimizationLevel::ORT_ENABLE_EXTENDED)
|
|
.value("ORT_ENABLE_ALL", GraphOptimizationLevel::ORT_ENABLE_ALL);
|
|
|
|
py::enum_<ExecutionMode>(m, "ExecutionMode")
|
|
.value("ORT_SEQUENTIAL", ExecutionMode::ORT_SEQUENTIAL)
|
|
.value("ORT_PARALLEL", ExecutionMode::ORT_PARALLEL);
|
|
|
|
py::enum_<ExecutionOrder>(m, "ExecutionOrder")
|
|
.value("DEFAULT", ExecutionOrder::DEFAULT)
|
|
.value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED)
|
|
.value("MEMORY_EFFICIENT", ExecutionOrder::MEMORY_EFFICIENT);
|
|
|
|
py::enum_<OrtAllocatorType>(m, "OrtAllocatorType")
|
|
.value("INVALID", OrtInvalidAllocator)
|
|
.value("ORT_DEVICE_ALLOCATOR", OrtDeviceAllocator)
|
|
.value("ORT_ARENA_ALLOCATOR", OrtArenaAllocator);
|
|
|
|
py::enum_<OrtMemType>(m, "OrtMemType")
|
|
.value("CPU_INPUT", OrtMemTypeCPUInput)
|
|
.value("CPU_OUTPUT", OrtMemTypeCPUOutput)
|
|
.value("CPU", OrtMemTypeCPU)
|
|
.value("DEFAULT", OrtMemTypeDefault);
|
|
|
|
py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device informaion.)pbdoc");
|
|
device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::DeviceId>())
|
|
.def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
|
|
.def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc")
|
|
.def_static("cpu", []() { return OrtDevice::CPU; })
|
|
.def_static("cuda", []() { return OrtDevice::GPU; })
|
|
.def_static("cann", []() { return OrtDevice::NPU; })
|
|
.def_static("fpga", []() { return OrtDevice::FPGA; })
|
|
.def_static("npu", []() { return OrtDevice::NPU; })
|
|
.def_static("dml", []() { return OrtDevice::DML; })
|
|
.def_static("webgpu", []() { return OrtDevice::GPU; })
|
|
.def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });
|
|
|
|
py::class_<OrtArenaCfg> ort_arena_cfg_binding(m, "OrtArenaCfg");
|
|
// Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option.
|
|
// This constructor kept for backwards compatibility, key-value pair constructor overload exposes all options
|
|
// There is a global var: arena_extend_strategy, which means we can't use that var name here
|
|
// See docs/C_API.md for details on what the following parameters mean and how to choose these values
|
|
ort_arena_cfg_binding.def(py::init([](size_t max_mem, int arena_extend_strategy_local,
|
|
int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
|
|
auto ort_arena_cfg = std::make_unique<OrtArenaCfg>();
|
|
ort_arena_cfg->max_mem = max_mem;
|
|
ort_arena_cfg->arena_extend_strategy = arena_extend_strategy_local;
|
|
ort_arena_cfg->initial_chunk_size_bytes = initial_chunk_size_bytes;
|
|
ort_arena_cfg->max_dead_bytes_per_chunk = max_dead_bytes_per_chunk;
|
|
return ort_arena_cfg;
|
|
}))
|
|
.def(py::init([](const py::dict& feeds) {
|
|
auto ort_arena_cfg = std::make_unique<OrtArenaCfg>();
|
|
for (const auto kvp : feeds) {
|
|
std::string key = kvp.first.cast<std::string>();
|
|
if (key == "max_mem") {
|
|
ort_arena_cfg->max_mem = kvp.second.cast<size_t>();
|
|
} else if (key == "arena_extend_strategy") {
|
|
ort_arena_cfg->arena_extend_strategy = kvp.second.cast<int>();
|
|
} else if (key == "initial_chunk_size_bytes") {
|
|
ort_arena_cfg->initial_chunk_size_bytes = kvp.second.cast<int>();
|
|
} else if (key == "max_dead_bytes_per_chunk") {
|
|
ort_arena_cfg->max_dead_bytes_per_chunk = kvp.second.cast<int>();
|
|
} else if (key == "initial_growth_chunk_size_bytes") {
|
|
ort_arena_cfg->initial_growth_chunk_size_bytes = kvp.second.cast<int>();
|
|
} else if (key == "max_power_of_two_extend_bytes") {
|
|
ort_arena_cfg->max_power_of_two_extend_bytes = kvp.second.cast<int>();
|
|
} else {
|
|
ORT_THROW("Invalid OrtArenaCfg option: ", key);
|
|
}
|
|
}
|
|
return ort_arena_cfg;
|
|
}))
|
|
.def_readwrite("max_mem", &OrtArenaCfg::max_mem)
|
|
.def_readwrite("arena_extend_strategy", &OrtArenaCfg::arena_extend_strategy)
|
|
.def_readwrite("initial_chunk_size_bytes", &OrtArenaCfg::initial_chunk_size_bytes)
|
|
.def_readwrite("max_dead_bytes_per_chunk", &OrtArenaCfg::max_dead_bytes_per_chunk)
|
|
.def_readwrite("initial_growth_chunk_size_bytes", &OrtArenaCfg::initial_growth_chunk_size_bytes)
|
|
.def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes);
|
|
|
|
py::class_<OrtMemoryInfo> ort_memory_info_binding(m, "OrtMemoryInfo");
|
|
ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
|
|
if (strcmp(name, onnxruntime::CPU) == 0) {
|
|
return std::make_unique<OrtMemoryInfo>(onnxruntime::CPU, type, OrtDevice(), id, mem_type);
|
|
} else if (strcmp(name, onnxruntime::CUDA) == 0) {
|
|
return std::make_unique<OrtMemoryInfo>(
|
|
onnxruntime::CUDA, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id)), id,
|
|
mem_type);
|
|
} else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) {
|
|
return std::make_unique<OrtMemoryInfo>(
|
|
onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id)),
|
|
id, mem_type);
|
|
} else {
|
|
throw std::runtime_error("Specified device is not supported.");
|
|
}
|
|
}));
|
|
|
|
py::class_<PySessionOptions>
|
|
sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
|
|
sess
|
|
.def(py::init())
|
|
.def_property(
|
|
"enable_cpu_mem_arena",
|
|
[](const PySessionOptions* options) -> bool { return options->value.enable_cpu_mem_arena; },
|
|
[](PySessionOptions* options, bool enable_cpu_mem_arena) -> void {
|
|
options->value.enable_cpu_mem_arena = enable_cpu_mem_arena;
|
|
},
|
|
R"pbdoc(Enables the memory arena on CPU. Arena may pre-allocate memory for future usage.
|
|
Set this option to false if you don't want it. Default is True.)pbdoc")
|
|
.def_property(
|
|
"enable_profiling",
|
|
[](const PySessionOptions* options) -> bool { return options->value.enable_profiling; },
|
|
[](PySessionOptions* options, bool enable_profiling) -> void {
|
|
options->value.enable_profiling = enable_profiling;
|
|
},
|
|
R"pbdoc(Enable profiling for this session. Default is false.)pbdoc")
|
|
.def_property(
|
|
"profile_file_prefix",
|
|
[](const PySessionOptions* options) -> std::basic_string<ORTCHAR_T> {
|
|
return options->value.profile_file_prefix;
|
|
},
|
|
[](PySessionOptions* options, std::basic_string<ORTCHAR_T> profile_file_prefix) -> void {
|
|
options->value.profile_file_prefix = std::move(profile_file_prefix);
|
|
},
|
|
R"pbdoc(The prefix of the profile file. The current time will be appended to the file name.)pbdoc")
|
|
.def_property(
|
|
"optimized_model_filepath",
|
|
[](const PySessionOptions* options) -> std::basic_string<ORTCHAR_T> {
|
|
return options->value.optimized_model_filepath;
|
|
},
|
|
[](PySessionOptions* options, std::basic_string<ORTCHAR_T> optimized_model_filepath) -> void {
|
|
options->value.optimized_model_filepath = std::move(optimized_model_filepath);
|
|
},
|
|
R"pbdoc(
|
|
File path to serialize optimized model to.
|
|
Optimized model is not serialized unless optimized_model_filepath is set.
|
|
Serialized model format will default to ONNX unless:
|
|
- add_session_config_entry is used to set 'session.save_model_format' to 'ORT', or
|
|
- there is no 'session.save_model_format' config entry and optimized_model_filepath ends in '.ort' (case insensitive)
|
|
|
|
)pbdoc")
|
|
.def_property(
|
|
"enable_cpu_mem_arena",
|
|
[](const PySessionOptions* options) -> bool { return options->value.enable_cpu_mem_arena; },
|
|
[](PySessionOptions* options, bool enable_cpu_mem_arena) -> void {
|
|
options->value.enable_cpu_mem_arena = enable_cpu_mem_arena;
|
|
},
|
|
R"pbdoc(Enable memory arena on CPU. Default is true.)pbdoc")
|
|
.def_property(
|
|
"enable_mem_pattern",
|
|
[](const PySessionOptions* options) -> bool { return options->value.enable_mem_pattern; },
|
|
[](PySessionOptions* options, bool enable_mem_pattern) -> void {
|
|
options->value.enable_mem_pattern = enable_mem_pattern;
|
|
},
|
|
R"pbdoc(Enable the memory pattern optimization. Default is true.)pbdoc")
|
|
.def_property(
|
|
"enable_mem_reuse",
|
|
[](const PySessionOptions* options) -> bool { return options->value.enable_mem_reuse; },
|
|
[](PySessionOptions* options, bool enable_mem_reuse) -> void {
|
|
options->value.enable_mem_reuse = enable_mem_reuse;
|
|
},
|
|
R"pbdoc(Enable the memory reuse optimization. Default is true.)pbdoc")
|
|
.def_property(
|
|
"logid",
|
|
[](const PySessionOptions* options) -> std::string {
|
|
return options->value.session_logid;
|
|
},
|
|
[](PySessionOptions* options, std::string logid) -> void {
|
|
options->value.session_logid = std::move(logid);
|
|
},
|
|
R"pbdoc(Logger id to use for session output.)pbdoc")
|
|
.def_property(
|
|
"log_severity_level",
|
|
[](const PySessionOptions* options) -> int { return options->value.session_log_severity_level; },
|
|
[](PySessionOptions* options, int log_severity_level) -> void {
|
|
options->value.session_log_severity_level = log_severity_level;
|
|
},
|
|
R"pbdoc(Log severity level. Applies to session load, initialization, etc.
|
|
0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.)pbdoc")
|
|
.def_property(
|
|
"log_verbosity_level",
|
|
[](const PySessionOptions* options) -> int { return options->value.session_log_verbosity_level; },
|
|
[](PySessionOptions* options, int log_verbosity_level) -> void {
|
|
options->value.session_log_verbosity_level = log_verbosity_level;
|
|
},
|
|
R"pbdoc(VLOG level if DEBUG build and session_log_severity_level is 0.
|
|
Applies to session load, initialization, etc. Default is 0.)pbdoc")
|
|
.def_property(
|
|
"intra_op_num_threads",
|
|
[](const PySessionOptions* options) -> int { return options->value.intra_op_param.thread_pool_size; },
|
|
[](PySessionOptions* options, int value) -> void { options->value.intra_op_param.thread_pool_size = value; },
|
|
R"pbdoc(Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose.)pbdoc")
|
|
.def_property(
|
|
"inter_op_num_threads",
|
|
[](const PySessionOptions* options) -> int { return options->value.inter_op_param.thread_pool_size; },
|
|
[](PySessionOptions* options, int value) -> void { options->value.inter_op_param.thread_pool_size = value; },
|
|
R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc")
|
|
.def_property(
|
|
"execution_mode",
|
|
[](const PySessionOptions* options) -> ExecutionMode { return options->value.execution_mode; },
|
|
[](PySessionOptions* options, ExecutionMode execution_mode) -> void {
|
|
options->value.execution_mode = execution_mode;
|
|
},
|
|
R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc")
|
|
.def_property(
|
|
"execution_order",
|
|
[](const PySessionOptions* options) -> ExecutionOrder { return options->value.execution_order; },
|
|
[](PySessionOptions* options, ExecutionOrder execution_order) -> void {
|
|
options->value.execution_order = execution_order;
|
|
},
|
|
R"pbdoc(Sets the execution order. Default is basic topological order.)pbdoc")
|
|
.def_property(
|
|
"graph_optimization_level",
|
|
[](const PySessionOptions* options) -> GraphOptimizationLevel {
|
|
GraphOptimizationLevel retval = ORT_ENABLE_ALL;
|
|
switch (options->value.graph_optimization_level) {
|
|
case onnxruntime::TransformerLevel::Default:
|
|
retval = ORT_DISABLE_ALL;
|
|
break;
|
|
case onnxruntime::TransformerLevel::Level1:
|
|
retval = ORT_ENABLE_BASIC;
|
|
break;
|
|
case onnxruntime::TransformerLevel::Level2:
|
|
retval = ORT_ENABLE_EXTENDED;
|
|
break;
|
|
case onnxruntime::TransformerLevel::Level3:
|
|
retval = ORT_ENABLE_ALL;
|
|
break;
|
|
default:
|
|
retval = ORT_ENABLE_ALL;
|
|
LOGS_DEFAULT(WARNING) << "Got invalid graph optimization level; defaulting to ORT_ENABLE_ALL";
|
|
break;
|
|
}
|
|
return retval;
|
|
},
|
|
|
|
[](PySessionOptions* options, GraphOptimizationLevel level) -> void {
|
|
switch (level) {
|
|
case ORT_DISABLE_ALL:
|
|
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Default;
|
|
break;
|
|
case ORT_ENABLE_BASIC:
|
|
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level1;
|
|
break;
|
|
case ORT_ENABLE_EXTENDED:
|
|
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level2;
|
|
break;
|
|
case ORT_ENABLE_ALL:
|
|
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level3;
|
|
break;
|
|
}
|
|
},
|
|
R"pbdoc(Graph optimization level for this session.)pbdoc")
|
|
.def_property(
|
|
"use_deterministic_compute",
|
|
[](const PySessionOptions* options) -> bool { return options->value.use_deterministic_compute; },
|
|
[](PySessionOptions* options, bool use_deterministic_compute) -> void {
|
|
options->value.use_deterministic_compute = use_deterministic_compute;
|
|
},
|
|
R"pbdoc(Whether to use deterministic compute. Default is false.)pbdoc")
|
|
.def(
|
|
"add_free_dimension_override_by_denotation",
|
|
[](PySessionOptions* options, const char* dim_name, int64_t dim_value)
|
|
-> void { options->value.free_dimension_overrides.push_back(
|
|
onnxruntime::FreeDimensionOverride{
|
|
dim_name,
|
|
onnxruntime::FreeDimensionOverrideType::Denotation,
|
|
dim_value}); },
|
|
R"pbdoc(Specify the dimension size for each denotation associated with an input's free dimension.)pbdoc")
|
|
.def(
|
|
"add_free_dimension_override_by_name",
|
|
[](PySessionOptions* options, const char* dim_name, int64_t dim_value)
|
|
-> void { options->value.free_dimension_overrides.push_back(
|
|
onnxruntime::FreeDimensionOverride{
|
|
dim_name,
|
|
onnxruntime::FreeDimensionOverrideType::Name,
|
|
dim_value}); },
|
|
R"pbdoc(Specify values of named dimensions within model inputs.)pbdoc")
|
|
.def(
|
|
"add_session_config_entry",
|
|
[](PySessionOptions* options, const char* config_key, const char* config_value) -> void {
|
|
// config_key and config_value will be copied
|
|
const Status status = options->value.config_options.AddConfigEntry(config_key, config_value);
|
|
if (!status.IsOK())
|
|
throw std::runtime_error(status.ErrorMessage());
|
|
},
|
|
R"pbdoc(Set a single session configuration entry as a pair of strings.)pbdoc")
|
|
.def(
|
|
"get_session_config_entry",
|
|
[](const PySessionOptions* options, const char* config_key) -> std::string {
|
|
const std::string key(config_key);
|
|
std::string value;
|
|
if (!options->value.config_options.TryGetConfigEntry(key, value))
|
|
throw std::runtime_error("SessionOptions does not have configuration with key: " + key);
|
|
|
|
return value;
|
|
},
|
|
R"pbdoc(Get a single session configuration value using the given configuration key.)pbdoc")
|
|
.def(
|
|
"register_custom_ops_library",
|
|
[](PySessionOptions* options, const char* library_name) -> void {
|
|
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
|
OrtPybindThrowIfError(options->RegisterCustomOpsLibrary(ToPathString(library_name)));
|
|
#else
|
|
ORT_UNUSED_PARAMETER(options);
|
|
ORT_UNUSED_PARAMETER(library_name);
|
|
ORT_THROW("Custom Ops are not supported in this build.");
|
|
#endif
|
|
},
|
|
R"pbdoc(Specify the path to the shared library containing the custom op kernels required to run a model.)pbdoc")
|
|
.def(
|
|
"add_initializer", [](PySessionOptions* options, const char* name, py::object& ml_value_pyobject) -> void {
|
|
ORT_ENFORCE(strcmp(Py_TYPE(ml_value_pyobject.ptr())->tp_name, PYTHON_ORTVALUE_OBJECT_NAME) == 0, "The provided Python object must be an OrtValue");
|
|
// The user needs to ensure that the python OrtValue being provided as an overriding initializer
|
|
// is not destructed as long as any session that uses the provided OrtValue initializer is still in scope
|
|
// This is no different than the native APIs
|
|
const OrtValue* ml_value = ml_value_pyobject.attr(PYTHON_ORTVALUE_NATIVE_OBJECT_ATTR).cast<OrtValue*>();
|
|
ORT_THROW_IF_ERROR(options->value.AddInitializer(name, ml_value));
|
|
})
|
|
.def("add_external_initializers", [](PySessionOptions* options, py::list& names, const py::list& ort_values) -> void {
|
|
#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_EXTERNAL_INITIALIZERS)
|
|
const auto init_num = ort_values.size();
|
|
ORT_ENFORCE(init_num == names.size(), "Expecting names and ort_values lists to have equal length");
|
|
InlinedVector<std::string> names_ptrs;
|
|
InlinedVector<OrtValue> values_ptrs;
|
|
names_ptrs.reserve(init_num);
|
|
values_ptrs.reserve(init_num);
|
|
for (size_t i = 0; i < init_num; ++i) {
|
|
names_ptrs.emplace_back(py::str(names[i]));
|
|
values_ptrs.emplace_back(*ort_values[i].attr(PYTHON_ORTVALUE_NATIVE_OBJECT_ATTR).cast<const OrtValue*>());
|
|
}
|
|
ORT_THROW_IF_ERROR(options->value.AddExternalInitializers(names_ptrs, values_ptrs));
|
|
#else
|
|
ORT_UNUSED_PARAMETER(options);
|
|
ORT_UNUSED_PARAMETER(names);
|
|
ORT_UNUSED_PARAMETER(ort_values);
|
|
ORT_THROW("External initializers are not supported in this build.");
|
|
#endif
|
|
});
|
|
|
|
py::class_<RunOptions>(m, "RunOptions", R"pbdoc(Configuration information for a single Run.)pbdoc")
|
|
.def(py::init())
|
|
.def_readwrite("log_severity_level", &RunOptions::run_log_severity_level,
|
|
R"pbdoc(Log severity level for a particular Run() invocation. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.)pbdoc")
|
|
.def_readwrite("log_verbosity_level", &RunOptions::run_log_verbosity_level,
|
|
R"pbdoc(VLOG level if DEBUG build and run_log_severity_level is 0.
|
|
Applies to a particular Run() invocation. Default is 0.)pbdoc")
|
|
.def_readwrite("logid", &RunOptions::run_tag,
|
|
"To identify logs generated by a particular Run() invocation.")
|
|
.def_readwrite("terminate", &RunOptions::terminate,
|
|
R"pbdoc(Set to True to terminate any currently executing calls that are using this
|
|
RunOptions instance. The individual calls will exit gracefully and return an error status.)pbdoc")
|
|
#ifdef ENABLE_TRAINING
|
|
.def_readwrite("training_mode", &RunOptions::training_mode,
|
|
R"pbdoc(Choose to run in training or inferencing mode)pbdoc")
|
|
#endif
|
|
.def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
|
|
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc")
|
|
.def(
|
|
"add_run_config_entry",
|
|
[](RunOptions* options, const char* config_key, const char* config_value) -> void {
|
|
// config_key and config_value will be copied
|
|
const Status status = options->config_options.AddConfigEntry(config_key, config_value);
|
|
if (!status.IsOK())
|
|
throw std::runtime_error(status.ErrorMessage());
|
|
},
|
|
R"pbdoc(Set a single run configuration entry as a pair of strings.)pbdoc")
|
|
.def(
|
|
"get_run_config_entry",
|
|
[](const RunOptions* options, const char* config_key) -> std::string {
|
|
const std::string key(config_key);
|
|
std::string value;
|
|
if (!options->config_options.TryGetConfigEntry(key, value))
|
|
throw std::runtime_error("RunOptions does not have configuration with key: " + key);
|
|
|
|
return value;
|
|
},
|
|
R"pbdoc(Get a single run configuration value using the given configuration key.)pbdoc")
|
|
.def(
|
|
"add_active_adapter", [](RunOptions* options, lora::LoraAdapter* adapter) {
|
|
options->active_adapters.push_back(adapter);
|
|
},
|
|
R"pbdoc(Adds specified adapter as an active adapter)pbdoc");
|
|
|
|
py::class_<ModelMetadata>(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model.
|
|
It is usually used to identify the model used to run the prediction and
|
|
facilitate the comparison.)pbdoc")
|
|
.def_readwrite("producer_name", &ModelMetadata::producer_name, "producer name")
|
|
.def_readwrite("graph_name", &ModelMetadata::graph_name, "graph name")
|
|
.def_readwrite("domain", &ModelMetadata::domain, "ONNX domain")
|
|
.def_readwrite("description", &ModelMetadata::description, "description of the model")
|
|
.def_readwrite("graph_description", &ModelMetadata::graph_description, "description of the graph hosted in the model")
|
|
.def_readwrite("version", &ModelMetadata::version, "version of the model")
|
|
.def_readwrite("custom_metadata_map", &ModelMetadata::custom_metadata_map, "additional metadata");
|
|
|
|
// Python binding for onnxruntime::NodeArg: read-only name, type and shape, plus a
// human-readable __str__.
// Fix: __str__ no longer declares an unused std::vector<py::object>, and returns
// the stream's string directly instead of copying it a second time.
// NOTE(review): both "type" and __str__ dereference na.Type() without a null
// check - confirm Type() cannot return nullptr for the NodeArgs exposed here.
py::class_<onnxruntime::NodeArg>(m, "NodeArg", R"pbdoc(Node argument definition, for both input and output,
including arg name, arg type (contains both type and shape).)pbdoc")
      .def_property_readonly("name", &onnxruntime::NodeArg::Name, "node name")
      .def_property_readonly(
          "type", [](const onnxruntime::NodeArg& na) -> std::string {
            return *(na.Type());
          },
          "node type")
      .def(
          "__str__", [](const onnxruntime::NodeArg& na) -> std::string {
            std::ostringstream res;
            res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
            auto shape = na.Shape();
            if (shape == nullptr || shape->dim_size() == 0) {
              // Missing shape and rank-0 shape are both rendered as an empty list.
              res << "[]";
            } else {
              res << "[";
              for (int i = 0; i < shape->dim_size(); ++i) {
                // Each dimension prints as its value, its quoted symbolic name, or None.
                if (utils::HasDimValue(shape->dim(i))) {
                  res << shape->dim(i).dim_value();
                } else if (utils::HasDimParam(shape->dim(i))) {
                  res << "'" << shape->dim(i).dim_param() << "'";
                } else {
                  res << "None";
                }

                if (i < shape->dim_size() - 1) {
                  res << ", ";
                }
              }
              res << "]";
            }
            res << ")";

            return res.str();
          },
          "converts the node into a readable string")
      .def_property_readonly(
          "shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
            auto shape = na.Shape();
            std::vector<py::object> arr;
            if (shape == nullptr || shape->dim_size() == 0) {
              return arr;
            }

            arr.resize(shape->dim_size());
            for (int i = 0; i < shape->dim_size(); ++i) {
              // Same three-way mapping as __str__, producing Python objects:
              // int for a fixed dimension, str for a symbolic one, None otherwise.
              if (utils::HasDimValue(shape->dim(i))) {
                arr[i] = py::cast(shape->dim(i).dim_value());
              } else if (utils::HasDimParam(shape->dim(i))) {
                arr[i] = py::cast(shape->dim(i).dim_param());
              } else {
                arr[i] = py::none();
              }
            }
            return arr;
          },
          "node shape (assuming the node holds a tensor)");
|
|
|
|
py::class_<SessionObjectInitializer> sessionObjectInitializer(m, "SessionObjectInitializer");
|
|
py::class_<PyInferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
|
|
// In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char*
|
|
// without any conversion. So this init method can be used for model file path (string) and model content (bytes)
|
|
.def(py::init([](const PySessionOptions& so, const std::string arg, bool is_arg_file_name,
|
|
bool load_config_from_model = false) {
|
|
auto env = GetEnv();
|
|
std::unique_ptr<PyInferenceSession> sess;
|
|
|
|
// separate creation of the session from model loading unless we have to read the config from the model.
|
|
// in a minimal build we only support load via Load(...) and not at session creation time
|
|
if (load_config_from_model) {
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
sess = std::make_unique<PyInferenceSession>(std::move(env), so, arg, is_arg_file_name);
|
|
|
|
RegisterCustomOpDomains(sess.get(), so);
|
|
|
|
OrtPybindThrowIfError(sess->GetSessionHandle()->Load());
|
|
#else
|
|
ORT_THROW("Loading configuration from an ONNX model is not supported in this build.");
|
|
#endif
|
|
} else {
|
|
sess = std::make_unique<PyInferenceSession>(std::move(env), so);
|
|
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
|
RegisterCustomOpDomains(sess.get(), so);
|
|
#endif
|
|
|
|
if (is_arg_file_name) {
|
|
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(arg));
|
|
} else {
|
|
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(arg.data(), narrow<int>(arg.size())));
|
|
}
|
|
}
|
|
|
|
return sess;
|
|
}))
|
|
.def(
|
|
"initialize_session",
|
|
[ep_registration_fn](PyInferenceSession* sess,
|
|
const std::vector<std::string>& provider_types = {},
|
|
const ProviderOptionsVector& provider_options = {},
|
|
const std::unordered_set<std::string>& disabled_optimizer_names = {}) {
|
|
InitializeSession(sess->GetSessionHandle(),
|
|
ep_registration_fn,
|
|
provider_types,
|
|
provider_options,
|
|
disabled_optimizer_names);
|
|
},
|
|
R"pbdoc(Load a model saved in ONNX or ORT format.)pbdoc")
|
|
// Binds `run`: converts the Python feeds (input name -> Python object) into OrtValues, runs the
// session with the GIL released and returns a py::list holding one entry per fetch
// (py::none() for a fetch whose OrtValue was not allocated).
// Throws std::runtime_error when the model input definitions cannot be obtained; failures of
// Run() are reported through OrtPybindThrowIfError.
.def("run",
     [](PyInferenceSession* sess, const std::vector<std::string>& output_names,
        const std::map<std::string, const py::object>& pyfeeds, RunOptions* run_options = nullptr)
         -> py::list {
       NameMLValMap feeds;
       if (run_options != nullptr && !run_options->active_adapters.empty()) {
         // Active adapters are present: AppendLoraParametersAsInputs() is given the feed map
         // (presumably to add the adapter parameters to it - see its definition).
         AppendLoraParametersAsInputs(*run_options, pyfeeds.size(), feeds);
       } else {
         feeds.reserve(pyfeeds.size());
       }

       for (const auto& feed : pyfeeds) {
         // No need to process 'None's sent in by the user
         // to feed Optional inputs in the graph.
         // We just won't include anything in the feed and ORT
         // will handle such implicit 'None's internally.
         if (!feed.second.is(py::none())) {
           OrtValue ml_value;
           auto px = sess->GetSessionHandle()->GetModelInputs();
           if (!px.first.IsOK() || !px.second) {
             throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
           }
           CreateGenericMLValue(px.second, GetAllocator(), feed.first, feed.second, &ml_value);
           // Surface any Python error raised during the conversion before continuing.
           ThrowIfPyErrOccured();
           feeds.insert(std::make_pair(feed.first, std::move(ml_value)));
         }
       }

       std::vector<OrtValue> fetches;
       fetches.reserve(output_names.size());

       {
         // release GIL to allow multiple python threads to invoke Run() in parallel.
         py::gil_scoped_release release;
         if (run_options != nullptr) {
           OrtPybindThrowIfError(sess->GetSessionHandle()->Run(*run_options, feeds, output_names, &fetches));
         } else {
           OrtPybindThrowIfError(sess->GetSessionHandle()->Run(feeds, output_names, &fetches));
         }
       }

       // The GIL is held again here, so Python objects can be created for the results.
       py::list result;
       size_t pos = 0;  // index of the current fetch, forwarded to the sparse tensor conversion
       for (const auto& fet : fetches) {
         if (fet.IsAllocated()) {
           if (fet.IsTensor()) {
             result.append(AddTensorAsPyObj(fet, nullptr, nullptr));
           } else if (fet.IsSparseTensor()) {
             result.append(GetPyObjectFromSparseTensor(pos, fet, nullptr));
           } else {
             result.append(AddNonTensorAsPyObj(fet, nullptr, nullptr));
           }
         } else {  // Send back None because the corresponding OrtValue was empty
           result.append(py::none());
         }
         ++pos;
       }
       return result;
     })
|
|
.def("run_async",
|
|
[](PyInferenceSession* sess,
|
|
const std::vector<std::string>& output_names,
|
|
const std::map<std::string, py::object>& pyfeeds,
|
|
PyCallback callback, py::object user_data = {},
|
|
RunOptions* run_options = nullptr)
|
|
-> void {
|
|
if (run_options != nullptr && !run_options->active_adapters.empty()) {
|
|
LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING)
|
|
<< "run_async has active adapters specified, but won't have an effect";
|
|
}
|
|
|
|
std::unique_ptr<AsyncResource> async_resource = std::make_unique<AsyncResource>();
|
|
async_resource->callback = callback;
|
|
async_resource->user_data = user_data;
|
|
// prepare feeds
|
|
async_resource->ReserveFeeds(pyfeeds.size());
|
|
for (const auto& feed : pyfeeds) {
|
|
if (!feed.second.is(py::none())) {
|
|
OrtValue ml_value;
|
|
auto px = sess->GetSessionHandle()->GetModelInputs();
|
|
if (!px.first.IsOK() || !px.second) {
|
|
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
|
|
}
|
|
CreateGenericMLValue(px.second, GetAllocator(), feed.first, feed.second, &ml_value);
|
|
ThrowIfPyErrOccured();
|
|
async_resource->feeds.push_back(ml_value);
|
|
async_resource->feeds_raw.push_back(&async_resource->feeds.back());
|
|
async_resource->feed_names.push_back(feed.first);
|
|
async_resource->feed_names_raw.push_back(async_resource->feed_names.back().c_str());
|
|
}
|
|
}
|
|
// prepare fetches
|
|
async_resource->ReserveFetches(output_names.size());
|
|
for (const auto& output_name : output_names) {
|
|
async_resource->fetch_names.push_back(output_name);
|
|
async_resource->fetch_names_raw.push_back(async_resource->fetch_names.back().c_str());
|
|
async_resource->fetches_raw.push_back({});
|
|
}
|
|
const RunOptions* run_async_option = run_options ? run_options : &async_resource->default_run_option;
|
|
common::Status status = sess->GetSessionHandle()->RunAsync(run_async_option,
|
|
gsl::span(async_resource->feed_names_raw.data(), async_resource->feed_names_raw.size()),
|
|
gsl::span(async_resource->feeds_raw.data(), async_resource->feeds_raw.size()),
|
|
gsl::span(async_resource->fetch_names_raw.data(), async_resource->fetch_names_raw.size()),
|
|
gsl::span(async_resource->fetches_raw.data(), async_resource->fetches_raw.size()),
|
|
AsyncCallback,
|
|
async_resource.get());
|
|
if (status.IsOK()) {
|
|
async_resource.release();
|
|
}
|
|
OrtPybindThrowIfError(status);
|
|
})
|
|
/// This method accepts a dictionary of feeds (name -> OrtValue) and the list of output_names
|
|
/// and returns a list of python objects representing OrtValues. Each name may represent either
|
|
/// a Tensor, SparseTensor or a TensorSequence.
|
|
.def("run_with_ort_values", [](PyInferenceSession* sess, const py::dict& feeds, const std::vector<std::string>& output_names, RunOptions* run_options = nullptr) -> std::vector<OrtValue> {
|
|
NameMLValMap ort_feeds;
|
|
if (run_options != nullptr && !run_options->active_adapters.empty()) {
|
|
AppendLoraParametersAsInputs(*run_options, feeds.size(), ort_feeds);
|
|
} else {
|
|
ort_feeds.reserve(feeds.size());
|
|
}
|
|
|
|
// item is always a copy since dict returns a value and not a ref
|
|
// and Apple XToolChain barks
|
|
for (const auto& item : feeds) {
|
|
auto name = item.first.cast<std::string>();
|
|
const OrtValue* ort_value = item.second.cast<const OrtValue*>();
|
|
ort_feeds.emplace(name, *ort_value);
|
|
}
|
|
|
|
std::vector<OrtValue> fetches;
|
|
fetches.reserve(output_names.size());
|
|
{
|
|
// release GIL to allow multiple python threads to invoke Run() in parallel.
|
|
py::gil_scoped_release release;
|
|
if (run_options != nullptr) {
|
|
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(*run_options, ort_feeds, output_names, &fetches));
|
|
} else {
|
|
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(ort_feeds, output_names, &fetches));
|
|
}
|
|
}
|
|
return fetches;
|
|
})
|
|
// Binds `run_with_ortvaluevector`: runs the session on parallel vectors of feed names / feed
// values and writes the outputs into `fetches`, passing `fetch_devices` through to Run().
// Active adapters in `run_options` are not applied here; a warning is logged instead.
.def("run_with_ortvaluevector",
     [](PyInferenceSession* sess,
        RunOptions run_options,
        const std::vector<std::string>& feed_names,
        const std::vector<OrtValue>& feeds,
        const std::vector<std::string>& fetch_names,
        std::vector<OrtValue>& fetches,
        const std::vector<OrtDevice>& fetch_devices) -> void {
       const bool adapters_requested = !run_options.active_adapters.empty();
       if (adapters_requested) {
         LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING)
             << "run_with_ortvaluevector has active adapters specified, but won't have an effect";
       }

       // release GIL to allow multiple python threads to invoke Run() in parallel.
       py::gil_scoped_release release;
       OrtPybindThrowIfError(
           sess->GetSessionHandle()->Run(run_options, feed_names, feeds, fetch_names, &fetches, &fetch_devices));
     })
|
|
// Binds `end_profiling`: returns the string produced by the session's EndProfiling().
.def("end_profiling",
     [](const PyInferenceSession* sess) -> std::string {
       std::string profiling_result = sess->GetSessionHandle()->EndProfiling();
       return profiling_result;
     })
|
|
// Read-only property `get_profiling_start_time_ns`: the start time reported by the
// session's profiler (GetStartTimeNs()).
.def_property_readonly("get_profiling_start_time_ns",
                       [](const PyInferenceSession* sess) -> uint64_t {
                         const auto& profiler = sess->GetSessionHandle()->GetProfiling();
                         return profiler.GetStartTimeNs();
                       })
|
|
// Binds `get_providers`: exposes the session's registered provider type names.
// The vector is returned by reference and tied to the session object's lifetime
// through the reference_internal policy.
.def(
    "get_providers",
    [](const PyInferenceSession* sess) -> const std::vector<std::string>& {
      return sess->GetSessionHandle()->GetRegisteredProviderTypes();
    },
    py::return_value_policy::reference_internal)
|
|
// Binds `get_provider_options`: exposes the options map of all providers of the session.
// The map is returned by reference and tied to the session object's lifetime
// through the reference_internal policy.
.def(
    "get_provider_options",
    [](const PyInferenceSession* sess) -> const ProviderOptionsMap& {
      return sess->GetSessionHandle()->GetAllProviderOptions();
    },
    py::return_value_policy::reference_internal)
|
|
// Read-only property `session_options`: builds a fresh PySessionOptions whose `value` is
// assigned from the session's current options. The released raw pointer becomes owned by
// Python via the take_ownership policy.
.def_property_readonly(
    "session_options",
    [](const PyInferenceSession* sess) -> PySessionOptions* {
      auto options_snapshot = std::make_unique<PySessionOptions>();
      options_snapshot->value = sess->GetSessionHandle()->GetSessionOptions();
      return options_snapshot.release();
    },
    py::return_value_policy::take_ownership)
|
|
// Read-only property `inputs_meta`: the model's input definitions (NodeArg pointers).
// A failed status from GetModelInputs() is raised through OrtPybindThrowIfError.
.def_property_readonly(
    "inputs_meta",
    [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
      auto model_inputs = sess->GetSessionHandle()->GetModelInputs();
      OrtPybindThrowIfError(model_inputs.first);
      const auto* input_defs = model_inputs.second;
      return *input_defs;
    },
    py::return_value_policy::reference_internal)
|
|
// Read-only property `outputs_meta`: the model's output definitions (NodeArg pointers).
// A failed status from GetModelOutputs() is raised through OrtPybindThrowIfError.
.def_property_readonly(
    "outputs_meta",
    [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
      auto model_outputs = sess->GetSessionHandle()->GetModelOutputs();
      OrtPybindThrowIfError(model_outputs.first);
      const auto* output_defs = model_outputs.second;
      return *output_defs;
    },
    py::return_value_policy::reference_internal)
|
|
// Read-only property `overridable_initializers`: the NodeArg list returned by
// GetOverridableInitializers(). A failed status is raised through OrtPybindThrowIfError.
.def_property_readonly(
    "overridable_initializers",
    [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
      auto initializers = sess->GetSessionHandle()->GetOverridableInitializers();
      OrtPybindThrowIfError(initializers.first);
      const auto* initializer_defs = initializers.second;
      return *initializer_defs;
    },
    py::return_value_policy::reference_internal)
|
|
// Read-only property `model_meta`: the ModelMetadata returned by GetModelMetadata().
// A failed status is raised through OrtPybindThrowIfError.
.def_property_readonly(
    "model_meta",
    [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& {
      auto metadata_result = sess->GetSessionHandle()->GetModelMetadata();
      OrtPybindThrowIfError(metadata_result.first);
      const auto* metadata = metadata_result.second;
      return *metadata;
    },
    py::return_value_policy::reference_internal)
|
|
// Binds `run_with_iobinding`: runs the session against the given IO binding, with or without
// explicit run options. Active adapters in `run_options` are not applied here; a warning is
// logged instead. A failed run is reported as std::runtime_error.
.def("run_with_iobinding",
     [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
       Status status;

       const bool adapters_requested = run_options != nullptr && !run_options->active_adapters.empty();
       if (adapters_requested) {
         LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING)
             << "run_with_iobinding has active adapters specified, but won't have an effect";
       }

       // release GIL to allow multiple python threads to invoke Run() in parallel.
       py::gil_scoped_release release;
       if (run_options != nullptr) {
         status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
       } else {
         status = sess->GetSessionHandle()->Run(*io_binding.Get());
       }

       if (!status.IsOK()) {
         throw std::runtime_error("Error in execution: " + status.ErrorMessage());
       }
     })
|
|
// Binds `get_tuning_results`: returns a list with one dict per tuning-results entry of the
// session, each with the keys "ep", "results" and "validators".
// In ORT_MINIMAL_BUILD builds the call always throws.
.def("get_tuning_results",
     [](PyInferenceSession* sess) -> py::list {
#if !defined(ORT_MINIMAL_BUILD)
       auto session_results = sess->GetSessionHandle()->GetTuningResults();
       py::list py_results;
       for (const auto& entry : session_results) {
         py::dict py_entry;
         py_entry["ep"] = entry.ep;
         py_entry["results"] = entry.results;
         py_entry["validators"] = entry.validators;
         py_results.append(std::move(py_entry));
       }

       return py_results;
#else
       ORT_UNUSED_PARAMETER(sess);
       ORT_THROW("TunableOp and get_tuning_results are not supported in this build.");
#endif
     })
|
|
// Binds `set_tuning_results`: converts a Python list of dicts (keys "ep", "results",
// "validators") into TuningResults objects and passes them, together with
// `error_on_invalid`, to the session's SetTuningResults(). A failed status is reported as
// std::runtime_error. In ORT_MINIMAL_BUILD builds the call always throws.
.def("set_tuning_results",
     [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void {
#if !defined(ORT_MINIMAL_BUILD)
       std::vector<TuningResults> tuning_results;
       for (auto py_item : results) {
         auto py_entry = py_item.cast<py::dict>();
         TuningResults parsed;
         parsed.ep = py_entry["ep"].cast<py::str>();

         // "results" is a dict of dicts: op signature -> (params signature -> kernel id).
         for (const auto [py_op_sig, py_kernel_map] : py_entry["results"].cast<py::dict>()) {
           KernelMap kernel_map;
           for (const auto [py_params_sig, py_kernel_id] : py_kernel_map.cast<py::dict>()) {
             kernel_map[py_params_sig.cast<py::str>()] = py_kernel_id.cast<py::int_>();
           }
           parsed.results[py_op_sig.cast<py::str>()] = kernel_map;
         }

         // "validators" is a flat dict of string -> string.
         for (const auto [py_key, py_value] : py_entry["validators"].cast<py::dict>()) {
           parsed.validators[py_key.cast<py::str>()] = py_value.cast<py::str>();
         }

         tuning_results.emplace_back(std::move(parsed));
       }

       Status status = sess->GetSessionHandle()->SetTuningResults(tuning_results, error_on_invalid);
       if (!status.IsOK()) {
         throw std::runtime_error("Error in execution: " + status.ErrorMessage());
       }
#else
       ORT_UNUSED_PARAMETER(sess);
       ORT_UNUSED_PARAMETER(results);
       ORT_UNUSED_PARAMETER(error_on_invalid);
       ORT_THROW("TunableOp and set_tuning_results are not supported in this build.");
#endif
     });
|
|
|
|
py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())
|
|
.value("kNextPowerOfTwo", onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo)
|
|
.value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested)
|
|
.export_values();
|
|
}
|
|
|
|
// Fills the pybind11 module `m` with the inference bindings: exception translators, global
// methods, the object / OrtValue / sparse tensor / IO binding / adapter format bindings and
// the schema and kernel submodules.
// Returns true on success. Returns false when the NumPy C API import fails (per the NumPy
// C API, import_array1(ret) returns `ret` from the enclosing function on error).
bool CreateInferencePybindStateModule(py::module& m) {
  m.doc() = "pybind11 stateful interface to ONNX runtime";
  RegisterExceptions(m);

  import_array1(false);

  // Create (or fetch) the shared Environment before anything below runs. The handle itself is
  // not used further in this function; the call makes sure the Environment, which holds the
  // default logger, already exists.
  auto env = GetEnv();

  addGlobalMethods(m);
  addObjectMethods(m, RegisterExecutionProviders);
  addOrtValueMethods(m);
  addSparseTensorMethods(m);
  addIoBindingMethods(m);
  addAdapterFormatMethods(m);

#if !defined(__APPLE__) && !defined(ORT_MINIMAL_BUILD)
  // A failure to initialize the provider bridge is only logged; module creation continues.
  const bool provider_bridge_ready = InitProvidersSharedLibrary();
  if (!provider_bridge_ready) {
    const auto& bridge_logger = logging::LoggingManager::DefaultLogger();
    LOGS(bridge_logger, WARNING) << "Init provider bridge failed.";
  }
#endif

  addGlobalSchemaFunctions(m);
  addOpSchemaSubmodule(m);
  addOpKernelSubmodule(m);
  return true;
}
|
|
|
|
// This function is only used by orttraining module
|
|
bool InitArray() {
|
|
import_array1(false);
|
|
return true;
|
|
}
|
|
|
|
namespace {
|
|
// This class provides a static shell for on-demand and thread-safe construction
|
|
// of Environment object for both Inference and Training python layers.
|
|
// Environment class contains objects such as default logger, that must be available
|
|
// for the entire duration of a program that makes use of onnxruntime library.
|
|
// Because Python is a garbage collected language and the order of destruction of objects
|
|
// is not guaranteed we design this class with the following important features.
|
|
|
|
// 1) we make this class a singleton that is a function local static. The function local statics
|
|
// are constructed when the function is called the very first time. This fact has several important
|
|
// properties.
|
|
// - First, it is constructed before it is first needed possibly by another static object
|
|
// and destroyed after that object is destroyed.
|
|
// - Second, it is constructed in a thread safe manner.
|
|
// - Last, this order of construction/destruction is enforced across the compilation units, as opposed
|
|
// to the static objects that are simply declared in order in a single unit, but their lifespan is
|
|
// unconnected to that of in other compilation units. This is achieved automatically by run-time
|
|
// by execution atexit() to build a chain.
|
|
// 2) We make Environment owned by a shared_ptr. This is done because python objects such as Inference and Training
|
|
// sessions depend on this global. We acquire a shared_ptr instance when those objects are instantiated
|
|
// and release it automatically when they are garbage collected. Although with this change all of the
|
|
// globals seem to have been destroyed after module is unloaded and GC runs before that, it is cheap and gives
|
|
// a piece of mind as there were situations when GC was still running in the past after Env was gone.
|
|
// TrainingEnv global also holds shared reference to this global.
|
|
// 3) We guard against singleton resurrection attempts to detect code runs that when it should
|
|
// not and make necessary adjustments.
|
|
// For all the related details and why it is needed see "Modern C++ design" by A. Alexandrescu Chapter 6.
|
|
class EnvInitializer {
|
|
public:
|
|
static std::shared_ptr<onnxruntime::Environment> SharedInstance() {
|
|
// Guard against attempts to resurrect the singleton
|
|
if (EnvInitializer::destroyed) {
|
|
ORT_THROW("Detected an attempt to resurrect destroyed Environment");
|
|
}
|
|
static EnvInitializer env_holder;
|
|
return env_holder.Get();
|
|
}
|
|
|
|
private:
|
|
EnvInitializer() {
|
|
std::unique_ptr<Environment> env_ptr;
|
|
Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON);
|
|
OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
|
|
std::make_unique<CLogSink>(),
|
|
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
|
&SessionObjectInitializer::default_logger_id),
|
|
env_ptr));
|
|
session_env_ = std::shared_ptr<Environment>(env_ptr.release());
|
|
destroyed = false;
|
|
}
|
|
|
|
~EnvInitializer() {
|
|
destroyed = true;
|
|
}
|
|
|
|
std::shared_ptr<Environment> Get() const {
|
|
return session_env_;
|
|
}
|
|
|
|
std::shared_ptr<Environment> session_env_;
|
|
|
|
static bool destroyed;
|
|
};
|
|
|
|
bool EnvInitializer::destroyed = false;
|
|
} // namespace
|
|
|
|
std::shared_ptr<onnxruntime::Environment> GetEnv() {
|
|
return EnvInitializer::SharedInstance();
|
|
}
|
|
|
|
} // namespace python
|
|
} // namespace onnxruntime
|