Plug n Allocate with external CUDA allocator via PyBind. (#6679)

This commit is contained in:
M. Zeeshan Siddiqui 2021-02-17 18:59:38 -08:00 committed by GitHub
parent dd8ef4409a
commit e44ac6524f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 848 additions and 764 deletions

View file

@ -142,7 +142,6 @@ class ProviderOptionsParser {
for (const auto& option : options) {
const auto& name = option.first;
const auto& value_str = option.second;
const auto value_parser_it = value_parsers_.find(name);
ORT_RETURN_IF(
value_parser_it == value_parsers_.end(),

View file

@ -64,6 +64,23 @@ void CUDAAllocator::Free(void* p) {
cudaFree(p); // do not throw error since it's OK for cudaFree to fail during shutdown
}
// Allocates `size` bytes through the user-supplied external allocation callback.
// Returns nullptr (without invoking the callback) for zero-size requests.
// Throws (via ORT_ENFORCE) if the external callback fails to produce a pointer,
// so callers never receive nullptr for a non-zero request.
void* CUDAExternalAllocator::Alloc(size_t size) {
  void* p = nullptr;
  if (size > 0) {
    p = alloc_(size);
    // Fail fast with a diagnostic; a silent nullptr here would surface later as
    // an illegal-address error on the device side, which is much harder to trace.
    // TODO(review): consider whether ORT_ENFORCE is appropriate here or whether
    // an allocation failure should be reported as a Status instead.
    ORT_ENFORCE(p != nullptr, "CUDAExternalAllocator: external alloc callback returned nullptr for size ", size);
  }
  return p;
}
// Releases memory previously obtained from Alloc by delegating to the
// user-supplied external free callback.
// nullptr is ignored so that Free is safe on the result of a zero-size Alloc;
// the external callback is not guaranteed to tolerate nullptr the way cudaFree does.
void CUDAExternalAllocator::Free(void* p) {
  if (p != nullptr) {
    free_(p);
  }
}
// Creates a CUDAFence constructed from the GPU data-transfer object looked up
// via the given session state. Returned as a shared_ptr (FencePtr).
// NOTE(review): GetGPUDataTransfer's ownership semantics are not visible here —
// presumably the fence borrows or shares the transfer object; confirm at the callee.
FencePtr CUDAAllocator::CreateFence(const SessionState* session_state) {
return std::make_shared<CUDAFence>(GetGPUDataTransfer(session_state));
}

View file

@ -23,6 +23,25 @@ class CUDAAllocator : public IAllocator {
void SetDevice(bool throw_when_fail) const;
};
class CUDAExternalAllocator : public CUDAAllocator {
typedef void* (*ExternalAlloc)(size_t size);
typedef void (*ExternalFree)(void* p);
public:
CUDAExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free)
: CUDAAllocator(device_id, name) {
alloc_ = reinterpret_cast<ExternalAlloc>(alloc);
free_ = reinterpret_cast<ExternalFree>(free);
}
void* Alloc(size_t size) override;
void Free(void* p) override;
private:
ExternalAlloc alloc_;
ExternalFree free_;
};
//TODO: add a default constructor
class CUDAPinnedAllocator : public IAllocator {
public:

File diff suppressed because it is too large Load diff

View file

@ -80,12 +80,15 @@ class CUDAExecutionProvider : public IExecutionProvider {
}
void RegisterAllocator(std::shared_ptr<AllocatorManager> allocator_manager) override;
static AllocatorPtr CreateCudaAllocator(OrtDevice::DeviceId device_id, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy,
CUDAExecutionProviderExternalAllocatorInfo external_alloc_info);
private:
CUDAExecutionProviderInfo info_;
cudaDeviceProp device_prop_;
bool external_stream_ = false;
cudaStream_t stream_ = nullptr;
struct DeferredReleaseCPUPtrs {
bool recorded = false;
std::vector<void*> cpu_ptrs;
@ -96,7 +99,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
class PerThreadContext final {
public:
PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy);
PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy,
CUDAExecutionProviderExternalAllocatorInfo external_alloc_info);
~PerThreadContext();
cublasHandle_t CublasHandle() const {
@ -129,7 +133,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
}
return reinterpret_cast<const T*>(constant_ones_half_->GetBuffer(stream_, count));
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
} else if (std::is_same<T, nv_bfloat16>::value) {
} else if (std::is_same<T, nv_bfloat16>::value) {
if (!constant_ones_bfloat16_) {
constant_ones_bfloat16_ = cuda::CreateConstantOnes<nv_bfloat16>();
}

View file

@ -16,6 +16,8 @@ constexpr const char* kMemLimit = "cuda_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search";
constexpr const char* kDoCopyInDefaultStream = "do_copy_in_default_stream";
constexpr const char* kcudaExternalAlloc = "cuda_external_alloc";
constexpr const char* kcudaExternalFree = "cuda_external_free";
} // namespace provider_option_names
} // namespace cuda
@ -51,6 +53,22 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
return Status::OK();
})
.AddValueParser(
cuda::provider_option_names::kcudaExternalAlloc,
[&info](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
info.external_allocator_info.alloc = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
cuda::provider_option_names::kcudaExternalFree,
[&info](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
info.external_allocator_info.free = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddAssignmentToReference(cuda::provider_option_names::kMemLimit, info.cuda_mem_limit)
.AddAssignmentToEnumReference(
cuda::provider_option_names::kArenaExtendStrategy,
@ -68,6 +86,8 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution
const ProviderOptions options{
{cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.cuda_mem_limit)},
{cuda::provider_option_names::kcudaExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{cuda::provider_option_names::kcudaExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
{cuda::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{cuda::provider_option_names::kCudnnConvAlgoSearch,

View file

@ -12,6 +12,20 @@
namespace onnxruntime {
// Information needed to construct CUDA execution providers.
// Optional user-supplied device-memory callbacks for the CUDA execution
// provider. When both pointers are set, the provider uses them instead of
// the default CUDA allocator.
struct CUDAExecutionProviderExternalAllocatorInfo {
  void* alloc{nullptr};  // address of the external allocation callback
  void* free{nullptr};   // address of the external free callback

  // Default member initializers already null both pointers, so the
  // hand-written assignments in the previous ctor were redundant.
  CUDAExecutionProviderExternalAllocatorInfo() = default;

  // True only when both callbacks have been provided. const: it reads but
  // never mutates state, so it is callable on const instances too.
  bool UseExternalAllocator() const {
    return (alloc != nullptr) && (free != nullptr);
  }
};
struct CUDAExecutionProviderInfo {
OrtDevice::DeviceId device_id{0};
size_t cuda_mem_limit{std::numeric_limits<size_t>::max()};
@ -20,6 +34,7 @@ struct CUDAExecutionProviderInfo {
bool do_copy_in_default_stream{true};
bool has_user_compute_stream{false};
void* user_compute_stream{nullptr};
CUDAExecutionProviderExternalAllocatorInfo external_allocator_info{};
static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info);

View file

@ -138,11 +138,13 @@ struct OrtStatus {
#if defined(USE_CUDA) || defined(USE_ROCM)
#ifdef USE_CUDA
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/cuda_execution_provider.h"
#include "core/providers/cuda/cuda_allocator.h"
// TODO remove deprecated global config
OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE;
// TODO remove deprecated global config
bool do_copy_in_default_stream = true;
onnxruntime::CUDAExecutionProviderExternalAllocatorInfo external_allocator_info{};
#endif
// TODO remove deprecated global config
OrtDevice::DeviceId cuda_device_id = 0;
@ -452,20 +454,7 @@ static AllocatorPtr GetCudaAllocator(OrtDevice::DeviceId id) {
static std::unordered_map<OrtDevice::DeviceId, AllocatorPtr> id_to_allocator_map;
if (id_to_allocator_map.find(id) == id_to_allocator_map.end()) {
// Use arena-based allocator
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId id) {
return onnxruntime::make_unique<CUDAAllocator>(id, CUDA);
},
id,
true,
{cuda_mem_limit,
static_cast<int>(arena_extend_strategy),
-1, -1});
auto allocator = CreateAllocator(default_memory_info);
id_to_allocator_map.insert({id, allocator});
id_to_allocator_map.insert({id, CUDAExecutionProvider::CreateCudaAllocator(id, cuda_mem_limit, arena_extend_strategy, external_allocator_info)});
}
return id_to_allocator_map[id];
@ -521,9 +510,14 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
info.arena_extend_strategy = arena_extend_strategy;
info.cudnn_conv_algo_search = cudnn_conv_algo_search;
info.do_copy_in_default_stream = do_copy_in_default_stream;
info.external_allocator_info = external_allocator_info;
return info;
}();
// This variable is never initialized because the APIs by which it should be initialized are deprecated; however, they
// still exist and are in use. Nevertheless, it is used to return a CUDAAllocator, hence we must try to initialize it
// here if we can, since FromProviderOptions might contain an external CUDA allocator.
external_allocator_info = info.external_allocator_info;
RegisterExecutionProvider(
sess, *onnxruntime::CreateExecutionProviderFactory_CUDA(info));
#endif
@ -826,6 +820,7 @@ void addGlobalMethods(py::module& m, Environment& env) {
info.arena_extend_strategy = arena_extend_strategy;
info.cudnn_conv_algo_search = cudnn_conv_algo_search;
info.do_copy_in_default_stream = do_copy_in_default_stream;
info.external_allocator_info = external_allocator_info;
return info;
}()),
#endif

View file

@ -65,7 +65,6 @@ class TestInferenceSession(unittest.TestCase):
import sys
import ctypes
CUDA_SUCCESS = 0
def runBaseTest1():
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"))
self.assertTrue('CUDAExecutionProvider' in sess.get_providers())
@ -122,6 +121,12 @@ class TestInferenceSession(unittest.TestCase):
test_get_and_set_option_with_values(
'do_copy_in_default_stream', [0, 1])
option['cuda_external_alloc'] = '0'
option['cuda_external_free'] = '0'
sess.set_providers(['CUDAExecutionProvider'], [option])
options = sess.get_provider_options()
self.assertEqual(options['CUDAExecutionProvider']['cuda_external_alloc'], '0')
self.assertEqual(options['CUDAExecutionProvider']['cuda_external_free'], '0')
#
# Note: Tests that throw an exception leave an empty session due to how set_providers currently works,
# so run them last. Each set_providers call will attempt to re-create a session, so it's