diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index eb8d1ec3a9..b617d57446 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -16,7 +16,6 @@ from .debug_options import DebugOptions from ._fallback import (ORTModuleFallbackException, _FallbackPolicy, _FallbackManager) -from .torch_cpp_extensions.cpu.torch_interop_utils import clear_all_grad_fns from onnxruntime.capi import _pybind_state as C from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_type @@ -40,10 +39,6 @@ class TrainingManager(GraphExecutionManager): def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs): """Runs the forward graph on execution_session with given model inputs and device""" - # Clear all gradient functions, to avoid a deadlock issue. - # Check the called function for more detailed comments. - clear_all_grad_fns() - # TODO: Try to reuse the output buffers as some of the output tensors are same sizes, # especially the backward graph outputs. # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on C++ side to not diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc index bc930899f4..a8445bf64f 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc @@ -114,16 +114,27 @@ void unregister_grad_fn(size_t ctx_address) PyNodeSharedPointerPool::GetInstance().UnRegisterGradFunc(ctx_address); } -// Supposed to be cleared on python program exit or before every forward run to resolve following issues: -// 1. When training program exits, PyNodeSharedPointerPool destructor is called, if grad_fns_ is not empty, +// Supposed to be cleared on python program exit to resolve following issue: +// When training program exits, PyNodeSharedPointerPool destructor is called, if grad_fns_ is not empty, // PyNode::release_variables() will be called. // (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L168) // The other hand, there is known issue when acquiring GIL in pybind11 destructors, there will be probabbly deadlock issue. // (https://github.com/pybind/pybind11/issues/1446) // The resolution here, we remove all maintained states before program exits. -// 2. When forward functions is called repeated without corresponding backward calls, grad functions keeps accumulating without releasing -// (happening in backward) -void clear_all_grad_fns(){ + +// A known existing issue: when forward functions is called repeatedly without corresponding backward calls, +// grad functions keeps accumulating without releasing, there might be memory (bound to those gradient function) leaks. +// Ideally this usually won't happen in real training case, so it should be fine. + +// We CANNOT explictly clear grad functions before each forward pass to mitigate the known issue above. +// For example: +// loss1 = forward_run(inputs1) +// loss2 = forward_run(inputs2) +// loss = loss1 + loss2 +// loss.backward() +// If we clear grad functions in the beggining of the second `forward_run`, when `loss.backward()` runs, +// the backward path of `loss1` will fail to run PythonOpGrad ops (if there is any). +void clear_all_grad_fns() { PyNodeSharedPointerPool::GetInstance().ClearAll(); } diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index b190a219e6..aea0ed2fef 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -215,88 +215,84 @@ def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06): def enable_custom_autograd_function(module): enable_custom_autograd_support() -def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_mode=False): - with torch.no_grad(): - model = copy.deepcopy(model).to(device) +def _run_model_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False): if is_eval_mode: model.eval() else: model.train() - with torch.no_grad(): - inputs_on_device = [input_.to(device) for input_ in input_list] - for i, val in enumerate(input_list): - if val.requires_grad: - inputs_on_device[i].requires_grad_() - target = label_input.to(device) + def generate_inputs(input_list_, label_input_): + with torch.no_grad(): + inputs_on_device = [input_.to(device) for input_ in input_list_] + for i, val in enumerate(input_list_): + if val.requires_grad: + inputs_on_device[i].requires_grad_() + with torch.no_grad(): + target = label_input_.to(device) + return inputs_on_device, target - output = model(*inputs_on_device) - forward_outputs = [output] + inputs_on_device1, target1 = generate_inputs(input_list, label_input) + if run_forward_twice is True: + inputs_on_device2, target2 = generate_inputs(input_list, label_input) + + output1 = model(*inputs_on_device1) + if run_forward_twice is True: + output2 = model(*inputs_on_device2) + + forward_outputs = [output1] grad_outputs = [] if not is_eval_mode: criterion = torch.nn.MSELoss() - loss = criterion(output, target) + loss = criterion(output1, target1) + + if run_forward_twice is True: + loss += criterion(output2, target2) + loss.backward() for name, param in model.named_parameters(): if param.requires_grad: grad_outputs.append(param.grad) return forward_outputs, grad_outputs -def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False): +def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False): + with torch.no_grad(): + model = copy.deepcopy(model).to(device) + + return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice) + +def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False): with torch.no_grad(): model = copy.deepcopy(model) model.to(device) enable_custom_autograd_function(model) model = ORTModule(model) - if is_eval_mode: - model.eval() - else: - model.train() - with torch.no_grad(): - inputs_on_device = [input_.to(device) for input_ in input_list] - for i, val in enumerate(input_list): - if val.requires_grad: - inputs_on_device[i].requires_grad_() - - target = label_input.to(device) - output = model(*inputs_on_device) - forward_outputs = [output] - grad_outputs = [] - - if not is_eval_mode: - criterion = torch.nn.MSELoss() - loss = criterion(output, target) - loss.backward() - for name, param in model.named_parameters(): - if param.requires_grad: - grad_outputs.append(param.grad) - return forward_outputs, grad_outputs + return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice) def compare_tensor_list(val_list_a, val_list_b): for val_a, val_b in zip(val_list_a, val_list_b): assert_values_are_close(val_a, val_b, atol=1e-7, rtol=1e-6) def run_training_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, - ignore_grad_compare=False, expected_outputs=[], expected_grads=[]): + run_forward_twice=False, ignore_grad_compare=False, expected_outputs=[], expected_grads=[]): cpu = torch.device("cpu") def cpu_barrier_func(): pass run_training_test_on_device_and_compare( cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cpu_barrier_func, - ignore_grad_compare, expected_outputs, expected_grads) + run_forward_twice, ignore_grad_compare, expected_outputs, expected_grads) def cuda_barrier_func(): torch.cuda.synchronize() cuda = torch.device('cuda:0') run_training_test_on_device_and_compare( cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cuda_barrier_func, - ignore_grad_compare, expected_outputs, expected_grads) + run_forward_twice, ignore_grad_compare, expected_outputs, expected_grads) def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, barrier_func, - ignore_grad_compare=False, expected_outputs=[], expected_grads=[]): + run_forward_twice=False, ignore_grad_compare=False, expected_outputs=[], expected_grads=[]): repeats = 16 for i in range(repeats): m = pt_model_builder_func() @@ -307,11 +303,11 @@ def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_mo x_ort = copy.deepcopy(x) outputs, grads = run_with_pytorch_on_device( - device, m, [x], pt_model_label_input) + device, m, [x], pt_model_label_input, run_forward_twice=run_forward_twice) barrier_func() outputs_ort, grads_ort = run_with_ort_on_device( - device, m_ort, [x_ort], pt_model_label_input) + device, m_ort, [x_ort], pt_model_label_input, run_forward_twice=run_forward_twice) barrier_func() val_list_a = [o.detach().cpu() for o in outputs if o is not None] @@ -330,14 +326,16 @@ def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_mo if len(expected_grads) > 0: compare_tensor_list(val_list_a, expected_grads) -def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input): +def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, + run_forward_twice=False): cpu = torch.device("cpu") def cpu_barrier_func(): pass run_evaluate_test_on_device_and_compare( - cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cpu_barrier_func) + cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, + cpu_barrier_func, run_forward_twice=run_forward_twice) def cuda_barrier_func(): torch.cuda.synchronize() @@ -345,9 +343,11 @@ def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generat cuda = torch.device('cuda:0') run_evaluate_test_on_device_and_compare( - cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cuda_barrier_func) + cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, + cuda_barrier_func, run_forward_twice=run_forward_twice) -def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, barrier_func): +def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, + pt_model_label_input, barrier_func, run_forward_twice=False): repeats = 16 for i in range(repeats): m = pt_model_builder_func() @@ -357,11 +357,11 @@ def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_mo x_ort = copy.deepcopy(x) outputs, grads = run_with_pytorch_on_device( - device, m, [x], pt_model_label_input, is_eval_mode=True) + device, m, [x], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice) barrier_func() outputs_ort, grads_ort = run_with_ort_on_device( - device, m_ort, [x_ort], pt_model_label_input, is_eval_mode=True) + device, m_ort, [x_ort], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice) barrier_func() val_list_a = [o.detach().cpu() for o in outputs if o is not None] diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 3b2e6bc6a3..a3f118380c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -138,6 +138,61 @@ def test_GeLU_custom_func_rets_not_as_module_output(): run_training_test_and_compare(model_builder, input_generator, label_input) +def test_GeLU_multiple_forward_runs(): + @torch.jit.script + def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + @torch.jit.script + def bias_gelu_backward(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff*g + + class GeLUFunction3(torch.autograd.Function): + @staticmethod + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_backward(grad_output, bias, input) + return tmp, tmp + + class GeLUModel(torch.nn.Module): + def __init__(self, output_size): + super(GeLUModel, self).__init__() + self.relu = GeLUFunction3.apply + self.bias = Parameter(torch.empty( + output_size, + device=torch.cuda.current_device(), + dtype=torch.float)) + + with torch.no_grad(): + self.bias.uniform_() + + def forward(self, model_input): + out = self.relu(model_input, self.bias) + return out + + output_size = 1024 + + def model_builder(): + return GeLUModel(output_size) + + def input_generator(): + return torch.randn(output_size, dtype=torch.float) + + # generate a label that have same shape as forward output. + label_input = torch.ones([output_size]) + + run_training_test_and_compare(model_builder, input_generator, label_input, run_forward_twice=True) + def test_MegatronF(): # MegatronGFunction is tested in distributed test files. class MegatronFFunction(torch.autograd.Function):