Use the same allocator as PyTorch (#9697)

* Use the same allocator as PyTorch

* Polish

* Fix AMD build
This commit is contained in:
Wei-Sheng Chin 2021-11-09 11:25:16 -08:00 committed by GitHub
parent 229c9a4e1c
commit bdc279a7ed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,19 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <torch/extension.h>
#include <c10/___gpu_identifier___/___gpu_allocator_header___.h>
#include <torch/extension.h>
// Allocates `nbytes` bytes by forwarding to raw_allocate() on the allocator
// object returned by the GPU caching-allocator namespace's get().
//
// NOTE(review): the ___gpu_identifier___ / ___gpu_allocator_header___ tokens
// look like placeholders substituted before compilation -- confirm against
// the build script that generates this file.
//
// @param nbytes  Number of bytes requested.
// @return        Whatever pointer raw_allocate() hands back.
void* delegate_raw_alloc(size_t nbytes) {
  return c10::___gpu_identifier___::___gpu_allocator_header___::get()->raw_allocate(nbytes);
}
// Releases `ptr` by forwarding to raw_deallocate() on the allocator object
// returned by the GPU caching-allocator namespace's get().
//
// NOTE(review): presumably `ptr` must have come from delegate_raw_alloc so
// that allocation and release use the same allocator -- verify with callers.
//
// @param ptr  Pointer to hand to raw_deallocate().
void delegate_raw_delete(void* ptr) {
  c10::___gpu_identifier___::___gpu_allocator_header___::get()->raw_deallocate(ptr);
}
// Returns the address of delegate_raw_alloc, cast to size_t.
//
// The function previously contained two consecutive return statements; the
// first (the address of the namespace-level raw_alloc) made the delegate
// return unreachable, so delegate_raw_alloc was never exported. Only the
// delegate is returned now, so allocations go through whichever allocator
// get() returns instead of being bound directly to raw_alloc.
//
// @return  Address of delegate_raw_alloc as an integer.
size_t gpu_caching_allocator_raw_alloc_address() {
  return reinterpret_cast<size_t>(&delegate_raw_alloc);
}
// Returns the address of delegate_raw_delete, cast to size_t.
//
// The function previously contained two consecutive return statements; the
// first (the address of the namespace-level raw_delete) made the delegate
// return unreachable, so delegate_raw_delete was never exported. Only the
// delegate is returned now, so deallocation goes through whichever allocator
// get() returns, matching the allocation path.
//
// @return  Address of delegate_raw_delete as an integer.
size_t gpu_caching_allocator_raw_delete_address() {
  return reinterpret_cast<size_t>(&delegate_raw_delete);
}
// Returns the address of the caching allocator's emptyCache function, cast
// to size_t.
//
// The function previously contained two consecutive return statements; the
// second, hard-coded to c10::cuda::CUDACachingAllocator, was unreachable.
// The placeholder-based form is kept because it is the one that actually
// executed and it matches the templated namespace used by the rest of this
// file.
// NOTE(review): the removed line may have been the intended replacement --
// confirm which spelling the non-CUDA build needs.
//
// @return  Address of emptyCache as an integer.
size_t gpu_caching_allocator_empty_cache_address() {
  // This is useful only if PYTORCH_NO_CUDA_MEMORY_CACHING=1 is not set.
  return reinterpret_cast<size_t>(&c10::___gpu_identifier___::___gpu_allocator_header___::emptyCache);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {