Adding an option for cudnn conv algorithms. (#5159)

* adding cudnn conv algorithm selection options.

* adding cudnn conv algorithm selection options.

* export the api

* adding the perf test option.

* accommodating PR comments.

* Move OrtSessionOptionsAppendExecutionProvider_CUDA to onnxruntime_c_api.h

* Accommodating PR comments.
This commit is contained in:
Du Li 2020-10-05 16:53:52 -07:00 committed by GitHub
parent a0b8218f9a
commit 323c4dfe02
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 156 additions and 57 deletions

View file

@ -7,6 +7,7 @@
#include <string.h>
#include "onnxruntime_session_options_config_keys.h"
// This value is used in structures passed to ORT so that a newer version of ORT will still work with them
#define ORT_API_VERSION 5
@ -259,6 +260,19 @@ typedef enum OrtMemType {
OrtMemTypeDefault = 0, // the default allocator for execution provider
} OrtMemType;
// Strategy used to pick the cuDNN forward-convolution algorithm.
// The numeric values are spelled out because consumers compare against them as
// plain integers (0, 1 or 2).
typedef enum OrtCudnnConvAlgoSearch {
  // Expensive exhaustive benchmarking using cudnnFindConvolutionForwardAlgorithmEx.
  EXHAUSTIVE = 0,
  // Lightweight heuristic based search using cudnnGetConvolutionForwardAlgorithm_v7.
  HEURISTIC = 1,
  // No search: always use CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM.
  DEFAULT = 2,
} OrtCudnnConvAlgoSearch;
// Options for the CUDA execution provider, passed to OrtSessionOptionsAppendExecutionProvider_CUDA.
// This is a plain C struct without member initializers, so the "default" values named
// below are conventions only: callers must set every field explicitly.
typedef struct OrtCUDAProviderOptions {
int device_id; // CUDA device id; 0 is the conventional default device.
OrtCudnnConvAlgoSearch cudnn_conv_algo_search; // cuDNN convolution algorithm search option (see OrtCudnnConvAlgoSearch).
size_t cuda_mem_limit; // CUDA memory limit; the conventional default is the maximum finite value of size_t.
int arena_extend_strategy; // arena extend strategy as an int; the conventional default is kNextPowerOfTwo (presumably 0 - TODO confirm the mapping).
} OrtCUDAProviderOptions;
struct OrtApi;
typedef struct OrtApi OrtApi;
@ -1032,7 +1046,8 @@ struct OrtApi {
*/
ORT_API2_STATUS(SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess, _Outptr_ uint64_t* out);
/**
/**
* Use this API to configure the global thread pool options to be used in the call to CreateEnvWithGlobalThreadPools.
* A value of 0 means ORT will pick the default.
* A value of 1 means the invoking thread will be used; no threads will be created in the thread pool.
@ -1072,6 +1087,14 @@ struct OrtApi {
*/
ORT_API2_STATUS(CreateEnvWithCustomLoggerAndGlobalThreadPools, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel logging_level,
_In_ const char* logid, _In_ const struct OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out);
#ifdef USE_CUDA
/**
* Append CUDA execution provider
*/
ORT_API2_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA,
_In_ OrtSessionOptions* options, _In_ OrtCUDAProviderOptions* cuda_options);
#endif // USE_CUDA
};
/*

View file

@ -250,8 +250,10 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& DisablePerSessionThreads();
SessionOptions& AddConfigEntry(const char* config_key, const char* config_value);
SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
#ifdef USE_CUDA
OrtStatus* OrtSessionOptionsAppendExecutionProvider_CUDA(OrtSessionOptions* options, OrtCUDAProviderOptions* cuda_options);
#endif
};
struct ModelMetadata : Base<OrtModelMetadata> {

View file

@ -461,6 +461,13 @@ inline SessionOptions& SessionOptions::AddInitializer(const char* name, const Or
return *this;
}
#ifdef USE_CUDA
// Forwards to the C API entry point of the same name to append the CUDA execution
// provider described by `cuda_options` to `options`.
// The status returned by the C API is handed to ThrowOnError; when control reaches
// the end of this function the result is always nullptr.
inline OrtStatus* SessionOptions::OrtSessionOptionsAppendExecutionProvider_CUDA(OrtSessionOptions* options, OrtCUDAProviderOptions* cuda_options) {
  OrtStatus* const append_status = GetApi().OrtSessionOptionsAppendExecutionProvider_CUDA(options, cuda_options);
  ThrowOnError(append_status);
  return nullptr;
}
#endif
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_));
}

View file

@ -120,7 +120,8 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in
: IExecutionProvider{onnxruntime::kCudaExecutionProvider},
device_id_(info.device_id),
cuda_mem_limit_(info.cuda_mem_limit),
arena_extend_strategy_(info.arena_extend_strategy) {
arena_extend_strategy_(info.arena_extend_strategy),
cudnn_conv_algo_(info.cudnn_conv_algo) {
CUDA_CALL_THROW(cudaSetDevice(device_id_));
// must wait GPU idle, otherwise cudaGetDeviceProperties might fail

View file

@ -24,6 +24,7 @@ struct CUDAExecutionProviderInfo {
OrtDevice::DeviceId device_id{0};
size_t cuda_mem_limit{std::numeric_limits<size_t>::max()};
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};
OrtCudnnConvAlgoSearch cudnn_conv_algo{OrtCudnnConvAlgoSearch::EXHAUSTIVE};
};
// Logical device representation.
@ -77,13 +78,15 @@ class CUDAExecutionProvider : public IExecutionProvider {
int GetDeviceId() const { return device_id_; }
const cudaDeviceProp& GetDeviceProp() const { return device_prop_; };
// Returns the cuDNN convolution algorithm search option this provider was created with, as a plain int.
int GetCudnnConvAlgo() const { return cudnn_conv_algo_; }
void UpdateProviderOptionsInfo();
private:
private:
OrtDevice::DeviceId device_id_;
cudaDeviceProp device_prop_;
size_t cuda_mem_limit_;
ArenaExtendStrategy arena_extend_strategy_;
int cudnn_conv_algo_;
struct DeferredReleaseCPUPtrs {
bool recorded = false;

View file

@ -6,6 +6,8 @@
#include "core/graph/onnx_protobuf.h"
#include "cuda_execution_provider.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/session/ort_apis.h"
#include "core/framework/bfc_arena.h"
using namespace onnxruntime;
@ -15,10 +17,12 @@ namespace onnxruntime {
struct CUDAProviderFactory : IExecutionProviderFactory {
CUDAProviderFactory(OrtDevice::DeviceId device_id,
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo)
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo,
OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE)
: device_id_(device_id),
cuda_mem_limit_(cuda_mem_limit),
arena_extend_strategy_(arena_extend_strategy) {}
arena_extend_strategy_(arena_extend_strategy),
cudnn_conv_algo_search_(cudnn_conv_algo_search) {}
~CUDAProviderFactory() override {}
std::unique_ptr<IExecutionProvider> CreateProvider() override;
@ -27,6 +31,7 @@ struct CUDAProviderFactory : IExecutionProviderFactory {
OrtDevice::DeviceId device_id_;
size_t cuda_mem_limit_;
ArenaExtendStrategy arena_extend_strategy_;
OrtCudnnConvAlgoSearch cudnn_conv_algo_search_;
};
std::unique_ptr<IExecutionProvider> CUDAProviderFactory::CreateProvider() {
@ -34,18 +39,28 @@ std::unique_ptr<IExecutionProvider> CUDAProviderFactory::CreateProvider() {
info.device_id = device_id_;
info.cuda_mem_limit = cuda_mem_limit_;
info.arena_extend_strategy = arena_extend_strategy_;
info.cudnn_conv_algo = cudnn_conv_algo_search_;
return onnxruntime::make_unique<CUDAExecutionProvider>(info);
}
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id,
OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE,
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo) {
return std::make_shared<onnxruntime::CUDAProviderFactory>(device_id, cuda_mem_limit, arena_extend_strategy);
return std::make_shared<onnxruntime::CUDAProviderFactory>(device_id, cuda_mem_limit, arena_extend_strategy, cudnn_conv_algo_search);
}
} // namespace onnxruntime
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id) {
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id){
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_CUDA(static_cast<OrtDevice::DeviceId>(device_id)));
return nullptr;
}
// C API implementation: builds a CUDA execution provider factory from the fields of
// `cuda_options` and appends it to the provider factories of `options`.
// Neither pointer is checked for null here. Always returns nullptr.
ORT_API_STATUS_IMPL(OrtApis::OrtSessionOptionsAppendExecutionProvider_CUDA,
                    _In_ OrtSessionOptions* options, _In_ OrtCUDAProviderOptions* cuda_options) {
  // The public struct carries plain ints; convert them to the internal types first.
  const auto target_device = static_cast<OrtDevice::DeviceId>(cuda_options->device_id);
  const auto extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(cuda_options->arena_extend_strategy);

  options->provider_factories.push_back(
      onnxruntime::CreateExecutionProviderFactory_CUDA(target_device,
                                                       cuda_options->cudnn_conv_algo_search,
                                                       cuda_options->cuda_mem_limit,
                                                       extend_strategy));
  return nullptr;
}

View file

@ -1,25 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
/**
* Configuration information for a cuda provider.
*
* Note: This struct is currently for internal use for Python API,
* not for C/C++/C#...APIs.
*/
// Option set for the CUDA provider; every member carries an in-class default.
struct CudaProviderOptions {
// CUDA device to use; defaults to the device with id=0.
OrtDevice::DeviceId device_id = 0;
// CUDA memory limit; defaults to the maximum finite value of size_t.
size_t cuda_mem_limit = std::numeric_limits<size_t>::max();
// Arena extend strategy; defaults to kNextPowerOfTwo.
onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo;
};
} // namespace onnxruntime

