diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/torch_gpu_allocator.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/torch_gpu_allocator.cc index 3799eb09b4..9f8edb4b75 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/torch_gpu_allocator.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/torch_gpu_allocator.cc @@ -1,19 +1,30 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include #include +#include + +void* delegate_raw_alloc(size_t nbytes) { + auto allocator = c10::___gpu_identifier___::___gpu_allocator_header___::get(); + return allocator->raw_allocate(nbytes); +} + +void delegate_raw_delete(void* ptr) { + auto allocator = c10::___gpu_identifier___::___gpu_allocator_header___::get(); + allocator->raw_deallocate(ptr); +} size_t gpu_caching_allocator_raw_alloc_address() { - return reinterpret_cast(&c10::___gpu_identifier___::___gpu_allocator_header___::raw_alloc); + return reinterpret_cast(&delegate_raw_alloc); } size_t gpu_caching_allocator_raw_delete_address() { - return reinterpret_cast(&c10::___gpu_identifier___::___gpu_allocator_header___::raw_delete); + return reinterpret_cast(&delegate_raw_delete); } size_t gpu_caching_allocator_empty_cache_address() { - return reinterpret_cast(&c10::___gpu_identifier___::___gpu_allocator_header___::emptyCache); + // This is useful only if PYTORCH_NO_CUDA_MEMORY_CACHING=1 is not set. + return reinterpret_cast(&c10::cuda::CUDACachingAllocator::emptyCache); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {