mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
don't clear grad_fns & add test (#10671)
This commit is contained in:
parent
1a62306db7
commit
d478a53d43
4 changed files with 119 additions and 58 deletions
|
|
@ -16,7 +16,6 @@ from .debug_options import DebugOptions
|
|||
from ._fallback import (ORTModuleFallbackException,
|
||||
_FallbackPolicy,
|
||||
_FallbackManager)
|
||||
from .torch_cpp_extensions.cpu.torch_interop_utils import clear_all_grad_fns
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_type
|
||||
|
|
@ -40,10 +39,6 @@ class TrainingManager(GraphExecutionManager):
|
|||
def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs):
|
||||
"""Runs the forward graph on execution_session with given model inputs and device"""
|
||||
|
||||
# Clear all gradient functions, to avoid a deadlock issue.
|
||||
# Check the called function for more detailed comments.
|
||||
clear_all_grad_fns()
|
||||
|
||||
# TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
|
||||
# especially the backward graph outputs.
|
||||
# REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on C++ side to not
|
||||
|
|
|
|||
|
|
@ -114,16 +114,27 @@ void unregister_grad_fn(size_t ctx_address)
|
|||
PyNodeSharedPointerPool::GetInstance().UnRegisterGradFunc(ctx_address);
|
||||
}
|
||||
|
||||
// Supposed to be cleared on python program exit or before every forward run to resolve following issues:
|
||||
// 1. When training program exits, PyNodeSharedPointerPool destructor is called, if grad_fns_ is not empty,
|
||||
// Supposed to be cleared on python program exit to resolve following issue:
|
||||
// When training program exits, PyNodeSharedPointerPool destructor is called, if grad_fns_ is not empty,
|
||||
// PyNode::release_variables() will be called.
|
||||
// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L168)
|
||||
// On the other hand, there is a known issue when acquiring the GIL in pybind11 destructors: it will probably cause a deadlock.
|
||||
// (https://github.com/pybind/pybind11/issues/1446)
|
||||
// The resolution here is to remove all maintained states before the program exits.
|
||||
// 2. When forward functions are called repeatedly without corresponding backward calls, grad functions keep accumulating without being released
|
||||
// (happening in backward)
|
||||
void clear_all_grad_fns(){
|
||||
|
||||
// A known existing issue: when forward functions are called repeatedly without corresponding backward calls,
|
||||
// grad functions keep accumulating without being released, so the memory bound to those gradient functions might leak.
|
||||
// This usually won't happen in real training cases, so it should be fine.
|
||||
|
||||
// We CANNOT explicitly clear grad functions before each forward pass to mitigate the known issue above.
|
||||
// For example:
|
||||
// loss1 = forward_run(inputs1)
|
||||
// loss2 = forward_run(inputs2)
|
||||
// loss = loss1 + loss2
|
||||
// loss.backward()
|
||||
// If we clear grad functions at the beginning of the second `forward_run`, when `loss.backward()` runs,
|
||||
// the backward path of `loss1` will fail to run PythonOpGrad ops (if there are any).
|
||||
void clear_all_grad_fns() {
|
||||
PyNodeSharedPointerPool::GetInstance().ClearAll();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -215,88 +215,84 @@ def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06):
|
|||
def enable_custom_autograd_function(module):
|
||||
enable_custom_autograd_support()
|
||||
|
||||
def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_mode=False):
|
||||
with torch.no_grad():
|
||||
model = copy.deepcopy(model).to(device)
|
||||
def _run_model_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
|
||||
if is_eval_mode:
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
||||
with torch.no_grad():
|
||||
inputs_on_device = [input_.to(device) for input_ in input_list]
|
||||
for i, val in enumerate(input_list):
|
||||
if val.requires_grad:
|
||||
inputs_on_device[i].requires_grad_()
|
||||
target = label_input.to(device)
|
||||
def generate_inputs(input_list_, label_input_):
|
||||
with torch.no_grad():
|
||||
inputs_on_device = [input_.to(device) for input_ in input_list_]
|
||||
for i, val in enumerate(input_list_):
|
||||
if val.requires_grad:
|
||||
inputs_on_device[i].requires_grad_()
|
||||
with torch.no_grad():
|
||||
target = label_input_.to(device)
|
||||
return inputs_on_device, target
|
||||
|
||||
output = model(*inputs_on_device)
|
||||
forward_outputs = [output]
|
||||
inputs_on_device1, target1 = generate_inputs(input_list, label_input)
|
||||
if run_forward_twice is True:
|
||||
inputs_on_device2, target2 = generate_inputs(input_list, label_input)
|
||||
|
||||
output1 = model(*inputs_on_device1)
|
||||
if run_forward_twice is True:
|
||||
output2 = model(*inputs_on_device2)
|
||||
|
||||
forward_outputs = [output1]
|
||||
grad_outputs = []
|
||||
|
||||
if not is_eval_mode:
|
||||
criterion = torch.nn.MSELoss()
|
||||
loss = criterion(output, target)
|
||||
loss = criterion(output1, target1)
|
||||
|
||||
if run_forward_twice is True:
|
||||
loss += criterion(output2, target2)
|
||||
|
||||
loss.backward()
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
grad_outputs.append(param.grad)
|
||||
return forward_outputs, grad_outputs
|
||||
|
||||
def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False):
|
||||
def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
|
||||
with torch.no_grad():
|
||||
model = copy.deepcopy(model).to(device)
|
||||
|
||||
return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice)
|
||||
|
||||
def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
|
||||
with torch.no_grad():
|
||||
model = copy.deepcopy(model)
|
||||
model.to(device)
|
||||
enable_custom_autograd_function(model)
|
||||
model = ORTModule(model)
|
||||
if is_eval_mode:
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
||||
with torch.no_grad():
|
||||
inputs_on_device = [input_.to(device) for input_ in input_list]
|
||||
for i, val in enumerate(input_list):
|
||||
if val.requires_grad:
|
||||
inputs_on_device[i].requires_grad_()
|
||||
|
||||
target = label_input.to(device)
|
||||
output = model(*inputs_on_device)
|
||||
forward_outputs = [output]
|
||||
grad_outputs = []
|
||||
|
||||
if not is_eval_mode:
|
||||
criterion = torch.nn.MSELoss()
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
grad_outputs.append(param.grad)
|
||||
return forward_outputs, grad_outputs
|
||||
return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice)
|
||||
|
||||
def compare_tensor_list(val_list_a, val_list_b):
|
||||
for val_a, val_b in zip(val_list_a, val_list_b):
|
||||
assert_values_are_close(val_a, val_b, atol=1e-7, rtol=1e-6)
|
||||
|
||||
def run_training_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input,
|
||||
ignore_grad_compare=False, expected_outputs=[], expected_grads=[]):
|
||||
run_forward_twice=False, ignore_grad_compare=False, expected_outputs=[], expected_grads=[]):
|
||||
cpu = torch.device("cpu")
|
||||
|
||||
def cpu_barrier_func():
|
||||
pass
|
||||
run_training_test_on_device_and_compare(
|
||||
cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cpu_barrier_func,
|
||||
ignore_grad_compare, expected_outputs, expected_grads)
|
||||
run_forward_twice, ignore_grad_compare, expected_outputs, expected_grads)
|
||||
|
||||
def cuda_barrier_func():
|
||||
torch.cuda.synchronize()
|
||||
cuda = torch.device('cuda:0')
|
||||
run_training_test_on_device_and_compare(
|
||||
cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cuda_barrier_func,
|
||||
ignore_grad_compare, expected_outputs, expected_grads)
|
||||
run_forward_twice, ignore_grad_compare, expected_outputs, expected_grads)
|
||||
|
||||
def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, barrier_func,
|
||||
ignore_grad_compare=False, expected_outputs=[], expected_grads=[]):
|
||||
run_forward_twice=False, ignore_grad_compare=False, expected_outputs=[], expected_grads=[]):
|
||||
repeats = 16
|
||||
for i in range(repeats):
|
||||
m = pt_model_builder_func()
|
||||
|
|
@ -307,11 +303,11 @@ def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_mo
|
|||
x_ort = copy.deepcopy(x)
|
||||
|
||||
outputs, grads = run_with_pytorch_on_device(
|
||||
device, m, [x], pt_model_label_input)
|
||||
device, m, [x], pt_model_label_input, run_forward_twice=run_forward_twice)
|
||||
barrier_func()
|
||||
|
||||
outputs_ort, grads_ort = run_with_ort_on_device(
|
||||
device, m_ort, [x_ort], pt_model_label_input)
|
||||
device, m_ort, [x_ort], pt_model_label_input, run_forward_twice=run_forward_twice)
|
||||
barrier_func()
|
||||
|
||||
val_list_a = [o.detach().cpu() for o in outputs if o is not None]
|
||||
|
|
@ -330,14 +326,16 @@ def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_mo
|
|||
if len(expected_grads) > 0:
|
||||
compare_tensor_list(val_list_a, expected_grads)
|
||||
|
||||
def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input):
|
||||
def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input,
|
||||
run_forward_twice=False):
|
||||
cpu = torch.device("cpu")
|
||||
|
||||
def cpu_barrier_func():
|
||||
pass
|
||||
|
||||
run_evaluate_test_on_device_and_compare(
|
||||
cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cpu_barrier_func)
|
||||
cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input,
|
||||
cpu_barrier_func, run_forward_twice=run_forward_twice)
|
||||
|
||||
def cuda_barrier_func():
|
||||
torch.cuda.synchronize()
|
||||
|
|
@ -345,9 +343,11 @@ def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generat
|
|||
|
||||
cuda = torch.device('cuda:0')
|
||||
run_evaluate_test_on_device_and_compare(
|
||||
cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cuda_barrier_func)
|
||||
cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input,
|
||||
cuda_barrier_func, run_forward_twice=run_forward_twice)
|
||||
|
||||
def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, barrier_func):
|
||||
def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator,
|
||||
pt_model_label_input, barrier_func, run_forward_twice=False):
|
||||
repeats = 16
|
||||
for i in range(repeats):
|
||||
m = pt_model_builder_func()
|
||||
|
|
@ -357,11 +357,11 @@ def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_mo
|
|||
x_ort = copy.deepcopy(x)
|
||||
|
||||
outputs, grads = run_with_pytorch_on_device(
|
||||
device, m, [x], pt_model_label_input, is_eval_mode=True)
|
||||
device, m, [x], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice)
|
||||
barrier_func()
|
||||
|
||||
outputs_ort, grads_ort = run_with_ort_on_device(
|
||||
device, m_ort, [x_ort], pt_model_label_input, is_eval_mode=True)
|
||||
device, m_ort, [x_ort], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice)
|
||||
barrier_func()
|
||||
|
||||
val_list_a = [o.detach().cpu() for o in outputs if o is not None]
|
||||
|
|
|
|||
|
|
@ -138,6 +138,61 @@ def test_GeLU_custom_func_rets_not_as_module_output():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
def test_GeLU_multiple_forward_runs():
|
||||
@torch.jit.script
|
||||
def bias_gelu(bias, y):
|
||||
x = bias + y
|
||||
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
|
||||
|
||||
@torch.jit.script
|
||||
def bias_gelu_backward(g, bias, y):
|
||||
x = bias + y
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 +
|
||||
0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
return ff*g
|
||||
|
||||
class GeLUFunction3(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, bias):
|
||||
ctx.save_for_backward(input, bias)
|
||||
return bias_gelu(bias, input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, bias = ctx.saved_tensors
|
||||
tmp = bias_gelu_backward(grad_output, bias, input)
|
||||
return tmp, tmp
|
||||
|
||||
class GeLUModel(torch.nn.Module):
|
||||
def __init__(self, output_size):
|
||||
super(GeLUModel, self).__init__()
|
||||
self.relu = GeLUFunction3.apply
|
||||
self.bias = Parameter(torch.empty(
|
||||
output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=torch.float))
|
||||
|
||||
with torch.no_grad():
|
||||
self.bias.uniform_()
|
||||
|
||||
def forward(self, model_input):
|
||||
out = self.relu(model_input, self.bias)
|
||||
return out
|
||||
|
||||
output_size = 1024
|
||||
|
||||
def model_builder():
|
||||
return GeLUModel(output_size)
|
||||
|
||||
def input_generator():
|
||||
return torch.randn(output_size, dtype=torch.float)
|
||||
|
||||
# Generate a label that has the same shape as the forward output.
|
||||
label_input = torch.ones([output_size])
|
||||
|
||||
run_training_test_and_compare(model_builder, input_generator, label_input, run_forward_twice=True)
|
||||
|
||||
def test_MegatronF():
|
||||
# MegatronGFunction is tested in distributed test files.
|
||||
class MegatronFFunction(torch.autograd.Function):
|
||||
|
|
|
|||
Loading…
Reference in a new issue