don't clear grad_fns & add test (#10671)

This commit is contained in:
pengwa 2022-03-11 14:31:54 +08:00 committed by GitHub
parent 1a62306db7
commit d478a53d43
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 119 additions and 58 deletions

View file

@ -16,7 +16,6 @@ from .debug_options import DebugOptions
from ._fallback import (ORTModuleFallbackException,
_FallbackPolicy,
_FallbackManager)
from .torch_cpp_extensions.cpu.torch_interop_utils import clear_all_grad_fns
from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_type
@ -40,10 +39,6 @@ class TrainingManager(GraphExecutionManager):
def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs):
"""Runs the forward graph on execution_session with given model inputs and device"""
# Clear all gradient functions, to avoid a deadlock issue.
# Check the called function for more detailed comments.
clear_all_grad_fns()
# TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
# especially the backward graph outputs.
# REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on C++ side to not

View file

@ -114,16 +114,27 @@ void unregister_grad_fn(size_t ctx_address)
PyNodeSharedPointerPool::GetInstance().UnRegisterGradFunc(ctx_address);
}
// Supposed to be cleared on python program exit or before every forward run to resolve following issues:
// 1. When training program exits, PyNodeSharedPointerPool destructor is called, if grad_fns_ is not empty,
// Supposed to be cleared on python program exit to resolve following issue:
// When training program exits, PyNodeSharedPointerPool destructor is called, if grad_fns_ is not empty,
// PyNode::release_variables() will be called.
// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L168)
// On the other hand, there is a known issue when acquiring the GIL in pybind11 destructors: it will probably deadlock.
// (https://github.com/pybind/pybind11/issues/1446)
// The resolution here, we remove all maintained states before program exits.
// 2. When forward functions is called repeated without corresponding backward calls, grad functions keeps accumulating without releasing
// (happening in backward)
void clear_all_grad_fns(){
// A known existing issue: when forward functions are called repeatedly without corresponding backward calls,
// grad functions keep accumulating without being released, so the memory bound to those gradient functions may leak.
// This usually won't happen in real training cases, so it should be fine.
// We CANNOT explicitly clear grad functions before each forward pass to mitigate the known issue above.
// For example:
// loss1 = forward_run(inputs1)
// loss2 = forward_run(inputs2)
// loss = loss1 + loss2
// loss.backward()
// If we clear grad functions at the beginning of the second `forward_run`, then when `loss.backward()` runs,
// the backward path of `loss1` will fail to run PythonOpGrad ops (if there are any).
void clear_all_grad_fns() {
  // Drop all state held by the PyNodeSharedPointerPool instance.
  PyNodeSharedPointerPool::GetInstance().ClearAll();
}

View file

@ -215,88 +215,84 @@ def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06):
def enable_custom_autograd_function(module):
    """Turn on custom autograd function support.

    Calls ``enable_custom_autograd_support()``. The ``module`` argument is
    accepted but is not referenced by this function.
    """
    enable_custom_autograd_support()
def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_mode=False):
with torch.no_grad():
model = copy.deepcopy(model).to(device)
def _run_model_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
if is_eval_mode:
model.eval()
else:
model.train()
with torch.no_grad():
inputs_on_device = [input_.to(device) for input_ in input_list]
for i, val in enumerate(input_list):
if val.requires_grad:
inputs_on_device[i].requires_grad_()
target = label_input.to(device)
def generate_inputs(input_list_, label_input_):
with torch.no_grad():
inputs_on_device = [input_.to(device) for input_ in input_list_]
for i, val in enumerate(input_list_):
if val.requires_grad:
inputs_on_device[i].requires_grad_()
with torch.no_grad():
target = label_input_.to(device)
return inputs_on_device, target
output = model(*inputs_on_device)
forward_outputs = [output]
inputs_on_device1, target1 = generate_inputs(input_list, label_input)
if run_forward_twice is True:
inputs_on_device2, target2 = generate_inputs(input_list, label_input)
output1 = model(*inputs_on_device1)
if run_forward_twice is True:
output2 = model(*inputs_on_device2)
forward_outputs = [output1]
grad_outputs = []
if not is_eval_mode:
criterion = torch.nn.MSELoss()
loss = criterion(output, target)
loss = criterion(output1, target1)
if run_forward_twice is True:
loss += criterion(output2, target2)
loss.backward()
for name, param in model.named_parameters():
if param.requires_grad:
grad_outputs.append(param.grad)
return forward_outputs, grad_outputs
def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False):
def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
    """Run a private copy of ``model`` with plain PyTorch on ``device``.

    The caller's model is deep-copied before being moved, so it is left
    untouched. The copy is then handed to ``_run_model_on_device`` and that
    helper's result is returned unchanged.
    """
    # Copy and move without recording autograd history.
    with torch.no_grad():
        device_model = copy.deepcopy(model)
        device_model = device_model.to(device)

    return _run_model_on_device(
        device,
        device_model,
        input_list,
        label_input,
        is_eval_mode=is_eval_mode,
        run_forward_twice=run_forward_twice,
    )