View file

@ -211,20 +211,57 @@ Status Conv<T>::ComputeInternal(OpKernelContext* context) const {
cudnnConvolutionFwdAlgoPerf_t perf;
int algo_count = 1;
CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx(
CudnnHandle(),
s_.x_tensor,
x_data,
s_.filter_desc,
w_data,
s_.conv_desc,
s_.y_tensor,
y_data,
1,
&algo_count,
&perf,
algo_search_workspace.get(),
AlgoSearchWorkspaceSize));
const CUDAExecutionProvider* cuda_ep = static_cast<const CUDAExecutionProvider*>(this->Info().GetExecutionProvider());
int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo();
ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo);
switch (cudnn_conv_algo) {
case 0:
CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx(
CudnnHandle(),
s_.x_tensor,
x_data,
s_.filter_desc,
w_data,
s_.conv_desc,
s_.y_tensor,
y_data,
1,
&algo_count,
&perf,
algo_search_workspace.get(),
AlgoSearchWorkspaceSize));
break;
case 1:
CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
CudnnHandle(),
s_.x_tensor,
s_.filter_desc,
s_.conv_desc,
s_.y_tensor,
1,
&algo_count,
&perf));
break;
default:
perf.algo = kDefaultConvAlgo;
CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
CudnnHandle(),
s_.x_tensor,
s_.filter_desc,
s_.conv_desc,
s_.y_tensor,
perf.algo,
&perf.memory));
if (std::is_same<T, MLFloat16>::value) {
perf.mathType = CUDNN_TENSOR_OP_MATH;
}
else {
perf.mathType = CUDNN_DEFAULT_MATH;
}
}
s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType});
}

