diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 7f4a84f99c..c02f6cb300 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7,6 +7,7 @@ #include #include "onnxruntime_session_options_config_keys.h" + // This value is used in structures passed to ORT so that a newer version of ORT will still work with them #define ORT_API_VERSION 5 @@ -259,6 +260,19 @@ typedef enum OrtMemType { OrtMemTypeDefault = 0, // the default allocator for execution provider } OrtMemType; +typedef enum OrtCudnnConvAlgoSearch { + EXHAUSTIVE, // expensive exhaustive benchmarking using cudnnFindConvolutionForwardAlgorithmEx + HEURISTIC, // lightweight heuristic based search using cudnnGetConvolutionForwardAlgorithm_v7 + DEFAULT, // default algorithm using CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM +} OrtCudnnConvAlgoSearch; + +typedef struct OrtCUDAProviderOptions { + int device_id; // cuda device with id=0 as default device. + OrtCudnnConvAlgoSearch cudnn_conv_algo_search; // cudnn conv algo search option + size_t cuda_mem_limit; // default cuda memory limitation to maximum finite value of size_t. + int arena_extend_strategy; // default area extend strategy to KNextPowerOfTwo. +} OrtCUDAProviderOptions; + struct OrtApi; typedef struct OrtApi OrtApi; @@ -1032,7 +1046,8 @@ struct OrtApi { */ ORT_API2_STATUS(SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess, _Outptr_ uint64_t* out); - /** + +/** * Use this API to configure the global thread pool options to be used in the call to CreateEnvWithGlobalThreadPools. * A value of 0 means ORT will pick the default. * A value of 1 means the invoking thread will be used; no threads will be created in the thread pool. @@ -1072,6 +1087,14 @@ struct OrtApi { */ ORT_API2_STATUS(CreateEnvWithCustomLoggerAndGlobalThreadPools, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel logging_level, _In_ const char* logid, _In_ const struct OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); + + #ifdef USE_CUDA + /** + * Append CUDA execution provider + */ + ORT_API2_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, + _In_ OrtSessionOptions* options, _In_ OrtCUDAProviderOptions* cuda_options); +#endif // USE_CUDA }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index e98d87a3a1..128792ec2e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -250,8 +250,10 @@ struct SessionOptions : Base { SessionOptions& DisablePerSessionThreads(); SessionOptions& AddConfigEntry(const char* config_key, const char* config_value); - SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val); +#ifdef USE_CUDA + OrtStatus* OrtSessionOptionsAppendExecutionProvider_CUDA(OrtSessionOptions* options, OrtCUDAProviderOptions* cuda_options); +#endif }; struct ModelMetadata : Base { diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 21926ab079..3c66c3cbc1 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -461,6 +461,13 @@ inline SessionOptions& SessionOptions::AddInitializer(const char* name, const Or return *this; } + #ifdef USE_CUDA +inline OrtStatus* SessionOptions::OrtSessionOptionsAppendExecutionProvider_CUDA(OrtSessionOptions * options, OrtCUDAProviderOptions * cuda_options) { + ThrowOnError(GetApi().OrtSessionOptionsAppendExecutionProvider_CUDA(options, cuda_options)); + return nullptr; +} +#endif + inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_)); } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 165fb35115..2a8a92b9be 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -120,7 +120,8 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in : IExecutionProvider{onnxruntime::kCudaExecutionProvider}, device_id_(info.device_id), cuda_mem_limit_(info.cuda_mem_limit), - arena_extend_strategy_(info.arena_extend_strategy) { + arena_extend_strategy_(info.arena_extend_strategy), + cudnn_conv_algo_(info.cudnn_conv_algo) { CUDA_CALL_THROW(cudaSetDevice(device_id_)); // must wait GPU idle, otherwise cudaGetDeviceProperties might fail diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 4507d87c17..139902c731 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -24,6 +24,7 @@ struct CUDAExecutionProviderInfo { OrtDevice::DeviceId device_id{0}; size_t cuda_mem_limit{std::numeric_limits::max()}; ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; + OrtCudnnConvAlgoSearch cudnn_conv_algo{OrtCudnnConvAlgoSearch::EXHAUSTIVE}; }; // Logical device representation. @@ -77,13 +78,15 @@ class CUDAExecutionProvider : public IExecutionProvider { int GetDeviceId() const { return device_id_; } const cudaDeviceProp& GetDeviceProp() const { return device_prop_; }; + int GetCudnnConvAlgo() const { return cudnn_conv_algo_; } void UpdateProviderOptionsInfo(); - private: +private: OrtDevice::DeviceId device_id_; cudaDeviceProp device_prop_; size_t cuda_mem_limit_; ArenaExtendStrategy arena_extend_strategy_; + int cudnn_conv_algo_; struct DeferredReleaseCPUPtrs { bool recorded = false; diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index f5026387f9..7518ed09dd 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -6,6 +6,8 @@ #include "core/graph/onnx_protobuf.h" #include "cuda_execution_provider.h" #include "core/session/abi_session_options_impl.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" #include "core/framework/bfc_arena.h" using namespace onnxruntime; @@ -15,10 +17,12 @@ namespace onnxruntime { struct CUDAProviderFactory : IExecutionProviderFactory { CUDAProviderFactory(OrtDevice::DeviceId device_id, size_t cuda_mem_limit = std::numeric_limits::max(), - ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo) + ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo, + OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE) : device_id_(device_id), cuda_mem_limit_(cuda_mem_limit), - arena_extend_strategy_(arena_extend_strategy) {} + arena_extend_strategy_(arena_extend_strategy), + cudnn_conv_algo_search_(cudnn_conv_algo_search) {} ~CUDAProviderFactory() override {} std::unique_ptr CreateProvider() override; @@ -27,6 +31,7 @@ struct CUDAProviderFactory : IExecutionProviderFactory { OrtDevice::DeviceId device_id_; size_t cuda_mem_limit_; ArenaExtendStrategy arena_extend_strategy_; + OrtCudnnConvAlgoSearch cudnn_conv_algo_search_; }; std::unique_ptr CUDAProviderFactory::CreateProvider() { @@ -34,18 +39,28 @@ std::unique_ptr CUDAProviderFactory::CreateProvider() { info.device_id = device_id_; info.cuda_mem_limit = cuda_mem_limit_; info.arena_extend_strategy = arena_extend_strategy_; + info.cudnn_conv_algo = cudnn_conv_algo_search_; return onnxruntime::make_unique(info); } std::shared_ptr CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id, + OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE, size_t cuda_mem_limit = std::numeric_limits::max(), ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo) { - return std::make_shared(device_id, cuda_mem_limit, arena_extend_strategy); + return std::make_shared(device_id, cuda_mem_limit, arena_extend_strategy, cudnn_conv_algo_search); } } // namespace onnxruntime -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id) { +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id){ options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_CUDA(static_cast(device_id))); return nullptr; } + +ORT_API_STATUS_IMPL(OrtApis::OrtSessionOptionsAppendExecutionProvider_CUDA, + _In_ OrtSessionOptions* options, _In_ OrtCUDAProviderOptions* cuda_options) { + options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_CUDA(static_cast(cuda_options->device_id), + cuda_options->cudnn_conv_algo_search, cuda_options->cuda_mem_limit, + static_cast(cuda_options->arena_extend_strategy))); + return nullptr; +} diff --git a/onnxruntime/core/providers/cuda/cuda_provider_options.h b/onnxruntime/core/providers/cuda/cuda_provider_options.h deleted file mode 100644 index b8fb8f8a97..0000000000 --- a/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace onnxruntime { - -/** - * Configuration information for a cuda provider. - * - * Note: This struct is currently for internal use for Python API, - * not for C/C++/C#...APIs. - */ -struct CudaProviderOptions { - - // use cuda device with id=0 as default device. - OrtDevice::DeviceId device_id = 0; - - // set default cuda memory limitation to maximum finite value of size_t. - size_t cuda_mem_limit = std::numeric_limits::max(); - - // set default area extend strategy to KNextPowerOfTwo. - onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo; -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index b65f3289b4..37be8f662b 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -211,20 +211,57 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { cudnnConvolutionFwdAlgoPerf_t perf; int algo_count = 1; - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( - CudnnHandle(), - s_.x_tensor, - x_data, - s_.filter_desc, - w_data, - s_.conv_desc, - s_.y_tensor, - y_data, - 1, - &algo_count, - &perf, - algo_search_workspace.get(), - AlgoSearchWorkspaceSize)); + const CUDAExecutionProvider* cuda_ep = static_cast(this->Info().GetExecutionProvider()); + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo); + switch (cudnn_conv_algo) { + case 0: + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( + CudnnHandle(), + s_.x_tensor, + x_data, + s_.filter_desc, + w_data, + s_.conv_desc, + s_.y_tensor, + y_data, + 1, + &algo_count, + &perf, + algo_search_workspace.get(), + AlgoSearchWorkspaceSize)); + break; + + case 1: + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7( + CudnnHandle(), + s_.x_tensor, + s_.filter_desc, + s_.conv_desc, + s_.y_tensor, + 1, + &algo_count, + &perf)); + break; + + default: + perf.algo = kDefaultConvAlgo; + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardWorkspaceSize( + CudnnHandle(), + s_.x_tensor, + s_.filter_desc, + s_.conv_desc, + s_.y_tensor, + perf.algo, + &perf.memory)); + if (std::is_same::value) { + perf.mathType = CUDNN_TENSOR_OP_MATH; + } + else { + perf.mathType = CUDNN_DEFAULT_MATH; + } + } + s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType}); } diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index a081285678..fc4b5d766a 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -160,8 +160,8 @@ class Conv : public CudaKernel { private: ConvAttributes conv_attrs_; - mutable CudnnConvState s_; + constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; }; } // namespace cuda diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index a642318738..c8fb7d0ca6 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -33,6 +33,9 @@ #include "abi_session_options_impl.h" #include "core/framework/TensorSeq.h" #include "core/platform/ort_mutex.h" +#ifdef USE_CUDA +#include "core/providers/cuda/cuda_provider_factory.h" +#endif using namespace onnxruntime::logging; using onnxruntime::BFloat16; @@ -2028,6 +2031,9 @@ static constexpr OrtApi ort_api_1_to_6 = { // Version 6 - In development, feel free to add/remove/rearrange here &OrtApis::AddInitializer, &OrtApis::CreateEnvWithCustomLoggerAndGlobalThreadPools, +#ifdef USE_CUDA + &OrtApis::OrtSessionOptionsAppendExecutionProvider_CUDA, +#endif }; // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other) diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 84f9eb2fe3..44ed0ff948 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -237,9 +237,13 @@ ORT_API_STATUS_IMPL(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const ORT_API_STATUS_IMPL(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection); ORT_API_STATUS_IMPL(SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess, _Outptr_ uint64_t* out); + ORT_API_STATUS_IMPL(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads); ORT_API_STATUS_IMPL(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads); ORT_API_STATUS_IMPL(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning); ORT_API_STATUS_IMPL(AddInitializer, _Inout_ OrtSessionOptions* options, _In_ const char* name, _In_ const OrtValue* val); + +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, + _In_ OrtSessionOptions* options, _In_ OrtCUDAProviderOptions* cuda_options); } // namespace OrtApis diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 0be1a573b2..a8ce6f4a13 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -138,8 +138,9 @@ struct OrtStatus { #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cuda/shared_inc/cuda_call.h" -#include "core/providers/cuda/cuda_provider_options.h" +#include "core/providers/cuda/cuda_execution_provider.h" OrtDevice::DeviceId cuda_device_id = 0; +OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE; size_t cuda_mem_limit = std::numeric_limits::max(); onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo; #endif @@ -178,6 +179,7 @@ std::string nuphar_settings; namespace onnxruntime { std::shared_ptr CreateExecutionProviderFactory_CPU(int use_arena); std::shared_ptr CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id, + OrtCudnnConvAlgoSearch cudnn_conv_algo_search, size_t cuda_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(int device_id); @@ -446,7 +448,7 @@ bool IsCudaDeviceIdValid(InferenceSession* sess, int id) { return true; } -void UpdateCudaProviderOptions(InferenceSession* sess, onnxruntime::CudaProviderOptions& options, +void UpdateCudaProviderOptions(InferenceSession* sess, onnxruntime::CUDAExecutionProviderInfo& options, std::unordered_map options_map) { std::unordered_map::iterator it; @@ -537,16 +539,18 @@ void RegisterExecutionProviders(InferenceSession* sess, const std::vectorsecond); RegisterExecutionProvider( sess, *onnxruntime::CreateExecutionProviderFactory_CUDA(cuda_provider_options.device_id, + cuda_provider_options.cudnn_conv_algo, cuda_provider_options.cuda_mem_limit, cuda_provider_options.arena_extend_strategy)); } else { RegisterExecutionProvider( sess, *onnxruntime::CreateExecutionProviderFactory_CUDA(cuda_device_id, + cudnn_conv_algo_search, cuda_mem_limit, arena_extend_strategy)); } @@ -749,7 +753,7 @@ void addGlobalMethods(py::module& m, const Environment& env) { std::vector> factories = { onnxruntime::CreateExecutionProviderFactory_CPU(0), #ifdef USE_CUDA - onnxruntime::CreateExecutionProviderFactory_CUDA(cuda_device_id, cuda_mem_limit, arena_extend_strategy), + onnxruntime::CreateExecutionProviderFactory_CUDA(cuda_device_id, cudnn_conv_algo_search, cuda_mem_limit, arena_extend_strategy), #endif #ifdef USE_DNNL onnxruntime::CreateExecutionProviderFactory_Dnnl(1), @@ -802,6 +806,7 @@ void addGlobalMethods(py::module& m, const Environment& env) { * */ m.def("set_cuda_device_id", [](const int id) { cuda_device_id = static_cast(id); }); + m.def("set_cudnn_conv_algo_search", [](const OrtCudnnConvAlgoSearch algo) { cudnn_conv_algo_search = algo; }); m.def("set_cuda_mem_limit", [](const int64_t limit) { cuda_mem_limit = static_cast(limit); }); m.def("set_arena_extend_strategy", [](const onnxruntime::ArenaExtendStrategy strategy) { arena_extend_strategy = strategy; }); #endif diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 7be58c050c..07a8436c13 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include #include #include #include @@ -332,7 +331,13 @@ int real_main(int argc, char* argv[], Ort::Env& env) { } if (enable_cuda) { #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, device_id)); + OrtCUDAProviderOptions cuda_options{ + 0, + OrtCudnnConvAlgoSearch::EXHAUSTIVE, + std::numeric_limits::max(), + 0, + }; + Ort::ThrowOnError(sf.OrtSessionOptionsAppendExecutionProvider_CUDA(sf, &cuda_options)); #else fprintf(stderr, "CUDA is not supported in this build"); return -1; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index ed139249c1..cc8daff820 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -48,12 +48,13 @@ namespace perftest { "\t-o [optimization level]: Default is 1. Valid values are 0 (disable), 1 (basic), 2 (extended), 99 (all).\n" "\t\tPlease see onnxruntime_c_api.h (enum GraphOptimizationLevel) for the full list of all optimization levels. \n" "\t-u [optimized_model_path]: Specify the optimized model path for saving.\n" + "\t-d [cudnn_conv_algorithm]: Specify CUDNN convolution algothrithms: 0(benchmark), 1(heuristic), 2(default). \n" "\t-h: help\n"); } /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:o:u:AMPIvhs"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:AMPIvhs"))) != -1) { switch (ch) { case 'm': if (!CompareCString(optarg, ORT_TSTR("duration"))) { @@ -177,6 +178,9 @@ namespace perftest { case 'I': test_config.run_config.generate_model_input_binding = true; break; + case 'd': + test_config.run_config.cudnn_conv_algo = static_cast(OrtStrtol(optarg, nullptr)); + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 941804d073..e399bc359f 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -46,7 +46,13 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif } else if (provider_name == onnxruntime::kCudaExecutionProvider) { #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptions cuda_options{ + 0, + OrtCudnnConvAlgoSearch::EXHAUSTIVE, + std::numeric_limits::max(), + 0, + }; + Ort::ThrowOnError(session_options.OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, &cuda_options)); #else ORT_THROW("CUDA is not supported in this build\n"); #endif diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 9554f12486..cd5b242962 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -50,6 +50,7 @@ struct RunConfig { int inter_op_num_threads{0}; GraphOptimizationLevel optimization_level{ORT_ENABLE_ALL}; std::basic_string optimized_model_path; + int cudnn_conv_algo{0}; }; struct PerformanceTestConfig { diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 880ddd9fcc..5e61fa6a40 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -11,6 +11,7 @@ namespace onnxruntime { std::shared_ptr CreateExecutionProviderFactory_CPU(int use_arena); std::shared_ptr CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id, + OrtCudnnConvAlgoSearch cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE, size_t cuda_mem_limit = std::numeric_limits::max(), ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo); std::shared_ptr CreateExecutionProviderFactory_Dnnl(int use_arena); diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index 0e24c4d9a8..a1c46e244c 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -20,6 +20,7 @@ namespace onnxruntime { std::shared_ptr CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id, + OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE, size_t cuda_mem_limit = std::numeric_limits::max(), onnxruntime::ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo); } @@ -568,7 +569,8 @@ void setup_training_params(BertParameters& params) { size_t cuda_mem_limit = std::numeric_limits::max(); if (params.cuda_mem_limit_in_gb > 0) cuda_mem_limit = static_cast(params.cuda_mem_limit_in_gb * 1024 * 1024 * 1024); - params.providers.emplace(kCudaExecutionProvider, CreateExecutionProviderFactory_CUDA(device_id, cuda_mem_limit)); + params.providers.emplace(kCudaExecutionProvider, CreateExecutionProviderFactory_CUDA(device_id, OrtCudnnConvAlgoSearch::EXHAUSTIVE, + cuda_mem_limit)); params.input_allocator = std::make_shared(device_id, CUDA_PINNED); #endif diff --git a/orttraining/orttraining/models/gpt2/main.cc b/orttraining/orttraining/models/gpt2/main.cc index 4baf3bd01e..c415bf9578 100644 --- a/orttraining/orttraining/models/gpt2/main.cc +++ b/orttraining/orttraining/models/gpt2/main.cc @@ -19,6 +19,7 @@ namespace onnxruntime { std::shared_ptr CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id, + OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE, size_t cuda_mem_limit = std::numeric_limits::max(), onnxruntime::ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo); } diff --git a/orttraining/orttraining/models/mnist/main.cc b/orttraining/orttraining/models/mnist/main.cc index e867d0e6b7..11f0f330b5 100644 --- a/orttraining/orttraining/models/mnist/main.cc +++ b/orttraining/orttraining/models/mnist/main.cc @@ -19,6 +19,7 @@ namespace onnxruntime { std::shared_ptr CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id, + OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE, size_t cuda_mem_limit = std::numeric_limits::max(), onnxruntime::ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo); }