def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
with torch.no_grad():
model = copy.deepcopy(model)
model.to(device)
enable_custom_autograd_function(model)
model = ORTModule(model)
if is_eval_mode:
model.eval()
else:
model.train()
with torch.no_grad():
inputs_on_device = [input_.to(device) for input_ in input_list]
for i, val in enumerate(input_list):
if val.requires_grad:
inputs_on_device[i].requires_grad_()
target = label_input.to(device)
output = model(*inputs_on_device)
forward_outputs = [output]
grad_outputs = []
if not is_eval_mode:
criterion = torch.nn.MSELoss()
loss = criterion(output, target)
loss.backward()
for name, param in model.named_parameters():
if param.requires_grad:
grad_outputs.append(param.grad)
return forward_outputs, grad_outputs
return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice)
def compare_tensor_list(val_list_a, val_list_b):
    """Assert that two sequences of values are pairwise close.

    Elements are paired positionally with ``zip``, so iteration stops at the
    shorter sequence and a length mismatch is not reported here. Each pair is
    checked with ``assert_values_are_close`` using ``atol=1e-7`` and
    ``rtol=1e-6``.
    """
    tolerances = {"atol": 1e-7, "rtol": 1e-6}
    for pair in zip(val_list_a, val_list_b):
        assert_values_are_close(*pair, **tolerances)
def run_training_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input,
ignore_grad_compare=False, expected_outputs=[], expected_grads=[]):
run_forward_twice=False, ignore_grad_compare=False, expected_outputs=[], expected_grads=[]):
cpu = torch.device("cpu")
def cpu_barrier_func():
pass
run_training_test_on_device_and_compare(
cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cpu_barrier_func,
ignore_grad_compare, expected_outputs, expected_grads)
run_forward_twice, ignore_grad_compare, expected_outputs, expected_grads)
def cuda_barrier_func():
torch.cuda.synchronize()
cuda = torch.device('cuda:0')
run_training_test_on_device_and_compare(
cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cuda_barrier_func,
ignore_grad_compare, expected_outputs, expected_grads)
run_forward_twice, ignore_grad_compare, expected_outputs, expected_grads)
def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, barrier_func,
ignore_grad_compare=False, expected_outputs=[], expected_grads=[]):
run_forward_twice=False, ignore_grad_compare=False, expected_outputs=[], expected_grads=[]):
repeats = 16
for i in range(repeats):
m = pt_model_builder_func()
@ -307,11 +303,11 @@ def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_mo
x_ort = copy.deepcopy(x)
outputs, grads = run_with_pytorch_on_device(
device, m, [x], pt_model_label_input)
device, m, [x], pt_model_label_input, run_forward_twice=run_forward_twice)
barrier_func()
outputs_ort, grads_ort = run_with_ort_on_device(
device, m_ort, [x_ort], pt_model_label_input)
device, m_ort, [x_ort], pt_model_label_input, run_forward_twice=run_forward_twice)
barrier_func()
val_list_a = [o.detach().cpu() for o in outputs if o is not None]
@ -330,14 +326,16 @@ def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_mo
if len(expected_grads) > 0:
compare_tensor_list(val_list_a, expected_grads)
def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input):
def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input,
run_forward_twice=False):
cpu = torch.device("cpu")
def cpu_barrier_func():
pass
run_evaluate_test_on_device_and_compare(
cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cpu_barrier_func)
cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input,
cpu_barrier_func, run_forward_twice=run_forward_twice)
def cuda_barrier_func():
torch.cuda.synchronize()
@ -345,9 +343,11 @@ def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generat
cuda = torch.device('cuda:0')
run_evaluate_test_on_device_and_compare(
cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cuda_barrier_func)
cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input,
cuda_barrier_func, run_forward_twice=run_forward_twice)
def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, barrier_func):
def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator,
pt_model_label_input, barrier_func, run_forward_twice=False):
repeats = 16
for i in range(repeats):
m = pt_model_builder_func()
@ -357,11 +357,11 @@ def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_mo
x_ort = copy.deepcopy(x)
outputs, grads = run_with_pytorch_on_device(
device, m, [x], pt_model_label_input, is_eval_mode=True)
device, m, [x], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice)
barrier_func()
outputs_ort, grads_ort = run_with_ort_on_device(
device, m_ort, [x_ort], pt_model_label_input, is_eval_mode=True)
device, m_ort, [x_ort], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice)
barrier_func()
val_list_a = [o.detach().cpu() for o in outputs if o is not None]