View file

@ -160,8 +160,8 @@ class Conv : public CudaKernel {
private:
ConvAttributes conv_attrs_;
mutable CudnnConvState<cudnnConvolutionFwdAlgoPerf_t> s_;
constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
};
} // namespace cuda

View file

@ -33,6 +33,9 @@
#include "abi_session_options_impl.h"
#include "core/framework/TensorSeq.h"
#include "core/platform/ort_mutex.h"
#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_factory.h"
#endif
using namespace onnxruntime::logging;
using onnxruntime::BFloat16;
@ -2028,6 +2031,9 @@ static constexpr OrtApi ort_api_1_to_6 = {
// Version 6 - In development, feel free to add/remove/rearrange here
&OrtApis::AddInitializer,
&OrtApis::CreateEnvWithCustomLoggerAndGlobalThreadPools,
#ifdef USE_CUDA
&OrtApis::OrtSessionOptionsAppendExecutionProvider_CUDA,
#endif
};
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)

View file

@ -237,9 +237,13 @@ ORT_API_STATUS_IMPL(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const
ORT_API_STATUS_IMPL(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection);
ORT_API_STATUS_IMPL(SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess, _Outptr_ uint64_t* out);
ORT_API_STATUS_IMPL(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads);
ORT_API_STATUS_IMPL(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads);
ORT_API_STATUS_IMPL(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning);
ORT_API_STATUS_IMPL(AddInitializer, _Inout_ OrtSessionOptions* options, _In_ const char* name,
_In_ const OrtValue* val);
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA,
_In_ OrtSessionOptions* options, _In_ OrtCUDAProviderOptions* cuda_options);
} // namespace OrtApis

