mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Deprecate Python global configuration functions [Part 2] (#6171)
Update Python API to allow more flexibility for setting providers and provider options. The providers argument (InferenceSession/TrainingSession constructors, InferenceSession.set_providers()) now also accepts entries that are tuples of (name, options dict). Fix get_available_providers() API (and the corresponding function in the C API) to return the providers in default priority order, so that its result can be used as a starting point for the providers argument while maintaining the default priority order. Convert some usages of the deprecated global configuration functions to use EP-specific options instead. Update some EP-specific option parsing to fail on unknown options. Other cleanup.
This commit is contained in:
parent
bbc9ed908a
commit
d761571afc
32 changed files with 937 additions and 370 deletions
|
|
@ -180,7 +180,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string[] providers = ortEnvInstance.GetAvailableProviders();
|
||||
|
||||
Assert.True(providers.Length > 0);
|
||||
Assert.Equal("CPUExecutionProvider", providers[0]);
|
||||
Assert.Equal("CPUExecutionProvider", providers[providers.Length - 1]);
|
||||
|
||||
# if USE_CUDA
|
||||
// In CUDA builds the CUDA execution provider must be among the reported providers.
Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider"));
|
||||
|
|
|
|||
|
|
@ -165,8 +165,8 @@
|
|||
"code_to_run = '''\n",
|
||||
"import onnxruntime\n",
|
||||
"s = 'codegen_dump_lower:verbose'\n",
|
||||
"onnxruntime.capi._pybind_state.set_nuphar_settings(s)\n",
|
||||
"sess = onnxruntime.InferenceSession('simple.onnx')\n",
|
||||
"providers = [('NupharExecutionProvider', {'nuphar_settings': s}), 'CPUExecutionProvider']\n",
|
||||
"sess = onnxruntime.InferenceSession('simple.onnx', providers=providers)\n",
|
||||
"'''\n",
|
||||
"\n",
|
||||
"log_file = 'simple_lower.log' \n",
|
||||
|
|
@ -992,8 +992,8 @@
|
|||
"bidaf_cache_dir = os.path.join(bidaf_dir, 'cache')\n",
|
||||
"create_cache_dir(bidaf_cache_dir)\n",
|
||||
"settings = 'nuphar_cache_path:{}'.format(bidaf_cache_dir)\n",
|
||||
"onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n",
|
||||
"sess = onnxruntime.InferenceSession(bidaf_converted)"
|
||||
"providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n",
|
||||
"sess = onnxruntime.InferenceSession(bidaf_converted, providers=providers)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1049,10 +1049,9 @@
|
|||
],
|
||||
"source": [
|
||||
"start_aot = timer()\n",
|
||||
"# NOTE: Nuphar settings string is not sticky. It needs to be reset before creating InferenceSession\n",
|
||||
"settings = 'nuphar_cache_path:{}'.format(bidaf_cache_dir)\n",
|
||||
"onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n",
|
||||
"sess = onnxruntime.InferenceSession(bidaf_converted)\n",
|
||||
"providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n",
|
||||
"sess = onnxruntime.InferenceSession(bidaf_converted, providers=providers)\n",
|
||||
"end_aot = timer()\n",
|
||||
"print_speedup('AOT', end_jit - start_jit, end_aot - start_aot)"
|
||||
]
|
||||
|
|
@ -1096,8 +1095,8 @@
|
|||
"settings = 'nuphar_cache_path:{}'.format(cache_dir)\n",
|
||||
"for isa in ['avx512', 'avx2', 'avx']:\n",
|
||||
" settings_with_isa = settings + ', nuphar_codegen_target:' + isa\n",
|
||||
" onnxruntime.capi._pybind_state.set_nuphar_settings(settings_with_isa)\n",
|
||||
" sess = onnxruntime.InferenceSession(model_name)\n",
|
||||
" providers = [('NupharExecutionProvider', {'nuphar_settings': settings_with_isa}), 'CPUExecutionProvider']\n",
|
||||
" sess = onnxruntime.InferenceSession(model_name, providers=providers)\n",
|
||||
" cache_versioned_dir = os.path.join(cache_dir, os.listdir(cache_dir)[0])\n",
|
||||
"\n",
|
||||
"# link object files to AOT dll\n",
|
||||
|
|
@ -1106,15 +1105,15 @@
|
|||
"# now load the model with AOT dll\n",
|
||||
"# NOTE: when nuphar_codegen_target is not set, it defaults to current CPU ISA\n",
|
||||
"settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_model_checksum:{}, nuphar_cache_force_no_jit:on'.format(cache_dir, multi_isa_so, model_checksum)\n",
|
||||
"onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n",
|
||||
"sess = onnxruntime.InferenceSession(model_name)\n",
|
||||
"providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n",
|
||||
"sess = onnxruntime.InferenceSession(model_name, providers=providers)\n",
|
||||
"\n",
|
||||
"# force to a different ISA which is a subset of current CPU\n",
|
||||
"# NOTE: if an incompatible ISA is used, exception on invalid instructions would be thrown\n",
|
||||
"for valid_isa in ['avx2', 'avx']:\n",
|
||||
" settings_with_isa = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_model_checksum:{}, nuphar_codegen_target:{}, nuphar_cache_force_no_jit:on'.format(cache_dir, multi_isa_so, model_checksum, valid_isa)\n",
|
||||
" onnxruntime.capi._pybind_state.set_nuphar_settings(settings_with_isa)\n",
|
||||
" sess = onnxruntime.InferenceSession(model_name)\n",
|
||||
" providers = [('NupharExecutionProvider', {'nuphar_settings': settings_with_isa}), 'CPUExecutionProvider']\n",
|
||||
" sess = onnxruntime.InferenceSession(model_name, providers=providers)\n",
|
||||
"\n",
|
||||
" start_nuphar = timer()\n",
|
||||
" for i in range(repeats):\n",
|
||||
|
|
@ -1160,8 +1159,8 @@
|
|||
"# use NUPHAR_PARALLEL_MIN_WORKLOADS=0 to turn off parallel schedule, using settings string\n",
|
||||
"# it can be set from environment variable too: os.environ['NUPHAR_PARALLEL_MIN_WORKLOADS'] = '0'\n",
|
||||
"settings = 'nuphar_parallel_min_workloads:0'\n",
|
||||
"onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n",
|
||||
"sess = onnxruntime.InferenceSession(bidaf_converted)\n",
|
||||
"providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n",
|
||||
"sess = onnxruntime.InferenceSession(bidaf_converted, providers=providers)\n",
|
||||
"\n",
|
||||
"start = timer()\n",
|
||||
"for i in range(bidaf_repeats):\n",
|
||||
|
|
@ -1200,4 +1199,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
|
@ -33,9 +33,9 @@
|
|||
|
||||
#include "core/common/code_location.h"
|
||||
#include "core/common/exceptions.h"
|
||||
#include "core/common/make_string.h"
|
||||
#include "core/common/make_unique.h"
|
||||
#include "core/common/status.h"
|
||||
#include "core/common/string_utils.h"
|
||||
|
||||
#ifdef USE_MIMALLOC_ARENA_ALLOCATOR
|
||||
#include <mimalloc.h>
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
#include <locale>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -74,41 +73,4 @@ inline std::string MakeString(const char* cstr) {
|
|||
return cstr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Tries to parse a value from an entire string.
|
||||
*/
|
||||
template <typename T>
|
||||
bool TryParse(const std::string& str, T& value) {
|
||||
if (std::is_integral<T>::value && std::is_unsigned<T>::value) {
|
||||
// if T is unsigned integral type, reject negative values which will wrap
|
||||
if (!str.empty() && str[0] == '-') {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// don't allow leading whitespace
|
||||
if (!str.empty() && std::isspace(str[0], std::locale::classic())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::istringstream is{str};
|
||||
is.imbue(std::locale::classic());
|
||||
T parsed_value{};
|
||||
|
||||
const bool parse_successful =
|
||||
is >> parsed_value &&
|
||||
is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters
|
||||
if (!parse_successful) {
|
||||
return false;
|
||||
}
|
||||
|
||||
value = std::move(parsed_value);
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool TryParse(const std::string& str, std::string& value) {
|
||||
value = str;
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
84
include/onnxruntime/core/common/parse_string.h
Normal file
84
include/onnxruntime/core/common/parse_string.h
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <locale>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
|
||||
#include "core/common/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
 * Tries to parse a value of type T from an entire string.
 *
 * The value is read with stream extraction (operator>>) using the classic "C"
 * locale. On success the parsed value is moved into `value` and true is
 * returned. On failure `value` is left untouched and false is returned.
 *
 * The following inputs are rejected:
 * - strings that start with whitespace
 * - strings with characters left over after the parsed value
 * - strings that start with '-' when T is an unsigned integral type
 */
template <typename T>
bool TryParseString(const std::string& str, T& value) {
  if (!str.empty()) {
    const char first_char = str[0];

    // if T is unsigned integral type, reject negative values which will wrap
    if (std::is_integral<T>::value && std::is_unsigned<T>::value && first_char == '-') {
      return false;
    }

    // don't allow leading whitespace
    if (std::isspace(first_char, std::locale::classic())) {
      return false;
    }
  }

  std::istringstream stream{str};
  stream.imbue(std::locale::classic());

  T result{};
  if (!(stream >> result)) {
    return false;
  }

  // don't allow trailing characters: the stream must be exhausted after the value
  if (stream.get() != std::istringstream::traits_type::eof()) {
    return false;
  }

  value = std::move(result);
  return true;
}
|
||||
|
||||
/**
 * Overload for std::string destinations.
 * No parsing is needed: `str` is copied into `value` unchanged and the call
 * always reports success.
 */
inline bool TryParseString(const std::string& str, std::string& value) {
  value.assign(str);
  return true;
}
|
||||
|
||||
/**
 * Overload for bool destinations.
 * Accepts exactly "0", "False", "false" (parsed as false) and
 * "1", "True", "true" (parsed as true). Any other string is rejected and
 * `value` is left untouched.
 */
inline bool TryParseString(const std::string& str, bool& value) {
  const bool is_false_literal = str == "0" || str == "False" || str == "false";
  const bool is_true_literal = str == "1" || str == "True" || str == "true";

  if (!is_false_literal && !is_true_literal) {
    return false;
  }

  value = is_true_literal;
  return true;
}
|
||||
|
||||
/**
|
||||
* Parses a value from an entire string.
|
||||
*/
|
||||
template <typename T>
|
||||
Status ParseString(const std::string& s, T& value) {
|
||||
ORT_RETURN_IF_NOT(TryParseString(s, value), "Failed to parse value: \"", value, "\"");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/**
 * Parses a value from an entire string and returns it.
 *
 * Unlike the Status-returning overload, a parse failure is reported through
 * ORT_THROW_IF_ERROR rather than through the return value.
 */
template <typename T>
T ParseString(const std::string& s) {
  T parsed{};
  ORT_THROW_IF_ERROR(ParseString(s, parsed));
  return parsed;
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -3,37 +3,163 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/parse_string.h"
|
||||
#include "core/framework/provider_options.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename TEnum>
|
||||
using EnumNameMapping = std::vector<std::pair<TEnum, std::string>>;
|
||||
|
||||
/**
|
||||
* Reads the named provider option.
|
||||
* Returns true if the option is present, false otherwise.
|
||||
* Given a mapping and an enumeration value, gets the corresponding name.
|
||||
*/
|
||||
template <typename T>
|
||||
bool ReadProviderOption(const ProviderOptions& options, const std::string& key, T& value) {
|
||||
auto it = options.find(key);
|
||||
if (it != options.end()) {
|
||||
ORT_ENFORCE(
|
||||
TryParse(it->second, value),
|
||||
"Failed to parse provider option \"", key, "\" with value \"", it->second, "\".");
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
template <typename TEnum>
|
||||
Status EnumToName(const EnumNameMapping<TEnum>& mapping, TEnum value, std::string& name) {
|
||||
const auto it = std::find_if(
|
||||
mapping.begin(), mapping.end(),
|
||||
[&value](const std::pair<TEnum, std::string>& entry) {
|
||||
return entry.first == value;
|
||||
});
|
||||
ORT_RETURN_IF(
|
||||
it == mapping.end(),
|
||||
"Failed to map enum value to name: ", static_cast<typename std::underlying_type<TEnum>::type>(value));
|
||||
name = it->second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename TEnum>
|
||||
std::string EnumToName(const EnumNameMapping<TEnum>& mapping, TEnum value) {
|
||||
std::string name;
|
||||
ORT_THROW_IF_ERROR(EnumToName(mapping, value, name));
|
||||
return name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads the named provider option.
|
||||
* Returns the value if the option is present or the specified default value otherwise.
|
||||
* Given a mapping and a name, gets the corresponding enumeration value.
|
||||
*/
|
||||
template <typename T>
|
||||
T ReadProviderOptionOrDefault(
|
||||
const ProviderOptions& options, const std::string& key, const T& default_value) {
|
||||
T value{};
|
||||
if (ReadProviderOption(options, key, value)) {
|
||||
return value;
|
||||
}
|
||||
return default_value;
|
||||
template <typename TEnum>
|
||||
Status NameToEnum(
|
||||
const EnumNameMapping<TEnum>& mapping, const std::string& name, TEnum& value) {
|
||||
const auto it = std::find_if(
|
||||
mapping.begin(), mapping.end(),
|
||||
[&name](const std::pair<TEnum, std::string>& entry) {
|
||||
return entry.second == name;
|
||||
});
|
||||
ORT_RETURN_IF(
|
||||
it == mapping.end(),
|
||||
"Failed to map enum name to value: ", name);
|
||||
value = it->first;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename TEnum>
|
||||
TEnum NameToEnum(const EnumNameMapping<TEnum>& mapping, const std::string& name) {
|
||||
TEnum value;
|
||||
ORT_THROW_IF_ERROR(NameToEnum(mapping, name, value));
|
||||
return value;
|
||||
}
|
||||
|
||||
class ProviderOptionsParser {
|
||||
public:
|
||||
/**
|
||||
* Adds a parser for a particular provider option value.
|
||||
*
|
||||
* @param name The provider option name.
|
||||
* @param value_parser An object that parses the option value.
|
||||
* It should be callable with the following signature and return
|
||||
* whether the parsing was successful:
|
||||
* Status value_parser(const std::string&)
|
||||
*
|
||||
* @return The current ProviderOptionsParser instance.
|
||||
*/
|
||||
template <typename ValueParserType>
|
||||
ProviderOptionsParser& AddValueParser(
|
||||
const std::string& name, ValueParserType value_parser) {
|
||||
ORT_ENFORCE(
|
||||
value_parsers_.emplace(name, ValueParser{value_parser}).second,
|
||||
"Provider option \"", name, "\" already has a value parser.");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a parser for a particular provider option value which converts a
|
||||
* value to the right type and assigns it to the given reference.
|
||||
*
|
||||
* IMPORTANT: This function stores a reference to the destination variable.
|
||||
* The caller must ensure that the reference is valid when Parse() is called!
|
||||
*
|
||||
* @param name The provider option name.
|
||||
* @param dest The destination variable reference.
|
||||
*
|
||||
* @return The current ProviderOptionsParser instance.
|
||||
*/
|
||||
template <typename ValueType>
|
||||
ProviderOptionsParser& AddAssignmentToReference(
|
||||
const std::string& name, ValueType& dest) {
|
||||
return AddValueParser(
|
||||
name,
|
||||
[&dest](const std::string& value_str) -> Status {
|
||||
return ParseString(value_str, dest);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a parser for a particular provider option value which maps an
|
||||
* enumeration name to a value and assigns it to the given reference.
|
||||
*
|
||||
* IMPORTANT: This function stores references to the mapping and destination
|
||||
* variables. The caller must ensure that the references are valid when
|
||||
* Parse() is called!
|
||||
*
|
||||
* @param name The provider option name.
|
||||
* @param mapping The enumeration value to name mapping.
|
||||
* @param dest The destination variable reference.
|
||||
*
|
||||
* @return The current ProviderOptionsParser instance.
|
||||
*/
|
||||
template <typename EnumType>
|
||||
ProviderOptionsParser& AddAssignmentToEnumReference(
|
||||
const std::string& name, const EnumNameMapping<EnumType>& mapping, EnumType& dest) {
|
||||
return AddValueParser(
|
||||
name,
|
||||
[&mapping, &dest](const std::string& value_str) -> Status {
|
||||
return NameToEnum(mapping, value_str, dest);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses the given provider options.
|
||||
*/
|
||||
Status Parse(const ProviderOptions& options) const {
|
||||
for (const auto& option : options) {
|
||||
const auto& name = option.first;
|
||||
const auto& value_str = option.second;
|
||||
|
||||
const auto value_parser_it = value_parsers_.find(name);
|
||||
ORT_RETURN_IF(
|
||||
value_parser_it == value_parsers_.end(),
|
||||
"Unknown provider option: \"", name, "\".");
|
||||
|
||||
const auto parse_status = value_parser_it->second(value_str);
|
||||
ORT_RETURN_IF_NOT(
|
||||
parse_status.IsOK(),
|
||||
"Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
using ValueParser = std::function<Status(const std::string&)>;
|
||||
std::unordered_map<std::string, ValueParser> value_parsers_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,11 +3,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "core/common/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
constexpr const char* kNoOp = "NoOp";
|
||||
constexpr const char* kConstant = "Constant";
|
||||
constexpr const char* kFunctionOp = "_kFunctionOp";
|
||||
|
|
@ -22,10 +19,10 @@ constexpr const char* kMSDmlDomain = "com.microsoft.dml";
|
|||
constexpr const char* kNGraphDomain = "com.intel.ai";
|
||||
constexpr const char* kMIGraphXDomain = "";
|
||||
constexpr const char* kVitisAIDomain = "com.xilinx";
|
||||
|
||||
constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
|
||||
constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
|
||||
constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider";
|
||||
constexpr const char* kNGraphExecutionProvider = "NGRAPHExecutionProvider";
|
||||
constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider";
|
||||
constexpr const char* kNupharExecutionProvider = "NupharExecutionProvider";
|
||||
constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider";
|
||||
|
|
@ -38,50 +35,4 @@ constexpr const char* kAclExecutionProvider = "ACLExecutionProvider";
|
|||
constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider";
|
||||
constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider";
|
||||
|
||||
constexpr const char* providers_available[] = {
|
||||
kCpuExecutionProvider,
|
||||
#ifdef USE_CUDA
|
||||
kCudaExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_DNNL
|
||||
kDnnlExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_NGRAPH
|
||||
kNGraphExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_OPENVINO
|
||||
kOpenVINOExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_NUPHAR
|
||||
kNupharExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_VITISAI
|
||||
kVitisAIExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_TENSORRT
|
||||
kTensorrtExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_NNAPI
|
||||
kNnapiExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_RKNPU
|
||||
kRknpuExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_DML
|
||||
kDmlExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_MIGRAPHX
|
||||
kMIGraphXExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_ACL
|
||||
kAclExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_ARMNN
|
||||
kArmNNExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_ROCM
|
||||
kRocmExecutionProvider,
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -26,11 +26,6 @@ from onnxruntime.capi._pybind_state import get_all_providers, get_available_prov
|
|||
NodeArg, ModelMetadata, GraphOptimizationLevel, ExecutionMode, ExecutionOrder, OrtDevice, SessionIOBinding, \
|
||||
OrtAllocatorType, OrtMemType, OrtArenaCfg, OrtMemoryInfo, create_and_register_allocator
|
||||
|
||||
try:
|
||||
from onnxruntime.capi._pybind_state import set_cuda_mem_limit, set_cuda_device_id
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession, IOBinding, OrtValue
|
||||
from onnxruntime.capi import onnxruntime_validation
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -13,33 +12,4 @@ enum class ArenaExtendStrategy : int32_t {
|
|||
kSameAsRequested,
|
||||
};
|
||||
|
||||
inline std::istream& operator>>(std::istream& is, ArenaExtendStrategy& value) {
|
||||
std::string value_str;
|
||||
if (is >> value_str) {
|
||||
if (value_str == "kNextPowerOfTwo") {
|
||||
value = ArenaExtendStrategy::kNextPowerOfTwo;
|
||||
} else if (value_str == "kSameAsRequested") {
|
||||
value = ArenaExtendStrategy::kSameAsRequested;
|
||||
} else {
|
||||
is.setstate(std::ios_base::failbit);
|
||||
}
|
||||
}
|
||||
return is;
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ArenaExtendStrategy value) {
|
||||
switch (value) {
|
||||
case ArenaExtendStrategy::kNextPowerOfTwo:
|
||||
os << "kNextPowerOfTwo";
|
||||
break;
|
||||
case ArenaExtendStrategy::kSameAsRequested:
|
||||
os << "kSameAsRequested";
|
||||
break;
|
||||
default:
|
||||
os << "unknown";
|
||||
break;
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/optional.h"
|
||||
#include "core/common/parse_string.h"
|
||||
#include "core/platform/env.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -20,7 +21,7 @@ optional<T> ParseEnvironmentVariable(const std::string& name) {
|
|||
|
||||
T parsed_value;
|
||||
ORT_ENFORCE(
|
||||
TryParse(value_str, parsed_value),
|
||||
TryParseString(value_str, parsed_value),
|
||||
"Failed to parse environment variable - name: \"", name, "\", value: \"", value_str, "\"");
|
||||
|
||||
return parsed_value;
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
|
||||
int GetDeviceId() const override { return info_.device_id; }
|
||||
const cudaDeviceProp& GetDeviceProp() const { return device_prop_; };
|
||||
int GetCudnnConvAlgo() const { return info_.cudnn_conv_algo; }
|
||||
int GetCudnnConvAlgo() const { return info_.cudnn_conv_algo_search; }
|
||||
|
||||
ProviderOptions GetProviderOptions() const override {
|
||||
return CUDAExecutionProviderInfo::ToProviderOptions(info_);
|
||||
|
|
|
|||
|
|
@ -3,58 +3,76 @@
|
|||
|
||||
#include "core/providers/cuda/cuda_execution_provider_info.h"
|
||||
|
||||
#include "core/common/string_utils.h"
|
||||
#include "core/common/make_string.h"
|
||||
#include "core/common/parse_string.h"
|
||||
#include "core/framework/provider_options_utils.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
namespace provider_option_names {
|
||||
constexpr const char* kDeviceId = "device_id";
|
||||
constexpr const char* kMemLimit = "cuda_mem_limit";
|
||||
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
|
||||
constexpr const char* kCudnnConvAlgo = "cudnn_conv_algo";
|
||||
constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search";
|
||||
constexpr const char* kDoCopyInDefaultStream = "do_copy_in_default_stream";
|
||||
} // namespace provider_option_names
|
||||
} // namespace cuda
|
||||
|
||||
namespace {
|
||||
const EnumNameMapping<OrtCudnnConvAlgoSearch> ort_cudnn_conv_algo_search_mapping{
|
||||
{OrtCudnnConvAlgoSearch::EXHAUSTIVE, "EXHAUSTIVE"},
|
||||
{OrtCudnnConvAlgoSearch::HEURISTIC, "HEURISTIC"},
|
||||
{OrtCudnnConvAlgoSearch::DEFAULT, "DEFAULT"},
|
||||
};
|
||||
|
||||
const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{
|
||||
{ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"},
|
||||
{ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"},
|
||||
};
|
||||
} // namespace
|
||||
|
||||
CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
|
||||
CUDAExecutionProviderInfo info{};
|
||||
|
||||
if (ReadProviderOption(options, provider_option_names::kDeviceId, info.device_id)) {
|
||||
int num_devices = 0;
|
||||
CUDA_CALL_THROW(cudaGetDeviceCount(&num_devices));
|
||||
ORT_ENFORCE(
|
||||
0 <= info.device_id && info.device_id < num_devices,
|
||||
"Invalid ", provider_option_names::kDeviceId, " value: ", info.device_id,
|
||||
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
|
||||
}
|
||||
ReadProviderOption(options, provider_option_names::kMemLimit, info.cuda_mem_limit);
|
||||
ReadProviderOption(options, provider_option_names::kArenaExtendStrategy, info.arena_extend_strategy);
|
||||
{
|
||||
int cudnn_conv_algo_val;
|
||||
if (ReadProviderOption(options, provider_option_names::kCudnnConvAlgo, cudnn_conv_algo_val)) {
|
||||
switch (cudnn_conv_algo_val) {
|
||||
case OrtCudnnConvAlgoSearch::EXHAUSTIVE:
|
||||
case OrtCudnnConvAlgoSearch::HEURISTIC:
|
||||
case OrtCudnnConvAlgoSearch::DEFAULT:
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Invalid ", provider_option_names::kCudnnConvAlgo, " value: ", cudnn_conv_algo_val);
|
||||
}
|
||||
info.cudnn_conv_algo = static_cast<OrtCudnnConvAlgoSearch>(cudnn_conv_algo_val);
|
||||
}
|
||||
}
|
||||
ReadProviderOption(options, provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream);
|
||||
ORT_THROW_IF_ERROR(
|
||||
ProviderOptionsParser{}
|
||||
.AddValueParser(
|
||||
cuda::provider_option_names::kDeviceId,
|
||||
[&info](const std::string& value_str) -> Status {
|
||||
ORT_RETURN_IF_ERROR(ParseString(value_str, info.device_id));
|
||||
int num_devices{};
|
||||
ORT_RETURN_IF_NOT(
|
||||
CUDA_CALL(cudaGetDeviceCount(&num_devices)),
|
||||
"cudaGetDeviceCount() failed.");
|
||||
ORT_RETURN_IF_NOT(
|
||||
0 <= info.device_id && info.device_id < num_devices,
|
||||
"Invalid device ID: ", info.device_id,
|
||||
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
|
||||
return Status::OK();
|
||||
})
|
||||
.AddAssignmentToReference(cuda::provider_option_names::kMemLimit, info.cuda_mem_limit)
|
||||
.AddAssignmentToEnumReference(
|
||||
cuda::provider_option_names::kArenaExtendStrategy,
|
||||
arena_extend_strategy_mapping, info.arena_extend_strategy)
|
||||
.AddAssignmentToEnumReference(
|
||||
cuda::provider_option_names::kCudnnConvAlgoSearch,
|
||||
ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)
|
||||
.AddAssignmentToReference(cuda::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream)
|
||||
.Parse(options));
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecutionProviderInfo& info) {
|
||||
const ProviderOptions options{
|
||||
{provider_option_names::kDeviceId, MakeString(info.device_id)},
|
||||
{provider_option_names::kMemLimit, MakeString(info.cuda_mem_limit)},
|
||||
{provider_option_names::kArenaExtendStrategy, MakeString(info.arena_extend_strategy)},
|
||||
{provider_option_names::kCudnnConvAlgo, MakeString(static_cast<int>(info.cudnn_conv_algo))},
|
||||
{provider_option_names::kDoCopyInDefaultStream, MakeString(info.do_copy_in_default_stream)},
|
||||
{cuda::provider_option_names::kDeviceId, MakeString(info.device_id)},
|
||||
{cuda::provider_option_names::kMemLimit, MakeString(info.cuda_mem_limit)},
|
||||
{cuda::provider_option_names::kArenaExtendStrategy,
|
||||
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
|
||||
{cuda::provider_option_names::kCudnnConvAlgoSearch,
|
||||
EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)},
|
||||
{cuda::provider_option_names::kDoCopyInDefaultStream, MakeString(info.do_copy_in_default_stream)},
|
||||
};
|
||||
|
||||
return options;
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ struct CUDAExecutionProviderInfo {
|
|||
OrtDevice::DeviceId device_id{0};
|
||||
size_t cuda_mem_limit{std::numeric_limits<size_t>::max()};
|
||||
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};
|
||||
OrtCudnnConvAlgoSearch cudnn_conv_algo{OrtCudnnConvAlgoSearch::EXHAUSTIVE};
|
||||
OrtCudnnConvAlgoSearch cudnn_conv_algo_search{OrtCudnnConvAlgoSearch::EXHAUSTIVE};
|
||||
bool do_copy_in_default_stream{true};
|
||||
|
||||
static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_CUDA,
|
|||
info.device_id = gsl::narrow<OrtDevice::DeviceId>(cuda_options->device_id);
|
||||
info.cuda_mem_limit = cuda_options->cuda_mem_limit;
|
||||
info.arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(cuda_options->arena_extend_strategy);
|
||||
info.cudnn_conv_algo = cuda_options->cudnn_conv_algo_search;
|
||||
info.cudnn_conv_algo_search = cuda_options->cudnn_conv_algo_search;
|
||||
info.do_copy_in_default_stream = cuda_options->do_copy_in_default_stream;
|
||||
|
||||
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_CUDA(info));
|
||||
|
|
|
|||
154
onnxruntime/core/providers/get_execution_providers.cc
Normal file
154
onnxruntime/core/providers/get_execution_providers.cc
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/get_execution_providers.h"
|
||||
|
||||
#include "core/graph/constants.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace {
|
||||
struct ProviderInfo {
|
||||
const char* name;
|
||||
bool available;
|
||||
};
|
||||
|
||||
// all providers ordered by default priority from highest to lowest
|
||||
// kCpuExecutionProvider should always be last
|
||||
constexpr ProviderInfo kProvidersInPriorityOrder[] =
|
||||
{
|
||||
{
|
||||
kTensorrtExecutionProvider,
|
||||
#ifdef USE_TENSORRT
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kCudaExecutionProvider,
|
||||
#ifdef USE_CUDA
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kMIGraphXExecutionProvider,
|
||||
#ifdef USE_MIGRAPHX
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kRocmExecutionProvider,
|
||||
#ifdef USE_ROCM
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kOpenVINOExecutionProvider,
|
||||
#ifdef USE_OPENVINO
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kDnnlExecutionProvider,
|
||||
#ifdef USE_DNNL
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kNupharExecutionProvider,
|
||||
#ifdef USE_NUPHAR
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kVitisAIExecutionProvider,
|
||||
#ifdef USE_VITISAI
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kNnapiExecutionProvider,
|
||||
#ifdef USE_NNAPI
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kArmNNExecutionProvider,
|
||||
#ifdef USE_ARMNN
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kAclExecutionProvider,
|
||||
#ifdef USE_ACL
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kDmlExecutionProvider,
|
||||
#ifdef USE_DML
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{
|
||||
kRknpuExecutionProvider,
|
||||
#ifdef USE_RKNPU
|
||||
true,
|
||||
#else
|
||||
false,
|
||||
#endif
|
||||
},
|
||||
{kCpuExecutionProvider, true}, // kCpuExecutionProvider is always last
|
||||
};
|
||||
} // namespace
|
||||
|
||||
const std::vector<std::string>& GetAllExecutionProviderNames() {
|
||||
static const auto all_execution_providers = []() {
|
||||
std::vector<std::string> result{};
|
||||
for (const auto& provider : kProvidersInPriorityOrder) {
|
||||
result.push_back(provider.name);
|
||||
}
|
||||
return result;
|
||||
}();
|
||||
|
||||
return all_execution_providers;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& GetAvailableExecutionProviderNames() {
|
||||
static const auto available_execution_providers = []() {
|
||||
std::vector<std::string> result{};
|
||||
for (const auto& provider : kProvidersInPriorityOrder) {
|
||||
if (provider.available) {
|
||||
result.push_back(provider.name);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}();
|
||||
|
||||
return available_execution_providers;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
23
onnxruntime/core/providers/get_execution_providers.h
Normal file
23
onnxruntime/core/providers/get_execution_providers.h
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
* Gets the names of all execution providers, in order of decreasing default
|
||||
* priority.
|
||||
*/
|
||||
const std::vector<std::string>& GetAllExecutionProviderNames();
|
||||
|
||||
/**
|
||||
* Gets the names of execution providers available in this build, in order of
|
||||
* decreasing default priority.
|
||||
*/
|
||||
const std::vector<std::string>& GetAvailableExecutionProviderNames();
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -3,32 +3,47 @@
|
|||
|
||||
#include "core/providers/rocm/rocm_execution_provider_info.h"
|
||||
|
||||
#include "core/common/string_utils.h"
|
||||
#include "core/common/make_string.h"
|
||||
#include "core/framework/provider_options_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
namespace provider_option_names {
|
||||
constexpr const char* kDeviceId = "device_id";
|
||||
constexpr const char* kMemLimit = "hip_mem_limit";
|
||||
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
|
||||
} // namespace provider_option_names
|
||||
} // namespace rocm
|
||||
|
||||
namespace {
|
||||
const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{
|
||||
{ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"},
|
||||
{ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"},
|
||||
};
|
||||
} // namespace
|
||||
|
||||
ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
|
||||
ROCMExecutionProviderInfo info{};
|
||||
|
||||
// TODO validate info.device_id
|
||||
ReadProviderOption(options, provider_option_names::kDeviceId, info.device_id);
|
||||
ReadProviderOption(options, provider_option_names::kMemLimit, info.hip_mem_limit);
|
||||
ReadProviderOption(options, provider_option_names::kArenaExtendStrategy, info.arena_extend_strategy);
|
||||
ORT_THROW_IF_ERROR(
|
||||
ProviderOptionsParser{}
|
||||
// TODO validate info.device_id
|
||||
.AddAssignmentToReference(rocm::provider_option_names::kDeviceId, info.device_id)
|
||||
.AddAssignmentToReference(rocm::provider_option_names::kMemLimit, info.hip_mem_limit)
|
||||
.AddAssignmentToEnumReference(
|
||||
rocm::provider_option_names::kArenaExtendStrategy,
|
||||
arena_extend_strategy_mapping, info.arena_extend_strategy)
|
||||
.Parse(options));
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) {
|
||||
const ProviderOptions options{
|
||||
{provider_option_names::kDeviceId, MakeString(info.device_id)},
|
||||
{provider_option_names::kMemLimit, MakeString(info.hip_mem_limit)},
|
||||
{provider_option_names::kArenaExtendStrategy, MakeString(info.arena_extend_strategy)},
|
||||
{rocm::provider_option_names::kDeviceId, MakeString(info.device_id)},
|
||||
{rocm::provider_option_names::kMemLimit, MakeString(info.hip_mem_limit)},
|
||||
{rocm::provider_option_names::kArenaExtendStrategy,
|
||||
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
|
||||
};
|
||||
|
||||
return options;
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "core/providers/get_execution_providers.h"
|
||||
#include "core/session/environment.h"
|
||||
#include "core/framework/callback.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
|
@ -1713,19 +1714,20 @@ ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char*** out_ptr,
|
|||
_In_ int* providers_length) {
|
||||
API_IMPL_BEGIN
|
||||
const size_t MAX_LEN = 30;
|
||||
int available_count = (int)(sizeof(providers_available) / sizeof(char*));
|
||||
char** out = (char**)malloc(available_count * sizeof(char*));
|
||||
const auto& available_providers = GetAvailableExecutionProviderNames();
|
||||
const int available_count = gsl::narrow<int>(available_providers.size());
|
||||
char** const out = (char**)malloc(available_count * sizeof(char*));
|
||||
if (out) {
|
||||
for (int i = 0; i < available_count; i++) {
|
||||
out[i] = (char*)malloc((MAX_LEN + 1) * sizeof(char));
|
||||
if (out[i]) {
|
||||
#ifdef _MSC_VER
|
||||
strncpy_s(out[i], MAX_LEN, providers_available[i], MAX_LEN);
|
||||
strncpy_s(out[i], MAX_LEN, available_providers[i].c_str(), MAX_LEN);
|
||||
out[i][MAX_LEN] = '\0';
|
||||
#elif defined(__APPLE__)
|
||||
strlcpy(out[i], providers_available[i], MAX_LEN);
|
||||
strlcpy(out[i], available_providers[i].c_str(), MAX_LEN);
|
||||
#else
|
||||
strncpy(out[i], providers_available[i], MAX_LEN);
|
||||
strncpy(out[i], available_providers[i].c_str(), MAX_LEN);
|
||||
out[i][MAX_LEN] = '\0';
|
||||
#endif
|
||||
}
|
||||
|
|
@ -1734,7 +1736,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char*** out_ptr,
|
|||
*providers_length = available_count;
|
||||
*out_ptr = out;
|
||||
API_IMPL_END
|
||||
return NULL;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,10 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import collections
|
||||
import collections.abc
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
|
||||
|
|
@ -17,6 +20,78 @@ def get_ort_device_type(device):
|
|||
raise Exception('Unsupported device type: ' + device)
|
||||
|
||||
|
||||
def check_and_normalize_provider_args(providers, provider_options, available_provider_names):
    """
    Validates the 'providers' and 'provider_options' arguments and returns a
    normalized version.

    :param providers: Optional sequence of providers in order of decreasing
        precedence. Values can either be provider names or tuples of
        (provider name, options dict).
    :param provider_options: Optional sequence of options dicts corresponding
        to the providers listed in 'providers'.
    :param available_provider_names: The available provider names.

    :return: Tuple of (normalized 'providers' sequence, normalized
        'provider_options' sequence).

    :raises ValueError: If an argument has the wrong type or shape, or if a
        requested provider is not in 'available_provider_names'.

    'providers' can contain either names or names and options. When any options
    are given in 'providers', 'provider_options' should not be used.

    The normalized result is a tuple of:
    1. Sequence of provider names in the same order as 'providers'.
    2. Sequence of corresponding provider options dicts with string keys and
        values. Unspecified provider options yield empty dicts.

    A provider listed more than once is kept only at its first position; later
    occurrences trigger a warning and are ignored.
    """
    if providers is None:
        return [], []

    # str and bytes are themselves Sequences. Without this explicit check a
    # single provider name passed by mistake would be iterated character by
    # character and reported as "Specified provider 'C' is unavailable".
    if isinstance(providers, (str, bytes)):
        raise ValueError("'providers' should be a sequence of provider names or "
                         "(provider name, options dict) tuples, not a single string.")

    if not isinstance(providers, collections.abc.Sequence):
        raise ValueError("'providers' should be a sequence.")

    # Insertion-ordered so the caller's precedence order is preserved.
    provider_name_to_options = collections.OrderedDict()

    def set_provider_options(name, options):
        if name not in available_provider_names:
            raise ValueError("Specified provider '{}' is unavailable. Available providers: '{}'".format(
                name, ", ".join(available_provider_names)))

        if name in provider_name_to_options:
            warnings.warn("Duplicate provider '{}' encountered, ignoring.".format(name))
            return

        # Build a fresh dict with stringified keys/values; the caller's dict
        # is never mutated.
        provider_name_to_options[name] = {str(key): str(value) for key, value in options.items()}

    if provider_options is not None:
        if not isinstance(provider_options, collections.abc.Sequence):
            raise ValueError("'provider_options' should be a sequence.")

        if len(providers) != len(provider_options):
            raise ValueError("'providers' and 'provider_options' should be the same length if both are given.")

        if not all(isinstance(provider, str) for provider in providers):
            raise ValueError("Only string values for 'providers' are supported if 'provider_options' is given.")

        if not all(isinstance(options_for_provider, dict) for options_for_provider in provider_options):
            raise ValueError("'provider_options' values must be dicts.")

        for name, options in zip(providers, provider_options):
            set_provider_options(name, options)

    else:
        for provider in providers:
            if isinstance(provider, str):
                set_provider_options(provider, dict())
            elif isinstance(provider, tuple) and len(provider) == 2 and \
                    isinstance(provider[0], str) and isinstance(provider[1], dict):
                set_provider_options(provider[0], provider[1])
            else:
                raise ValueError("'providers' values must be either strings or (string, dict) tuples.")

    return list(provider_name_to_options.keys()), list(provider_name_to_options.values())
|
||||
|
||||
|
||||
class Session:
|
||||
"""
|
||||
This is the main class used to run a model.
|
||||
|
|
@ -55,34 +130,23 @@ class Session:
|
|||
"Return registered execution providers' configurations."
|
||||
return self._provider_options
|
||||
|
||||
def set_providers(self, providers, provider_options=None):
|
||||
def set_providers(self, providers=None, provider_options=None):
|
||||
"""
|
||||
Register the input list of execution providers. The underlying session is re-created.
|
||||
|
||||
:param providers: list of execution providers
|
||||
:param provider_options: list of provider options dict for each provider, in the same order as 'providers'
|
||||
:param providers: Optional sequence of providers in order of decreasing
|
||||
precedence. Values can either be provider names or tuples of
|
||||
(provider name, options dict). If not provided, then all available
|
||||
providers are used with the default precedence.
|
||||
:param provider_options: Optional sequence of options dicts corresponding
|
||||
to the providers listed in 'providers'.
|
||||
|
||||
The list of providers is ordered by Priority. For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
'providers' can contain either names or names and options. When any options
|
||||
are given in 'providers', 'provider_options' should not be used.
|
||||
|
||||
The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
|
||||
"""
|
||||
if not set(providers).issubset(C.get_available_providers()):
|
||||
raise ValueError("{} does not contain a subset of available providers {}".format(
|
||||
providers, C.get_available_providers()))
|
||||
|
||||
if provider_options:
|
||||
if not isinstance(providers, list) or not isinstance(provider_options, list):
|
||||
raise ValueError("Inputs must be two python lists.")
|
||||
|
||||
if len(providers) != len(provider_options):
|
||||
raise ValueError("Two input lists must have same length.")
|
||||
|
||||
for option in provider_options:
|
||||
if not isinstance(option, dict):
|
||||
raise ValueError("Provider options must be list of python dict.")
|
||||
|
||||
for key, val in option.items():
|
||||
option[key] = str(val)
|
||||
|
||||
# recreate the underlying C.InferenceSession
|
||||
self._reset_session(providers, provider_options)
|
||||
|
||||
|
|
@ -173,8 +237,12 @@ class InferenceSession(Session):
|
|||
"""
|
||||
:param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string
|
||||
:param sess_options: session options
|
||||
:param providers: list of providers to use for session. If empty, will use all available providers.
|
||||
:param provider_options: list of provider options dict for each provider, in the same order as 'providers'
|
||||
:param providers: Optional sequence of providers in order of decreasing
|
||||
precedence. Values can either be provider names or tuples of
|
||||
(provider name, options dict). If not provided, then all available
|
||||
providers are used with the default precedence.
|
||||
:param provider_options: Optional sequence of options dicts corresponding
|
||||
to the providers listed in 'providers'.
|
||||
|
||||
The model type will be inferred unless explicitly set in the SessionOptions.
|
||||
To explicitly set:
|
||||
|
|
@ -184,6 +252,12 @@ class InferenceSession(Session):
|
|||
|
||||
A file extension of '.ort' will be inferred as an ORT format model.
|
||||
All other filenames are assumed to be ONNX format models.
|
||||
|
||||
'providers' can contain either names or names and options. When any options
|
||||
are given in 'providers', 'provider_options' should not be used.
|
||||
|
||||
The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
|
||||
"""
|
||||
|
||||
Session.__init__(self)
|
||||
|
|
@ -208,15 +282,22 @@ class InferenceSession(Session):
|
|||
if self._enable_fallback:
|
||||
print("EP Error using {}".format(self._providers))
|
||||
print("Falling back to {} and retrying.".format(self._fallback_providers))
|
||||
self._create_inference_session(self._fallback_providers)
|
||||
self._create_inference_session(self._fallback_providers, None)
|
||||
# Fallback only once.
|
||||
self.disable_fallback()
|
||||
else:
|
||||
raise
|
||||
|
||||
def _create_inference_session(self, providers, provider_options):
|
||||
available_providers = C.get_available_providers()
|
||||
|
||||
# validate providers and provider_options before other initialization
|
||||
providers, provider_options = check_and_normalize_provider_args(providers,
|
||||
provider_options,
|
||||
available_providers)
|
||||
|
||||
# Tensorrt can fall back to CUDA. All others fall back to CPU.
|
||||
if 'TensorrtExecutionProvider' in C.get_available_providers():
|
||||
if 'TensorrtExecutionProvider' in available_providers:
|
||||
self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
else:
|
||||
self._fallback_providers = ['CPUExecutionProvider']
|
||||
|
|
@ -228,7 +309,7 @@ class InferenceSession(Session):
|
|||
sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
|
||||
|
||||
# initialize the C++ InferenceSession
|
||||
sess.initialize_session(providers or [], provider_options or [])
|
||||
sess.initialize_session(providers, provider_options)
|
||||
|
||||
self._sess = sess
|
||||
self._sess_options = self._sess.session_options
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
#include "core/framework/arena_extend_strategy.h"
|
||||
#include "core/framework/data_transfer_utils.h"
|
||||
#include "core/framework/data_types_internal.h"
|
||||
#include "core/providers/get_execution_providers.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
#include "core/framework/provider_options_utils.h"
|
||||
#include "core/framework/random_seed.h"
|
||||
|
|
@ -200,6 +201,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ACL(in
|
|||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(int device_id);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(uint32_t flags);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
|
||||
} // namespace onnxruntime
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
|
@ -425,27 +427,6 @@ static inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime
|
|||
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(p)));
|
||||
}
|
||||
|
||||
// ordered by default priority from highest to lowest. kCpuExecutionProvider should always be last.
|
||||
static const std::vector<std::string>& GetAllProviders() {
|
||||
static std::vector<std::string> all_providers = {kTensorrtExecutionProvider, kCudaExecutionProvider,
|
||||
kMIGraphXExecutionProvider, kRocmExecutionProvider,
|
||||
kOpenVINOExecutionProvider, kDnnlExecutionProvider,
|
||||
kNupharExecutionProvider, kVitisAIExecutionProvider,
|
||||
kNnapiExecutionProvider,
|
||||
kArmNNExecutionProvider, kAclExecutionProvider,
|
||||
kDmlExecutionProvider, kCpuExecutionProvider};
|
||||
return all_providers;
|
||||
}
|
||||
|
||||
static const std::vector<std::string>& GetAvailableProviders() {
|
||||
auto InitializeProviders = []() {
|
||||
std::vector<std::string> available_providers(std::begin(providers_available), std::end(providers_available));
|
||||
return available_providers;
|
||||
};
|
||||
static std::vector<std::string> available_providers = InitializeProviders();
|
||||
return available_providers;
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
||||
static bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id) {
|
||||
|
|
@ -537,7 +518,7 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
|
|||
info.device_id = cuda_device_id;
|
||||
info.cuda_mem_limit = cuda_mem_limit;
|
||||
info.arena_extend_strategy = arena_extend_strategy;
|
||||
info.cudnn_conv_algo = cudnn_conv_algo_search;
|
||||
info.cudnn_conv_algo_search = cudnn_conv_algo_search;
|
||||
info.do_copy_in_default_stream = do_copy_in_default_stream;
|
||||
return info;
|
||||
}();
|
||||
|
|
@ -605,7 +586,10 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
|
|||
#if USE_NUPHAR
|
||||
const auto it = provider_options_map.find(type);
|
||||
if (it != provider_options_map.end()) {
|
||||
ReadProviderOption(it->second, "nuphar_settings", nuphar_settings);
|
||||
ORT_THROW_IF_ERROR(
|
||||
ProviderOptionsParser{}
|
||||
.AddAssignmentToReference("nuphar_settings", nuphar_settings)
|
||||
.Parse(it->second));
|
||||
}
|
||||
|
||||
RegisterExecutionProvider(
|
||||
|
|
@ -638,6 +622,10 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
|
|||
LOGS_DEFAULT(WARNING) << "NNAPI execution provider can only be used to generate ORT format model in this build.";
|
||||
#endif
|
||||
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Nnapi(0));
|
||||
#endif
|
||||
} else if (type == kRknpuExecutionProvider) {
|
||||
#ifdef USE_RKNPU
|
||||
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Rknpu());
|
||||
#endif
|
||||
} else {
|
||||
// unknown provider
|
||||
|
|
@ -697,7 +685,7 @@ void InitializeSession(InferenceSession* sess, const std::vector<std::string>& p
|
|||
|
||||
if (provider_types.empty()) {
|
||||
// use default registration priority.
|
||||
RegisterExecutionProviders(sess, GetAllProviders(), provider_options_map);
|
||||
RegisterExecutionProviders(sess, GetAllExecutionProviderNames(), provider_options_map);
|
||||
} else {
|
||||
RegisterExecutionProviders(sess, provider_types, provider_options_map);
|
||||
}
|
||||
|
|
@ -752,13 +740,15 @@ void addGlobalMethods(py::module& m, Environment& env) {
|
|||
},
|
||||
"Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
|
||||
m.def(
|
||||
"get_all_providers", []() -> const std::vector<std::string>& { return GetAllProviders(); },
|
||||
"get_all_providers", []() -> const std::vector<std::string>& { return GetAllExecutionProviderNames(); },
|
||||
"Return list of Execution Providers that this version of Onnxruntime can support. "
|
||||
"The order of elements represents the default priority order of Execution Providers"
|
||||
" from highest to lowest.");
|
||||
"The order of elements represents the default priority order of Execution Providers "
|
||||
"from highest to lowest.");
|
||||
m.def(
|
||||
"get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableProviders(); },
|
||||
"Return list of available Execution Providers available in this installed version of Onnxruntime.");
|
||||
"get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableExecutionProviderNames(); },
|
||||
"Return list of available Execution Providers available in this installed version of Onnxruntime. "
|
||||
"The order of elements represents the default priority order of Execution Providers "
|
||||
"from highest to lowest.");
|
||||
m.def(
|
||||
"enable_telemetry_events", []() -> void { platform_env.GetTelemetryProvider().EnableTelemetryEvents(); },
|
||||
"Enables platform-specific telemetry collection where applicable.");
|
||||
|
|
@ -833,7 +823,7 @@ void addGlobalMethods(py::module& m, Environment& env) {
|
|||
info.device_id = cuda_device_id;
|
||||
info.cuda_mem_limit = cuda_mem_limit;
|
||||
info.arena_extend_strategy = arena_extend_strategy;
|
||||
info.cudnn_conv_algo = cudnn_conv_algo_search;
|
||||
info.cudnn_conv_algo_search = cudnn_conv_algo_search;
|
||||
info.do_copy_in_default_stream = do_copy_in_default_stream;
|
||||
return info;
|
||||
}()),
|
||||
|
|
@ -874,6 +864,9 @@ void addGlobalMethods(py::module& m, Environment& env) {
|
|||
#endif
|
||||
#ifdef USE_NNAPI
|
||||
onnxruntime::CreateExecutionProviderFactory_NNAPI(0),
|
||||
#endif
|
||||
#ifdef USE_RKNPU
|
||||
onnxruntime::CreateExecutionProviderFactory_Rknpu(),
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
@ -905,7 +898,7 @@ void addGlobalMethods(py::module& m, Environment& env) {
|
|||
});
|
||||
// TODO remove deprecated global config
|
||||
m.def("set_cudnn_conv_algo_search", [](const OrtCudnnConvAlgoSearch algo) {
|
||||
LogDeprecationWarning("set_cudnn_conv_algo_search", "CUDA execution provider option \"cudnn_conv_algo\"");
|
||||
LogDeprecationWarning("set_cudnn_conv_algo_search", "CUDA execution provider option \"cudnn_conv_algo_search\"");
|
||||
#ifdef USE_ROCM
|
||||
ORT_UNUSED_PARAMETER(algo);
|
||||
ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM");
|
||||
|
|
@ -1084,11 +1077,6 @@ void addObjectMethods(py::module& m, Environment& env) {
|
|||
.value("CPU", OrtMemType::OrtMemTypeCPU)
|
||||
.value("DEFAULT", OrtMemType::OrtMemTypeDefault);
|
||||
|
||||
py::enum_<OrtCudnnConvAlgoSearch>(m, "OrtCudnnConvAlgoSearch")
|
||||
.value("EXHAUSTIVE", OrtCudnnConvAlgoSearch::EXHAUSTIVE)
|
||||
.value("HEURISTIC", OrtCudnnConvAlgoSearch::HEURISTIC)
|
||||
.value("DEFAULT", OrtCudnnConvAlgoSearch::DEFAULT);
|
||||
|
||||
py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device informaion.)pbdoc");
|
||||
device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::DeviceId>())
|
||||
.def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
|
||||
|
|
|
|||
|
|
@ -11,9 +11,6 @@
|
|||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
||||
using namespace onnxruntime;
|
||||
using namespace onnxruntime::logging;
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
struct CustomOpLibrary {
|
||||
CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/string_utils.h"
|
||||
#include "core/common/make_string.h"
|
||||
#include "core/common/parse_string.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
|
|
@ -12,29 +13,26 @@ namespace {
|
|||
template <typename T>
|
||||
void TestSuccessfulParse(const std::string& input, const T& expected_value) {
|
||||
T value;
|
||||
ASSERT_TRUE(TryParse(input, value));
|
||||
ASSERT_TRUE(TryParseString(input, value));
|
||||
EXPECT_EQ(value, expected_value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestFailedParse(const std::string& input) {
|
||||
T value;
|
||||
EXPECT_FALSE(TryParse(input, value));
|
||||
EXPECT_FALSE(TryParseString(input, value));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(StringUtilsTest, TryParse) {
|
||||
TEST(StringUtilsTest, TryParseString) {
|
||||
TestSuccessfulParse("-1", -1);
|
||||
TestSuccessfulParse("42", 42u);
|
||||
TestSuccessfulParse("2.5", 2.5f);
|
||||
TestSuccessfulParse("1", true);
|
||||
TestSuccessfulParse("0", false);
|
||||
|
||||
// out of range
|
||||
TestFailedParse<int16_t>("32768");
|
||||
TestFailedParse<uint32_t>("-1");
|
||||
TestFailedParse<float>("1e100");
|
||||
TestFailedParse<bool>("2");
|
||||
// invalid representation
|
||||
TestFailedParse<int32_t>("1.2");
|
||||
TestFailedParse<int32_t>("one");
|
||||
|
|
@ -43,12 +41,21 @@ TEST(StringUtilsTest, TryParse) {
|
|||
TestFailedParse<int32_t>("1 ");
|
||||
}
|
||||
|
||||
TEST(StringUtilsTest, TryParseString) {
|
||||
TEST(StringUtilsTest, TryParseStringAsString) {
|
||||
// when parsing a string as a string, allow leading and trailing whitespace
|
||||
const std::string s = " this is a string! ";
|
||||
TestSuccessfulParse(s, s);
|
||||
}
|
||||
|
||||
TEST(StringUtilsTest, TryParseStringAsBool) {
|
||||
TestSuccessfulParse("True", true);
|
||||
TestSuccessfulParse("1", true);
|
||||
TestSuccessfulParse("False", false);
|
||||
TestSuccessfulParse("0", false);
|
||||
|
||||
TestFailedParse<bool>("2");
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct S {
|
||||
int i{};
|
||||
|
|
@ -73,12 +80,12 @@ std::istream& operator>>(std::istream& is, S& s) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
TEST(StringUtilsTest, MakeStringAndTryParseCustomType) {
|
||||
TEST(StringUtilsTest, MakeStringAndTryParseStringWithCustomType) {
|
||||
S s;
|
||||
s.i = 42;
|
||||
const auto str = MakeString(s);
|
||||
S parsed_s;
|
||||
ASSERT_TRUE(TryParse(str, parsed_s));
|
||||
ASSERT_TRUE(TryParseString(str, parsed_s));
|
||||
ASSERT_EQ(parsed_s, s);
|
||||
}
|
||||
|
||||
|
|
|
|||
59
onnxruntime/test/framework/provider_options_utils_test.cc
Normal file
59
onnxruntime/test/framework/provider_options_utils_test.cc
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/provider_options_utils.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "asserts.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
namespace {
|
||||
enum class TestEnum {
|
||||
A,
|
||||
Unmapped,
|
||||
};
|
||||
|
||||
const EnumNameMapping<TestEnum> test_enum_mapping{
|
||||
{TestEnum::A, "A"},
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Registers int, bool and enum options on a parser, then checks that:
//  - registering an already-registered option name throws,
//  - a map of known options parses and assigns the bound variables,
//  - a map containing an unregistered option name fails to parse.
TEST(ProviderOptionsUtilsTest, ProviderOptionsParser) {
  int int_value;
  bool bool_value;
  TestEnum enum_value;

  ProviderOptionsParser parser{};
  parser.AddAssignmentToReference("int", int_value);
  parser.AddAssignmentToReference("bool", bool_value);
  parser.AddAssignmentToEnumReference("enum", test_enum_mapping, enum_value);

  // adding same option again should throw
  ASSERT_THROW(parser.AddAssignmentToReference("int", int_value), OnnxRuntimeException);

  const ProviderOptions known_options{{"int", "3"}, {"bool", "true"}, {"enum", "A"}};
  ASSERT_STATUS_OK(parser.Parse(known_options));
  EXPECT_EQ(int_value, 3);
  EXPECT_EQ(bool_value, true);
  EXPECT_EQ(enum_value, TestEnum::A);

  const ProviderOptions unknown_options{{"unknown option", "some value"}};
  ASSERT_FALSE(parser.Parse(unknown_options).IsOK());
}
|
||||
|
||||
// EnumToName yields the mapped name for a mapped value and a failing status
// for a value that has no entry in the mapping.
TEST(ProviderOptionsUtilsTest, EnumToName) {
  std::string mapped_name;

  ASSERT_STATUS_OK(EnumToName(test_enum_mapping, TestEnum::A, mapped_name));
  EXPECT_EQ(mapped_name, "A");

  const auto unmapped_status = EnumToName(test_enum_mapping, TestEnum::Unmapped, mapped_name);
  ASSERT_FALSE(unmapped_status.IsOK());
}
|
||||
|
||||
// NameToEnum yields the mapped value for a known name and a failing status
// for a name that has no entry in the mapping.
TEST(ProviderOptionsUtilsTest, NameToEnum) {
  TestEnum parsed_value;

  ASSERT_STATUS_OK(NameToEnum(test_enum_mapping, "A", parsed_value));
  EXPECT_EQ(parsed_value, TestEnum::A);

  const auto invalid_status = NameToEnum(test_enum_mapping, "invalid", parsed_value);
  ASSERT_FALSE(invalid_status.IsOK());
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -11,7 +11,7 @@
|
|||
#include "mem_buffer.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/common/status.h"
|
||||
#include "core/common/string_utils.h"
|
||||
#include "core/common/make_string.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/endian.h"
|
||||
#include "core/framework/allocator.h"
|
||||
|
|
|
|||
52
onnxruntime/test/providers/get_execution_providers_test.cc
Normal file
52
onnxruntime/test/providers/get_execution_providers_test.cc
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/get_execution_providers.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "core/graph/constants.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
// Both the "all" and the "available" provider lists must be non-empty and
// must end with the CPU execution provider.
TEST(GetExecutionProvidersTest, CpuEpAlwaysLast) {
  // The ASSERT_* lives inside the lambda on purpose: a failure aborts only
  // that lambda call, so the second list is still checked.
  const auto expect_cpu_last = [](const std::vector<std::string>& names) {
    ASSERT_FALSE(names.empty());
    EXPECT_EQ(names.back(), kCpuExecutionProvider);
  };

  expect_cpu_last(GetAllExecutionProviderNames());
  expect_cpu_last(GetAvailableExecutionProviderNames());
}
|
||||
|
||||
// The "available" list must equal the "all" list filtered down to the
// available names, i.e. both lists agree on relative ordering.
TEST(GetExecutionProvidersTest, ConsistentOrdering) {
  const auto& all_names = GetAllExecutionProviderNames();
  const auto& available_names = GetAvailableExecutionProviderNames();

  // Walk the full list in order and keep only names that are also available.
  std::vector<std::string> filtered_from_all{};
  for (const std::string& name : all_names) {
    const bool is_available =
        std::find(available_names.begin(), available_names.end(), name) != available_names.end();
    if (is_available) {
      filtered_from_all.push_back(name);
    }
  }

  EXPECT_EQ(available_names, filtered_from_all);
}
|
||||
|
||||
// Neither provider list may name the same provider twice: the number of
// distinct names must equal the list length.
TEST(GetExecutionProvidersTest, NoDuplicates) {
  const auto expect_unique = [](const std::vector<std::string>& names) {
    std::unordered_set<std::string> distinct_names;
    distinct_names.insert(names.begin(), names.end());
    EXPECT_EQ(names.size(), distinct_names.size());
  };

  expect_unique(GetAllExecutionProviderNames());
  expect_unique(GetAvailableExecutionProviderNames());
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -81,13 +81,13 @@ class TestInferenceSession(unittest.TestCase):
|
|||
|
||||
def runBaseTest2():
|
||||
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"))
|
||||
self.assertTrue('CUDAExecutionProvider' in sess.get_providers())
|
||||
self.assertIn('CUDAExecutionProvider', sess.get_providers())
|
||||
|
||||
# test get/set of "cuda_mem_limit" configuration.
|
||||
options = sess.get_provider_options()
|
||||
self.assertTrue('CUDAExecutionProvider' in options)
|
||||
self.assertIn('CUDAExecutionProvider', options)
|
||||
option = options['CUDAExecutionProvider']
|
||||
self.assertTrue('cuda_mem_limit' in option)
|
||||
self.assertIn('cuda_mem_limit', option)
|
||||
ori_mem_limit = option['cuda_mem_limit']
|
||||
new_mem_limit = int(ori_mem_limit) // 2
|
||||
option['cuda_mem_limit'] = new_mem_limit
|
||||
|
|
@ -100,16 +100,27 @@ class TestInferenceSession(unittest.TestCase):
|
|||
options = sess.get_provider_options()
|
||||
self.assertEqual(options['CUDAExecutionProvider']['cuda_mem_limit'], ori_mem_limit)
|
||||
|
||||
# test get/set of "arena_extend_strategy" configuration.
|
||||
options = sess.get_provider_options()
|
||||
self.assertTrue('CUDAExecutionProvider' in options)
|
||||
option = options['CUDAExecutionProvider']
|
||||
self.assertTrue('arena_extend_strategy' in option)
|
||||
for strategy in ['kNextPowerOfTwo', 'kSameAsRequested']:
|
||||
option['arena_extend_strategy'] = strategy
|
||||
sess.set_providers(['CUDAExecutionProvider'], [option])
|
||||
options = sess.get_provider_options()
|
||||
self.assertEqual(options['CUDAExecutionProvider']['arena_extend_strategy'], strategy)
|
||||
def test_get_and_set_option_with_values(option_name, option_values):
    """Round-trip each value of one CUDA EP option through set_providers().

    For every value in ``option_values`` the option is written into the CUDA
    provider options, applied with ``sess.set_providers``, and then read back
    with ``sess.get_provider_options``. The value read back is compared
    against ``str(option_value)``.

    Args:
        option_name: Key expected to exist in the CUDA provider options.
        option_values: Iterable of values to set and verify in turn.
    """
    provider_options = sess.get_provider_options()
    self.assertIn('CUDAExecutionProvider', provider_options)
    # Use the snapshot fetched just above; indexing a variable from the
    # enclosing scope would read (and mutate) a stale options dict instead.
    cuda_options = provider_options['CUDAExecutionProvider']
    self.assertIn(option_name, cuda_options)
    for option_value in option_values:
        cuda_options[option_name] = option_value
        sess.set_providers(['CUDAExecutionProvider'], [cuda_options])
        new_provider_options = sess.get_provider_options()
        # .get() chains keep a missing provider/option as an assertion failure
        # rather than a KeyError.
        self.assertEqual(
            new_provider_options.get('CUDAExecutionProvider', {}).get(option_name),
            str(option_value))
|
||||
|
||||
test_get_and_set_option_with_values(
|
||||
'arena_extend_strategy', ['kNextPowerOfTwo', 'kSameAsRequested'])
|
||||
|
||||
test_get_and_set_option_with_values(
|
||||
'cudnn_conv_algo_search', ["DEFAULT", "EXHAUSTIVE", "HEURISTIC"])
|
||||
|
||||
test_get_and_set_option_with_values(
|
||||
'do_copy_in_default_stream', [0, 1])
|
||||
|
||||
#
|
||||
# Note: Tests that throw an exception leave an empty session due to how set_providers currently works,
|
||||
|
|
@ -181,21 +192,17 @@ class TestInferenceSession(unittest.TestCase):
|
|||
|
||||
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"))
|
||||
|
||||
# configure session with not legit option values and that shloud fail
|
||||
# configure session with invalid option values and that should fail
|
||||
with self.assertRaises(RuntimeError):
|
||||
option = {'device_id': num_device}
|
||||
sess.set_providers(['CUDAExecutionProvider'], [option])
|
||||
option = {'device_id': 'non_legit_value'}
|
||||
option = {'device_id': 'invalid_value'}
|
||||
sess.set_providers(['CUDAExecutionProvider'], [option])
|
||||
|
||||
# configure session with not legit option should cause no effect
|
||||
option = {'device_id': 0}
|
||||
sess.set_providers(['CUDAExecutionProvider'], [option])
|
||||
option = {'non_legit_option': num_device}
|
||||
sess.set_providers(['CUDAExecutionProvider'], [option])
|
||||
self.assertEqual(['CUDAExecutionProvider', 'CPUExecutionProvider'], sess.get_providers())
|
||||
|
||||
|
||||
# configure session with invalid option should fail
|
||||
with self.assertRaises(RuntimeError):
|
||||
option = {'invalid_option': 123}
|
||||
sess.set_providers(['CUDAExecutionProvider'], [option])
|
||||
|
||||
libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
|
||||
for libname in libnames:
|
||||
|
|
@ -218,8 +225,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
with self.assertRaises(ValueError) as context:
|
||||
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"))
|
||||
sess.set_providers(['InvalidProvider'])
|
||||
self.assertTrue(
|
||||
'[\'InvalidProvider\'] does not contain a subset of available providers' in str(context.exception))
|
||||
self.assertTrue('\'InvalidProvider\' is unavailable' in str(context.exception))
|
||||
|
||||
def testSessionProviders(self):
|
||||
if 'CUDAExecutionProvider' in onnxrt.get_available_providers():
|
||||
|
|
@ -712,7 +718,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), so, ['CPUExecutionProvider'])
|
||||
res = sess.run(["Y"], {"X": np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)})
|
||||
self.assertTrue(np.array_equal(res[0], np.array([[2.0, 2.0], [12.0, 12.0], [30.0, 30.0]], dtype=np.float32)))
|
||||
|
||||
|
||||
def testRegisterCustomOpsLibrary(self):
|
||||
if sys.platform.startswith("win"):
|
||||
shared_library = 'custom_op_library.dll'
|
||||
|
|
@ -809,17 +815,17 @@ class TestInferenceSession(unittest.TestCase):
|
|||
# 1. if there are intermittent failure in this test, something is wrong
|
||||
# 2. it's easier to repro on slower GPU (like M60, Geforce 1070)
|
||||
|
||||
# to repro #4829, uncomment the line below to run copy in a separate stream
|
||||
#onnxrt.capi._pybind_state.set_do_copy_in_default_stream(False)
|
||||
# to repro #4829, set the CUDA EP do_copy_in_default_stream option to False
|
||||
providers = [("CUDAExecutionProvider", {"do_copy_in_default_stream": True}), "CPUExecutionProvider"]
|
||||
|
||||
session = onnxrt.InferenceSession(get_name("issue4829.onnx"))
|
||||
session = onnxrt.InferenceSession(get_name("issue4829.onnx"), providers=providers)
|
||||
shape = np.array([2,2], dtype=np.int64)
|
||||
for iteration in range(100000):
|
||||
result = session.run(output_names=['output'], input_feed={'shape': shape})
|
||||
|
||||
def testSharedAllocatorUsingCreateAndRegisterAllocator(self):
|
||||
# Create and register an arena based allocator
|
||||
|
||||
|
||||
# ort_arena_cfg = onnxrt.OrtArenaCfg(0, -1, -1, -1) (create an OrtArenaCfg like this template if you want to use non-default parameters)
|
||||
ort_memory_info = onnxrt.OrtMemoryInfo("Cpu", onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, 0, onnxrt.OrtMemType.DEFAULT)
|
||||
# Use this option if using non-default OrtArenaCfg : onnxrt.create_and_register_allocator(ort_memory_info, ort_arena_cfg)
|
||||
|
|
@ -836,5 +842,55 @@ class TestInferenceSession(unittest.TestCase):
|
|||
so2.log_severity_level = 1
|
||||
onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so2)
|
||||
|
||||
def testCheckAndNormalizeProviderArgs(self):
    """Exercise check_and_normalize_provider_args with accepted and rejected inputs.

    Accepted inputs are compared against the expected normalized
    (providers, provider_options) pair; rejected inputs must raise ValueError.
    """
    from onnxruntime.capi.onnxruntime_inference_collection import check_and_normalize_provider_args

    valid_providers = ["a", "b", "c"]

    def check_success(providers, provider_options, expected_providers, expected_provider_options):
        normalized = check_and_normalize_provider_args(providers, provider_options, valid_providers)
        actual_providers, actual_provider_options = normalized
        self.assertEqual(actual_providers, expected_providers)
        self.assertEqual(actual_provider_options, expected_provider_options)

    def check_failure(providers, provider_options):
        with self.assertRaises(ValueError):
            check_and_normalize_provider_args(providers, provider_options, valid_providers)

    # (providers, provider_options, expected_providers, expected_provider_options)
    success_cases = [
        (None, None, [], []),
        (["a"], None, ["a"], [{}]),
        (["a", "b"], None, ["a", "b"], [{}, {}]),
        # options given inline as a (name, options) tuple
        ([("a", {1: 2}), "b"], None, ["a", "b"], [{"1": "2"}, {}]),
        # options given as a parallel list
        (["a", "b"], [{1: 2}, {}], ["a", "b"], [{"1": "2"}, {}]),
    ]
    for success_case in success_cases:
        check_success(*success_case)

    # a repeated provider name is expected to warn and keep only the first entry
    with self.assertWarns(UserWarning):
        check_success(["a", "b", "a"], [{"x": 1}, {}, {"y": 2}], ["a", "b"], [{"x": "1"}, {}])

    # (providers, provider_options)
    failure_cases = [
        (["d"], None),                 # provider not valid
        (3, None),                     # providers not sequence
        ([3], None),                   # providers value invalid
        (["a"], 3),                    # provider_options not sequence
        (["a"], ["not dict"]),         # provider_options value invalid
        (["a", "b"], [{1: 2}]),        # providers and provider_options length mismatch
        ([("a", {1: 2})], [{3: 4}]),   # provider options unsupported mixed specification
    ]
    for failure_case in failure_cases:
        check_failure(*failure_case)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -25,13 +25,13 @@ def reference_gemm(a, b, c, alpha, beta, transA, transB):
|
|||
|
||||
def set_gemm_node_attrs(attrs, config):
    """Copy the non-trivial Gemm settings from ``config`` into ``attrs``.

    ``attrs`` is modified in place and nothing is returned. An attribute is
    only written when its config value differs from the value skipped here
    (1.0 for alpha/beta, falsy for transA/transB).

    Args:
        attrs: Mutable mapping that receives the node attributes.
        config: Mapping with the keys 'alpha', 'beta', 'transA' and 'transB'.
    """
    if config['alpha'] != 1.0:
        attrs['alpha'] = config['alpha']
    if config['beta'] != 1.0:
        attrs['beta'] = config['beta']
    # The transpose flags are always stored as the integer 1, whatever truthy
    # value the config carries.
    if config['transA']:
        attrs['transA'] = 1
    if config['transB']:
        attrs['transB'] = 1
|
||||
|
||||
def generate_gemm_inputs_initializers(graph, config, added_inputs_initializers={}, extend=False):
|
||||
M = config['M']
|
||||
|
|
@ -323,6 +323,16 @@ def set_gemm_model_inputs(config, test_inputs, a, b, c):
|
|||
if config['withC'] and not config['initC']:
|
||||
test_inputs[config['C']] = c
|
||||
|
||||
|
||||
def make_providers(nuphar_settings):
    """Build a providers list that passes ``nuphar_settings`` to the Nuphar EP.

    Args:
        nuphar_settings: Value stored under the 'nuphar_settings' option key.

    Returns:
        A two-element list: a ('NupharExecutionProvider', options) tuple
        followed by the plain name 'CPUExecutionProvider'.
    """
    nuphar_options = {'nuphar_settings': nuphar_settings}
    nuphar_entry = ('NupharExecutionProvider', nuphar_options)
    return [nuphar_entry, 'CPUExecutionProvider']
|
||||
|
||||
|
||||
class TestNuphar(unittest.TestCase):
|
||||
|
||||
def test_bidaf(self):
|
||||
|
|
@ -385,8 +395,9 @@ class TestNuphar(unittest.TestCase):
|
|||
for model in [bidaf_opt_scan_model, bidaf_int8_scan_only_model]:
|
||||
nuphar_settings = 'nuphar_cache_path:{}'.format(cache_dir)
|
||||
for isa in ['avx', 'avx2', 'avx512']:
|
||||
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings + ', nuphar_codegen_target:' + isa)
|
||||
sess = onnxrt.InferenceSession(model) # JIT cache happens when initializing session
|
||||
# JIT cache happens when initializing session
|
||||
sess = onnxrt.InferenceSession(
|
||||
model, providers=make_providers(nuphar_settings + ', nuphar_codegen_target:' + isa))
|
||||
|
||||
cache_dir_content = os.listdir(cache_dir)
|
||||
assert len(cache_dir_content) == 1
|
||||
|
|
@ -400,15 +411,13 @@ class TestNuphar(unittest.TestCase):
|
|||
|
||||
nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}'.format(
|
||||
cache_dir, so_name, 'on')
|
||||
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
|
||||
sess = onnxrt.InferenceSession(model)
|
||||
sess = onnxrt.InferenceSession(model, providers=make_providers(nuphar_settings))
|
||||
sess.run([], feed)
|
||||
|
||||
# test avx
|
||||
nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}, nuphar_codegen_target:{}'.format(
|
||||
cache_dir, so_name, 'on', 'avx')
|
||||
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
|
||||
sess = onnxrt.InferenceSession(model)
|
||||
sess = onnxrt.InferenceSession(model, providers=make_providers(nuphar_settings))
|
||||
sess.run([], feed)
|
||||
|
||||
def test_bert_squad(self):
|
||||
|
|
@ -672,7 +681,7 @@ class TestNuphar(unittest.TestCase):
|
|||
|
||||
def test_loop_to_scan(self):
|
||||
loop_model_filename = get_name("nuphar_tiny_model_with_loop_shape_infered.onnx")
|
||||
scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx"
|
||||
scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx"
|
||||
subprocess.run([
|
||||
sys.executable, '-m', 'onnxruntime.nuphar.model_editor',
|
||||
'--input', loop_model_filename,
|
||||
|
|
@ -680,13 +689,13 @@ class TestNuphar(unittest.TestCase):
|
|||
], check=True)
|
||||
|
||||
validate_with_ort(loop_model_filename, scan_model_filename)
|
||||
|
||||
|
||||
def test_loop_to_scan_with_inconvertible_loop(self):
|
||||
# nuphar_onnx_test_loop11_inconvertible_loop.onnx contains a Loop op with dynamic loop count.
|
||||
# This Loop op cannot be converted to a Scan op.
|
||||
# Set --keep_unconvertible_loop_ops option so conversion will not fail due to unconvertible loop ops.
|
||||
loop_model_filename = get_name("nuphar_onnx_test_loop11_inconvertible_loop.onnx")
|
||||
scan_model_filename = "nuphar_onnx_test_loop11_inconvertible_loop_unchanged.onnx"
|
||||
scan_model_filename = "nuphar_onnx_test_loop11_inconvertible_loop_unchanged.onnx"
|
||||
subprocess.run([
|
||||
sys.executable, '-m', 'onnxruntime.nuphar.model_editor',
|
||||
'--input', loop_model_filename,
|
||||
|
|
@ -695,15 +704,15 @@ class TestNuphar(unittest.TestCase):
|
|||
], check=True)
|
||||
|
||||
# onnxruntime is failing with:
|
||||
# onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :
|
||||
# FAIL : Non-zero status code returned while running Loop node. Name:''
|
||||
# onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :
|
||||
# FAIL : Non-zero status code returned while running Loop node. Name:''
|
||||
# Status Message: Inconsistent shape in loop output for output. Expected:{1} Got:{0}
|
||||
# skip validate_with_ort for now
|
||||
# validate_with_ort(loop_model_filename, scan_model_filename)
|
||||
|
||||
def test_loop_to_scan_tool(self):
|
||||
loop_model_filename = get_name("nuphar_tiny_model_with_loop_shape_infered.onnx")
|
||||
scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx"
|
||||
scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx"
|
||||
subprocess.run([
|
||||
sys.executable, '-m', 'onnxruntime.nuphar.model_tools',
|
||||
'--input', loop_model_filename,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <core/common/make_unique.h>
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "providers.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
|
@ -15,7 +9,16 @@
|
|||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <algorithm>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/make_unique.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "providers.h"
|
||||
#include "test_allocator.h"
|
||||
#include "test_fixture.h"
|
||||
|
||||
|
|
@ -1188,19 +1191,19 @@ TEST(CApiTest, get_available_providers) {
|
|||
int len = 0;
|
||||
char** providers;
|
||||
ASSERT_EQ(g_ort->GetAvailableProviders(&providers, &len), nullptr);
|
||||
ASSERT_TRUE(len > 0);
|
||||
ASSERT_EQ(strcmp(providers[0], "CPUExecutionProvider"), 0);
|
||||
ASSERT_GT(len, 0);
|
||||
ASSERT_STREQ(providers[len-1], "CPUExecutionProvider");
|
||||
ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr);
|
||||
}
|
||||
|
||||
// Checks the C++ wrapper Ort::GetAvailableProviders(): the list must be
// non-empty and must end with the CPU execution provider.
TEST(CApiTest, get_available_providers_cpp) {
  std::vector<std::string> providers = Ort::GetAvailableProviders();
  // The CPU EP is expected at the end of the list, so only the last element
  // is checked; asserting on providers[0] as well would contradict this
  // whenever another EP is included in the build.
  ASSERT_FALSE(providers.empty());
  ASSERT_EQ(providers.back(), "CPUExecutionProvider");

#ifdef USE_CUDA
  // CUDA EP will exist in the list but its position may vary based on other EPs included in the build
  ASSERT_TRUE(std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") != providers.end());
#endif
}
|
||||
|
||||
|
|
|
|||
|
|
@ -576,7 +576,7 @@ void setup_training_params(BertParameters& params) {
|
|||
if (params.cuda_mem_limit_in_gb > 0) {
|
||||
info.cuda_mem_limit = gsl::narrow<size_t>(params.cuda_mem_limit_in_gb * 1024 * 1024 * 1024);
|
||||
}
|
||||
info.cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE;
|
||||
info.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE;
|
||||
|
||||
params.providers.emplace(kCudaExecutionProvider, CreateExecutionProviderFactory_CUDA(info));
|
||||
params.input_allocator = std::make_shared<CUDAPinnedAllocator>(info.device_id, CUDA_PINNED);
|
||||
|
|
|
|||
|
|
@ -7,11 +7,12 @@ import sys
|
|||
import os
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import Session, InferenceSession, IOBinding
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import (Session, InferenceSession, IOBinding,
|
||||
check_and_normalize_provider_args)
|
||||
|
||||
|
||||
class TrainingSession(InferenceSession):
|
||||
def __init__(self, path_or_bytes, parameters, sess_options=None):
|
||||
def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None):
|
||||
Session.__init__(self)
|
||||
|
||||
if sess_options:
|
||||
|
|
@ -19,10 +20,13 @@ class TrainingSession(InferenceSession):
|
|||
else:
|
||||
self._sess = C.TrainingSession()
|
||||
|
||||
providers, provider_options = check_and_normalize_provider_args(providers, provider_options,
|
||||
C.get_available_providers())
|
||||
|
||||
if isinstance(path_or_bytes, str):
|
||||
config_result = self._sess.load_model(path_or_bytes, parameters)
|
||||
config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options)
|
||||
elif isinstance(path_or_bytes, bytes):
|
||||
config_result = self._sess.read_bytes(path_or_bytes, parameters)
|
||||
config_result = self._sess.read_bytes(path_or_bytes, parameters, providers, provider_options)
|
||||
else:
|
||||
raise TypeError("Unable to load from type '{0}'".format(type(path_or_bytes)))
|
||||
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
#endif
|
||||
#endif
|
||||
})
|
||||
.def("load_model", [](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters) {
|
||||
.def("load_model", [](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
|
||||
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path));
|
||||
|
||||
#if defined(USE_MPI)
|
||||
|
|
@ -315,12 +315,11 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
#endif
|
||||
const auto config_result = ConfigureSessionForTraining(static_cast<TrainingSession*>(sess->GetSessionHandle()), parameters);
|
||||
|
||||
std::vector<std::string> provider_types = {};
|
||||
InitializeSession(sess->GetSessionHandle(), provider_types);
|
||||
InitializeSession(sess->GetSessionHandle(), provider_types, provider_options);
|
||||
|
||||
return config_result;
|
||||
})
|
||||
.def("read_bytes", [](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters) {
|
||||
.def("read_bytes", [](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
|
||||
std::istringstream buffer(serialized_model);
|
||||
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer));
|
||||
|
||||
|
|
@ -331,8 +330,7 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
#endif
|
||||
const auto config_result = ConfigureSessionForTraining(static_cast<TrainingSession*>(sess->GetSessionHandle()), parameters);
|
||||
|
||||
std::vector<std::string> provider_types = {};
|
||||
InitializeSession(sess->GetSessionHandle(), provider_types);
|
||||
InitializeSession(sess->GetSessionHandle(), provider_types, provider_options);
|
||||
|
||||
return config_result;
|
||||
})
|
||||
|
|
|
|||
|
|
@ -195,13 +195,6 @@ class ORTTrainer(object):
|
|||
break
|
||||
assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})"
|
||||
|
||||
# Set GPU device and memory limit
|
||||
if 'cuda' in self.options.device.id.lower():
|
||||
mem_limit = self.options.device.mem_limit
|
||||
if mem_limit > 0:
|
||||
ort.set_cuda_mem_limit(self.options.device.mem_limit)
|
||||
ort.set_cuda_device_id(_utils.get_device_index(self.options.device.id))
|
||||
|
||||
# TODO: Remove when experimental checkpoint functions are removed.
|
||||
self._state_dict = {}
|
||||
|
||||
|
|
@ -655,10 +648,30 @@ class ORTTrainer(object):
|
|||
# old ort session may already exist and occupy GPU memory when creating new session, this may cause OOM error.
|
||||
# for example, load_state_dict will be called before returning from the function, and it calls _init_session again
|
||||
del self._training_session
|
||||
|
||||
# Set provider-specific options if needed
|
||||
def get_providers():
    """Return the available providers, attaching CUDA EP options when needed.

    Starts from ``ort.get_available_providers()``. When the configured device
    id contains 'cuda', the 'CUDAExecutionProvider' entry is replaced in place
    by a (name, options) tuple carrying the device index and, if positive, the
    memory limit.

    Raises:
        RuntimeError: A CUDA device is configured but the CUDA provider is
            not in the available list.
    """
    providers = ort.get_available_providers()

    device_id = self.options.device.id
    if 'cuda' not in device_id.lower():
        return providers

    cuda_ep_options = {"device_id": _utils.get_device_index(device_id)}
    mem_limit = self.options.device.mem_limit
    if mem_limit > 0:
        cuda_ep_options["cuda_mem_limit"] = mem_limit

    cuda_ep_name = "CUDAExecutionProvider"
    if cuda_ep_name not in providers:
        raise RuntimeError(
            "ORTTrainer options specify a CUDA device but the {} provider is unavailable.".format(
                cuda_ep_name))

    # Swap the plain name for a (name, options) tuple at the same position so
    # the ordering of the provider list is left untouched.
    cuda_ep_position = providers.index(cuda_ep_name)
    providers[cuda_ep_position] = (cuda_ep_name, cuda_ep_options)
    return providers
|
||||
|
||||
# TrainingSession
|
||||
self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(),
|
||||
ort_parameters,
|
||||
session_options)
|
||||
self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(), ort_parameters,
|
||||
session_options, get_providers())
|
||||
|
||||
# I/O bindings
|
||||
self._train_io_binding = self._training_session.io_binding()
|
||||
|
|
|
|||
Loading…
Reference in a new issue