View file

@ -138,6 +138,61 @@ def test_GeLU_custom_func_rets_not_as_module_output():
run_training_test_and_compare(model_builder, input_generator, label_input)
def test_GeLU_multiple_forward_runs():
    """Exercise a custom autograd GeLU when forward runs twice before backward.

    Builds a small module around a ``torch.autograd.Function`` and passes it to
    ``run_training_test_and_compare`` with ``run_forward_twice=True``.
    """
    # Bias-add followed by GeLU in its tanh form:
    #   x * 0.5 * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x^2)))
    # where 0.79788456 is approximately sqrt(2 / pi).
    @torch.jit.script
    def bias_gelu(bias, y):
        x = bias + y
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    # Derivative of ``bias_gelu`` with respect to ``x = bias + y``, multiplied
    # by the incoming gradient ``g``.
    @torch.jit.script
    def bias_gelu_backward(g, bias, y):
        x = bias + y
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        # 0.1070322243 is 3 * 0.79788456 * 0.044715, i.e. the x^2 coefficient
        # of the derivative of the tanh argument.
        ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 +
                                                     0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        return ff*g

    class GeLUFunction3(torch.autograd.Function):
        """Custom autograd function wrapping the scripted forward/backward above."""

        @staticmethod
        def forward(ctx, input, bias):
            # Both operands are needed again to evaluate the gradient.
            ctx.save_for_backward(input, bias)
            return bias_gelu(bias, input)

        @staticmethod
        def backward(ctx, grad_output):
            input, bias = ctx.saved_tensors
            tmp = bias_gelu_backward(grad_output, bias, input)
            # The same gradient tensor is returned for both ``input`` and ``bias``.
            return tmp, tmp

    class GeLUModel(torch.nn.Module):
        """Module with one learnable bias that applies ``GeLUFunction3``."""

        def __init__(self, output_size):
            super(GeLUModel, self).__init__()
            # Despite the attribute name, this holds the custom GeLU function.
            self.relu = GeLUFunction3.apply
            # NOTE(review): the bias is created on torch.cuda.current_device(),
            # so building the model presumably requires CUDA - confirm.
            self.bias = Parameter(torch.empty(
                output_size,
                device=torch.cuda.current_device(),
                dtype=torch.float))

            # Fill the uninitialized bias in place without recording autograd history.
            with torch.no_grad():
                self.bias.uniform_()

        def forward(self, model_input):
            out = self.relu(model_input, self.bias)
            return out

    output_size = 1024

    def model_builder():
        return GeLUModel(output_size)

    def input_generator():
        return torch.randn(output_size, dtype=torch.float)

    # Generate a label that has the same shape as the forward output.
    label_input = torch.ones([output_size])

    # Ask the comparison helper to run the forward pass twice.
    run_training_test_and_compare(model_builder, input_generator, label_input, run_forward_twice=True)
def test_MegatronF():
# MegatronGFunction is tested in distributed test files.
class MegatronFFunction(torch.autograd.Function):