View file

@ -138,8 +138,9 @@ struct OrtStatus {
#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_factory.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/cuda_provider_options.h"
#include "core/providers/cuda/cuda_execution_provider.h"
OrtDevice::DeviceId cuda_device_id = 0;
OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE;
size_t cuda_mem_limit = std::numeric_limits<size_t>::max();
onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo;
#endif
@ -178,6 +179,7 @@ std::string nuphar_settings;
namespace onnxruntime {
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CPU(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id,
OrtCudnnConvAlgoSearch cudnn_conv_algo_search,
size_t cuda_mem_limit,
onnxruntime::ArenaExtendStrategy arena_extend_strategy);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(int device_id);
@ -446,7 +448,7 @@ bool IsCudaDeviceIdValid(InferenceSession* sess, int id) {
return true;
}
void UpdateCudaProviderOptions(InferenceSession* sess, onnxruntime::CudaProviderOptions& options,
void UpdateCudaProviderOptions(InferenceSession* sess, onnxruntime::CUDAExecutionProviderInfo& options,
std::unordered_map<std::string, std::string> options_map) {
std::unordered_map<std::string, std::string>::iterator it;
@ -537,16 +539,18 @@ void RegisterExecutionProviders(InferenceSession* sess, const std::vector<std::s
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
onnxruntime::CudaProviderOptions cuda_provider_options;
onnxruntime::CUDAExecutionProviderInfo cuda_provider_options;
UpdateCudaProviderOptions(sess, cuda_provider_options, it->second);
RegisterExecutionProvider(
sess, *onnxruntime::CreateExecutionProviderFactory_CUDA(cuda_provider_options.device_id,
cuda_provider_options.cudnn_conv_algo,
cuda_provider_options.cuda_mem_limit,
cuda_provider_options.arena_extend_strategy));
} else {
RegisterExecutionProvider(
sess, *onnxruntime::CreateExecutionProviderFactory_CUDA(cuda_device_id,
cudnn_conv_algo_search,
cuda_mem_limit,
arena_extend_strategy));
}
@ -749,7 +753,7 @@ void addGlobalMethods(py::module& m, const Environment& env) {
std::vector<std::shared_ptr<onnxruntime::IExecutionProviderFactory>> factories = {
onnxruntime::CreateExecutionProviderFactory_CPU(0),
#ifdef USE_CUDA
onnxruntime::CreateExecutionProviderFactory_CUDA(cuda_device_id, cuda_mem_limit, arena_extend_strategy),
onnxruntime::CreateExecutionProviderFactory_CUDA(cuda_device_id, cudnn_conv_algo_search, cuda_mem_limit, arena_extend_strategy),
#endif
#ifdef USE_DNNL
onnxruntime::CreateExecutionProviderFactory_Dnnl(1),
@ -802,6 +806,7 @@ void addGlobalMethods(py::module& m, const Environment& env) {
*
*/
m.def("set_cuda_device_id", [](const int id) { cuda_device_id = static_cast<OrtDevice::DeviceId>(id); });
m.def("set_cudnn_conv_algo_search", [](const OrtCudnnConvAlgoSearch algo) { cudnn_conv_algo_search = algo; });
m.def("set_cuda_mem_limit", [](const int64_t limit) { cuda_mem_limit = static_cast<size_t>(limit); });
m.def("set_arena_extend_strategy", [](const onnxruntime::ArenaExtendStrategy strategy) { arena_extend_strategy = strategy; });
#endif

View file

@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <core/session/onnxruntime_cxx_api.h>
#include <set>
#include <iostream>
#include <fstream>
@ -332,7 +331,13 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_cuda) {
#ifdef USE_CUDA
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, device_id));
OrtCUDAProviderOptions cuda_options{
0,
OrtCudnnConvAlgoSearch::EXHAUSTIVE,
std::numeric_limits<size_t>::max(),
0,
};
Ort::ThrowOnError(sf.OrtSessionOptionsAppendExecutionProvider_CUDA(sf, &cuda_options));
#else
fprintf(stderr, "CUDA is not supported in this build");
return -1;

View file

@ -48,12 +48,13 @@ namespace perftest {
"\t-o [optimization level]: Default is 1. Valid values are 0 (disable), 1 (basic), 2 (extended), 99 (all).\n"
"\t\tPlease see onnxruntime_c_api.h (enum GraphOptimizationLevel) for the full list of all optimization levels. \n"
"\t-u [optimized_model_path]: Specify the optimized model path for saving.\n"
"\t-d [cudnn_conv_algorithm]: Specify CUDNN convolution algothrithms: 0(benchmark), 1(heuristic), 2(default). \n"
"\t-h: help\n");
}
/*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) {
int ch;
while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:o:u:AMPIvhs"))) != -1) {
while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:AMPIvhs"))) != -1) {
switch (ch) {
case 'm':
if (!CompareCString(optarg, ORT_TSTR("duration"))) {
@ -177,6 +178,9 @@ namespace perftest {
case 'I':
test_config.run_config.generate_model_input_binding = true;
break;
case 'd':
test_config.run_config.cudnn_conv_algo = static_cast<int>(OrtStrtol<PATH_CHAR_TYPE>(optarg, nullptr));
break;
case '?':
case 'h':
default:

View file

@ -46,7 +46,13 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#endif
} else if (provider_name == onnxruntime::kCudaExecutionProvider) {
#ifdef USE_CUDA
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
OrtCUDAProviderOptions cuda_options{
0,
OrtCudnnConvAlgoSearch::EXHAUSTIVE,
std::numeric_limits<size_t>::max(),
0,
};
Ort::ThrowOnError(session_options.OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, &cuda_options));
#else
ORT_THROW("CUDA is not supported in this build\n");
#endif

View file

@ -50,6 +50,7 @@ struct RunConfig {
int inter_op_num_threads{0};
GraphOptimizationLevel optimization_level{ORT_ENABLE_ALL};
std::basic_string<ORTCHAR_T> optimized_model_path;
int cudnn_conv_algo{0};
};
struct PerformanceTestConfig {

View file

@ -11,6 +11,7 @@ namespace onnxruntime {
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CPU(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id,
OrtCudnnConvAlgoSearch cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE,
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int use_arena);

View file

@ -20,6 +20,7 @@
namespace onnxruntime {
// Local forward declaration of the CUDA execution provider factory function.
// Parameter order: device id, cuDNN conv algorithm search option, CUDA memory limit,
// arena extend strategy. Because the search option precedes the memory limit, a
// positional caller that wants to set the limit must also pass the search option.
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id,
OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE,
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
onnxruntime::ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo);
}
@ -568,7 +569,8 @@ void setup_training_params(BertParameters& params) {
size_t cuda_mem_limit = std::numeric_limits<size_t>::max();
if (params.cuda_mem_limit_in_gb > 0)
cuda_mem_limit = static_cast<size_t>(params.cuda_mem_limit_in_gb * 1024 * 1024 * 1024);
params.providers.emplace(kCudaExecutionProvider, CreateExecutionProviderFactory_CUDA(device_id, cuda_mem_limit));
params.providers.emplace(kCudaExecutionProvider, CreateExecutionProviderFactory_CUDA(device_id, OrtCudnnConvAlgoSearch::EXHAUSTIVE,
cuda_mem_limit));
params.input_allocator = std::make_shared<CUDAPinnedAllocator>(device_id, CUDA_PINNED);
#endif

View file

@ -19,6 +19,7 @@
namespace onnxruntime {
// Local forward declaration of the CUDA execution provider factory function.
// Parameter order: device id, cuDNN conv algorithm search option, CUDA memory limit,
// arena extend strategy; all but the device id have default arguments.
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id,
OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE,
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
onnxruntime::ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo);
}

View file

@ -19,6 +19,7 @@
namespace onnxruntime {
// Local forward declaration of the CUDA execution provider factory function.
// Parameter order: device id, cuDNN conv algorithm search option, CUDA memory limit,
// arena extend strategy; all but the device id have default arguments.
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id,
OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE,
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
onnxruntime::ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo);
}