Deprecate Python global configuration functions [Part 2] (#6171)

Update Python API to allow more flexibility for setting providers and provider options.

The providers argument (InferenceSession/TrainingSession constructors, InferenceSession.set_providers()) now also accepts (provider name, options dict) tuples in addition to plain provider names.
Fix the get_available_providers() API (and the corresponding function in the C API) to return the providers in default priority order, so its result can be used as a starting point for the providers argument while preserving that order.
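For illustration, a minimal sketch of the new usage; the model path and the CUDA option values below are hypothetical, and a CUDA-enabled build is assumed:

import onnxruntime

# get_available_providers() now returns names in default priority order,
# so its result can be filtered or reordered and passed back in directly.
print(onnxruntime.get_available_providers())

# Plain names and (name, options dict) tuples can be mixed in 'providers'.
providers = [
    ('CUDAExecutionProvider', {'device_id': 0, 'cuda_mem_limit': 2 * 1024 * 1024 * 1024}),
    'CPUExecutionProvider',
]
sess = onnxruntime.InferenceSession('model.onnx', providers=providers)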
Convert some usages of the deprecated global configuration functions to use EP-specific options instead.

Update some EP-specific option parsing to fail on unknown options.
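A sketch of the stricter behavior (again assuming a CUDA-enabled build and a hypothetical model path): a misspelled option is now rejected instead of being silently ignored.

sess = onnxruntime.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
try:
    # 'cuda_mem_limt' is a deliberate typo of 'cuda_mem_limit'.
    sess.set_providers([('CUDAExecutionProvider', {'cuda_mem_limt': '0'})])
except RuntimeError:
    print('unknown provider option rejected')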

Other cleanup.
Edward Chen 2021-01-07 10:10:55 -08:00 committed by GitHub
parent bbc9ed908a
commit d761571afc
32 changed files with 937 additions and 370 deletions


@ -180,7 +180,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string[] providers = ortEnvInstance.GetAvailableProviders();
Assert.True(providers.Length > 0);
Assert.Equal("CPUExecutionProvider", providers[0]);
Assert.Equal("CPUExecutionProvider", providers[providers.Length - 1]);
#if USE_CUDA
Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider"));


@ -165,8 +165,8 @@
"code_to_run = '''\n",
"import onnxruntime\n",
"s = 'codegen_dump_lower:verbose'\n",
"onnxruntime.capi._pybind_state.set_nuphar_settings(s)\n",
"sess = onnxruntime.InferenceSession('simple.onnx')\n",
"providers = [('NupharExecutionProvider', {'nuphar_settings': s}), 'CPUExecutionProvider']\n",
"sess = onnxruntime.InferenceSession('simple.onnx', providers=providers)\n",
"'''\n",
"\n",
"log_file = 'simple_lower.log' \n",
@ -992,8 +992,8 @@
"bidaf_cache_dir = os.path.join(bidaf_dir, 'cache')\n",
"create_cache_dir(bidaf_cache_dir)\n",
"settings = 'nuphar_cache_path:{}'.format(bidaf_cache_dir)\n",
"onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n",
"sess = onnxruntime.InferenceSession(bidaf_converted)"
"providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n",
"sess = onnxruntime.InferenceSession(bidaf_converted, providers=providers)"
]
},
{
@ -1049,10 +1049,9 @@
],
"source": [
"start_aot = timer()\n",
"# NOTE: Nuphar settings string is not sticky. It needs to be reset before creating InferenceSession\n",
"settings = 'nuphar_cache_path:{}'.format(bidaf_cache_dir)\n",
"onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n",
"sess = onnxruntime.InferenceSession(bidaf_converted)\n",
"providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n",
"sess = onnxruntime.InferenceSession(bidaf_converted, providers=providers)\n",
"end_aot = timer()\n",
"print_speedup('AOT', end_jit - start_jit, end_aot - start_aot)"
]
@ -1096,8 +1095,8 @@
"settings = 'nuphar_cache_path:{}'.format(cache_dir)\n",
"for isa in ['avx512', 'avx2', 'avx']:\n",
" settings_with_isa = settings + ', nuphar_codegen_target:' + isa\n",
" onnxruntime.capi._pybind_state.set_nuphar_settings(settings_with_isa)\n",
" sess = onnxruntime.InferenceSession(model_name)\n",
" providers = [('NupharExecutionProvider', {'nuphar_settings': settings_with_isa}), 'CPUExecutionProvider']\n",
" sess = onnxruntime.InferenceSession(model_name, providers=providers)\n",
" cache_versioned_dir = os.path.join(cache_dir, os.listdir(cache_dir)[0])\n",
"\n",
"# link object files to AOT dll\n",
@ -1106,15 +1105,15 @@
"# now load the model with AOT dll\n",
"# NOTE: when nuphar_codegen_target is not set, it defaults to current CPU ISA\n",
"settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_model_checksum:{}, nuphar_cache_force_no_jit:on'.format(cache_dir, multi_isa_so, model_checksum)\n",
"onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n",
"sess = onnxruntime.InferenceSession(model_name)\n",
"providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n",
"sess = onnxruntime.InferenceSession(model_name, providers=providers)\n",
"\n",
"# force to a different ISA which is a subset of current CPU\n",
"# NOTE: if an incompatible ISA is used, exception on invalid instructions would be thrown\n",
"for valid_isa in ['avx2', 'avx']:\n",
" settings_with_isa = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_model_checksum:{}, nuphar_codegen_target:{}, nuphar_cache_force_no_jit:on'.format(cache_dir, multi_isa_so, model_checksum, valid_isa)\n",
" onnxruntime.capi._pybind_state.set_nuphar_settings(settings_with_isa)\n",
" sess = onnxruntime.InferenceSession(model_name)\n",
" providers = [('NupharExecutionProvider', {'nuphar_settings': settings_with_isa}), 'CPUExecutionProvider']\n",
" sess = onnxruntime.InferenceSession(model_name, providers=providers)\n",
"\n",
" start_nuphar = timer()\n",
" for i in range(repeats):\n",
@ -1160,8 +1159,8 @@
"# use NUPHAR_PARALLEL_MIN_WORKLOADS=0 to turn off parallel schedule, using settings string\n",
"# it can be set from environment variable too: os.environ['NUPHAR_PARALLEL_MIN_WORKLOADS'] = '0'\n",
"settings = 'nuphar_parallel_min_workloads:0'\n",
"onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n",
"sess = onnxruntime.InferenceSession(bidaf_converted)\n",
"providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n",
"sess = onnxruntime.InferenceSession(bidaf_converted, providers=providers)\n",
"\n",
"start = timer()\n",
"for i in range(bidaf_repeats):\n",
@ -1200,4 +1199,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}


@ -33,9 +33,9 @@
#include "core/common/code_location.h"
#include "core/common/exceptions.h"
#include "core/common/make_string.h"
#include "core/common/make_unique.h"
#include "core/common/status.h"
#include "core/common/string_utils.h"
#ifdef USE_MIMALLOC_ARENA_ALLOCATOR
#include <mimalloc.h>


@ -19,7 +19,6 @@
#include <locale>
#include <sstream>
#include <type_traits>
namespace onnxruntime {
@ -74,41 +73,4 @@ inline std::string MakeString(const char* cstr) {
return cstr;
}
/**
* Tries to parse a value from an entire string.
*/
template <typename T>
bool TryParse(const std::string& str, T& value) {
if (std::is_integral<T>::value && std::is_unsigned<T>::value) {
// if T is unsigned integral type, reject negative values which will wrap
if (!str.empty() && str[0] == '-') {
return false;
}
}
// don't allow leading whitespace
if (!str.empty() && std::isspace(str[0], std::locale::classic())) {
return false;
}
std::istringstream is{str};
is.imbue(std::locale::classic());
T parsed_value{};
const bool parse_successful =
is >> parsed_value &&
is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters
if (!parse_successful) {
return false;
}
value = std::move(parsed_value);
return true;
}
inline bool TryParse(const std::string& str, std::string& value) {
value = str;
return true;
}
} // namespace onnxruntime


@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <locale>
#include <sstream>
#include <type_traits>
#include "core/common/common.h"
namespace onnxruntime {
/**
* Tries to parse a value from an entire string.
*/
template <typename T>
bool TryParseString(const std::string& str, T& value) {
if (std::is_integral<T>::value && std::is_unsigned<T>::value) {
// if T is unsigned integral type, reject negative values which will wrap
if (!str.empty() && str[0] == '-') {
return false;
}
}
// don't allow leading whitespace
if (!str.empty() && std::isspace(str[0], std::locale::classic())) {
return false;
}
std::istringstream is{str};
is.imbue(std::locale::classic());
T parsed_value{};
const bool parse_successful =
is >> parsed_value &&
is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters
if (!parse_successful) {
return false;
}
value = std::move(parsed_value);
return true;
}
inline bool TryParseString(const std::string& str, std::string& value) {
value = str;
return true;
}
inline bool TryParseString(const std::string& str, bool& value) {
if (str == "0" || str == "False" || str == "false") {
value = false;
return true;
}
if (str == "1" || str == "True" || str == "true") {
value = true;
return true;
}
return false;
}
/**
* Parses a value from an entire string.
*/
template <typename T>
Status ParseString(const std::string& s, T& value) {
ORT_RETURN_IF_NOT(TryParseString(s, value), "Failed to parse value: \"", s, "\"");
return Status::OK();
}
/**
* Parses a value from an entire string.
*/
template <typename T>
T ParseString(const std::string& s) {
T value{};
ORT_THROW_IF_ERROR(ParseString(s, value));
return value;
}
} // namespace onnxruntime


@ -3,37 +3,163 @@
#pragma once
#include <algorithm>
#include <functional>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "core/common/common.h"
#include "core/common/parse_string.h"
#include "core/framework/provider_options.h"
namespace onnxruntime {
template <typename TEnum>
using EnumNameMapping = std::vector<std::pair<TEnum, std::string>>;
/**
* Reads the named provider option.
* Returns true if the option is present, false otherwise.
* Given a mapping and an enumeration value, gets the corresponding name.
*/
template <typename T>
bool ReadProviderOption(const ProviderOptions& options, const std::string& key, T& value) {
auto it = options.find(key);
if (it != options.end()) {
ORT_ENFORCE(
TryParse(it->second, value),
"Failed to parse provider option \"", key, "\" with value \"", it->second, "\".");
return true;
}
return false;
template <typename TEnum>
Status EnumToName(const EnumNameMapping<TEnum>& mapping, TEnum value, std::string& name) {
const auto it = std::find_if(
mapping.begin(), mapping.end(),
[&value](const std::pair<TEnum, std::string>& entry) {
return entry.first == value;
});
ORT_RETURN_IF(
it == mapping.end(),
"Failed to map enum value to name: ", static_cast<typename std::underlying_type<TEnum>::type>(value));
name = it->second;
return Status::OK();
}
template <typename TEnum>
std::string EnumToName(const EnumNameMapping<TEnum>& mapping, TEnum value) {
std::string name;
ORT_THROW_IF_ERROR(EnumToName(mapping, value, name));
return name;
}
/**
* Reads the named provider option.
* Returns the value if the option is present or the specified default value otherwise.
* Given a mapping and a name, gets the corresponding enumeration value.
*/
template <typename T>
T ReadProviderOptionOrDefault(
const ProviderOptions& options, const std::string& key, const T& default_value) {
T value{};
if (ReadProviderOption(options, key, value)) {
return value;
}
return default_value;
template <typename TEnum>
Status NameToEnum(
const EnumNameMapping<TEnum>& mapping, const std::string& name, TEnum& value) {
const auto it = std::find_if(
mapping.begin(), mapping.end(),
[&name](const std::pair<TEnum, std::string>& entry) {
return entry.second == name;
});
ORT_RETURN_IF(
it == mapping.end(),
"Failed to map enum name to value: ", name);
value = it->first;
return Status::OK();
}
template <typename TEnum>
TEnum NameToEnum(const EnumNameMapping<TEnum>& mapping, const std::string& name) {
TEnum value;
ORT_THROW_IF_ERROR(NameToEnum(mapping, name, value));
return value;
}
class ProviderOptionsParser {
public:
/**
* Adds a parser for a particular provider option value.
*
* @param name The provider option name.
* @param value_parser An object that parses the option value.
* It should be callable with the following signature and return
* whether the parsing was successful:
* Status value_parser(const std::string&)
*
* @return The current ProviderOptionsParser instance.
*/
template <typename ValueParserType>
ProviderOptionsParser& AddValueParser(
const std::string& name, ValueParserType value_parser) {
ORT_ENFORCE(
value_parsers_.emplace(name, ValueParser{value_parser}).second,
"Provider option \"", name, "\" already has a value parser.");
return *this;
}
/**
* Adds a parser for a particular provider option value which converts a
* value to the right type and assigns it to the given reference.
*
* IMPORTANT: This function stores a reference to the destination variable.
* The caller must ensure that the reference is valid when Parse() is called!
*
* @param name The provider option name.
* @param dest The destination variable reference.
*
* @return The current ProviderOptionsParser instance.
*/
template <typename ValueType>
ProviderOptionsParser& AddAssignmentToReference(
const std::string& name, ValueType& dest) {
return AddValueParser(
name,
[&dest](const std::string& value_str) -> Status {
return ParseString(value_str, dest);
});
}
/**
* Adds a parser for a particular provider option value which maps an
* enumeration name to a value and assigns it to the given reference.
*
* IMPORTANT: This function stores references to the mapping and destination
* variables. The caller must ensure that the references are valid when
* Parse() is called!
*
* @param name The provider option name.
* @param mapping The enumeration value to name mapping.
* @param dest The destination variable reference.
*
* @return The current ProviderOptionsParser instance.
*/
template <typename EnumType>
ProviderOptionsParser& AddAssignmentToEnumReference(
const std::string& name, const EnumNameMapping<EnumType>& mapping, EnumType& dest) {
return AddValueParser(
name,
[&mapping, &dest](const std::string& value_str) -> Status {
return NameToEnum(mapping, value_str, dest);
});
}
/**
* Parses the given provider options.
*/
Status Parse(const ProviderOptions& options) const {
for (const auto& option : options) {
const auto& name = option.first;
const auto& value_str = option.second;
const auto value_parser_it = value_parsers_.find(name);
ORT_RETURN_IF(
value_parser_it == value_parsers_.end(),
"Unknown provider option: \"", name, "\".");
const auto parse_status = value_parser_it->second(value_str);
ORT_RETURN_IF_NOT(
parse_status.IsOK(),
"Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage());
}
return Status::OK();
}
private:
using ValueParser = std::function<Status(const std::string&)>;
std::unordered_map<std::string, ValueParser> value_parsers_;
};
} // namespace onnxruntime


@ -3,11 +3,8 @@
#pragma once
#include <vector>
#include "core/common/common.h"
namespace onnxruntime {
constexpr const char* kNoOp = "NoOp";
constexpr const char* kConstant = "Constant";
constexpr const char* kFunctionOp = "_kFunctionOp";
@ -22,10 +19,10 @@ constexpr const char* kMSDmlDomain = "com.microsoft.dml";
constexpr const char* kNGraphDomain = "com.intel.ai";
constexpr const char* kMIGraphXDomain = "";
constexpr const char* kVitisAIDomain = "com.xilinx";
constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider";
constexpr const char* kNGraphExecutionProvider = "NGRAPHExecutionProvider";
constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider";
constexpr const char* kNupharExecutionProvider = "NupharExecutionProvider";
constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider";
@ -38,50 +35,4 @@ constexpr const char* kAclExecutionProvider = "ACLExecutionProvider";
constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider";
constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider";
constexpr const char* providers_available[] = {
kCpuExecutionProvider,
#ifdef USE_CUDA
kCudaExecutionProvider,
#endif
#ifdef USE_DNNL
kDnnlExecutionProvider,
#endif
#ifdef USE_NGRAPH
kNGraphExecutionProvider,
#endif
#ifdef USE_OPENVINO
kOpenVINOExecutionProvider,
#endif
#ifdef USE_NUPHAR
kNupharExecutionProvider,
#endif
#ifdef USE_VITISAI
kVitisAIExecutionProvider,
#endif
#ifdef USE_TENSORRT
kTensorrtExecutionProvider,
#endif
#ifdef USE_NNAPI
kNnapiExecutionProvider,
#endif
#ifdef USE_RKNPU
kRknpuExecutionProvider,
#endif
#ifdef USE_DML
kDmlExecutionProvider,
#endif
#ifdef USE_MIGRAPHX
kMIGraphXExecutionProvider,
#endif
#ifdef USE_ACL
kAclExecutionProvider,
#endif
#ifdef USE_ARMNN
kArmNNExecutionProvider,
#endif
#ifdef USE_ROCM
kRocmExecutionProvider,
#endif
};
} // namespace onnxruntime


@ -26,11 +26,6 @@ from onnxruntime.capi._pybind_state import get_all_providers, get_available_prov
NodeArg, ModelMetadata, GraphOptimizationLevel, ExecutionMode, ExecutionOrder, OrtDevice, SessionIOBinding, \
OrtAllocatorType, OrtMemType, OrtArenaCfg, OrtMemoryInfo, create_and_register_allocator
try:
from onnxruntime.capi._pybind_state import set_cuda_mem_limit, set_cuda_device_id
except ImportError:
pass
from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession, IOBinding, OrtValue
from onnxruntime.capi import onnxruntime_validation


@ -3,8 +3,7 @@
#pragma once
#include <iostream>
#include <string>
#include <cstdint>
namespace onnxruntime {
@ -13,33 +12,4 @@ enum class ArenaExtendStrategy : int32_t {
kSameAsRequested,
};
inline std::istream& operator>>(std::istream& is, ArenaExtendStrategy& value) {
std::string value_str;
if (is >> value_str) {
if (value_str == "kNextPowerOfTwo") {
value = ArenaExtendStrategy::kNextPowerOfTwo;
} else if (value_str == "kSameAsRequested") {
value = ArenaExtendStrategy::kSameAsRequested;
} else {
is.setstate(std::ios_base::failbit);
}
}
return is;
}
inline std::ostream& operator<<(std::ostream& os, ArenaExtendStrategy value) {
switch (value) {
case ArenaExtendStrategy::kNextPowerOfTwo:
os << "kNextPowerOfTwo";
break;
case ArenaExtendStrategy::kSameAsRequested:
os << "kSameAsRequested";
break;
default:
os << "unknown";
break;
}
return os;
}
} // namespace onnxruntime


@ -5,6 +5,7 @@
#include "core/common/common.h"
#include "core/common/optional.h"
#include "core/common/parse_string.h"
#include "core/platform/env.h"
namespace onnxruntime {
@ -20,7 +21,7 @@ optional<T> ParseEnvironmentVariable(const std::string& name) {
T parsed_value;
ORT_ENFORCE(
TryParse(value_str, parsed_value),
TryParseString(value_str, parsed_value),
"Failed to parse environment variable - name: \"", name, "\", value: \"", value_str, "\"");
return parsed_value;


@ -71,7 +71,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
int GetDeviceId() const override { return info_.device_id; }
const cudaDeviceProp& GetDeviceProp() const { return device_prop_; };
int GetCudnnConvAlgo() const { return info_.cudnn_conv_algo; }
int GetCudnnConvAlgo() const { return info_.cudnn_conv_algo_search; }
ProviderOptions GetProviderOptions() const override {
return CUDAExecutionProviderInfo::ToProviderOptions(info_);


@ -3,58 +3,76 @@
#include "core/providers/cuda/cuda_execution_provider_info.h"
#include "core/common/string_utils.h"
#include "core/common/make_string.h"
#include "core/common/parse_string.h"
#include "core/framework/provider_options_utils.h"
#include "core/providers/cuda/cuda_common.h"
namespace onnxruntime {
namespace cuda {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kMemLimit = "cuda_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
constexpr const char* kCudnnConvAlgo = "cudnn_conv_algo";
constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search";
constexpr const char* kDoCopyInDefaultStream = "do_copy_in_default_stream";
} // namespace provider_option_names
} // namespace cuda
namespace {
const EnumNameMapping<OrtCudnnConvAlgoSearch> ort_cudnn_conv_algo_search_mapping{
{OrtCudnnConvAlgoSearch::EXHAUSTIVE, "EXHAUSTIVE"},
{OrtCudnnConvAlgoSearch::HEURISTIC, "HEURISTIC"},
{OrtCudnnConvAlgoSearch::DEFAULT, "DEFAULT"},
};
const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{
{ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"},
{ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"},
};
} // namespace
CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
CUDAExecutionProviderInfo info{};
if (ReadProviderOption(options, provider_option_names::kDeviceId, info.device_id)) {
int num_devices = 0;
CUDA_CALL_THROW(cudaGetDeviceCount(&num_devices));
ORT_ENFORCE(
0 <= info.device_id && info.device_id < num_devices,
"Invalid ", provider_option_names::kDeviceId, " value: ", info.device_id,
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
}
ReadProviderOption(options, provider_option_names::kMemLimit, info.cuda_mem_limit);
ReadProviderOption(options, provider_option_names::kArenaExtendStrategy, info.arena_extend_strategy);
{
int cudnn_conv_algo_val;
if (ReadProviderOption(options, provider_option_names::kCudnnConvAlgo, cudnn_conv_algo_val)) {
switch (cudnn_conv_algo_val) {
case OrtCudnnConvAlgoSearch::EXHAUSTIVE:
case OrtCudnnConvAlgoSearch::HEURISTIC:
case OrtCudnnConvAlgoSearch::DEFAULT:
break;
default:
ORT_THROW("Invalid ", provider_option_names::kCudnnConvAlgo, " value: ", cudnn_conv_algo_val);
}
info.cudnn_conv_algo = static_cast<OrtCudnnConvAlgoSearch>(cudnn_conv_algo_val);
}
}
ReadProviderOption(options, provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream);
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
cuda::provider_option_names::kDeviceId,
[&info](const std::string& value_str) -> Status {
ORT_RETURN_IF_ERROR(ParseString(value_str, info.device_id));
int num_devices{};
ORT_RETURN_IF_NOT(
CUDA_CALL(cudaGetDeviceCount(&num_devices)),
"cudaGetDeviceCount() failed.");
ORT_RETURN_IF_NOT(
0 <= info.device_id && info.device_id < num_devices,
"Invalid device ID: ", info.device_id,
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
return Status::OK();
})
.AddAssignmentToReference(cuda::provider_option_names::kMemLimit, info.cuda_mem_limit)
.AddAssignmentToEnumReference(
cuda::provider_option_names::kArenaExtendStrategy,
arena_extend_strategy_mapping, info.arena_extend_strategy)
.AddAssignmentToEnumReference(
cuda::provider_option_names::kCudnnConvAlgoSearch,
ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)
.AddAssignmentToReference(cuda::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream)
.Parse(options));
return info;
}
ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecutionProviderInfo& info) {
const ProviderOptions options{
{provider_option_names::kDeviceId, MakeString(info.device_id)},
{provider_option_names::kMemLimit, MakeString(info.cuda_mem_limit)},
{provider_option_names::kArenaExtendStrategy, MakeString(info.arena_extend_strategy)},
{provider_option_names::kCudnnConvAlgo, MakeString(static_cast<int>(info.cudnn_conv_algo))},
{provider_option_names::kDoCopyInDefaultStream, MakeString(info.do_copy_in_default_stream)},
{cuda::provider_option_names::kDeviceId, MakeString(info.device_id)},
{cuda::provider_option_names::kMemLimit, MakeString(info.cuda_mem_limit)},
{cuda::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{cuda::provider_option_names::kCudnnConvAlgoSearch,
EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)},
{cuda::provider_option_names::kDoCopyInDefaultStream, MakeString(info.do_copy_in_default_stream)},
};
return options;
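On the Python side, the renamed option is set by enum name; a hedged sketch using the names from the mapping above (model path hypothetical, CUDA build assumed):

import onnxruntime

providers = [
    ('CUDAExecutionProvider', {'cudnn_conv_algo_search': 'HEURISTIC'}),
    'CPUExecutionProvider',
]
sess = onnxruntime.InferenceSession('model.onnx', providers=providers)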


@ -16,7 +16,7 @@ struct CUDAExecutionProviderInfo {
OrtDevice::DeviceId device_id{0};
size_t cuda_mem_limit{std::numeric_limits<size_t>::max()};
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};
OrtCudnnConvAlgoSearch cudnn_conv_algo{OrtCudnnConvAlgoSearch::EXHAUSTIVE};
OrtCudnnConvAlgoSearch cudnn_conv_algo_search{OrtCudnnConvAlgoSearch::EXHAUSTIVE};
bool do_copy_in_default_stream{true};
static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);


@ -55,7 +55,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_CUDA,
info.device_id = gsl::narrow<OrtDevice::DeviceId>(cuda_options->device_id);
info.cuda_mem_limit = cuda_options->cuda_mem_limit;
info.arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(cuda_options->arena_extend_strategy);
info.cudnn_conv_algo = cuda_options->cudnn_conv_algo_search;
info.cudnn_conv_algo_search = cuda_options->cudnn_conv_algo_search;
info.do_copy_in_default_stream = cuda_options->do_copy_in_default_stream;
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_CUDA(info));


@ -0,0 +1,154 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/get_execution_providers.h"
#include "core/graph/constants.h"
namespace onnxruntime {
namespace {
struct ProviderInfo {
const char* name;
bool available;
};
// all providers ordered by default priority from highest to lowest
// kCpuExecutionProvider should always be last
constexpr ProviderInfo kProvidersInPriorityOrder[] =
{
{
kTensorrtExecutionProvider,
#ifdef USE_TENSORRT
true,
#else
false,
#endif
},
{
kCudaExecutionProvider,
#ifdef USE_CUDA
true,
#else
false,
#endif
},
{
kMIGraphXExecutionProvider,
#ifdef USE_MIGRAPHX
true,
#else
false,
#endif
},
{
kRocmExecutionProvider,
#ifdef USE_ROCM
true,
#else
false,
#endif
},
{
kOpenVINOExecutionProvider,
#ifdef USE_OPENVINO
true,
#else
false,
#endif
},
{
kDnnlExecutionProvider,
#ifdef USE_DNNL
true,
#else
false,
#endif
},
{
kNupharExecutionProvider,
#ifdef USE_NUPHAR
true,
#else
false,
#endif
},
{
kVitisAIExecutionProvider,
#ifdef USE_VITISAI
true,
#else
false,
#endif
},
{
kNnapiExecutionProvider,
#ifdef USE_NNAPI
true,
#else
false,
#endif
},
{
kArmNNExecutionProvider,
#ifdef USE_ARMNN
true,
#else
false,
#endif
},
{
kAclExecutionProvider,
#ifdef USE_ACL
true,
#else
false,
#endif
},
{
kDmlExecutionProvider,
#ifdef USE_DML
true,
#else
false,
#endif
},
{
kRknpuExecutionProvider,
#ifdef USE_RKNPU
true,
#else
false,
#endif
},
{kCpuExecutionProvider, true}, // kCpuExecutionProvider is always last
};
} // namespace
const std::vector<std::string>& GetAllExecutionProviderNames() {
static const auto all_execution_providers = []() {
std::vector<std::string> result{};
for (const auto& provider : kProvidersInPriorityOrder) {
result.push_back(provider.name);
}
return result;
}();
return all_execution_providers;
}
const std::vector<std::string>& GetAvailableExecutionProviderNames() {
static const auto available_execution_providers = []() {
std::vector<std::string> result{};
for (const auto& provider : kProvidersInPriorityOrder) {
if (provider.available) {
result.push_back(provider.name);
}
}
return result;
}();
return available_execution_providers;
}
} // namespace onnxruntime


@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <vector>
namespace onnxruntime {
/**
* Gets the names of all execution providers, in order of decreasing default
* priority.
*/
const std::vector<std::string>& GetAllExecutionProviderNames();
/**
* Gets the names of execution providers available in this build, in order of
* decreasing default priority.
*/
const std::vector<std::string>& GetAvailableExecutionProviderNames();
} // namespace onnxruntime
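These lists back the Python get_all_providers() and get_available_providers() functions; a quick sketch of the documented ordering guarantee:

import onnxruntime

# CPUExecutionProvider is documented to always be last in the default priority order.
assert onnxruntime.get_available_providers()[-1] == 'CPUExecutionProvider'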


@ -3,32 +3,47 @@
#include "core/providers/rocm/rocm_execution_provider_info.h"
#include "core/common/string_utils.h"
#include "core/common/make_string.h"
#include "core/framework/provider_options_utils.h"
namespace onnxruntime {
namespace rocm {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kMemLimit = "hip_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
} // namespace provider_option_names
} // namespace rocm
namespace {
const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{
{ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"},
{ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"},
};
} // namespace
ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
ROCMExecutionProviderInfo info{};
// TODO validate info.device_id
ReadProviderOption(options, provider_option_names::kDeviceId, info.device_id);
ReadProviderOption(options, provider_option_names::kMemLimit, info.hip_mem_limit);
ReadProviderOption(options, provider_option_names::kArenaExtendStrategy, info.arena_extend_strategy);
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
// TODO validate info.device_id
.AddAssignmentToReference(rocm::provider_option_names::kDeviceId, info.device_id)
.AddAssignmentToReference(rocm::provider_option_names::kMemLimit, info.hip_mem_limit)
.AddAssignmentToEnumReference(
rocm::provider_option_names::kArenaExtendStrategy,
arena_extend_strategy_mapping, info.arena_extend_strategy)
.Parse(options));
return info;
}
ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) {
const ProviderOptions options{
{provider_option_names::kDeviceId, MakeString(info.device_id)},
{provider_option_names::kMemLimit, MakeString(info.hip_mem_limit)},
{provider_option_names::kArenaExtendStrategy, MakeString(info.arena_extend_strategy)},
{rocm::provider_option_names::kDeviceId, MakeString(info.device_id)},
{rocm::provider_option_names::kMemLimit, MakeString(info.hip_mem_limit)},
{rocm::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
};
return options;


@ -23,6 +23,7 @@
#include "core/framework/allocator.h"
#include "core/framework/tensor.h"
#include "core/framework/ml_value.h"
#include "core/providers/get_execution_providers.h"
#include "core/session/environment.h"
#include "core/framework/callback.h"
#include "core/framework/tensorprotoutils.h"
@ -1713,19 +1714,20 @@ ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char*** out_ptr,
_In_ int* providers_length) {
API_IMPL_BEGIN
const size_t MAX_LEN = 30;
int available_count = (int)(sizeof(providers_available) / sizeof(char*));
char** out = (char**)malloc(available_count * sizeof(char*));
const auto& available_providers = GetAvailableExecutionProviderNames();
const int available_count = gsl::narrow<int>(available_providers.size());
char** const out = (char**)malloc(available_count * sizeof(char*));
if (out) {
for (int i = 0; i < available_count; i++) {
out[i] = (char*)malloc((MAX_LEN + 1) * sizeof(char));
if (out[i]) {
#ifdef _MSC_VER
strncpy_s(out[i], MAX_LEN, providers_available[i], MAX_LEN);
strncpy_s(out[i], MAX_LEN, available_providers[i].c_str(), MAX_LEN);
out[i][MAX_LEN] = '\0';
#elif defined(__APPLE__)
strlcpy(out[i], providers_available[i], MAX_LEN);
strlcpy(out[i], available_providers[i].c_str(), MAX_LEN);
#else
strncpy(out[i], providers_available[i], MAX_LEN);
strncpy(out[i], available_providers[i].c_str(), MAX_LEN);
out[i][MAX_LEN] = '\0';
#endif
}
@ -1734,7 +1736,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char*** out_ptr,
*providers_length = available_count;
*out_ptr = out;
API_IMPL_END
return NULL;
return nullptr;
}
ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr,


@ -2,7 +2,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import collections
import collections.abc
import os
import warnings
from onnxruntime.capi import _pybind_state as C
@ -17,6 +20,78 @@ def get_ort_device_type(device):
raise Exception('Unsupported device type: ' + device)
def check_and_normalize_provider_args(providers, provider_options, available_provider_names):
"""
Validates the 'providers' and 'provider_options' arguments and returns a
normalized version.
:param providers: Optional sequence of providers in order of decreasing
precedence. Values can either be provider names or tuples of
(provider name, options dict).
:param provider_options: Optional sequence of options dicts corresponding
to the providers listed in 'providers'.
:param available_provider_names: The available provider names.
:return: Tuple of (normalized 'providers' sequence, normalized
'provider_options' sequence).
'providers' can contain either names or names and options. When any options
are given in 'providers', 'provider_options' should not be used.
The normalized result is a tuple of:
1. Sequence of provider names in the same order as 'providers'.
2. Sequence of corresponding provider options dicts with string keys and
values. Unspecified provider options yield empty dicts.
"""
if providers is None:
return [], []
provider_name_to_options = collections.OrderedDict()
def set_provider_options(name, options):
if name not in available_provider_names:
raise ValueError("Specified provider '{}' is unavailable. Available providers: '{}'".format(
name, ", ".join(available_provider_names)))
if name in provider_name_to_options:
warnings.warn("Duplicate provider '{}' encountered, ignoring.".format(name))
return
normalized_options = {str(key): str(value) for key, value in options.items()}
provider_name_to_options[name] = normalized_options
if not isinstance(providers, collections.abc.Sequence):
raise ValueError("'providers' should be a sequence.")
if provider_options is not None:
if not isinstance(provider_options, collections.abc.Sequence):
raise ValueError("'provider_options' should be a sequence.")
if len(providers) != len(provider_options):
raise ValueError("'providers' and 'provider_options' should be the same length if both are given.")
if not all([isinstance(provider, str) for provider in providers]):
raise ValueError("Only string values for 'providers' are supported if 'provider_options' is given.")
if not all([isinstance(options_for_provider, dict) for options_for_provider in provider_options]):
raise ValueError("'provider_options' values must be dicts.")
for name, options in zip(providers, provider_options):
set_provider_options(name, options)
else:
for provider in providers:
if isinstance(provider, str):
set_provider_options(provider, dict())
elif isinstance(provider, tuple) and len(provider) == 2 and \
isinstance(provider[0], str) and isinstance(provider[1], dict):
set_provider_options(provider[0], provider[1])
else:
raise ValueError("'providers' values must be either strings or (string, dict) tuples.")
return list(provider_name_to_options.keys()), list(provider_name_to_options.values())
class Session:
"""
This is the main class used to run a model.
@ -55,34 +130,23 @@ class Session:
"Return registered execution providers' configurations."
return self._provider_options
def set_providers(self, providers, provider_options=None):
def set_providers(self, providers=None, provider_options=None):
"""
Register the input list of execution providers. The underlying session is re-created.
:param providers: list of execution providers
:param provider_options: list of provider options dict for each provider, in the same order as 'providers'
:param providers: Optional sequence of providers in order of decreasing
precedence. Values can either be provider names or tuples of
(provider name, options dict). If not provided, then all available
providers are used with the default precedence.
:param provider_options: Optional sequence of options dicts corresponding
to the providers listed in 'providers'.
The list of providers is ordered by Priority. For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
'providers' can contain either names or names and options. When any options
are given in 'providers', 'provider_options' should not be used.
The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
"""
if not set(providers).issubset(C.get_available_providers()):
raise ValueError("{} does not contain a subset of available providers {}".format(
providers, C.get_available_providers()))
if provider_options:
if not isinstance(providers, list) or not isinstance(provider_options, list):
raise ValueError("Inputs must be two python lists.")
if len(providers) != len(provider_options):
raise ValueError("Two input lists must have same length.")
for option in provider_options:
if not isinstance(option, dict):
raise ValueError("Provider options must be list of python dict.")
for key, val in option.items():
option[key] = str(val)
# recreate the underlying C.InferenceSession
self._reset_session(providers, provider_options)
@ -173,8 +237,12 @@ class InferenceSession(Session):
"""
:param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string
:param sess_options: session options
:param providers: list of providers to use for session. If empty, will use all available providers.
:param provider_options: list of provider options dict for each provider, in the same order as 'providers'
:param providers: Optional sequence of providers in order of decreasing
precedence. Values can either be provider names or tuples of
(provider name, options dict). If not provided, then all available
providers are used with the default precedence.
:param provider_options: Optional sequence of options dicts corresponding
to the providers listed in 'providers'.
The model type will be inferred unless explicitly set in the SessionOptions.
To explicitly set:
@ -184,6 +252,12 @@ class InferenceSession(Session):
A file extension of '.ort' will be inferred as an ORT format model.
All other filenames are assumed to be ONNX format models.
'providers' can contain either names or names and options. When any options
are given in 'providers', 'provider_options' should not be used.
The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
"""
Session.__init__(self)
@ -208,15 +282,22 @@ class InferenceSession(Session):
if self._enable_fallback:
print("EP Error using {}".format(self._providers))
print("Falling back to {} and retrying.".format(self._fallback_providers))
self._create_inference_session(self._fallback_providers)
self._create_inference_session(self._fallback_providers, None)
# Fallback only once.
self.disable_fallback()
else:
raise
def _create_inference_session(self, providers, provider_options):
available_providers = C.get_available_providers()
# validate providers and provider_options before other initialization
providers, provider_options = check_and_normalize_provider_args(providers,
provider_options,
available_providers)
# Tensorrt can fall back to CUDA. All others fall back to CPU.
if 'TensorrtExecutionProvider' in C.get_available_providers():
if 'TensorrtExecutionProvider' in available_providers:
self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
self._fallback_providers = ['CPUExecutionProvider']
@ -228,7 +309,7 @@ class InferenceSession(Session):
sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
# initialize the C++ InferenceSession
sess.initialize_session(providers or [], provider_options or [])
sess.initialize_session(providers, provider_options)
self._sess = sess
self._sess_options = self._sess.session_options
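To make the normalization concrete, a small sketch of what check_and_normalize_provider_args returns; note that option keys and values are stringified and unspecified options become empty dicts:

from onnxruntime.capi.onnxruntime_inference_collection import check_and_normalize_provider_args

names, options = check_and_normalize_provider_args(
    providers=[('CUDAExecutionProvider', {'device_id': 0}), 'CPUExecutionProvider'],
    provider_options=None,
    available_provider_names=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# names   -> ['CUDAExecutionProvider', 'CPUExecutionProvider']
# options -> [{'device_id': '0'}, {}]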


@ -15,6 +15,7 @@
#include "core/framework/arena_extend_strategy.h"
#include "core/framework/data_transfer_utils.h"
#include "core/framework/data_types_internal.h"
#include "core/providers/get_execution_providers.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/provider_options_utils.h"
#include "core/framework/random_seed.h"
@ -200,6 +201,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ACL(in
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(int device_id);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(uint32_t flags);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
} // namespace onnxruntime
#if defined(_MSC_VER)
@ -425,27 +427,6 @@ static inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(p)));
}
// ordered by default priority from highest to lowest. kCpuExecutionProvider should always be last.
static const std::vector<std::string>& GetAllProviders() {
static std::vector<std::string> all_providers = {kTensorrtExecutionProvider, kCudaExecutionProvider,
kMIGraphXExecutionProvider, kRocmExecutionProvider,
kOpenVINOExecutionProvider, kDnnlExecutionProvider,
kNupharExecutionProvider, kVitisAIExecutionProvider,
kNnapiExecutionProvider,
kArmNNExecutionProvider, kAclExecutionProvider,
kDmlExecutionProvider, kCpuExecutionProvider};
return all_providers;
}
static const std::vector<std::string>& GetAvailableProviders() {
auto InitializeProviders = []() {
std::vector<std::string> available_providers(std::begin(providers_available), std::end(providers_available));
return available_providers;
};
static std::vector<std::string> available_providers = InitializeProviders();
return available_providers;
}
#ifdef USE_CUDA
static bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id) {
@ -537,7 +518,7 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
info.device_id = cuda_device_id;
info.cuda_mem_limit = cuda_mem_limit;
info.arena_extend_strategy = arena_extend_strategy;
info.cudnn_conv_algo = cudnn_conv_algo_search;
info.cudnn_conv_algo_search = cudnn_conv_algo_search;
info.do_copy_in_default_stream = do_copy_in_default_stream;
return info;
}();
@ -605,7 +586,10 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
#if USE_NUPHAR
const auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
ReadProviderOption(it->second, "nuphar_settings", nuphar_settings);
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddAssignmentToReference("nuphar_settings", nuphar_settings)
.Parse(it->second));
}
RegisterExecutionProvider(
@ -638,6 +622,10 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
LOGS_DEFAULT(WARNING) << "NNAPI execution provider can only be used to generate ORT format model in this build.";
#endif
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Nnapi(0));
#endif
} else if (type == kRknpuExecutionProvider) {
#ifdef USE_RKNPU
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Rknpu());
#endif
} else {
// unknown provider
@ -697,7 +685,7 @@ void InitializeSession(InferenceSession* sess, const std::vector<std::string>& p
if (provider_types.empty()) {
// use default registration priority.
RegisterExecutionProviders(sess, GetAllProviders(), provider_options_map);
RegisterExecutionProviders(sess, GetAllExecutionProviderNames(), provider_options_map);
} else {
RegisterExecutionProviders(sess, provider_types, provider_options_map);
}
@ -752,13 +740,15 @@ void addGlobalMethods(py::module& m, Environment& env) {
},
"Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal");
m.def(
"get_all_providers", []() -> const std::vector<std::string>& { return GetAllProviders(); },
"get_all_providers", []() -> const std::vector<std::string>& { return GetAllExecutionProviderNames(); },
"Return list of Execution Providers that this version of Onnxruntime can support. "
"The order of elements represents the default priority order of Execution Providers"
" from highest to lowest.");
"The order of elements represents the default priority order of Execution Providers "
"from highest to lowest.");
m.def(
"get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableProviders(); },
"Return list of available Execution Providers available in this installed version of Onnxruntime.");
"get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableExecutionProviderNames(); },
"Return list of available Execution Providers available in this installed version of Onnxruntime. "
"The order of elements represents the default priority order of Execution Providers "
"from highest to lowest.");
m.def(
"enable_telemetry_events", []() -> void { platform_env.GetTelemetryProvider().EnableTelemetryEvents(); },
"Enables platform-specific telemetry collection where applicable.");
@ -833,7 +823,7 @@ void addGlobalMethods(py::module& m, Environment& env) {
info.device_id = cuda_device_id;
info.cuda_mem_limit = cuda_mem_limit;
info.arena_extend_strategy = arena_extend_strategy;
info.cudnn_conv_algo = cudnn_conv_algo_search;
info.cudnn_conv_algo_search = cudnn_conv_algo_search;
info.do_copy_in_default_stream = do_copy_in_default_stream;
return info;
}()),
@ -874,6 +864,9 @@ void addGlobalMethods(py::module& m, Environment& env) {
#endif
#ifdef USE_NNAPI
onnxruntime::CreateExecutionProviderFactory_NNAPI(0),
#endif
#ifdef USE_RKNPU
onnxruntime::CreateExecutionProviderFactory_Rknpu(),
#endif
};
@ -905,7 +898,7 @@ void addGlobalMethods(py::module& m, Environment& env) {
});
// TODO remove deprecated global config
m.def("set_cudnn_conv_algo_search", [](const OrtCudnnConvAlgoSearch algo) {
LogDeprecationWarning("set_cudnn_conv_algo_search", "CUDA execution provider option \"cudnn_conv_algo\"");
LogDeprecationWarning("set_cudnn_conv_algo_search", "CUDA execution provider option \"cudnn_conv_algo_search\"");
#ifdef USE_ROCM
ORT_UNUSED_PARAMETER(algo);
ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM");
@ -1084,11 +1077,6 @@ void addObjectMethods(py::module& m, Environment& env) {
.value("CPU", OrtMemType::OrtMemTypeCPU)
.value("DEFAULT", OrtMemType::OrtMemTypeDefault);
py::enum_<OrtCudnnConvAlgoSearch>(m, "OrtCudnnConvAlgoSearch")
.value("EXHAUSTIVE", OrtCudnnConvAlgoSearch::EXHAUSTIVE)
.value("HEURISTIC", OrtCudnnConvAlgoSearch::HEURISTIC)
.value("DEFAULT", OrtCudnnConvAlgoSearch::DEFAULT);
py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc");
device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::DeviceId>())
.def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")


@ -11,9 +11,6 @@
namespace onnxruntime {
namespace python {
using namespace onnxruntime;
using namespace onnxruntime::logging;
#if !defined(ORT_MINIMAL_BUILD)
struct CustomOpLibrary {
CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so);


@ -1,7 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/string_utils.h"
#include "core/common/make_string.h"
#include "core/common/parse_string.h"
#include "gtest/gtest.h"
@ -12,29 +13,26 @@ namespace {
template <typename T>
void TestSuccessfulParse(const std::string& input, const T& expected_value) {
T value;
ASSERT_TRUE(TryParse(input, value));
ASSERT_TRUE(TryParseString(input, value));
EXPECT_EQ(value, expected_value);
}
template <typename T>
void TestFailedParse(const std::string& input) {
T value;
EXPECT_FALSE(TryParse(input, value));
EXPECT_FALSE(TryParseString(input, value));
}
} // namespace
TEST(StringUtilsTest, TryParse) {
TEST(StringUtilsTest, TryParseString) {
TestSuccessfulParse("-1", -1);
TestSuccessfulParse("42", 42u);
TestSuccessfulParse("2.5", 2.5f);
TestSuccessfulParse("1", true);
TestSuccessfulParse("0", false);
// out of range
TestFailedParse<int16_t>("32768");
TestFailedParse<uint32_t>("-1");
TestFailedParse<float>("1e100");
TestFailedParse<bool>("2");
// invalid representation
TestFailedParse<int32_t>("1.2");
TestFailedParse<int32_t>("one");
@ -43,12 +41,21 @@ TEST(StringUtilsTest, TryParse) {
TestFailedParse<int32_t>("1 ");
}
TEST(StringUtilsTest, TryParseString) {
TEST(StringUtilsTest, TryParseStringAsString) {
// when parsing a string as a string, allow leading and trailing whitespace
const std::string s = " this is a string! ";
TestSuccessfulParse(s, s);
}
TEST(StringUtilsTest, TryParseStringAsBool) {
TestSuccessfulParse("True", true);
TestSuccessfulParse("1", true);
TestSuccessfulParse("False", false);
TestSuccessfulParse("0", false);
TestFailedParse<bool>("2");
}
namespace {
struct S {
int i{};
@ -73,12 +80,12 @@ std::istream& operator>>(std::istream& is, S& s) {
}
} // namespace
TEST(StringUtilsTest, MakeStringAndTryParseCustomType) {
TEST(StringUtilsTest, MakeStringAndTryParseStringWithCustomType) {
S s;
s.i = 42;
const auto str = MakeString(s);
S parsed_s;
ASSERT_TRUE(TryParse(str, parsed_s));
ASSERT_TRUE(TryParseString(str, parsed_s));
ASSERT_EQ(parsed_s, s);
}


@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/provider_options_utils.h"
#include "gtest/gtest.h"
#include "asserts.h"
namespace onnxruntime {
namespace test {
namespace {
enum class TestEnum {
A,
Unmapped,
};
const EnumNameMapping<TestEnum> test_enum_mapping{
{TestEnum::A, "A"},
};
} // namespace
TEST(ProviderOptionsUtilsTest, ProviderOptionsParser) {
int i;
bool b;
TestEnum e;
ProviderOptionsParser parser{};
parser.AddAssignmentToReference("int", i);
parser.AddAssignmentToReference("bool", b);
parser.AddAssignmentToEnumReference("enum", test_enum_mapping, e);
// adding same option again should throw
ASSERT_THROW(parser.AddAssignmentToReference("int", i), OnnxRuntimeException);
ASSERT_STATUS_OK(parser.Parse({{"int", "3"}, {"bool", "true"}, {"enum", "A"}}));
EXPECT_EQ(i, 3);
EXPECT_EQ(b, true);
EXPECT_EQ(e, TestEnum::A);
ASSERT_FALSE(parser.Parse({{"unknown option", "some value"}}).IsOK());
}
TEST(ProviderOptionsUtilsTest, EnumToName) {
std::string name;
ASSERT_STATUS_OK(EnumToName(test_enum_mapping, TestEnum::A, name));
EXPECT_EQ(name, "A");
ASSERT_FALSE(EnumToName(test_enum_mapping, TestEnum::Unmapped, name).IsOK());
}
TEST(ProviderOptionsUtilsTest, NameToEnum) {
TestEnum value;
ASSERT_STATUS_OK(NameToEnum(test_enum_mapping, "A", value));
EXPECT_EQ(value, TestEnum::A);
ASSERT_FALSE(NameToEnum(test_enum_mapping, "invalid", value).IsOK());
}
} // namespace test
} // namespace onnxruntime


@ -11,7 +11,7 @@
#include "mem_buffer.h"
#include "core/common/safeint.h"
#include "core/common/status.h"
#include "core/common/string_utils.h"
#include "core/common/make_string.h"
#include "core/framework/data_types.h"
#include "core/framework/endian.h"
#include "core/framework/allocator.h"


@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/get_execution_providers.h"
#include <algorithm>
#include <iterator>
#include <unordered_set>
#include "gtest/gtest.h"
#include "core/graph/constants.h"
namespace onnxruntime {
namespace test {
TEST(GetExecutionProvidersTest, CpuEpAlwaysLast) {
const auto check = [](const std::vector<std::string>& providers) {
ASSERT_FALSE(providers.empty());
EXPECT_EQ(providers.back(), kCpuExecutionProvider);
};
check(GetAllExecutionProviderNames());
check(GetAvailableExecutionProviderNames());
}
TEST(GetExecutionProvidersTest, ConsistentOrdering) {
const auto& all = GetAllExecutionProviderNames();
const auto& available = GetAvailableExecutionProviderNames();
std::vector<std::string> available_from_all{};
std::copy_if(
all.begin(), all.end(),
std::back_inserter(available_from_all),
[&available](const std::string& value) {
return std::find(available.begin(), available.end(), value) != available.end();
});
EXPECT_EQ(available, available_from_all);
}
TEST(GetExecutionProvidersTest, NoDuplicates) {
const auto check = [](const std::vector<std::string>& providers) {
const std::unordered_set<std::string> providers_set(providers.begin(), providers.end());
EXPECT_EQ(providers.size(), providers_set.size());
};
check(GetAllExecutionProviderNames());
check(GetAvailableExecutionProviderNames());
}
} // namespace test
} // namespace onnxruntime


@ -81,13 +81,13 @@ class TestInferenceSession(unittest.TestCase):
def runBaseTest2():
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"))
self.assertTrue('CUDAExecutionProvider' in sess.get_providers())
self.assertIn('CUDAExecutionProvider', sess.get_providers())
# test get/set of "cuda_mem_limit" configuration.
options = sess.get_provider_options()
self.assertTrue('CUDAExecutionProvider' in options)
self.assertIn('CUDAExecutionProvider', options)
option = options['CUDAExecutionProvider']
self.assertTrue('cuda_mem_limit' in option)
self.assertIn('cuda_mem_limit', option)
ori_mem_limit = option['cuda_mem_limit']
new_mem_limit = int(ori_mem_limit) // 2
option['cuda_mem_limit'] = new_mem_limit
@ -100,16 +100,27 @@ class TestInferenceSession(unittest.TestCase):
options = sess.get_provider_options()
self.assertEqual(options['CUDAExecutionProvider']['cuda_mem_limit'], ori_mem_limit)
# test get/set of "arena_extend_strategy" configuration.
options = sess.get_provider_options()
self.assertTrue('CUDAExecutionProvider' in options)
option = options['CUDAExecutionProvider']
self.assertTrue('arena_extend_strategy' in option)
for strategy in ['kNextPowerOfTwo', 'kSameAsRequested']:
option['arena_extend_strategy'] = strategy
sess.set_providers(['CUDAExecutionProvider'], [option])
options = sess.get_provider_options()
self.assertEqual(options['CUDAExecutionProvider']['arena_extend_strategy'], strategy)
def test_get_and_set_option_with_values(option_name, option_values):
provider_options = sess.get_provider_options()
self.assertIn('CUDAExecutionProvider', provider_options)
cuda_options = provider_options['CUDAExecutionProvider']
self.assertIn(option_name, cuda_options)
for option_value in option_values:
cuda_options[option_name] = option_value
sess.set_providers(['CUDAExecutionProvider'], [cuda_options])
new_provider_options = sess.get_provider_options()
self.assertEqual(
new_provider_options.get('CUDAExecutionProvider', {}).get(option_name),
str(option_value))
test_get_and_set_option_with_values(
'arena_extend_strategy', ['kNextPowerOfTwo', 'kSameAsRequested'])
test_get_and_set_option_with_values(
'cudnn_conv_algo_search', ["DEFAULT", "EXHAUSTIVE", "HEURISTIC"])
test_get_and_set_option_with_values(
'do_copy_in_default_stream', [0, 1])
#
# Note: Tests that throw an exception leave an empty session due to how set_providers currently works,
@ -181,21 +192,17 @@ class TestInferenceSession(unittest.TestCase):
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"))
# configure session with not legit option values and that shloud fail
# configure session with invalid option values and that should fail
with self.assertRaises(RuntimeError):
option = {'device_id': num_device}
sess.set_providers(['CUDAExecutionProvider'], [option])
option = {'device_id': 'non_legit_value'}
option = {'device_id': 'invalid_value'}
sess.set_providers(['CUDAExecutionProvider'], [option])
# configure session with not legit option should cause no effect
option = {'device_id': 0}
sess.set_providers(['CUDAExecutionProvider'], [option])
option = {'non_legit_option': num_device}
sess.set_providers(['CUDAExecutionProvider'], [option])
self.assertEqual(['CUDAExecutionProvider', 'CPUExecutionProvider'], sess.get_providers())
# configure session with invalid option should fail
with self.assertRaises(RuntimeError):
option = {'invalid_option': 123}
sess.set_providers(['CUDAExecutionProvider'], [option])
libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
for libname in libnames:
@ -218,8 +225,7 @@ class TestInferenceSession(unittest.TestCase):
with self.assertRaises(ValueError) as context:
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"))
sess.set_providers(['InvalidProvider'])
self.assertTrue(
'[\'InvalidProvider\'] does not contain a subset of available providers' in str(context.exception))
self.assertTrue('\'InvalidProvider\' is unavailable' in str(context.exception))
def testSessionProviders(self):
if 'CUDAExecutionProvider' in onnxrt.get_available_providers():
@ -712,7 +718,7 @@ class TestInferenceSession(unittest.TestCase):
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), so, ['CPUExecutionProvider'])
res = sess.run(["Y"], {"X": np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)})
self.assertTrue(np.array_equal(res[0], np.array([[2.0, 2.0], [12.0, 12.0], [30.0, 30.0]], dtype=np.float32)))
def testRegisterCustomOpsLibrary(self):
if sys.platform.startswith("win"):
shared_library = 'custom_op_library.dll'
@ -809,17 +815,17 @@ class TestInferenceSession(unittest.TestCase):
# 1. if there are intermittent failure in this test, something is wrong
# 2. it's easier to repro on slower GPU (like M60, Geforce 1070)
# to repro #4829, uncomment the line below to run copy in a separate stream
#onnxrt.capi._pybind_state.set_do_copy_in_default_stream(False)
# to repro #4829, set the CUDA EP do_copy_in_default_stream option to False
providers = [("CUDAExecutionProvider", {"do_copy_in_default_stream": True}), "CPUExecutionProvider"]
session = onnxrt.InferenceSession(get_name("issue4829.onnx"))
session = onnxrt.InferenceSession(get_name("issue4829.onnx"), providers=providers)
shape = np.array([2,2], dtype=np.int64)
for iteration in range(100000):
result = session.run(output_names=['output'], input_feed={'shape': shape})
def testSharedAllocatorUsingCreateAndRegisterAllocator(self):
# Create and register an arena based allocator
# ort_arena_cfg = onnxrt.OrtArenaCfg(0, -1, -1, -1) (create an OrtArenaCfg like this template if you want to use non-default parameters)
ort_memory_info = onnxrt.OrtMemoryInfo("Cpu", onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, 0, onnxrt.OrtMemType.DEFAULT)
# Use this option if using non-default OrtArenaCfg : onnxrt.create_and_register_allocator(ort_memory_info, ort_arena_cfg)
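Assembled from the comments above, the non-default allocator path reads as follows (a sketch; the OrtArenaCfg argument values are the template's placeholders, not tuned settings):
ort_arena_cfg = onnxrt.OrtArenaCfg(0, -1, -1, -1)
ort_memory_info = onnxrt.OrtMemoryInfo("Cpu", onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, 0, onnxrt.OrtMemType.DEFAULT)
onnxrt.create_and_register_allocator(ort_memory_info, ort_arena_cfg)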
@ -836,5 +842,55 @@ class TestInferenceSession(unittest.TestCase):
so2.log_severity_level = 1
onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so2)
def testCheckAndNormalizeProviderArgs(self):
from onnxruntime.capi.onnxruntime_inference_collection import check_and_normalize_provider_args
valid_providers = ["a", "b", "c"]
def check_success(providers, provider_options, expected_providers, expected_provider_options):
actual_providers, actual_provider_options = check_and_normalize_provider_args(
providers, provider_options, valid_providers)
self.assertEqual(actual_providers, expected_providers)
self.assertEqual(actual_provider_options, expected_provider_options)
check_success(None, None, [], [])
check_success(["a"], None, ["a"], [{}])
check_success(["a", "b"], None, ["a", "b"], [{}, {}])
check_success([("a", {1: 2}), "b"], None, ["a", "b"], [{"1": "2"}, {}])
check_success(["a", "b"], [{1: 2}, {}], ["a", "b"], [{"1": "2"}, {}])
with self.assertWarns(UserWarning):
check_success(["a", "b", "a"], [{"x": 1}, {}, {"y": 2}], ["a", "b"], [{"x": "1"}, {}])
def check_failure(providers, provider_options):
with self.assertRaises(ValueError):
check_and_normalize_provider_args(providers, provider_options, valid_providers)
# provider not valid
check_failure(["d"], None)
# providers not sequence
check_failure(3, None)
# providers value invalid
check_failure([3], None)
# provider_options not sequence
check_failure(["a"], 3)
# provider_options value invalid
check_failure(["a"], ["not dict"])
# providers and provider_options length mismatch
check_failure(["a", "b"], [{1: 2}])
# provider options unsupported mixed specification
check_failure([("a", {1: 2})], [{3: 4}])
if __name__ == '__main__':
unittest.main()

View file

@ -25,13 +25,13 @@ def reference_gemm(a, b, c, alpha, beta, transA, transB):
def set_gemm_node_attrs(attrs, config):
if config['alpha'] != 1.0:
attrs['alpha'] = config['alpha']
if config['beta'] != 1.0:
attrs['beta'] = config['beta']
if config['transA']:
attrs['transA'] = 1
if config['transB']:
attrs['transB'] = 1
def generate_gemm_inputs_initializers(graph, config, added_inputs_initializers={}, extend=False):
M = config['M']
@ -323,6 +323,16 @@ def set_gemm_model_inputs(config, test_inputs, a, b, c):
if config['withC'] and not config['initC']:
test_inputs[config['C']] = c
def make_providers(nuphar_settings):
return [
('NupharExecutionProvider', {
'nuphar_settings': nuphar_settings
}),
'CPUExecutionProvider',
]
class TestNuphar(unittest.TestCase):
def test_bidaf(self):
@ -385,8 +395,9 @@ class TestNuphar(unittest.TestCase):
for model in [bidaf_opt_scan_model, bidaf_int8_scan_only_model]:
nuphar_settings = 'nuphar_cache_path:{}'.format(cache_dir)
for isa in ['avx', 'avx2', 'avx512']:
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings + ', nuphar_codegen_target:' + isa)
sess = onnxrt.InferenceSession(model) # JIT cache happens when initializing session
# JIT cache happens when initializing session
sess = onnxrt.InferenceSession(
model, providers=make_providers(nuphar_settings + ', nuphar_codegen_target:' + isa))
cache_dir_content = os.listdir(cache_dir)
assert len(cache_dir_content) == 1
@ -400,15 +411,13 @@ class TestNuphar(unittest.TestCase):
nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}'.format(
cache_dir, so_name, 'on')
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
sess = onnxrt.InferenceSession(model)
sess = onnxrt.InferenceSession(model, providers=make_providers(nuphar_settings))
sess.run([], feed)
# test avx
nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}, nuphar_codegen_target:{}'.format(
cache_dir, so_name, 'on', 'avx')
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
sess = onnxrt.InferenceSession(model)
sess = onnxrt.InferenceSession(model, providers=make_providers(nuphar_settings))
sess.run([], feed)
def test_bert_squad(self):
@ -672,7 +681,7 @@ class TestNuphar(unittest.TestCase):
def test_loop_to_scan(self):
loop_model_filename = get_name("nuphar_tiny_model_with_loop_shape_infered.onnx")
scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx"
subprocess.run([
sys.executable, '-m', 'onnxruntime.nuphar.model_editor',
'--input', loop_model_filename,
@ -680,13 +689,13 @@ class TestNuphar(unittest.TestCase):
], check=True)
validate_with_ort(loop_model_filename, scan_model_filename)
def test_loop_to_scan_with_inconvertible_loop(self):
# nuphar_onnx_test_loop11_inconvertible_loop.onnx contains a Loop op with dynamic loop count.
# This Loop op cannot be converted to a Scan op.
# Set --keep_unconvertible_loop_ops option so conversion will not fail due to unconvertible loop ops.
loop_model_filename = get_name("nuphar_onnx_test_loop11_inconvertible_loop.onnx")
scan_model_filename = "nuphar_onnx_test_loop11_inconvertible_loop_unchanged.onnx"
subprocess.run([
sys.executable, '-m', 'onnxruntime.nuphar.model_editor',
'--input', loop_model_filename,
@ -695,15 +704,15 @@ class TestNuphar(unittest.TestCase):
], check=True)
# onnxruntime is failing with:
# onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :
# FAIL : Non-zero status code returned while running Loop node. Name:''
# Status Message: Inconsistent shape in loop output for output. Expected:{1} Got:{0}
# skip validate_with_ort for now
# validate_with_ort(loop_model_filename, scan_model_filename)
def test_loop_to_scan_tool(self):
loop_model_filename = get_name("nuphar_tiny_model_with_loop_shape_infered.onnx")
scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx"
subprocess.run([
sys.executable, '-m', 'onnxruntime.nuphar.model_tools',
'--input', loop_model_filename,

View file

@ -1,12 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <core/common/make_unique.h>
#include "core/session/onnxruntime_c_api.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/graph/constants.h"
#include "providers.h"
#include <memory>
#include <vector>
#include <iostream>
@ -15,7 +9,16 @@
#include <atomic>
#include <mutex>
#include <algorithm>
#include <gtest/gtest.h>
#include "core/common/common.h"
#include "core/common/make_unique.h"
#include "core/graph/constants.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "providers.h"
#include "test_allocator.h"
#include "test_fixture.h"
@ -1188,19 +1191,19 @@ TEST(CApiTest, get_available_providers) {
int len = 0;
char** providers;
ASSERT_EQ(g_ort->GetAvailableProviders(&providers, &len), nullptr);
ASSERT_TRUE(len > 0);
ASSERT_EQ(strcmp(providers[0], "CPUExecutionProvider"), 0);
ASSERT_GT(len, 0);
ASSERT_STREQ(providers[len-1], "CPUExecutionProvider");
ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr);
}
TEST(CApiTest, get_available_providers_cpp) {
std::vector<std::string> providers = Ort::GetAvailableProviders();
ASSERT_TRUE(providers.size() > 0);
ASSERT_TRUE(providers[0] == std::string("CPUExecutionProvider"));
ASSERT_FALSE(providers.empty());
ASSERT_EQ(providers.back(), "CPUExecutionProvider");
#ifdef USE_CUDA
// CUDA EP will exist in the list but its position may vary based on other EPs included in the build
ASSERT_TRUE(std::find(providers.begin(), providers.end(), std::string("CUDAExecutionProvider")) != providers.end());
ASSERT_TRUE(std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") != providers.end());
#endif
}
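On the Python side, the same reordering means get_available_providers() can seed the providers argument while preserving default priority order (a sketch; 'model.onnx' is a placeholder and the CUDA option value is illustrative):
import onnxruntime as onnxrt
providers = onnxrt.get_available_providers()  # CPUExecutionProvider now comes last
# Attach options to the CUDA EP, if present, without disturbing the order.
providers = [(p, {'device_id': 0}) if p == 'CUDAExecutionProvider' else p
             for p in providers]
sess = onnxrt.InferenceSession('model.onnx', providers=providers)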

View file

@ -576,7 +576,7 @@ void setup_training_params(BertParameters& params) {
if (params.cuda_mem_limit_in_gb > 0) {
info.cuda_mem_limit = gsl::narrow<size_t>(params.cuda_mem_limit_in_gb * 1024 * 1024 * 1024);
}
info.cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE;
info.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE;
params.providers.emplace(kCudaExecutionProvider, CreateExecutionProviderFactory_CUDA(info));
params.input_allocator = std::make_shared<CUDAPinnedAllocator>(info.device_id, CUDA_PINNED);

View file

@ -7,11 +7,12 @@ import sys
import os
from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import Session, InferenceSession, IOBinding
from onnxruntime.capi.onnxruntime_inference_collection import (Session, InferenceSession, IOBinding,
check_and_normalize_provider_args)
class TrainingSession(InferenceSession):
def __init__(self, path_or_bytes, parameters, sess_options=None):
def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None):
Session.__init__(self)
if sess_options:
@ -19,10 +20,13 @@ class TrainingSession(InferenceSession):
else:
self._sess = C.TrainingSession()
providers, provider_options = check_and_normalize_provider_args(providers, provider_options,
C.get_available_providers())
if isinstance(path_or_bytes, str):
config_result = self._sess.load_model(path_or_bytes, parameters)
config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options)
elif isinstance(path_or_bytes, bytes):
config_result = self._sess.read_bytes(path_or_bytes, parameters)
config_result = self._sess.read_bytes(path_or_bytes, parameters, providers, provider_options)
else:
raise TypeError("Unable to load from type '{0}'".format(type(path_or_bytes)))
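A minimal sketch of the extended constructor (assuming a training-enabled build; 'model.onnx' is a placeholder and ort_parameters stands for an already-configured TrainingParameters instance):
import onnxruntime as ort
providers = [('CUDAExecutionProvider', {'device_id': 0}), 'CPUExecutionProvider']
session = ort.TrainingSession('model.onnx', ort_parameters, providers=providers)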

View file

@ -305,7 +305,7 @@ void addObjectMethodsForTraining(py::module& m) {
#endif
#endif
})
.def("load_model", [](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters) {
.def("load_model", [](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path));
#if defined(USE_MPI)
@ -315,12 +315,11 @@ void addObjectMethodsForTraining(py::module& m) {
#endif
const auto config_result = ConfigureSessionForTraining(static_cast<TrainingSession*>(sess->GetSessionHandle()), parameters);
std::vector<std::string> provider_types = {};
InitializeSession(sess->GetSessionHandle(), provider_types);
InitializeSession(sess->GetSessionHandle(), provider_types, provider_options);
return config_result;
})
.def("read_bytes", [](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters) {
.def("read_bytes", [](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
std::istringstream buffer(serialized_model);
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer));
@ -331,8 +330,7 @@ void addObjectMethodsForTraining(py::module& m) {
#endif
const auto config_result = ConfigureSessionForTraining(static_cast<TrainingSession*>(sess->GetSessionHandle()), parameters);
std::vector<std::string> provider_types = {};
InitializeSession(sess->GetSessionHandle(), provider_types);
InitializeSession(sess->GetSessionHandle(), provider_types, provider_options);
return config_result;
})

View file

@ -195,13 +195,6 @@ class ORTTrainer(object):
break
assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})"
# Set GPU device and memory limit
if 'cuda' in self.options.device.id.lower():
mem_limit = self.options.device.mem_limit
if mem_limit > 0:
ort.set_cuda_mem_limit(self.options.device.mem_limit)
ort.set_cuda_device_id(_utils.get_device_index(self.options.device.id))
# TODO: Remove when experimental checkpoint functions are removed.
self._state_dict = {}
@ -655,10 +648,30 @@ class ORTTrainer(object):
# An old ort session may already exist and occupy GPU memory when creating a new session; this may cause an OOM error.
# For example, load_state_dict will be called before this function returns, and it calls _init_session again.
del self._training_session
# Set provider-specific options if needed
def get_providers():
providers = ort.get_available_providers()
if 'cuda' in self.options.device.id.lower():
cuda_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)}
if self.options.device.mem_limit > 0:
cuda_ep_options["cuda_mem_limit"] = self.options.device.mem_limit
cuda_ep_name = "CUDAExecutionProvider"
if cuda_ep_name not in providers:
raise RuntimeError(
"ORTTrainer options specify a CUDA device but the {} provider is unavailable.".format(
cuda_ep_name))
providers[providers.index(cuda_ep_name)] = (cuda_ep_name, cuda_ep_options)
return providers
# TrainingSession
self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(),
ort_parameters,
session_options)
self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(), ort_parameters,
session_options, get_providers())
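For a device id such as 'cuda:1' with a memory limit set, and assuming only the CUDA and CPU EPs are in the build, get_providers() above yields a list shaped like the following (values illustrative):
[('CUDAExecutionProvider', {'device_id': 1, 'cuda_mem_limit': 2147483648}),
 'CPUExecutionProvider']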
# I/O bindings
self._train_io_binding = self._training_session.io_binding()