diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index 1ef865b736..54188f2647 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -180,7 +180,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests string[] providers = ortEnvInstance.GetAvailableProviders(); Assert.True(providers.Length > 0); - Assert.Equal("CPUExecutionProvider", providers[0]); + Assert.Equal("CPUExecutionProvider", providers[providers.Length - 1]); # if USE_CUDA Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider");); diff --git a/docs/python/notebooks/onnxruntime-nuphar-tutorial.ipynb b/docs/python/notebooks/onnxruntime-nuphar-tutorial.ipynb index ccc6a19044..fff760062f 100644 --- a/docs/python/notebooks/onnxruntime-nuphar-tutorial.ipynb +++ b/docs/python/notebooks/onnxruntime-nuphar-tutorial.ipynb @@ -165,8 +165,8 @@ "code_to_run = '''\n", "import onnxruntime\n", "s = 'codegen_dump_lower:verbose'\n", - "onnxruntime.capi._pybind_state.set_nuphar_settings(s)\n", - "sess = onnxruntime.InferenceSession('simple.onnx')\n", + "providers = [('NupharExecutionProvider', {'nuphar_settings': s}), 'CPUExecutionProvider']\n", + "sess = onnxruntime.InferenceSession('simple.onnx', providers=providers)\n", "'''\n", "\n", "log_file = 'simple_lower.log' \n", @@ -992,8 +992,8 @@ "bidaf_cache_dir = os.path.join(bidaf_dir, 'cache')\n", "create_cache_dir(bidaf_cache_dir)\n", "settings = 'nuphar_cache_path:{}'.format(bidaf_cache_dir)\n", - "onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n", - "sess = onnxruntime.InferenceSession(bidaf_converted)" + "providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n", + "sess = onnxruntime.InferenceSession(bidaf_converted, providers=providers)" ] }, { @@ -1049,10 +1049,9 @@ ], "source": [ "start_aot = timer()\n", - "# NOTE: Nuphar settings string is not sticky. 
It needs to be reset before creating InferenceSession\n", "settings = 'nuphar_cache_path:{}'.format(bidaf_cache_dir)\n", - "onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n", - "sess = onnxruntime.InferenceSession(bidaf_converted)\n", + "providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n", + "sess = onnxruntime.InferenceSession(bidaf_converted, providers=providers)\n", "end_aot = timer()\n", "print_speedup('AOT', end_jit - start_jit, end_aot - start_aot)" ] @@ -1096,8 +1095,8 @@ "settings = 'nuphar_cache_path:{}'.format(cache_dir)\n", "for isa in ['avx512', 'avx2', 'avx']:\n", " settings_with_isa = settings + ', nuphar_codegen_target:' + isa\n", - " onnxruntime.capi._pybind_state.set_nuphar_settings(settings_with_isa)\n", - " sess = onnxruntime.InferenceSession(model_name)\n", + " providers = [('NupharExecutionProvider', {'nuphar_settings': settings_with_isa}), 'CPUExecutionProvider']\n", + " sess = onnxruntime.InferenceSession(model_name, providers=providers)\n", " cache_versioned_dir = os.path.join(cache_dir, os.listdir(cache_dir)[0])\n", "\n", "# link object files to AOT dll\n", @@ -1106,15 +1105,15 @@ "# now load the model with AOT dll\n", "# NOTE: when nuphar_codegen_target is not set, it defaults to current CPU ISA\n", "settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_model_checksum:{}, nuphar_cache_force_no_jit:on'.format(cache_dir, multi_isa_so, model_checksum)\n", - "onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n", - "sess = onnxruntime.InferenceSession(model_name)\n", + "providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n", + "sess = onnxruntime.InferenceSession(model_name, providers=providers)\n", "\n", "# force to a different ISA which is a subset of current CPU\n", "# NOTE: if an incompatible ISA is used, exception on invalid instructions would be thrown\n", "for valid_isa in ['avx2', 'avx']:\n", " settings_with_isa = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_model_checksum:{}, nuphar_codegen_target:{}, nuphar_cache_force_no_jit:on'.format(cache_dir, multi_isa_so, model_checksum, valid_isa)\n", - " onnxruntime.capi._pybind_state.set_nuphar_settings(settings_with_isa)\n", - " sess = onnxruntime.InferenceSession(model_name)\n", + " providers = [('NupharExecutionProvider', {'nuphar_settings': settings_with_isa}), 'CPUExecutionProvider']\n", + " sess = onnxruntime.InferenceSession(model_name, providers=providers)\n", "\n", " start_nuphar = timer()\n", " for i in range(repeats):\n", @@ -1160,8 +1159,8 @@ "# use NUPHAR_PARALLEL_MIN_WORKLOADS=0 to turn off parallel schedule, using settings string\n", "# it can be set from environment variable too: os.environ['NUPHAR_PARALLEL_MIN_WORKLOADS'] = '0'\n", "settings = 'nuphar_parallel_min_workloads:0'\n", - "onnxruntime.capi._pybind_state.set_nuphar_settings(settings)\n", - "sess = onnxruntime.InferenceSession(bidaf_converted)\n", + "providers = [('NupharExecutionProvider', {'nuphar_settings': settings}), 'CPUExecutionProvider']\n", + "sess = onnxruntime.InferenceSession(bidaf_converted, providers=providers)\n", "\n", "start = timer()\n", "for i in range(bidaf_repeats):\n", @@ -1200,4 +1199,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index 4b1c350849..bdca7b6938 100644 --- a/include/onnxruntime/core/common/common.h +++ 
b/include/onnxruntime/core/common/common.h @@ -33,9 +33,9 @@ #include "core/common/code_location.h" #include "core/common/exceptions.h" +#include "core/common/make_string.h" #include "core/common/make_unique.h" #include "core/common/status.h" -#include "core/common/string_utils.h" #ifdef USE_MIMALLOC_ARENA_ALLOCATOR #include diff --git a/include/onnxruntime/core/common/string_utils.h b/include/onnxruntime/core/common/make_string.h similarity index 68% rename from include/onnxruntime/core/common/string_utils.h rename to include/onnxruntime/core/common/make_string.h index fae0530485..b1e942aecc 100644 --- a/include/onnxruntime/core/common/string_utils.h +++ b/include/onnxruntime/core/common/make_string.h @@ -19,7 +19,6 @@ #include #include -#include namespace onnxruntime { @@ -74,41 +73,4 @@ inline std::string MakeString(const char* cstr) { return cstr; } -/** - * Tries to parse a value from an entire string. - */ -template <typename T> -bool TryParse(const std::string& str, T& value) { - if (std::is_integral<T>::value && std::is_unsigned<T>::value) { - // if T is unsigned integral type, reject negative values which will wrap - if (!str.empty() && str[0] == '-') { - return false; - } - } - - // don't allow leading whitespace - if (!str.empty() && std::isspace(str[0], std::locale::classic())) { - return false; - } - - std::istringstream is{str}; - is.imbue(std::locale::classic()); - T parsed_value{}; - - const bool parse_successful = - is >> parsed_value && - is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters - if (!parse_successful) { - return false; - } - - value = std::move(parsed_value); - return true; -} - -inline bool TryParse(const std::string& str, std::string& value) { - value = str; - return true; -} - } // namespace onnxruntime diff --git a/include/onnxruntime/core/common/parse_string.h b/include/onnxruntime/core/common/parse_string.h new file mode 100644 index 0000000000..988ce3d3df --- /dev/null +++ b/include/onnxruntime/core/common/parse_string.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" + +namespace onnxruntime { + +/** + * Tries to parse a value from an entire string. + */ +template <typename T> +bool TryParseString(const std::string& str, T& value) { + if (std::is_integral<T>::value && std::is_unsigned<T>::value) { + // if T is unsigned integral type, reject negative values which will wrap + if (!str.empty() && str[0] == '-') { + return false; + } + } + + // don't allow leading whitespace + if (!str.empty() && std::isspace(str[0], std::locale::classic())) { + return false; + } + + std::istringstream is{str}; + is.imbue(std::locale::classic()); + T parsed_value{}; + + const bool parse_successful = - is >> parsed_value && + is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters + if (!parse_successful) { + return false; + } + + value = std::move(parsed_value); + return true; +} + +inline bool TryParseString(const std::string& str, std::string& value) { + value = str; + return true; +} + +inline bool TryParseString(const std::string& str, bool& value) { + if (str == "0" || str == "False" || str == "false") { + value = false; + return true; + } + + if (str == "1" || str == "True" || str == "true") { + value = true; + return true; + } + + return false; +} + +/** + * Parses a value from an entire string.
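+ * + * A minimal usage sketch (the variable name is illustrative): + *   int device_id{}; + *   ORT_RETURN_IF_ERROR(ParseString("1", device_id));  // device_id == 1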
+ */ +template <typename T> +Status ParseString(const std::string& s, T& value) { + ORT_RETURN_IF_NOT(TryParseString(s, value), "Failed to parse value: \"", value, "\""); + return Status::OK(); +} + +/** + * Parses a value from an entire string. + */ +template <typename T> +T ParseString(const std::string& s) { + T value{}; + ORT_THROW_IF_ERROR(ParseString(s, value)); + return value; +} + +} // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/provider_options_utils.h b/include/onnxruntime/core/framework/provider_options_utils.h index 8804dde7ad..88fd0f9b6f 100644 --- a/include/onnxruntime/core/framework/provider_options_utils.h +++ b/include/onnxruntime/core/framework/provider_options_utils.h @@ -3,37 +3,163 @@ #pragma once +#include +#include +#include +#include +#include + #include "core/common/common.h" +#include "core/common/parse_string.h" #include "core/framework/provider_options.h" namespace onnxruntime { + +template <typename TEnum> +using EnumNameMapping = std::vector<std::pair<TEnum, std::string>>; + /** - * Reads the named provider option. - * Returns true if the option is present, false otherwise. + * Given a mapping and an enumeration value, gets the corresponding name. */ -template <typename T> -bool ReadProviderOption(const ProviderOptions& options, const std::string& key, T& value) { - auto it = options.find(key); - if (it != options.end()) { - ORT_ENFORCE( - TryParse(it->second, value), - "Failed to parse provider option \"", key, "\" with value \"", it->second, "\"."); - return true; - } - return false; +template <typename TEnum> +Status EnumToName(const EnumNameMapping<TEnum>& mapping, TEnum value, std::string& name) { + const auto it = std::find_if( + mapping.begin(), mapping.end(), + [&value](const std::pair<TEnum, std::string>& entry) { + return entry.first == value; + }); + ORT_RETURN_IF( + it == mapping.end(), + "Failed to map enum value to name: ", static_cast<typename std::underlying_type<TEnum>::type>(value)); + name = it->second; + return Status::OK(); +} + +template <typename TEnum> +std::string EnumToName(const EnumNameMapping<TEnum>& mapping, TEnum value) { + std::string name; + ORT_THROW_IF_ERROR(EnumToName(mapping, value, name)); + return name; } /** - * Reads the named provider option. - * Returns the value if the option is present or the specified default value otherwise. + * Given a mapping and a name, gets the corresponding enumeration value. */ -template <typename T> -T ReadProviderOptionOrDefault( - const ProviderOptions& options, const std::string& key, const T& default_value) { - T value{}; - if (ReadProviderOption(options, key, value)) { - return value; - } - return default_value; +template <typename TEnum> +Status NameToEnum( + const EnumNameMapping<TEnum>& mapping, const std::string& name, TEnum& value) { + const auto it = std::find_if( + mapping.begin(), mapping.end(), + [&name](const std::pair<TEnum, std::string>& entry) { + return entry.second == name; + }); + ORT_RETURN_IF( + it == mapping.end(), + "Failed to map enum name to value: ", name); + value = it->first; + return Status::OK(); } + +template <typename TEnum> +TEnum NameToEnum(const EnumNameMapping<TEnum>& mapping, const std::string& name) { + TEnum value; + ORT_THROW_IF_ERROR(NameToEnum(mapping, name, value)); + return value; +} + +class ProviderOptionsParser { + public: + /** + * Adds a parser for a particular provider option value. + * + * @param name The provider option name. + * @param value_parser An object that parses the option value. + * It should be callable with the following signature and return + * whether the parsing was successful: + * Status value_parser(const std::string&) + * + * @return The current ProviderOptionsParser instance.
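+ * + * Usage sketch (the option name and parsing logic are illustrative): + *   int log_level{}; + *   parser.AddValueParser( + *       "log_level", + *       [&log_level](const std::string& value_str) -> Status { + *         return ParseString(value_str, log_level); + *       });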
+ */ + template <typename ValueParserType> + ProviderOptionsParser& AddValueParser( + const std::string& name, ValueParserType value_parser) { + ORT_ENFORCE( + value_parsers_.emplace(name, ValueParser{value_parser}).second, + "Provider option \"", name, "\" already has a value parser."); + return *this; + } + + /** + * Adds a parser for a particular provider option value which converts a + * value to the right type and assigns it to the given reference. + * + * IMPORTANT: This function stores a reference to the destination variable. + * The caller must ensure that the reference is valid when Parse() is called! + * + * @param name The provider option name. + * @param dest The destination variable reference. + * + * @return The current ProviderOptionsParser instance. + */ + template <typename ValueType> + ProviderOptionsParser& AddAssignmentToReference( + const std::string& name, ValueType& dest) { + return AddValueParser( + name, + [&dest](const std::string& value_str) -> Status { + return ParseString(value_str, dest); + }); + } + + /** + * Adds a parser for a particular provider option value which maps an + * enumeration name to a value and assigns it to the given reference. + * + * IMPORTANT: This function stores references to the mapping and destination + * variables. The caller must ensure that the references are valid when + * Parse() is called! + * + * @param name The provider option name. + * @param mapping The enumeration value to name mapping. + * @param dest The destination variable reference. + * + * @return The current ProviderOptionsParser instance. + */ + template <typename EnumType> + ProviderOptionsParser& AddAssignmentToEnumReference( + const std::string& name, const EnumNameMapping<EnumType>& mapping, EnumType& dest) { + return AddValueParser( + name, + [&mapping, &dest](const std::string& value_str) -> Status { + return NameToEnum(mapping, value_str, dest); + }); + } + + /** + * Parses the given provider options.
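+ * + * For example (a sketch; the option name is illustrative): + *   ProviderOptionsParser parser{}; + *   int device_id = 0; + *   parser.AddAssignmentToReference("device_id", device_id); + *   ORT_RETURN_IF_ERROR(parser.Parse({{"device_id", "1"}}));  // device_id is now 1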
+ */ + Status Parse(const ProviderOptions& options) const { + for (const auto& option : options) { + const auto& name = option.first; + const auto& value_str = option.second; + + const auto value_parser_it = value_parsers_.find(name); + ORT_RETURN_IF( + value_parser_it == value_parsers_.end(), + "Unknown provider option: \"", name, "\"."); + + const auto parse_status = value_parser_it->second(value_str); + ORT_RETURN_IF_NOT( + parse_status.IsOK(), + "Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage()); + } + + return Status::OK(); + } + + private: + using ValueParser = std::function<Status(const std::string&)>; + std::unordered_map<std::string, ValueParser> value_parsers_; +}; + } // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 4d56c370f3..f723a3957b 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -3,11 +3,8 @@ #pragma once -#include - -#include "core/common/common.h" - namespace onnxruntime { + constexpr const char* kNoOp = "NoOp"; constexpr const char* kConstant = "Constant"; constexpr const char* kFunctionOp = "_kFunctionOp"; @@ -22,10 +19,10 @@ constexpr const char* kMSDmlDomain = "com.microsoft.dml"; constexpr const char* kNGraphDomain = "com.intel.ai"; constexpr const char* kMIGraphXDomain = ""; constexpr const char* kVitisAIDomain = "com.xilinx"; + constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; -constexpr const char* kNGraphExecutionProvider = "NGRAPHExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; constexpr const char* kNupharExecutionProvider = "NupharExecutionProvider"; constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider"; @@ -38,50 +35,4 @@ constexpr const char* kAclExecutionProvider = "ACLExecutionProvider"; constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider"; constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider"; -constexpr const char* providers_available[] = { - kCpuExecutionProvider, -#ifdef USE_CUDA - kCudaExecutionProvider, -#endif -#ifdef USE_DNNL - kDnnlExecutionProvider, -#endif -#ifdef USE_NGRAPH - kNGraphExecutionProvider, -#endif -#ifdef USE_OPENVINO - kOpenVINOExecutionProvider, -#endif -#ifdef USE_NUPHAR - kNupharExecutionProvider, -#endif -#ifdef USE_VITISAI - kVitisAIExecutionProvider, -#endif -#ifdef USE_TENSORRT - kTensorrtExecutionProvider, -#endif -#ifdef USE_NNAPI - kNnapiExecutionProvider, -#endif -#ifdef USE_RKNPU - kRknpuExecutionProvider, -#endif -#ifdef USE_DML - kDmlExecutionProvider, -#endif -#ifdef USE_MIGRAPHX - kMIGraphXExecutionProvider, -#endif -#ifdef USE_ACL - kAclExecutionProvider, -#endif -#ifdef USE_ARMNN - kArmNNExecutionProvider, -#endif -#ifdef USE_ROCM - kRocmExecutionProvider, -#endif -}; - } // namespace onnxruntime diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index afa58cefc8..b74c899bc1 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -26,11 +26,6 @@ from onnxruntime.capi._pybind_state import get_all_providers, get_available_prov NodeArg, ModelMetadata, GraphOptimizationLevel, ExecutionMode, ExecutionOrder, OrtDevice, SessionIOBinding, \ OrtAllocatorType, OrtMemType, OrtArenaCfg, OrtMemoryInfo, create_and_register_allocator -try: - from onnxruntime.capi._pybind_state import set_cuda_mem_limit,
set_cuda_device_id -except ImportError: - pass - from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession, IOBinding, OrtValue from onnxruntime.capi import onnxruntime_validation diff --git a/onnxruntime/core/framework/arena_extend_strategy.h b/onnxruntime/core/framework/arena_extend_strategy.h index 0cd04c40d9..036aaa3ba1 100644 --- a/onnxruntime/core/framework/arena_extend_strategy.h +++ b/onnxruntime/core/framework/arena_extend_strategy.h @@ -3,8 +3,7 @@ #pragma once -#include -#include +#include <cstdint> namespace onnxruntime { @@ -13,33 +12,4 @@ enum class ArenaExtendStrategy : int32_t { kNextPowerOfTwo, kSameAsRequested, }; -inline std::istream& operator>>(std::istream& is, ArenaExtendStrategy& value) { - std::string value_str; - if (is >> value_str) { - if (value_str == "kNextPowerOfTwo") { - value = ArenaExtendStrategy::kNextPowerOfTwo; - } else if (value_str == "kSameAsRequested") { - value = ArenaExtendStrategy::kSameAsRequested; - } else { - is.setstate(std::ios_base::failbit); - } - } - return is; -} - -inline std::ostream& operator<<(std::ostream& os, ArenaExtendStrategy value) { - switch (value) { - case ArenaExtendStrategy::kNextPowerOfTwo: - os << "kNextPowerOfTwo"; - break; - case ArenaExtendStrategy::kSameAsRequested: - os << "kSameAsRequested"; - break; - default: - os << "unknown"; - break; - } - return os; -} - } // namespace onnxruntime diff --git a/onnxruntime/core/platform/env_var_utils.h b/onnxruntime/core/platform/env_var_utils.h index b4d9569857..e79dd4fff7 100644 --- a/onnxruntime/core/platform/env_var_utils.h +++ b/onnxruntime/core/platform/env_var_utils.h @@ -5,6 +5,7 @@ #include "core/common/common.h" #include "core/common/optional.h" +#include "core/common/parse_string.h" #include "core/platform/env.h" namespace onnxruntime { @@ -20,7 +21,7 @@ optional ParseEnvironmentVariable(const std::string& name) { T parsed_value; ORT_ENFORCE( - TryParse(value_str, parsed_value), + TryParseString(value_str, parsed_value), "Failed to parse environment variable - name: \"", name, "\", value: \"", value_str, "\""); return parsed_value; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 5bb2e8f8ba..888c7dcf57 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -71,7 +71,7 @@ class CUDAExecutionProvider : public IExecutionProvider { int GetDeviceId() const override { return info_.device_id; } const cudaDeviceProp& GetDeviceProp() const { return device_prop_; }; - int GetCudnnConvAlgo() const { return info_.cudnn_conv_algo; } + int GetCudnnConvAlgo() const { return info_.cudnn_conv_algo_search; } ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index 7fa4d5a046..d70420a3ad 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -3,58 +3,76 @@ #include "core/providers/cuda/cuda_execution_provider_info.h" -#include "core/common/string_utils.h" +#include "core/common/make_string.h" +#include "core/common/parse_string.h" #include "core/framework/provider_options_utils.h" #include "core/providers/cuda/cuda_common.h" namespace onnxruntime { +namespace cuda { namespace provider_option_names { constexpr const char*
kDeviceId = "device_id"; constexpr const char* kMemLimit = "cuda_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; -constexpr const char* kCudnnConvAlgo = "cudnn_conv_algo"; +constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search"; constexpr const char* kDoCopyInDefaultStream = "do_copy_in_default_stream"; } // namespace provider_option_names +} // namespace cuda + +namespace { +const EnumNameMapping ort_cudnn_conv_algo_search_mapping{ + {OrtCudnnConvAlgoSearch::EXHAUSTIVE, "EXHAUSTIVE"}, + {OrtCudnnConvAlgoSearch::HEURISTIC, "HEURISTIC"}, + {OrtCudnnConvAlgoSearch::DEFAULT, "DEFAULT"}, +}; + +const EnumNameMapping arena_extend_strategy_mapping{ + {ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"}, + {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, +}; +} // namespace CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { CUDAExecutionProviderInfo info{}; - if (ReadProviderOption(options, provider_option_names::kDeviceId, info.device_id)) { - int num_devices = 0; - CUDA_CALL_THROW(cudaGetDeviceCount(&num_devices)); - ORT_ENFORCE( - 0 <= info.device_id && info.device_id < num_devices, - "Invalid ", provider_option_names::kDeviceId, " value: ", info.device_id, - ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); - } - ReadProviderOption(options, provider_option_names::kMemLimit, info.cuda_mem_limit); - ReadProviderOption(options, provider_option_names::kArenaExtendStrategy, info.arena_extend_strategy); - { - int cudnn_conv_algo_val; - if (ReadProviderOption(options, provider_option_names::kCudnnConvAlgo, cudnn_conv_algo_val)) { - switch (cudnn_conv_algo_val) { - case OrtCudnnConvAlgoSearch::EXHAUSTIVE: - case OrtCudnnConvAlgoSearch::HEURISTIC: - case OrtCudnnConvAlgoSearch::DEFAULT: - break; - default: - ORT_THROW("Invalid ", provider_option_names::kCudnnConvAlgo, " value: ", cudnn_conv_algo_val); - } - info.cudnn_conv_algo = static_cast(cudnn_conv_algo_val); - } - } - ReadProviderOption(options, provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream); + ORT_THROW_IF_ERROR( + ProviderOptionsParser{} + .AddValueParser( + cuda::provider_option_names::kDeviceId, + [&info](const std::string& value_str) -> Status { + ORT_RETURN_IF_ERROR(ParseString(value_str, info.device_id)); + int num_devices{}; + ORT_RETURN_IF_NOT( + CUDA_CALL(cudaGetDeviceCount(&num_devices)), + "cudaGetDeviceCount() failed."); + ORT_RETURN_IF_NOT( + 0 <= info.device_id && info.device_id < num_devices, + "Invalid device ID: ", info.device_id, + ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); + return Status::OK(); + }) + .AddAssignmentToReference(cuda::provider_option_names::kMemLimit, info.cuda_mem_limit) + .AddAssignmentToEnumReference( + cuda::provider_option_names::kArenaExtendStrategy, + arena_extend_strategy_mapping, info.arena_extend_strategy) + .AddAssignmentToEnumReference( + cuda::provider_option_names::kCudnnConvAlgoSearch, + ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search) + .AddAssignmentToReference(cuda::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) + .Parse(options)); return info; } ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecutionProviderInfo& info) { const ProviderOptions options{ - {provider_option_names::kDeviceId, MakeString(info.device_id)}, - {provider_option_names::kMemLimit, MakeString(info.cuda_mem_limit)}, - 
{provider_option_names::kArenaExtendStrategy, MakeString(info.arena_extend_strategy)}, - {provider_option_names::kCudnnConvAlgo, MakeString(static_cast<int>(info.cudnn_conv_algo))}, - {provider_option_names::kDoCopyInDefaultStream, MakeString(info.do_copy_in_default_stream)}, + {cuda::provider_option_names::kDeviceId, MakeString(info.device_id)}, + {cuda::provider_option_names::kMemLimit, MakeString(info.cuda_mem_limit)}, + {cuda::provider_option_names::kArenaExtendStrategy, + EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, + {cuda::provider_option_names::kCudnnConvAlgoSearch, + EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)}, + {cuda::provider_option_names::kDoCopyInDefaultStream, MakeString(info.do_copy_in_default_stream)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index b6c763c283..5ba2d07b9c 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -16,7 +16,7 @@ struct CUDAExecutionProviderInfo { OrtDevice::DeviceId device_id{0}; size_t cuda_mem_limit{std::numeric_limits<size_t>::max()}; ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; - OrtCudnnConvAlgoSearch cudnn_conv_algo{OrtCudnnConvAlgoSearch::EXHAUSTIVE}; + OrtCudnnConvAlgoSearch cudnn_conv_algo_search{OrtCudnnConvAlgoSearch::EXHAUSTIVE}; bool do_copy_in_default_stream{true}; static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 0b1495c232..ef3c7a4269 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -55,7 +55,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_CUDA, info.device_id = gsl::narrow<OrtDevice::DeviceId>(cuda_options->device_id); info.cuda_mem_limit = cuda_options->cuda_mem_limit; info.arena_extend_strategy = static_cast<ArenaExtendStrategy>(cuda_options->arena_extend_strategy); - info.cudnn_conv_algo = cuda_options->cudnn_conv_algo_search; + info.cudnn_conv_algo_search = cuda_options->cudnn_conv_algo_search; info.do_copy_in_default_stream = cuda_options->do_copy_in_default_stream; options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_CUDA(info)); diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc new file mode 100644 index 0000000000..866d71feef --- /dev/null +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +#include "core/providers/get_execution_providers.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +namespace { +struct ProviderInfo { + const char* name; + bool available; +}; + +// all providers ordered by default priority from highest to lowest +// kCpuExecutionProvider should always be last +constexpr ProviderInfo kProvidersInPriorityOrder[] = + { + { + kTensorrtExecutionProvider, +#ifdef USE_TENSORRT + true, +#else + false, +#endif + }, + { + kCudaExecutionProvider, +#ifdef USE_CUDA + true, +#else + false, +#endif + }, + { + kMIGraphXExecutionProvider, +#ifdef USE_MIGRAPHX + true, +#else + false, +#endif + }, + { + kRocmExecutionProvider, +#ifdef USE_ROCM + true, +#else + false, +#endif + }, + { + kOpenVINOExecutionProvider, +#ifdef USE_OPENVINO + true, +#else + false, +#endif + }, + { + kDnnlExecutionProvider, +#ifdef USE_DNNL + true, +#else + false, +#endif + }, + { + kNupharExecutionProvider, +#ifdef USE_NUPHAR + true, +#else + false, +#endif + }, + { + kVitisAIExecutionProvider, +#ifdef USE_VITISAI + true, +#else + false, +#endif + }, + { + kNnapiExecutionProvider, +#ifdef USE_NNAPI + true, +#else + false, +#endif + }, + { + kArmNNExecutionProvider, +#ifdef USE_ARMNN + true, +#else + false, +#endif + }, + { + kAclExecutionProvider, +#ifdef USE_ACL + true, +#else + false, +#endif + }, + { + kDmlExecutionProvider, +#ifdef USE_DML + true, +#else + false, +#endif + }, + { + kRknpuExecutionProvider, +#ifdef USE_RKNPU + true, +#else + false, +#endif + }, + {kCpuExecutionProvider, true}, // kCpuExecutionProvider is always last +}; +} // namespace + +const std::vector& GetAllExecutionProviderNames() { + static const auto all_execution_providers = []() { + std::vector result{}; + for (const auto& provider : kProvidersInPriorityOrder) { + result.push_back(provider.name); + } + return result; + }(); + + return all_execution_providers; +} + +const std::vector& GetAvailableExecutionProviderNames() { + static const auto available_execution_providers = []() { + std::vector result{}; + for (const auto& provider : kProvidersInPriorityOrder) { + if (provider.available) { + result.push_back(provider.name); + } + } + return result; + }(); + + return available_execution_providers; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/get_execution_providers.h b/onnxruntime/core/providers/get_execution_providers.h new file mode 100644 index 0000000000..04d9166d3f --- /dev/null +++ b/onnxruntime/core/providers/get_execution_providers.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +namespace onnxruntime { + +/** + * Gets the names of all execution providers, in order of decreasing default + * priority. + */ +const std::vector& GetAllExecutionProviderNames(); + +/** + * Gets the names of execution providers available in this build, in order of + * decreasing default priority. 
+ */ +const std::vector<std::string>& GetAvailableExecutionProviderNames(); + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 46333b8109..a9a8679ec3 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -3,32 +3,47 @@ #include "core/providers/rocm/rocm_execution_provider_info.h" -#include "core/common/string_utils.h" +#include "core/common/make_string.h" #include "core/framework/provider_options_utils.h" namespace onnxruntime { +namespace rocm { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kMemLimit = "hip_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; } // namespace provider_option_names +} // namespace rocm + +namespace { +const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{ + {ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"}, + {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, +}; +} // namespace ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { ROCMExecutionProviderInfo info{}; - // TODO validate info.device_id - ReadProviderOption(options, provider_option_names::kDeviceId, info.device_id); - ReadProviderOption(options, provider_option_names::kMemLimit, info.hip_mem_limit); - ReadProviderOption(options, provider_option_names::kArenaExtendStrategy, info.arena_extend_strategy); + ORT_THROW_IF_ERROR( + ProviderOptionsParser{} + // TODO validate info.device_id + .AddAssignmentToReference(rocm::provider_option_names::kDeviceId, info.device_id) + .AddAssignmentToReference(rocm::provider_option_names::kMemLimit, info.hip_mem_limit) + .AddAssignmentToEnumReference( + rocm::provider_option_names::kArenaExtendStrategy, + arena_extend_strategy_mapping, info.arena_extend_strategy) + .Parse(options)); return info; } ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) { const ProviderOptions options{ - {provider_option_names::kDeviceId, MakeString(info.device_id)}, - {provider_option_names::kMemLimit, MakeString(info.hip_mem_limit)}, - {provider_option_names::kArenaExtendStrategy, MakeString(info.arena_extend_strategy)}, + {rocm::provider_option_names::kDeviceId, MakeString(info.device_id)}, + {rocm::provider_option_names::kMemLimit, MakeString(info.hip_mem_limit)}, + {rocm::provider_option_names::kArenaExtendStrategy, + EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, }; return options; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 6f613efb92..0a85199e9a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -23,6 +23,7 @@ #include "core/framework/allocator.h" #include "core/framework/tensor.h" #include "core/framework/ml_value.h" +#include "core/providers/get_execution_providers.h" #include "core/session/environment.h" #include "core/framework/callback.h" #include "core/framework/tensorprotoutils.h" @@ -1713,19 +1714,20 @@ ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char*** out_ptr, _In_ int* providers_length) { API_IMPL_BEGIN const size_t MAX_LEN = 30; - int available_count = (int)(sizeof(providers_available) / sizeof(char*)); - char** out = (char**)malloc(available_count * sizeof(char*)); + const auto&
available_providers = GetAvailableExecutionProviderNames(); + const int available_count = gsl::narrow<int>(available_providers.size()); + char** const out = (char**)malloc(available_count * sizeof(char*)); if (out) { for (int i = 0; i < available_count; i++) { out[i] = (char*)malloc((MAX_LEN + 1) * sizeof(char)); if (out[i]) { #ifdef _MSC_VER - strncpy_s(out[i], MAX_LEN, providers_available[i], MAX_LEN); + strncpy_s(out[i], MAX_LEN, available_providers[i].c_str(), MAX_LEN); out[i][MAX_LEN] = '\0'; #elif defined(__APPLE__) - strlcpy(out[i], providers_available[i], MAX_LEN); + strlcpy(out[i], available_providers[i].c_str(), MAX_LEN); #else - strncpy(out[i], providers_available[i], MAX_LEN); + strncpy(out[i], available_providers[i].c_str(), MAX_LEN); out[i][MAX_LEN] = '\0'; #endif } @@ -1734,7 +1736,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char*** out_ptr, *providers_length = available_count; *out_ptr = out; API_IMPL_END - return NULL; + return nullptr; } ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr, diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 4a36123df9..e1603fcbe7 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -2,7 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +import collections +import collections.abc import os +import warnings from onnxruntime.capi import _pybind_state as C @@ -17,6 +20,78 @@ def get_ort_device_type(device): raise Exception('Unsupported device type: ' + device) +def check_and_normalize_provider_args(providers, provider_options, available_provider_names): + """ + Validates the 'providers' and 'provider_options' arguments and returns a + normalized version. + + :param providers: Optional sequence of providers in order of decreasing + precedence. Values can either be provider names or tuples of + (provider name, options dict). + :param provider_options: Optional sequence of options dicts corresponding + to the providers listed in 'providers'. + :param available_provider_names: The available provider names. + + :return: Tuple of (normalized 'providers' sequence, normalized + 'provider_options' sequence). + + 'providers' can contain either names or names and options. When any options + are given in 'providers', 'provider_options' should not be used. + + The normalized result is a tuple of: + 1. Sequence of provider names in the same order as 'providers'. + 2. Sequence of corresponding provider options dicts with string keys and + values. Unspecified provider options yield empty dicts. + """ + if providers is None: + return [], [] + + provider_name_to_options = collections.OrderedDict() + + def set_provider_options(name, options): + if name not in available_provider_names: + raise ValueError("Specified provider '{}' is unavailable.
Available providers: '{}'".format( + name, ", ".join(available_provider_names))) + + if name in provider_name_to_options: + warnings.warn("Duplicate provider '{}' encountered, ignoring.".format(name)) + return + + normalized_options = {str(key): str(value) for key, value in options.items()} + provider_name_to_options[name] = normalized_options + + if not isinstance(providers, collections.abc.Sequence): + raise ValueError("'providers' should be a sequence.") + + if provider_options is not None: + if not isinstance(provider_options, collections.abc.Sequence): + raise ValueError("'provider_options' should be a sequence.") + + if len(providers) != len(provider_options): + raise ValueError("'providers' and 'provider_options' should be the same length if both are given.") + + if not all([isinstance(provider, str) for provider in providers]): + raise ValueError("Only string values for 'providers' are supported if 'provider_options' is given.") + + if not all([isinstance(options_for_provider, dict) for options_for_provider in provider_options]): + raise ValueError("'provider_options' values must be dicts.") + + for name, options in zip(providers, provider_options): + set_provider_options(name, options) + + else: + for provider in providers: + if isinstance(provider, str): + set_provider_options(provider, dict()) + elif isinstance(provider, tuple) and len(provider) == 2 and \ + isinstance(provider[0], str) and isinstance(provider[1], dict): + set_provider_options(provider[0], provider[1]) + else: + raise ValueError("'providers' values must be either strings or (string, dict) tuples.") + + return list(provider_name_to_options.keys()), list(provider_name_to_options.values()) + + class Session: """ This is the main class used to run a model. @@ -55,34 +130,23 @@ class Session: "Return registered execution providers' configurations." return self._provider_options - def set_providers(self, providers, provider_options=None): + def set_providers(self, providers=None, provider_options=None): """ Register the input list of execution providers. The underlying session is re-created. - :param providers: list of execution providers - :param provider_options: list of provider options dict for each provider, in the same order as 'providers' + :param providers: Optional sequence of providers in order of decreasing + precedence. Values can either be provider names or tuples of + (provider name, options dict). If not provided, then all available + providers are used with the default precedence. + :param provider_options: Optional sequence of options dicts corresponding + to the providers listed in 'providers'. - The list of providers is ordered by Priority. For example ['CUDAExecutionProvider', 'CPUExecutionProvider'] + 'providers' can contain either names or names and options. When any options + are given in 'providers', 'provider_options' should not be used. + + The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider'] means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider. 
""" - if not set(providers).issubset(C.get_available_providers()): - raise ValueError("{} does not contain a subset of available providers {}".format( - providers, C.get_available_providers())) - - if provider_options: - if not isinstance(providers, list) or not isinstance(provider_options, list): - raise ValueError("Inputs must be two python lists.") - - if len(providers) != len(provider_options): - raise ValueError("Two input lists must have same length.") - - for option in provider_options: - if not isinstance(option, dict): - raise ValueError("Provider options must be list of python dict.") - - for key, val in option.items(): - option[key] = str(val) - # recreate the underlying C.InferenceSession self._reset_session(providers, provider_options) @@ -173,8 +237,12 @@ class InferenceSession(Session): """ :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string :param sess_options: session options - :param providers: list of providers to use for session. If empty, will use all available providers. - :param provider_options: list of provider options dict for each provider, in the same order as 'providers' + :param providers: Optional sequence of providers in order of decreasing + precedence. Values can either be provider names or tuples of + (provider name, options dict). If not provided, then all available + providers are used with the default precedence. + :param provider_options: Optional sequence of options dicts corresponding + to the providers listed in 'providers'. The model type will be inferred unless explicitly set in the SessionOptions. To explicitly set: @@ -184,6 +252,12 @@ class InferenceSession(Session): A file extension of '.ort' will be inferred as an ORT format model. All other filenames are assumed to be ONNX format models. + + 'providers' can contain either names or names and options. When any options + are given in 'providers', 'provider_options' should not be used. + + The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider'] + means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider. """ Session.__init__(self) @@ -208,15 +282,22 @@ class InferenceSession(Session): if self._enable_fallback: print("EP Error using {}".format(self._providers)) print("Falling back to {} and retrying.".format(self._fallback_providers)) - self._create_inference_session(self._fallback_providers) + self._create_inference_session(self._fallback_providers, None) # Fallback only once. self.disable_fallback() else: raise def _create_inference_session(self, providers, provider_options): + available_providers = C.get_available_providers() + + # validate providers and provider_options before other initialization + providers, provider_options = check_and_normalize_provider_args(providers, + provider_options, + available_providers) + # Tensorrt can fall back to CUDA. All others fall back to CPU. 
- if 'TensorrtExecutionProvider' in C.get_available_providers(): + if 'TensorrtExecutionProvider' in available_providers: self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] else: self._fallback_providers = ['CPUExecutionProvider'] @@ -228,7 +309,7 @@ class InferenceSession(Session): sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model) # initialize the C++ InferenceSession - sess.initialize_session(providers or [], provider_options or []) + sess.initialize_session(providers, provider_options) self._sess = sess self._sess_options = self._sess.session_options diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 977bf02a81..08e2628598 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -15,6 +15,7 @@ #include "core/framework/arena_extend_strategy.h" #include "core/framework/data_transfer_utils.h" #include "core/framework/data_types_internal.h" +#include "core/providers/get_execution_providers.h" #include "core/framework/kernel_registry.h" #include "core/framework/provider_options_utils.h" #include "core/framework/random_seed.h" @@ -200,6 +201,7 @@ std::shared_ptr CreateExecutionProviderFactory_ACL(in std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN(int use_arena); std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(int device_id); std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(uint32_t flags); +std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu(); } // namespace onnxruntime #if defined(_MSC_VER) @@ -425,27 +427,6 @@ static inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(p))); } -// ordered by default priority from highest to lowest. kCpuExecutionProvider should always be last.
-static const std::vector<std::string>& GetAllProviders() { - static std::vector<std::string> all_providers = {kTensorrtExecutionProvider, kCudaExecutionProvider, - kMIGraphXExecutionProvider, kRocmExecutionProvider, - kOpenVINOExecutionProvider, kDnnlExecutionProvider, - kNupharExecutionProvider, kVitisAIExecutionProvider, - kNnapiExecutionProvider, - kArmNNExecutionProvider, kAclExecutionProvider, - kDmlExecutionProvider, kCpuExecutionProvider}; - return all_providers; -} - -static const std::vector<std::string>& GetAvailableProviders() { - auto InitializeProviders = []() { - std::vector<std::string> available_providers(std::begin(providers_available), std::end(providers_available)); - return available_providers; - }; - static std::vector<std::string> available_providers = InitializeProviders(); - return available_providers; -} - #ifdef USE_CUDA static bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id) { @@ -537,7 +518,7 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector info.device_id = cuda_device_id; info.cuda_mem_limit = cuda_mem_limit; info.arena_extend_strategy = arena_extend_strategy; - info.cudnn_conv_algo = cudnn_conv_algo_search; + info.cudnn_conv_algo_search = cudnn_conv_algo_search; info.do_copy_in_default_stream = do_copy_in_default_stream; return info; }(); @@ -605,7 +586,10 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector #if USE_NUPHAR const auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { - ReadProviderOption(it->second, "nuphar_settings", nuphar_settings); + ORT_THROW_IF_ERROR( + ProviderOptionsParser{} + .AddAssignmentToReference("nuphar_settings", nuphar_settings) + .Parse(it->second)); } RegisterExecutionProvider( @@ -638,6 +622,10 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector LOGS_DEFAULT(WARNING) << "NNAPI execution provider can only be used to generate ORT format model in this build."; #endif RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Nnapi(0)); +#endif + } else if (type == kRknpuExecutionProvider) { +#ifdef USE_RKNPU + RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Rknpu()); #endif } else { // unknown provider @@ -697,7 +685,7 @@ void InitializeSession(InferenceSession* sess, const std::vector& p if (provider_types.empty()) { // use default registration priority. - RegisterExecutionProviders(sess, GetAllProviders(), provider_options_map); + RegisterExecutionProviders(sess, GetAllExecutionProviderNames(), provider_options_map); } else { RegisterExecutionProviders(sess, provider_types, provider_options_map); } @@ -752,13 +740,15 @@ void addGlobalMethods(py::module& m, Environment& env) { }, "Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal"); m.def( - "get_all_providers", []() -> const std::vector<std::string>& { return GetAllProviders(); }, + "get_all_providers", []() -> const std::vector<std::string>& { return GetAllExecutionProviderNames(); }, "Return list of Execution Providers that this version of Onnxruntime can support.
" - "The order of elements represents the default priority order of Execution Providers" - " from highest to lowest."); + "The order of elements represents the default priority order of Execution Providers " + "from highest to lowest."); m.def( - "get_available_providers", []() -> const std::vector& { return GetAvailableProviders(); }, - "Return list of available Execution Providers available in this installed version of Onnxruntime."); + "get_available_providers", []() -> const std::vector& { return GetAvailableExecutionProviderNames(); }, + "Return list of available Execution Providers available in this installed version of Onnxruntime. " + "The order of elements represents the default priority order of Execution Providers " + "from highest to lowest."); m.def( "enable_telemetry_events", []() -> void { platform_env.GetTelemetryProvider().EnableTelemetryEvents(); }, "Enables platform-specific telemetry collection where applicable."); @@ -833,7 +823,7 @@ void addGlobalMethods(py::module& m, Environment& env) { info.device_id = cuda_device_id; info.cuda_mem_limit = cuda_mem_limit; info.arena_extend_strategy = arena_extend_strategy; - info.cudnn_conv_algo = cudnn_conv_algo_search; + info.cudnn_conv_algo_search = cudnn_conv_algo_search; info.do_copy_in_default_stream = do_copy_in_default_stream; return info; }()), @@ -874,6 +864,9 @@ void addGlobalMethods(py::module& m, Environment& env) { #endif #ifdef USE_NNAPI onnxruntime::CreateExecutionProviderFactory_NNAPI(0), +#endif +#ifdef USE_RKNPU + onnxruntime::CreateExecutionProviderFactory_Rknpu(), #endif }; @@ -905,7 +898,7 @@ void addGlobalMethods(py::module& m, Environment& env) { }); // TODO remove deprecated global config m.def("set_cudnn_conv_algo_search", [](const OrtCudnnConvAlgoSearch algo) { - LogDeprecationWarning("set_cudnn_conv_algo_search", "CUDA execution provider option \"cudnn_conv_algo\""); + LogDeprecationWarning("set_cudnn_conv_algo_search", "CUDA execution provider option \"cudnn_conv_algo_search\""); #ifdef USE_ROCM ORT_UNUSED_PARAMETER(algo); ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM"); @@ -1084,11 +1077,6 @@ void addObjectMethods(py::module& m, Environment& env) { .value("CPU", OrtMemType::OrtMemTypeCPU) .value("DEFAULT", OrtMemType::OrtMemTypeDefault); - py::enum_(m, "OrtCudnnConvAlgoSearch") - .value("EXHAUSTIVE", OrtCudnnConvAlgoSearch::EXHAUSTIVE) - .value("HEURISTIC", OrtCudnnConvAlgoSearch::HEURISTIC) - .value("DEFAULT", OrtCudnnConvAlgoSearch::DEFAULT); - py::class_ device(m, "OrtDevice", R"pbdoc(ONNXRuntime device informaion.)pbdoc"); device.def(py::init()) .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc") diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 2ffde9a672..767a401c8b 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -11,9 +11,6 @@ namespace onnxruntime { namespace python { -using namespace onnxruntime; -using namespace onnxruntime::logging; - #if !defined(ORT_MINIMAL_BUILD) struct CustomOpLibrary { CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so); diff --git a/onnxruntime/test/common/string_utils_test.cc b/onnxruntime/test/common/string_utils_test.cc index ca6688c6c5..c18bef8f0b 100644 --- a/onnxruntime/test/common/string_utils_test.cc +++ b/onnxruntime/test/common/string_utils_test.cc @@ -1,7 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/common/string_utils.h" +#include "core/common/make_string.h" +#include "core/common/parse_string.h" #include "gtest/gtest.h" @@ -12,29 +13,26 @@ namespace { template void TestSuccessfulParse(const std::string& input, const T& expected_value) { T value; - ASSERT_TRUE(TryParse(input, value)); + ASSERT_TRUE(TryParseString(input, value)); EXPECT_EQ(value, expected_value); } template void TestFailedParse(const std::string& input) { T value; - EXPECT_FALSE(TryParse(input, value)); + EXPECT_FALSE(TryParseString(input, value)); } } // namespace -TEST(StringUtilsTest, TryParse) { +TEST(StringUtilsTest, TryParseString) { TestSuccessfulParse("-1", -1); TestSuccessfulParse("42", 42u); TestSuccessfulParse("2.5", 2.5f); - TestSuccessfulParse("1", true); - TestSuccessfulParse("0", false); // out of range TestFailedParse("32768"); TestFailedParse("-1"); TestFailedParse("1e100"); - TestFailedParse("2"); // invalid representation TestFailedParse("1.2"); TestFailedParse("one"); @@ -43,12 +41,21 @@ TEST(StringUtilsTest, TryParse) { TestFailedParse("1 "); } -TEST(StringUtilsTest, TryParseString) { +TEST(StringUtilsTest, TryParseStringAsString) { // when parsing a string as a string, allow leading and trailing whitespace const std::string s = " this is a string! "; TestSuccessfulParse(s, s); } +TEST(StringUtilsTest, TryParseStringAsBool) { + TestSuccessfulParse("True", true); + TestSuccessfulParse("1", true); + TestSuccessfulParse("False", false); + TestSuccessfulParse("0", false); + + TestFailedParse("2"); +} + namespace { struct S { int i{}; @@ -73,12 +80,12 @@ std::istream& operator>>(std::istream& is, S& s) { } } // namespace -TEST(StringUtilsTest, MakeStringAndTryParseCustomType) { +TEST(StringUtilsTest, MakeStringAndTryParseStringWithCustomType) { S s; s.i = 42; const auto str = MakeString(s); S parsed_s; - ASSERT_TRUE(TryParse(str, parsed_s)); + ASSERT_TRUE(TryParseString(str, parsed_s)); ASSERT_EQ(parsed_s, s); } diff --git a/onnxruntime/test/framework/provider_options_utils_test.cc b/onnxruntime/test/framework/provider_options_utils_test.cc new file mode 100644 index 0000000000..d1e35e1a2f --- /dev/null +++ b/onnxruntime/test/framework/provider_options_utils_test.cc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/framework/provider_options_utils.h" + +#include "gtest/gtest.h" + +#include "asserts.h" + +namespace onnxruntime { +namespace test { + +namespace { +enum class TestEnum { + A, + Unmapped, +}; + +const EnumNameMapping test_enum_mapping{ + {TestEnum::A, "A"}, +}; +} // namespace + +TEST(ProviderOptionsUtilsTest, ProviderOptionsParser) { + int i; + bool b; + TestEnum e; + ProviderOptionsParser parser{}; + parser.AddAssignmentToReference("int", i); + parser.AddAssignmentToReference("bool", b); + parser.AddAssignmentToEnumReference("enum", test_enum_mapping, e); + + // adding same option again should throw + ASSERT_THROW(parser.AddAssignmentToReference("int", i), OnnxRuntimeException); + + ASSERT_STATUS_OK(parser.Parse({{"int", "3"}, {"bool", "true"}, {"enum", "A"}})); + EXPECT_EQ(i, 3); + EXPECT_EQ(b, true); + EXPECT_EQ(e, TestEnum::A); + + ASSERT_FALSE(parser.Parse({{"unknown option", "some value"}}).IsOK()); +} + +TEST(ProviderOptionsUtilsTest, EnumToName) { + std::string name; + ASSERT_STATUS_OK(EnumToName(test_enum_mapping, TestEnum::A, name)); + EXPECT_EQ(name, "A"); + ASSERT_FALSE(EnumToName(test_enum_mapping, TestEnum::Unmapped, name).IsOK()); +} + +TEST(ProviderOptionsUtilsTest, NameToEnum) { + TestEnum value; + ASSERT_STATUS_OK(NameToEnum(test_enum_mapping, "A", value)); + EXPECT_EQ(value, TestEnum::A); + ASSERT_FALSE(NameToEnum(test_enum_mapping, "invalid", value).IsOK()); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/onnx/tensorprotoutils.cc b/onnxruntime/test/onnx/tensorprotoutils.cc index ba081d6af1..33fc5dcfdf 100644 --- a/onnxruntime/test/onnx/tensorprotoutils.cc +++ b/onnxruntime/test/onnx/tensorprotoutils.cc @@ -11,7 +11,7 @@ #include "mem_buffer.h" #include "core/common/safeint.h" #include "core/common/status.h" -#include "core/common/string_utils.h" +#include "core/common/make_string.h" #include "core/framework/data_types.h" #include "core/framework/endian.h" #include "core/framework/allocator.h" diff --git a/onnxruntime/test/providers/get_execution_providers_test.cc b/onnxruntime/test/providers/get_execution_providers_test.cc new file mode 100644 index 0000000000..e2fe29609a --- /dev/null +++ b/onnxruntime/test/providers/get_execution_providers_test.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/get_execution_providers.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { +namespace test { + +TEST(GetExecutionProvidersTest, CpuEpAlwaysLast) { + const auto check = [](const std::vector& providers) { + ASSERT_FALSE(providers.empty()); + EXPECT_EQ(providers.back(), kCpuExecutionProvider); + }; + + check(GetAllExecutionProviderNames()); + check(GetAvailableExecutionProviderNames()); +} + +TEST(GetExecutionProvidersTest, ConsistentOrdering) { + const auto& all = GetAllExecutionProviderNames(); + const auto& available = GetAvailableExecutionProviderNames(); + std::vector available_from_all{}; + std::copy_if( + all.begin(), all.end(), + std::back_inserter(available_from_all), + [&available](const std::string& value) { + return std::find(available.begin(), available.end(), value) != available.end(); + }); + + EXPECT_EQ(available, available_from_all); +} + +TEST(GetExecutionProvidersTest, NoDuplicates) { + const auto check = [](const std::vector& providers) { + const std::unordered_set providers_set(providers.begin(), providers.end()); + EXPECT_EQ(providers.size(), providers_set.size()); + }; + + check(GetAllExecutionProviderNames()); + check(GetAvailableExecutionProviderNames()); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index ec58417e73..dabab43f14 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -81,13 +81,13 @@ class TestInferenceSession(unittest.TestCase): def runBaseTest2(): sess = onnxrt.InferenceSession(get_name("mul_1.onnx")) - self.assertTrue('CUDAExecutionProvider' in sess.get_providers()) + self.assertIn('CUDAExecutionProvider', sess.get_providers()) # test get/set of "cuda_mem_limit" configuration. options = sess.get_provider_options() - self.assertTrue('CUDAExecutionProvider' in options) + self.assertIn('CUDAExecutionProvider', options) option = options['CUDAExecutionProvider'] - self.assertTrue('cuda_mem_limit' in option) + self.assertIn('cuda_mem_limit', option) ori_mem_limit = option['cuda_mem_limit'] new_mem_limit = int(ori_mem_limit) // 2 option['cuda_mem_limit'] = new_mem_limit @@ -100,16 +100,27 @@ class TestInferenceSession(unittest.TestCase): options = sess.get_provider_options() self.assertEqual(options['CUDAExecutionProvider']['cuda_mem_limit'], ori_mem_limit) - # test get/set of "arena_extend_strategy" configuration. 
-            options = sess.get_provider_options()
-            self.assertTrue('CUDAExecutionProvider' in options)
-            option = options['CUDAExecutionProvider']
-            self.assertTrue('arena_extend_strategy' in option)
-            for strategy in ['kNextPowerOfTwo', 'kSameAsRequested']:
-                option['arena_extend_strategy'] = strategy
-                sess.set_providers(['CUDAExecutionProvider'], [option])
-                options = sess.get_provider_options()
-                self.assertEqual(options['CUDAExecutionProvider']['arena_extend_strategy'], strategy)
+            def test_get_and_set_option_with_values(option_name, option_values):
+                provider_options = sess.get_provider_options()
+                self.assertIn('CUDAExecutionProvider', provider_options)
+                cuda_options = provider_options['CUDAExecutionProvider']
+                self.assertIn(option_name, cuda_options)
+                for option_value in option_values:
+                    cuda_options[option_name] = option_value
+                    sess.set_providers(['CUDAExecutionProvider'], [cuda_options])
+                    new_provider_options = sess.get_provider_options()
+                    self.assertEqual(
+                        new_provider_options.get('CUDAExecutionProvider', {}).get(option_name),
+                        str(option_value))
+
+            test_get_and_set_option_with_values(
+                'arena_extend_strategy', ['kNextPowerOfTwo', 'kSameAsRequested'])
+
+            test_get_and_set_option_with_values(
+                'cudnn_conv_algo_search', ["DEFAULT", "EXHAUSTIVE", "HEURISTIC"])
+
+            test_get_and_set_option_with_values(
+                'do_copy_in_default_stream', [0, 1])
 
         #
         # Note: Tests that throw an exception leave an empty session due to how set_providers currently works,
@@ -181,21 +192,17 @@ class TestInferenceSession(unittest.TestCase):
         sess = onnxrt.InferenceSession(get_name("mul_1.onnx"))
 
-        # configure session with not legit option values and that shloud fail
+        # configure session with invalid option values and that should fail
         with self.assertRaises(RuntimeError):
             option = {'device_id': num_device}
             sess.set_providers(['CUDAExecutionProvider'], [option])
-            option = {'device_id': 'non_legit_value'}
+            option = {'device_id': 'invalid_value'}
             sess.set_providers(['CUDAExecutionProvider'], [option])
 
-        # configure session with not legit option should cause no effect
-        option = {'device_id': 0}
-        sess.set_providers(['CUDAExecutionProvider'], [option])
-        option = {'non_legit_option': num_device}
-        sess.set_providers(['CUDAExecutionProvider'], [option])
-        self.assertEqual(['CUDAExecutionProvider', 'CPUExecutionProvider'], sess.get_providers())
-
-
+        # configure session with invalid option should fail
+        with self.assertRaises(RuntimeError):
+            option = {'invalid_option': 123}
+            sess.set_providers(['CUDAExecutionProvider'], [option])
 
         libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
         for libname in libnames:
@@ -218,8 +225,7 @@ class TestInferenceSession(unittest.TestCase):
         with self.assertRaises(ValueError) as context:
             sess = onnxrt.InferenceSession(get_name("mul_1.onnx"))
             sess.set_providers(['InvalidProvider'])
-        self.assertTrue(
-            '[\'InvalidProvider\'] does not contain a subset of available providers' in str(context.exception))
+        self.assertIn('\'InvalidProvider\' is unavailable', str(context.exception))
 
     def testSessionProviders(self):
         if 'CUDAExecutionProvider' in onnxrt.get_available_providers():
@@ -712,7 +718,7 @@ class TestInferenceSession(unittest.TestCase):
         sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), so, ['CPUExecutionProvider'])
         res = sess.run(["Y"], {"X": np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)})
         self.assertTrue(np.array_equal(res[0], np.array([[2.0, 2.0], [12.0, 12.0], [30.0, 30.0]], dtype=np.float32)))
-
+
     def testRegisterCustomOpsLibrary(self):
         if sys.platform.startswith("win"):
             shared_library = 'custom_op_library.dll'
@@ -809,17 +815,17 @@ class TestInferenceSession(unittest.TestCase):
         # 1. if there are intermittent failures in this test, something is wrong
         # 2. it's easier to repro on slower GPU (like M60, Geforce 1070)
-        # to repro #4829, uncomment the line below to run copy in a separate stream
-        #onnxrt.capi._pybind_state.set_do_copy_in_default_stream(False)
+        # to repro #4829, set the CUDA EP do_copy_in_default_stream option to False
+        providers = [("CUDAExecutionProvider", {"do_copy_in_default_stream": True}), "CPUExecutionProvider"]
 
-        session = onnxrt.InferenceSession(get_name("issue4829.onnx"))
+        session = onnxrt.InferenceSession(get_name("issue4829.onnx"), providers=providers)
         shape = np.array([2,2], dtype=np.int64)
         for iteration in range(100000):
             result = session.run(output_names=['output'], input_feed={'shape': shape})
 
     def testSharedAllocatorUsingCreateAndRegisterAllocator(self):
         # Create and register an arena based allocator
-
+
         # ort_arena_cfg = onnxrt.OrtArenaCfg(0, -1, -1, -1) (create an OrtArenaCfg like this if you want to use non-default parameters)
         ort_memory_info = onnxrt.OrtMemoryInfo("Cpu", onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, 0, onnxrt.OrtMemType.DEFAULT)
         # Use this option if using non-default OrtArenaCfg : onnxrt.create_and_register_allocator(ort_memory_info, ort_arena_cfg)
@@ -836,5 +842,55 @@ class TestInferenceSession(unittest.TestCase):
         so2.log_severity_level = 1
         onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so2)
 
+    def testCheckAndNormalizeProviderArgs(self):
+        from onnxruntime.capi.onnxruntime_inference_collection import check_and_normalize_provider_args
+
+        valid_providers = ["a", "b", "c"]
+
+        def check_success(providers, provider_options, expected_providers, expected_provider_options):
+            actual_providers, actual_provider_options = check_and_normalize_provider_args(
+                providers, provider_options, valid_providers)
+            self.assertEqual(actual_providers, expected_providers)
+            self.assertEqual(actual_provider_options, expected_provider_options)
+
+        check_success(None, None, [], [])
+
+        check_success(["a"], None, ["a"], [{}])
+
+        check_success(["a", "b"], None, ["a", "b"], [{}, {}])
+
+        check_success([("a", {1: 2}), "b"], None, ["a", "b"], [{"1": "2"}, {}])
+
+        check_success(["a", "b"], [{1: 2}, {}], ["a", "b"], [{"1": "2"}, {}])
+
+        with self.assertWarns(UserWarning):
+            check_success(["a", "b", "a"], [{"x": 1}, {}, {"y": 2}], ["a", "b"], [{"x": "1"}, {}])
+
+        def check_failure(providers, provider_options):
+            with self.assertRaises(ValueError):
+                check_and_normalize_provider_args(providers, provider_options, valid_providers)
+
+        # provider not valid
+        check_failure(["d"], None)
+
+        # providers not sequence
+        check_failure(3, None)
+
+        # providers value invalid
+        check_failure([3], None)
+
+        # provider_options not sequence
+        check_failure(["a"], 3)
+
+        # provider_options value invalid
+        check_failure(["a"], ["not dict"])
+
+        # providers and provider_options length mismatch
+        check_failure(["a", "b"], [{1: 2}])
+
+        # provider options unsupported mixed specification
+        check_failure([("a", {1: 2})], [{3: 4}])
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/onnxruntime/test/python/onnxruntime_test_python_nuphar.py b/onnxruntime/test/python/onnxruntime_test_python_nuphar.py
index b493480bc9..47d862e92e 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_nuphar.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_nuphar.py
@@ -25,13 +25,13 @@ def reference_gemm(a, b, c, alpha, beta, transA, transB):
 
 def set_gemm_node_attrs(attrs, config):
     if config['alpha'] != 1.0:
-        attrs['alpha'] = config['alpha']
+        attrs['alpha'] = config['alpha']
     if config['beta'] != 1.0:
-        attrs['beta'] = config['beta']
+        attrs['beta'] = config['beta']
     if config['transA']:
-        attrs['transA'] = 1
+        attrs['transA'] = 1
     if config['transB']:
-        attrs['transB'] = 1
+        attrs['transB'] = 1
 
 def generate_gemm_inputs_initializers(graph, config, added_inputs_initializers={}, extend=False):
     M = config['M']
@@ -323,6 +323,16 @@ def set_gemm_model_inputs(config, test_inputs, a, b, c):
     if config['withC'] and not config['initC']:
         test_inputs[config['C']] = c
 
+
+def make_providers(nuphar_settings):
+    return [
+        ('NupharExecutionProvider', {
+            'nuphar_settings': nuphar_settings
+        }),
+        'CPUExecutionProvider',
+    ]
+
+
 class TestNuphar(unittest.TestCase):
 
     def test_bidaf(self):
@@ -385,8 +395,9 @@ class TestNuphar(unittest.TestCase):
         for model in [bidaf_opt_scan_model, bidaf_int8_scan_only_model]:
             nuphar_settings = 'nuphar_cache_path:{}'.format(cache_dir)
             for isa in ['avx', 'avx2', 'avx512']:
-                onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings + ', nuphar_codegen_target:' + isa)
-                sess = onnxrt.InferenceSession(model)  # JIT cache happens when initializing session
+                # JIT cache happens when initializing session
+                sess = onnxrt.InferenceSession(
+                    model, providers=make_providers(nuphar_settings + ', nuphar_codegen_target:' + isa))
             cache_dir_content = os.listdir(cache_dir)
             assert len(cache_dir_content) == 1
@@ -400,15 +411,13 @@ class TestNuphar(unittest.TestCase):
             nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}'.format(
                 cache_dir, so_name, 'on')
-            onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
-            sess = onnxrt.InferenceSession(model)
+            sess = onnxrt.InferenceSession(model, providers=make_providers(nuphar_settings))
             sess.run([], feed)
 
             # test avx
             nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}, nuphar_codegen_target:{}'.format(
                 cache_dir, so_name, 'on', 'avx')
-            onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
-            sess = onnxrt.InferenceSession(model)
+            sess = onnxrt.InferenceSession(model, providers=make_providers(nuphar_settings))
             sess.run([], feed)
 
     def test_bert_squad(self):
@@ -672,7 +681,7 @@ class TestNuphar(unittest.TestCase):
     def test_loop_to_scan(self):
         loop_model_filename = get_name("nuphar_tiny_model_with_loop_shape_infered.onnx")
-        scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx"
+        scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx"
         subprocess.run([
             sys.executable, '-m', 'onnxruntime.nuphar.model_editor', '--input', loop_model_filename,
@@ -680,13 +689,13 @@ class TestNuphar(unittest.TestCase):
             ], check=True)
 
         validate_with_ort(loop_model_filename, scan_model_filename)
-
+
     def test_loop_to_scan_with_inconvertible_loop(self):
         # nuphar_onnx_test_loop11_inconvertible_loop.onnx contains a Loop op with dynamic loop count.
         # This Loop op cannot be converted to a Scan op.
         # Set --keep_unconvertible_loop_ops option so conversion will not fail due to unconvertible loop ops.
         loop_model_filename = get_name("nuphar_onnx_test_loop11_inconvertible_loop.onnx")
-        scan_model_filename = "nuphar_onnx_test_loop11_inconvertible_loop_unchanged.onnx"
+        scan_model_filename = "nuphar_onnx_test_loop11_inconvertible_loop_unchanged.onnx"
         subprocess.run([
             sys.executable, '-m', 'onnxruntime.nuphar.model_editor', '--input', loop_model_filename,
@@ -695,15 +704,15 @@ class TestNuphar(unittest.TestCase):
             ], check=True)
 
         # onnxruntime is failing with:
-        # onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :
-        # FAIL : Non-zero status code returned while running Loop node. Name:''
+        # onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :
+        # FAIL : Non-zero status code returned while running Loop node. Name:''
         # Status Message: Inconsistent shape in loop output for output. Expected:{1} Got:{0}
         # skip validate_with_ort for now
         # validate_with_ort(loop_model_filename, scan_model_filename)
 
     def test_loop_to_scan_tool(self):
         loop_model_filename = get_name("nuphar_tiny_model_with_loop_shape_infered.onnx")
-        scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx"
+        scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx"
         subprocess.run([
             sys.executable, '-m', 'onnxruntime.nuphar.model_tools', '--input', loop_model_filename,
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 248059c916..4efdd116e6 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -1,12 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include
-#include "core/session/onnxruntime_c_api.h"
-#include "core/session/onnxruntime_cxx_api.h"
-#include "core/session/onnxruntime_session_options_config_keys.h"
-#include "core/graph/constants.h"
-#include "providers.h"
 #include
 #include
 #include
@@ -15,7 +9,16 @@
 #include
 #include
 #include
+
 #include
+
+#include "core/common/common.h"
+#include "core/common/make_unique.h"
+#include "core/graph/constants.h"
+#include "core/session/onnxruntime_c_api.h"
+#include "core/session/onnxruntime_cxx_api.h"
+#include "core/session/onnxruntime_session_options_config_keys.h"
+#include "providers.h"
 #include "test_allocator.h"
 #include "test_fixture.h"
@@ -1188,19 +1191,19 @@ TEST(CApiTest, get_available_providers) {
   int len = 0;
   char** providers;
   ASSERT_EQ(g_ort->GetAvailableProviders(&providers, &len), nullptr);
-  ASSERT_TRUE(len > 0);
-  ASSERT_EQ(strcmp(providers[0], "CPUExecutionProvider"), 0);
+  ASSERT_GT(len, 0);
+  ASSERT_STREQ(providers[len-1], "CPUExecutionProvider");
   ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr);
 }
 
 TEST(CApiTest, get_available_providers_cpp) {
   std::vector<std::string> providers = Ort::GetAvailableProviders();
-  ASSERT_TRUE(providers.size() > 0);
-  ASSERT_TRUE(providers[0] == std::string("CPUExecutionProvider"));
+  ASSERT_FALSE(providers.empty());
+  ASSERT_EQ(providers.back(), "CPUExecutionProvider");
 #ifdef USE_CUDA
   // CUDA EP will exist in the list but its position may vary based on other EPs included in the build
-  ASSERT_TRUE(std::find(providers.begin(), providers.end(), std::string("CUDAExecutionProvider")) != providers.end());
+  ASSERT_TRUE(std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") != providers.end());
 #endif
 }
diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc
index df960b009f..110bfc9cb3 100644
--- a/orttraining/orttraining/models/bert/main.cc
+++ b/orttraining/orttraining/models/bert/main.cc
@@ -576,7 +576,7 @@ void setup_training_params(BertParameters& params) {
     if (params.cuda_mem_limit_in_gb > 0) {
       info.cuda_mem_limit = gsl::narrow<size_t>(params.cuda_mem_limit_in_gb * 1024 * 1024 * 1024);
     }
-    info.cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE;
+    info.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE;
     params.providers.emplace(kCudaExecutionProvider, CreateExecutionProviderFactory_CUDA(info));
     params.input_allocator = std::make_shared<CUDAPinnedAllocator>(info.device_id, CUDA_PINNED);
diff --git a/orttraining/orttraining/python/deprecated/training_session.py b/orttraining/orttraining/python/deprecated/training_session.py
index ef8bedc54c..e75edb162a 100644
--- a/orttraining/orttraining/python/deprecated/training_session.py
+++ b/orttraining/orttraining/python/deprecated/training_session.py
@@ -7,11 +7,12 @@ import sys
 import os
 
 from onnxruntime.capi import _pybind_state as C
-from onnxruntime.capi.onnxruntime_inference_collection import Session, InferenceSession, IOBinding
+from onnxruntime.capi.onnxruntime_inference_collection import (Session, InferenceSession, IOBinding,
+                                                               check_and_normalize_provider_args)
 
 
 class TrainingSession(InferenceSession):
-    def __init__(self, path_or_bytes, parameters, sess_options=None):
+    def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None):
         Session.__init__(self)
 
         if sess_options:
@@ -19,10 +20,13 @@ class TrainingSession(InferenceSession):
         else:
             self._sess = C.TrainingSession()
 
+        providers, provider_options = check_and_normalize_provider_args(providers, provider_options,
+                                                                        C.get_available_providers())
+
         if isinstance(path_or_bytes, str):
-            config_result = self._sess.load_model(path_or_bytes, parameters)
+            config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options)
         elif isinstance(path_or_bytes, bytes):
-            config_result = self._sess.read_bytes(path_or_bytes, parameters)
+            config_result = self._sess.read_bytes(path_or_bytes, parameters, providers, provider_options)
         else:
             raise TypeError("Unable to load from type '{0}'".format(type(path_or_bytes)))
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index ac4a1074bb..c6f2f2e159 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -305,7 +305,7 @@ void addObjectMethodsForTraining(py::module& m) {
 #endif
 #endif
       })
-      .def("load_model", [](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters) {
+      .def("load_model", [](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
        OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path));
 
 #if defined(USE_MPI)
@@ -315,12 +315,11 @@ void addObjectMethodsForTraining(py::module& m) {
 #endif
        const auto config_result = ConfigureSessionForTraining(static_cast<training::TrainingSession*>(sess->GetSessionHandle()), parameters);
 
-       std::vector<std::string> provider_types = {};
-       InitializeSession(sess->GetSessionHandle(), provider_types);
+       InitializeSession(sess->GetSessionHandle(), provider_types, provider_options);
 
        return config_result;
      })
-      .def("read_bytes", [](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters) {
+      .def("read_bytes", [](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
        std::istringstream buffer(serialized_model);
        OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer));
 
@@ -331,8 +330,7 @@ void addObjectMethodsForTraining(py::module& m) {
 #endif
        const auto config_result = ConfigureSessionForTraining(static_cast<training::TrainingSession*>(sess->GetSessionHandle()), parameters);
 
-       std::vector<std::string> provider_types = {};
-       InitializeSession(sess->GetSessionHandle(), provider_types);
+       InitializeSession(sess->GetSessionHandle(), provider_types, provider_options);
 
        return config_result;
      })
diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py
index 6d5c55253b..4de0a9c210 100644
--- a/orttraining/orttraining/python/training/orttrainer.py
+++ b/orttraining/orttraining/python/training/orttrainer.py
@@ -195,13 +195,6 @@ class ORTTrainer(object):
                 break
         assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})"
 
-        # Set GPU device and memory limit
-        if 'cuda' in self.options.device.id.lower():
-            mem_limit = self.options.device.mem_limit
-            if mem_limit > 0:
-                ort.set_cuda_mem_limit(self.options.device.mem_limit)
-            ort.set_cuda_device_id(_utils.get_device_index(self.options.device.id))
-
         # TODO: Remove when experimental checkpoint functions are removed.
         self._state_dict = {}
 
@@ -655,10 +648,30 @@ class ORTTrainer(object):
         # an old ort session may already exist and occupy GPU memory when creating a new session; this may cause an OOM error.
         # for example, load_state_dict will be called before returning from this function, and it calls _init_session again
         del self._training_session
+
+        # Set provider-specific options if needed
+        def get_providers():
+            providers = ort.get_available_providers()
+
+            if 'cuda' in self.options.device.id.lower():
+                cuda_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)}
+                if self.options.device.mem_limit > 0:
+                    cuda_ep_options["cuda_mem_limit"] = self.options.device.mem_limit
+
+                cuda_ep_name = "CUDAExecutionProvider"
+
+                if cuda_ep_name not in providers:
+                    raise RuntimeError(
+                        "ORTTrainer options specify a CUDA device but the {} provider is unavailable.".format(
+                            cuda_ep_name))
+
+                providers[providers.index(cuda_ep_name)] = (cuda_ep_name, cuda_ep_options)
+
+            return providers
+
         # TrainingSession
-        self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(),
-                                                     ort_parameters,
-                                                     session_options)
+        self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(), ort_parameters,
+                                                     session_options, get_providers())
 
         # I/O bindings
         self._train_io_binding = self._training_session.io_binding()
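
Taken together, these changes replace the global, process-wide setters (set_nuphar_settings, set_cuda_mem_limit, set_cuda_device_id, set_do_copy_in_default_stream) with per-session provider options. A minimal sketch of the resulting Python usage, assuming a CUDA-enabled build and an illustrative local file model.onnx (the file name and option values are not from this diff):

import onnxruntime as ort

# Providers are listed highest-priority first; each entry is either a bare
# provider name or a (name, options_dict) tuple. Option values are
# stringified internally before being applied.
providers = [
    ("CUDAExecutionProvider", {
        "device_id": 0,
        "cuda_mem_limit": 2 * 1024 * 1024 * 1024,
        "arena_extend_strategy": "kSameAsRequested",
        "do_copy_in_default_stream": True,
    }),
    "CPUExecutionProvider",
]
sess = ort.InferenceSession("model.onnx", providers=providers)

# The effective options can be read back; values come back as strings.
print(sess.get_provider_options()["CUDAExecutionProvider"])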
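
The normalization contract exercised by testCheckAndNormalizeProviderArgs, restated here on real provider names instead of the synthetic "a"/"b"/"c" used in the test. This is a sketch of the helper's observable behavior only; as the test shows, it also warns on duplicate providers and rejects mixed tuple-plus-provider_options specifications:

from onnxruntime.capi.onnxruntime_inference_collection import check_and_normalize_provider_args

available = ["CUDAExecutionProvider", "CPUExecutionProvider"]

# Tuples and bare names normalize to two parallel lists, with option keys
# and values converted to strings.
providers, provider_options = check_and_normalize_provider_args(
    [("CUDAExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"],
    None,
    available)
assert providers == ["CUDAExecutionProvider", "CPUExecutionProvider"]
assert provider_options == [{"device_id": "0"}, {}]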
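
The ordering change is also observable from Python: the CPU EP is now guaranteed to come last in the available-provider list (it is the fallback), so callers that assumed index 0, as the C# and C API tests above previously did, must check the tail instead:

import onnxruntime as ort

providers = ort.get_available_providers()
assert len(providers) > 0
# CPU EP is the fallback and is therefore always listed last.
assert providers[-1] == "CPUExecutionProvider"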