Check gradient correctness in the UTs (#6803)

* Check gradient correctness in the UTs

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Sherlock 2021-02-25 13:31:07 -08:00 committed by GitHub
parent fa8a9015bd
commit 8a450d523f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 150 additions and 96 deletions

View file

@ -251,14 +251,18 @@ class ORTModule(torch.nn.Module):
return user_outputs
@staticmethod
def backward(ctx, *grad_output):
def backward(ctx, *grad_outputs):
'''Performs backward pass based on grad wrt module output
'''
# Use IO binding
# Push user output grads to ONNX backend.
backward_grad_output_ortvalue = []
for grad_output in grad_output[:len(self._onnx_graphs_info.backward_output_grad_names)]:
for grad_output in grad_outputs[:len(self._onnx_graphs_info.backward_output_grad_names)]:
# Force torch tensors to be contiguous before converting into OrtValue
if not grad_output.is_contiguous():
grad_output = grad_output.contiguous()
backward_grad_output_ortvalue.append(onnxruntime.OrtValue.ortvalue_from_data_ptr(list(grad_output.size()), _utils.dtype_torch_to_numpy(
grad_output.dtype), grad_output.device.type, _utils.get_device_index(grad_output.device), grad_output.data_ptr()))

View file

@ -145,3 +145,19 @@ def _get_name(name):
if os.path.exists(res):
return res
raise FileNotFoundError("Unable to find '{0}' or '{1}' or '{2}'".format(name, rel, res))
def assert_gradients_match_and_reset_gradient(ort_model, pt_model, reset_gradient=True, rtol=1e-05, atol=1e-06):
ort_named_params = list(ort_model.named_parameters())
pt_named_params = list(pt_model.named_parameters())
assert len(ort_named_params) == len(pt_named_params)
for ort_named_param, pt_named_param in zip(ort_named_params, pt_named_params):
ort_name, ort_param = ort_named_param
pt_name, pt_param = pt_named_param
assert pt_name in ort_name
assert torch.allclose(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)
if reset_gradient:
ort_param.grad = None
pt_param.grad = None

View file

@ -3,6 +3,7 @@
# orttraining_test_ortmodule_api.py
import math
import copy
import torch
from transformers import AutoConfig, BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
@ -367,143 +368,176 @@ def test_input_requires_grad_backward_creates_input_grad(device):
s.backward()
assert x.grad is not None
def test_multiple_forward_only_calls():
N, D_in, H, D_out = 32, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model = ORTModule(model)
def test_gradient_correctness():
device = 'cuda'
N, D_in, H, D_out = 32, 128, 500, 10
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
def run_step(model, x):
prediction = model(x)
loss = prediction.sum()
loss.backward()
for step in range(10):
x = torch.randn(N, D_in, device='cuda', requires_grad=False)
prediction1 = model(x)
x = torch.randn(N, D_in, device=device)
run_step(pt_model, x)
run_step(ort_model, x)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
def test_multiple_forward_only_calls():
device = 'cuda'
N, D_in, H, D_out = 32, 784, 500, 10
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
for step in range(10):
x = torch.randn(N, D_in, device=device, requires_grad=False)
pt_prediction = pt_model(x)
ort_prediction = ort_model(x)
assert torch.allclose(ort_prediction, pt_prediction)
def test_multiple_overlapping_forward_backward_calls():
device = 'cuda'
N, D_in, H, D_out = 32, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model = ORTModule(model)
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
for step in range(10):
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
assert x1.grad is None and x2.grad is None
def run_step(model, x1, x2):
prediction1 = model(x1)
s1 = prediction1.sum()
loss1 = prediction1.sum()
prediction2 = model(x2)
s2 = prediction2.sum()
loss2 = prediction2.sum()
loss1.backward()
loss2.backward()
s1.backward()
s2.backward()
assert x1.grad is not None and x2.grad is not None
for step in range(10):
pt_x1 = torch.randn(N, D_in, device=device, requires_grad=True)
pt_x2 = torch.randn(N, D_in, device=device, requires_grad=True)
ort_x1 = pt_x1.clone().detach()
ort_x2 = pt_x2.clone().detach()
ort_x1.requires_grad = True
ort_x2.requires_grad = True
run_step(pt_model, pt_x1, pt_x2)
run_step(ort_model, ort_x1, ort_x2)
assert torch.allclose(ort_x1.grad, pt_x1.grad)
assert torch.allclose(ort_x2.grad, pt_x2.grad)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
def test_multiple_ortmodules_training():
N, D_in, H, D_out = 32, 784, 500, 10
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model1 = ORTModule(model1)
model2 = ORTModule(model2)
for step in range(10):
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
assert x1.grad is None and x2.grad is None
device = 'cuda'
N, D_in, H, D_out = 32, 784, 128, 10
pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
ort_model1 = ORTModule(copy.deepcopy(pt_model1))
ort_model2 = ORTModule(copy.deepcopy(pt_model2))
def run_step(model1, model2, x1, x2):
prediction1 = model1(x1)
s1 = prediction1.sum()
loss1 = prediction1.sum()
loss1.backward()
prediction2 = model2(x2)
s2 = prediction2.sum()
loss2 = prediction2.sum()
loss2.backward()
s1.backward()
s2.backward()
for step in range(10):
x1 = torch.randn(N, D_in, device=device)
x2 = torch.randn(N, D_in, device=device)
run_step(pt_model1, pt_model2, x1, x2)
run_step(ort_model1, ort_model2, x1, x2)
assert x1.grad is not None and x2.grad is not None
for param in model1.parameters():
assert param.grad is not None
param.grad = None
for param in model2.parameters():
assert param.grad is not None
param.grad = None
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
def test_multiple_ortmodules_common_backbone_training():
N, D_in, H, D_out = 32, 64, 500, 64
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
device = 'cuda'
N, D_in, H, D_out = 32, 64, 128, 64
pt_model0 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
# model is the common backbone shared by model1 and model2
model = ORTModule(model)
model1 = ORTModule(model1)
model2 = ORTModule(model2)
ort_model0 = ORTModule(copy.deepcopy(pt_model0))
ort_model1 = ORTModule(copy.deepcopy(pt_model1))
ort_model2 = ORTModule(copy.deepcopy(pt_model2))
def run_step(backbone_layers, task_layers, x):
prediction = task_layers(backbone_layers(x))
loss = prediction.sum()
loss.backward()
for step in range(10):
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
assert x1.grad is None and x2.grad is None
# Run task 1
x1 = torch.randn(N, D_in, device=device)
run_step(pt_model0, pt_model1, x1)
run_step(ort_model0, ort_model1, x1)
prediction1 = model1(model(x1))
s1 = prediction1.sum()
s1.backward()
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model0, pt_model0, reset_gradient=False)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1)
prediction2 = model2(model(x2))
s2 = prediction2.sum()
s2.backward()
# Run task 2
x2 = torch.randn(N, D_in, device=device)
run_step(pt_model0, pt_model2, x1)
run_step(ort_model0, ort_model2, x1)
assert x1.grad is not None and x2.grad is not None
for param in model.parameters():
assert param.grad is not None
param.grad = None
for param in model1.parameters():
assert param.grad is not None
param.grad = None
for param in model2.parameters():
assert param.grad is not None
param.grad = None
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model0, pt_model0, reset_gradient=True)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
def test_multiple_chained_ortmodules_training():
device = 'cuda'
N, D_in, H, D_out = 32, 128, 500, 128
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model1 = ORTModule(model1)
model2 = ORTModule(model2)
pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
ort_model1 = ORTModule(copy.deepcopy(pt_model1))
ort_model2 = ORTModule(copy.deepcopy(pt_model2))
all_params = list(model1.parameters()) + list(model2.parameters())
def run_step(layers1, layers2, x):
prediction = layers2(layers1(x))
loss = prediction.sum()
loss.backward()
for step in range(10):
x = torch.randn(N, D_in, device='cuda', requires_grad=True)
output1 = model1(x)
output2 = model2(output1)
s = output2.sum()
s.backward()
x = torch.randn(N, D_in, device=device, requires_grad=True)
run_step(pt_model1, pt_model2, x)
run_step(ort_model1, ort_model2, x)
assert x.grad is not None
for param in all_params:
assert param.grad is not None
param.grad = None
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
def test_mixed_nnmodule_ortmodules_training():
device = 'cuda'
N, D_in, H, D_out = 32, 128, 500, 128
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model3 = NeuralNetMultiplePositionalArguments(D_in, H, D_out).to('cuda')
model1 = ORTModule(model1)
# model2 is intentionally left as nn.module
model3 = ORTModule(model3)
pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
pt_model3 = NeuralNetMultiplePositionalArguments(D_in, H, D_out).to(device)
all_params = list(model1.parameters()) + list(model2.parameters()) + list(model3.parameters())
for step in range(10):
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
ort_model1 = ORTModule(copy.deepcopy(pt_model1))
ort_model2 = copy.deepcopy(pt_model2) # model2 is intentionally left as nn.module
ort_model3 = ORTModule(copy.deepcopy(pt_model3))
def run_step(model1, model2, model3, x1, x2):
a1 = model1(x1)
a2 = model2(x2)
a3 = model3(torch.sin(a1), torch.cos(a2))
loss = a3.sum()
loss.backward()
assert x1.grad is not None and x2.grad is not None
for param in all_params:
assert param.grad is not None
param.grad = None
for step in range(10):
x1 = torch.randn(N, D_in, device=device)
x2 = torch.randn(N, D_in, device=device)
run_step(pt_model1, pt_model2, pt_model3, x1, x2)
run_step(ort_model1, ort_model2, ort_model3, x1, x2)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model3, pt_model3)
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device):