Reviewer annotation (free text ahead of the first "diff --git" line; it is not part of any hunk).

Summary of the patch below:

orttraining/orttraining/python/training/ortmodule.py
  * Imports load_inline from torch.utils.cpp_extension.
  * Adds _load_torch_allocator_cpp_extension(), which builds an inline C++ extension exposing two
    functions that return, as size_t, the addresses of c10::cuda::CUDACachingAllocator::raw_alloc
    and c10::cuda::CUDACachingAllocator::raw_delete.
  * ORTModule.__init__ sets self._use_external_cuda_allocator = True, loads that extension and
    stores the two returned values in self._torch_alloc and self._torch_free.
  * When the device type is 'cuda' and the flag is set, the CUDAExecutionProvider options gain
    "cuda_external_alloc" and "cuda_external_free", each the str() of the stored value.
  * Removes the torch.cuda.empty_cache() call (and its TODO comment) that ran after ONNX export.

orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
  * test_gpu_reserved_memory_with_torch_no_grad no longer calls or patches torch.cuda.empty_cache
    and drops both equality assertions against mem_reserved_before_export.

NOTE(review): in the text under review the two #include lines had no header name and both
reinterpret_cast calls had no template argument. <size_t> follows from the declared return type;
the header names are inferred from the symbols used - confirm against the upstream commit.
NOTE(review): the patch text had lost its line breaks. Blank context lines were re-inserted so that
every hunk matches its @@ line counts, but their exact positions are inferred - verify before
applying.
NOTE(review): the extension is loaded unconditionally in __init__ with with_cuda=True, including
when the module later runs on a CPU device; presumably this requires a CUDA build toolchain on
every host - confirm this is intended.

diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py
index 0d4784614c..d24357c288 100644
--- a/orttraining/orttraining/python/training/ortmodule.py
+++ b/orttraining/orttraining/python/training/ortmodule.py
@@ -10,6 +10,7 @@ import numpy as np
 from inspect import signature
 
 from torch.utils.dlpack import from_dlpack
+from torch.utils.cpp_extension import load_inline
 from collections import abc
 
 # Needed to re-implement PyTorch's cpu,cuda,to methods
@@ -157,6 +158,22 @@ def _ort_output_to_torch_tensor(ort_output):
     tensor = from_dlpack(ort_output.to_dlpack())
     return tensor.to(torch.bool) if tensor.dtype == torch.uint8 else tensor
 
+def _load_torch_allocator_cpp_extension():
+    torch_cuda_allocator_addresses_cpp_source = """
+    #include <torch/extension.h>
+    #include <c10/cuda/CUDACachingAllocator.h>
+    size_t cuda_caching_allocator_raw_alloc_address() {
+        return reinterpret_cast<size_t>(&c10::cuda::CUDACachingAllocator::raw_alloc);
+    }
+    size_t cuda_caching_allocator_raw_delete_address() {
+        return reinterpret_cast<size_t>(&c10::cuda::CUDACachingAllocator::raw_delete);
+    }
+    """
+
+    return load_inline(name='inline_extension', cpp_sources=[torch_cuda_allocator_addresses_cpp_source],
+                       functions=['cuda_caching_allocator_raw_alloc_address', 'cuda_caching_allocator_raw_delete_address'],
+                       verbose=True, with_cuda=True)
+
 class ORTModule(torch.nn.Module):
 
     def __init__(self, module):
@@ -194,6 +211,13 @@ class ORTModule(torch.nn.Module):
         self._save_onnx = False
         self._save_onnx_prefix = ''
 
+        # CPP extension to get torch CUDA allocator's alloc and free function addresses
+        self._use_external_cuda_allocator = True
+        if self._use_external_cuda_allocator:
+            self._torch_cuda_allocator = _load_torch_allocator_cpp_extension()
+            self._torch_alloc = self._torch_cuda_allocator.cuda_caching_allocator_raw_alloc_address()
+            self._torch_free = self._torch_cuda_allocator.cuda_caching_allocator_raw_delete_address()
+
     def _initialize_module_gradient_graph_builder(self):
         # TODO: PyTorch exporter bug: changes the initializer order
         initializer_names = [p[0] for p in self._original_module.named_parameters()]
@@ -219,7 +243,10 @@ class ORTModule(torch.nn.Module):
         if self._device.type == 'cuda':
             # Configure the InferenceSessions to use the specific GPU on which the model is placed.
             providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
-            provider_options = [{"device_id": str(self._device.index)}, {}]
+            if self._use_external_cuda_allocator:
+                provider_options = [{"device_id": str(self._device.index), "cuda_external_alloc": str(self._torch_alloc), "cuda_external_free": str(self._torch_free)}, {}]
+            else:
+                provider_options = [{"device_id": str(self._device.index)}, {}]
         elif self._device.type == 'cpu':
             providers = ["CPUExecutionProvider"]
             provider_options = [{}]
@@ -461,8 +488,4 @@ class ORTModule(torch.nn.Module):
         except RuntimeError as e:
             raise RuntimeError('There was an error while exporting the PyTorch model to ONNX: {}'.format(e))
 
-        # TODO: this step might not be needed when we use the torch external allocator
-        # clear cache after model export
-        torch.cuda.empty_cache()
-
         return onnx.load_model_from_string(f.getvalue())
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 850097df8a..5d786e6aa8 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -337,22 +337,18 @@ def test_gpu_reserved_memory_with_torch_no_grad():
     model_with_no_grad(x, y, None, None, None, None, z)
     mem_reserved_after_export_with_torch_no_grad = torch.cuda.memory_reserved(device)
     del model_with_no_grad
-    torch.cuda.empty_cache()
     mem_reserved_after_cache_empty = torch.cuda.memory_reserved(device)
-    assert mem_reserved_before_export == mem_reserved_after_cache_empty
 
-    # Create another model and get the memory_reserved when torch.no_grad and torch.cuda.empty_cache
-    # has not been enabled after export
+    # Create another model and get the memory_reserved when torch.no_grad has not been enabled after export.
     model_without_no_grad = _get_bert_for_sequence_classification_model(device)
     model_without_no_grad = ORTModule(model_without_no_grad)
     mem_reserved_after_export_without_torch_no_grad = 0
 
-    with patch('torch.no_grad'), patch('torch.cuda.empty_cache'):
+    with patch('torch.no_grad'):
         model_without_no_grad(x, y, None, None, None, None, z)
         mem_reserved_after_export_without_torch_no_grad = torch.cuda.memory_reserved(device)
 
     assert mem_reserved_after_export_with_torch_no_grad < mem_reserved_after_export_without_torch_no_grad
-    assert mem_reserved_before_export == mem_reserved_after_export_with_torch_no_grad
 
 @pytest.mark.parametrize("return_type, device", [
     (dict, 'cpu'),