mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
Check gradient correctness in the UTs (#6803)
* Check gradient correctness in the UTs Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
fa8a9015bd
commit
8a450d523f
3 changed files with 150 additions and 96 deletions
|
|
@ -251,14 +251,18 @@ class ORTModule(torch.nn.Module):
|
|||
return user_outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_output):
|
||||
def backward(ctx, *grad_outputs):
|
||||
'''Performs backward pass based on grad wrt module output
|
||||
'''
|
||||
|
||||
# Use IO binding
|
||||
# Push user output grads to ONNX backend.
|
||||
backward_grad_output_ortvalue = []
|
||||
for grad_output in grad_output[:len(self._onnx_graphs_info.backward_output_grad_names)]:
|
||||
for grad_output in grad_outputs[:len(self._onnx_graphs_info.backward_output_grad_names)]:
|
||||
# Force torch tensors to be contiguous before converting into OrtValue
|
||||
if not grad_output.is_contiguous():
|
||||
grad_output = grad_output.contiguous()
|
||||
|
||||
backward_grad_output_ortvalue.append(onnxruntime.OrtValue.ortvalue_from_data_ptr(list(grad_output.size()), _utils.dtype_torch_to_numpy(
|
||||
grad_output.dtype), grad_output.device.type, _utils.get_device_index(grad_output.device), grad_output.data_ptr()))
|
||||
|
||||
|
|
|
|||
|
|
@ -145,3 +145,19 @@ def _get_name(name):
|
|||
if os.path.exists(res):
|
||||
return res
|
||||
raise FileNotFoundError("Unable to find '{0}' or '{1}' or '{2}'".format(name, rel, res))
|
||||
|
||||
def assert_gradients_match_and_reset_gradient(ort_model, pt_model, reset_gradient=True, rtol=1e-05, atol=1e-06):
|
||||
ort_named_params = list(ort_model.named_parameters())
|
||||
pt_named_params = list(pt_model.named_parameters())
|
||||
assert len(ort_named_params) == len(pt_named_params)
|
||||
|
||||
for ort_named_param, pt_named_param in zip(ort_named_params, pt_named_params):
|
||||
ort_name, ort_param = ort_named_param
|
||||
pt_name, pt_param = pt_named_param
|
||||
|
||||
assert pt_name in ort_name
|
||||
assert torch.allclose(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)
|
||||
|
||||
if reset_gradient:
|
||||
ort_param.grad = None
|
||||
pt_param.grad = None
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
# orttraining_test_ortmodule_api.py
|
||||
|
||||
import math
|
||||
import copy
|
||||
import torch
|
||||
from transformers import AutoConfig, BertForSequenceClassification
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
|
|
@ -367,143 +368,176 @@ def test_input_requires_grad_backward_creates_input_grad(device):
|
|||
s.backward()
|
||||
assert x.grad is not None
|
||||
|
||||
def test_multiple_forward_only_calls():
|
||||
N, D_in, H, D_out = 32, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
model = ORTModule(model)
|
||||
def test_gradient_correctness():
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 128, 500, 10
|
||||
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
def run_step(model, x):
|
||||
prediction = model(x)
|
||||
loss = prediction.sum()
|
||||
loss.backward()
|
||||
|
||||
for step in range(10):
|
||||
x = torch.randn(N, D_in, device='cuda', requires_grad=False)
|
||||
prediction1 = model(x)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
run_step(pt_model, x)
|
||||
run_step(ort_model, x)
|
||||
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
def test_multiple_forward_only_calls():
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 784, 500, 10
|
||||
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
for step in range(10):
|
||||
x = torch.randn(N, D_in, device=device, requires_grad=False)
|
||||
pt_prediction = pt_model(x)
|
||||
ort_prediction = ort_model(x)
|
||||
|
||||
assert torch.allclose(ort_prediction, pt_prediction)
|
||||
|
||||
def test_multiple_overlapping_forward_backward_calls():
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
model = ORTModule(model)
|
||||
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
for step in range(10):
|
||||
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
|
||||
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
|
||||
assert x1.grad is None and x2.grad is None
|
||||
|
||||
def run_step(model, x1, x2):
|
||||
prediction1 = model(x1)
|
||||
s1 = prediction1.sum()
|
||||
loss1 = prediction1.sum()
|
||||
|
||||
prediction2 = model(x2)
|
||||
s2 = prediction2.sum()
|
||||
loss2 = prediction2.sum()
|
||||
|
||||
loss1.backward()
|
||||
loss2.backward()
|
||||
|
||||
s1.backward()
|
||||
s2.backward()
|
||||
assert x1.grad is not None and x2.grad is not None
|
||||
for step in range(10):
|
||||
pt_x1 = torch.randn(N, D_in, device=device, requires_grad=True)
|
||||
pt_x2 = torch.randn(N, D_in, device=device, requires_grad=True)
|
||||
|
||||
ort_x1 = pt_x1.clone().detach()
|
||||
ort_x2 = pt_x2.clone().detach()
|
||||
ort_x1.requires_grad = True
|
||||
ort_x2.requires_grad = True
|
||||
|
||||
run_step(pt_model, pt_x1, pt_x2)
|
||||
run_step(ort_model, ort_x1, ort_x2)
|
||||
|
||||
assert torch.allclose(ort_x1.grad, pt_x1.grad)
|
||||
assert torch.allclose(ort_x2.grad, pt_x2.grad)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
def test_multiple_ortmodules_training():
|
||||
N, D_in, H, D_out = 32, 784, 500, 10
|
||||
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
model1 = ORTModule(model1)
|
||||
model2 = ORTModule(model2)
|
||||
|
||||
for step in range(10):
|
||||
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
|
||||
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
|
||||
assert x1.grad is None and x2.grad is None
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 784, 128, 10
|
||||
pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model1 = ORTModule(copy.deepcopy(pt_model1))
|
||||
ort_model2 = ORTModule(copy.deepcopy(pt_model2))
|
||||
|
||||
def run_step(model1, model2, x1, x2):
|
||||
prediction1 = model1(x1)
|
||||
s1 = prediction1.sum()
|
||||
loss1 = prediction1.sum()
|
||||
loss1.backward()
|
||||
|
||||
prediction2 = model2(x2)
|
||||
s2 = prediction2.sum()
|
||||
loss2 = prediction2.sum()
|
||||
loss2.backward()
|
||||
|
||||
s1.backward()
|
||||
s2.backward()
|
||||
for step in range(10):
|
||||
x1 = torch.randn(N, D_in, device=device)
|
||||
x2 = torch.randn(N, D_in, device=device)
|
||||
run_step(pt_model1, pt_model2, x1, x2)
|
||||
run_step(ort_model1, ort_model2, x1, x2)
|
||||
|
||||
assert x1.grad is not None and x2.grad is not None
|
||||
for param in model1.parameters():
|
||||
assert param.grad is not None
|
||||
param.grad = None
|
||||
for param in model2.parameters():
|
||||
assert param.grad is not None
|
||||
param.grad = None
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
|
||||
|
||||
def test_multiple_ortmodules_common_backbone_training():
|
||||
N, D_in, H, D_out = 32, 64, 500, 64
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 64, 128, 64
|
||||
pt_model0 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
# model is the common backbone shared by model1 and model2
|
||||
model = ORTModule(model)
|
||||
model1 = ORTModule(model1)
|
||||
model2 = ORTModule(model2)
|
||||
ort_model0 = ORTModule(copy.deepcopy(pt_model0))
|
||||
ort_model1 = ORTModule(copy.deepcopy(pt_model1))
|
||||
ort_model2 = ORTModule(copy.deepcopy(pt_model2))
|
||||
|
||||
def run_step(backbone_layers, task_layers, x):
|
||||
prediction = task_layers(backbone_layers(x))
|
||||
loss = prediction.sum()
|
||||
loss.backward()
|
||||
|
||||
for step in range(10):
|
||||
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
|
||||
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
|
||||
assert x1.grad is None and x2.grad is None
|
||||
# Run task 1
|
||||
x1 = torch.randn(N, D_in, device=device)
|
||||
run_step(pt_model0, pt_model1, x1)
|
||||
run_step(ort_model0, ort_model1, x1)
|
||||
|
||||
prediction1 = model1(model(x1))
|
||||
s1 = prediction1.sum()
|
||||
s1.backward()
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model0, pt_model0, reset_gradient=False)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1)
|
||||
|
||||
prediction2 = model2(model(x2))
|
||||
s2 = prediction2.sum()
|
||||
s2.backward()
|
||||
# Run task 2
|
||||
x2 = torch.randn(N, D_in, device=device)
|
||||
run_step(pt_model0, pt_model2, x1)
|
||||
run_step(ort_model0, ort_model2, x1)
|
||||
|
||||
assert x1.grad is not None and x2.grad is not None
|
||||
for param in model.parameters():
|
||||
assert param.grad is not None
|
||||
param.grad = None
|
||||
for param in model1.parameters():
|
||||
assert param.grad is not None
|
||||
param.grad = None
|
||||
for param in model2.parameters():
|
||||
assert param.grad is not None
|
||||
param.grad = None
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model0, pt_model0, reset_gradient=True)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
|
||||
|
||||
def test_multiple_chained_ortmodules_training():
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 128, 500, 128
|
||||
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
model1 = ORTModule(model1)
|
||||
model2 = ORTModule(model2)
|
||||
pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model1 = ORTModule(copy.deepcopy(pt_model1))
|
||||
ort_model2 = ORTModule(copy.deepcopy(pt_model2))
|
||||
|
||||
all_params = list(model1.parameters()) + list(model2.parameters())
|
||||
def run_step(layers1, layers2, x):
|
||||
prediction = layers2(layers1(x))
|
||||
loss = prediction.sum()
|
||||
loss.backward()
|
||||
|
||||
for step in range(10):
|
||||
x = torch.randn(N, D_in, device='cuda', requires_grad=True)
|
||||
output1 = model1(x)
|
||||
output2 = model2(output1)
|
||||
s = output2.sum()
|
||||
s.backward()
|
||||
x = torch.randn(N, D_in, device=device, requires_grad=True)
|
||||
run_step(pt_model1, pt_model2, x)
|
||||
run_step(ort_model1, ort_model2, x)
|
||||
|
||||
assert x.grad is not None
|
||||
for param in all_params:
|
||||
assert param.grad is not None
|
||||
param.grad = None
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
|
||||
|
||||
def test_mixed_nnmodule_ortmodules_training():
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 128, 500, 128
|
||||
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
|
||||
model3 = NeuralNetMultiplePositionalArguments(D_in, H, D_out).to('cuda')
|
||||
model1 = ORTModule(model1)
|
||||
# model2 is intentionally left as nn.module
|
||||
model3 = ORTModule(model3)
|
||||
pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
pt_model3 = NeuralNetMultiplePositionalArguments(D_in, H, D_out).to(device)
|
||||
|
||||
all_params = list(model1.parameters()) + list(model2.parameters()) + list(model3.parameters())
|
||||
|
||||
for step in range(10):
|
||||
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
|
||||
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
|
||||
ort_model1 = ORTModule(copy.deepcopy(pt_model1))
|
||||
ort_model2 = copy.deepcopy(pt_model2) # model2 is intentionally left as nn.module
|
||||
ort_model3 = ORTModule(copy.deepcopy(pt_model3))
|
||||
|
||||
def run_step(model1, model2, model3, x1, x2):
|
||||
a1 = model1(x1)
|
||||
a2 = model2(x2)
|
||||
a3 = model3(torch.sin(a1), torch.cos(a2))
|
||||
loss = a3.sum()
|
||||
loss.backward()
|
||||
|
||||
assert x1.grad is not None and x2.grad is not None
|
||||
for param in all_params:
|
||||
assert param.grad is not None
|
||||
param.grad = None
|
||||
for step in range(10):
|
||||
x1 = torch.randn(N, D_in, device=device)
|
||||
x2 = torch.randn(N, D_in, device=device)
|
||||
run_step(pt_model1, pt_model2, pt_model3, x1, x2)
|
||||
run_step(ort_model1, ort_model2, ort_model3, x1, x2)
|
||||
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model3, pt_model3)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
|
||||
def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device):
|
||||
|
|
|
|||
Loading…
Reference in a new issue