diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index f794026b86..7737abd090 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -251,14 +251,18 @@ class ORTModule(torch.nn.Module): return user_outputs @staticmethod - def backward(ctx, *grad_output): + def backward(ctx, *grad_outputs): '''Performs backward pass based on grad wrt module output ''' # Use IO binding # Push user output grads to ONNX backend. backward_grad_output_ortvalue = [] - for grad_output in grad_output[:len(self._onnx_graphs_info.backward_output_grad_names)]: + for grad_output in grad_outputs[:len(self._onnx_graphs_info.backward_output_grad_names)]: + # Force torch tensors to be contiguous before converting into OrtValue + if not grad_output.is_contiguous(): + grad_output = grad_output.contiguous() + backward_grad_output_ortvalue.append(onnxruntime.OrtValue.ortvalue_from_data_ptr(list(grad_output.size()), _utils.dtype_torch_to_numpy( grad_output.dtype), grad_output.device.type, _utils.get_device_index(grad_output.device), grad_output.data_ptr())) diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index ab2bb13b07..6b7852c328 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -145,3 +145,19 @@ def _get_name(name): if os.path.exists(res): return res raise FileNotFoundError("Unable to find '{0}' or '{1}' or '{2}'".format(name, rel, res)) + +def assert_gradients_match_and_reset_gradient(ort_model, pt_model, reset_gradient=True, rtol=1e-05, atol=1e-06): + ort_named_params = list(ort_model.named_parameters()) + pt_named_params = list(pt_model.named_parameters()) + assert len(ort_named_params) == len(pt_named_params) + + for ort_named_param, pt_named_param in zip(ort_named_params, pt_named_params): + ort_name, 
ort_param = ort_named_param + pt_name, pt_param = pt_named_param + + assert pt_name in ort_name + assert torch.allclose(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol) + + if reset_gradient: + ort_param.grad = None + pt_param.grad = None diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 4994145bd7..940e269d1e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -3,6 +3,7 @@ # orttraining_test_ortmodule_api.py import math +import copy import torch from transformers import AutoConfig, BertForSequenceClassification from transformers.modeling_outputs import SequenceClassifierOutput @@ -367,143 +368,176 @@ def test_input_requires_grad_backward_creates_input_grad(device): s.backward() assert x.grad is not None -def test_multiple_forward_only_calls(): - N, D_in, H, D_out = 32, 784, 500, 10 - model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') - model = ORTModule(model) +def test_gradient_correctness(): + device = 'cuda' + N, D_in, H, D_out = 32, 128, 500, 10 + pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, x): + prediction = model(x) + loss = prediction.sum() + loss.backward() + for step in range(10): - x = torch.randn(N, D_in, device='cuda', requires_grad=False) - prediction1 = model(x) + x = torch.randn(N, D_in, device=device) + run_step(pt_model, x) + run_step(ort_model, x) + + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + +def test_multiple_forward_only_calls(): + device = 'cuda' + N, D_in, H, D_out = 32, 784, 500, 10 + pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + for step in range(10): + x = torch.randn(N, D_in, 
device=device, requires_grad=False) + pt_prediction = pt_model(x) + ort_prediction = ort_model(x) + + assert torch.allclose(ort_prediction, pt_prediction) def test_multiple_overlapping_forward_backward_calls(): + device = 'cuda' N, D_in, H, D_out = 32, 784, 500, 10 - model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') - model = ORTModule(model) + pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) - for step in range(10): - x1 = torch.randn(N, D_in, device='cuda', requires_grad=True) - x2 = torch.randn(N, D_in, device='cuda', requires_grad=True) - assert x1.grad is None and x2.grad is None - + def run_step(model, x1, x2): prediction1 = model(x1) - s1 = prediction1.sum() + loss1 = prediction1.sum() prediction2 = model(x2) - s2 = prediction2.sum() + loss2 = prediction2.sum() + + loss1.backward() + loss2.backward() - s1.backward() - s2.backward() - assert x1.grad is not None and x2.grad is not None + for step in range(10): + pt_x1 = torch.randn(N, D_in, device=device, requires_grad=True) + pt_x2 = torch.randn(N, D_in, device=device, requires_grad=True) + + ort_x1 = pt_x1.clone().detach() + ort_x2 = pt_x2.clone().detach() + ort_x1.requires_grad = True + ort_x2.requires_grad = True + + run_step(pt_model, pt_x1, pt_x2) + run_step(ort_model, ort_x1, ort_x2) + + assert torch.allclose(ort_x1.grad, pt_x1.grad) + assert torch.allclose(ort_x2.grad, pt_x2.grad) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) def test_multiple_ortmodules_training(): - N, D_in, H, D_out = 32, 784, 500, 10 - model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') - model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') - model1 = ORTModule(model1) - model2 = ORTModule(model2) - - for step in range(10): - x1 = torch.randn(N, D_in, device='cuda', requires_grad=True) - x2 = torch.randn(N, D_in, device='cuda', requires_grad=True) - assert x1.grad is None and 
x2.grad is None + device = 'cuda' + N, D_in, H, D_out = 32, 784, 128, 10 + pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model1 = ORTModule(copy.deepcopy(pt_model1)) + ort_model2 = ORTModule(copy.deepcopy(pt_model2)) + def run_step(model1, model2, x1, x2): prediction1 = model1(x1) - s1 = prediction1.sum() + loss1 = prediction1.sum() + loss1.backward() prediction2 = model2(x2) - s2 = prediction2.sum() + loss2 = prediction2.sum() + loss2.backward() - s1.backward() - s2.backward() + for step in range(10): + x1 = torch.randn(N, D_in, device=device) + x2 = torch.randn(N, D_in, device=device) + run_step(pt_model1, pt_model2, x1, x2) + run_step(ort_model1, ort_model2, x1, x2) - assert x1.grad is not None and x2.grad is not None - for param in model1.parameters(): - assert param.grad is not None - param.grad = None - for param in model2.parameters(): - assert param.grad is not None - param.grad = None + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) def test_multiple_ortmodules_common_backbone_training(): - N, D_in, H, D_out = 32, 64, 500, 64 - model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') - model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') - model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + device = 'cuda' + N, D_in, H, D_out = 32, 64, 128, 64 + pt_model0 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) # model is the common backbone shared by model1 and model2 - model = ORTModule(model) - model1 = ORTModule(model1) - model2 = ORTModule(model2) + ort_model0 = ORTModule(copy.deepcopy(pt_model0)) + ort_model1 = 
ORTModule(copy.deepcopy(pt_model1)) + ort_model2 = ORTModule(copy.deepcopy(pt_model2)) + + def run_step(backbone_layers, task_layers, x): + prediction = task_layers(backbone_layers(x)) + loss = prediction.sum() + loss.backward() for step in range(10): - x1 = torch.randn(N, D_in, device='cuda', requires_grad=True) - x2 = torch.randn(N, D_in, device='cuda', requires_grad=True) - assert x1.grad is None and x2.grad is None + # Run task 1 + x1 = torch.randn(N, D_in, device=device) + run_step(pt_model0, pt_model1, x1) + run_step(ort_model0, ort_model1, x1) - prediction1 = model1(model(x1)) - s1 = prediction1.sum() - s1.backward() + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model0, pt_model0, reset_gradient=False) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1) - prediction2 = model2(model(x2)) - s2 = prediction2.sum() - s2.backward() + # Run task 2 + x2 = torch.randn(N, D_in, device=device) + run_step(pt_model0, pt_model2, x2) + run_step(ort_model0, ort_model2, x2) - assert x1.grad is not None and x2.grad is not None - for param in model.parameters(): - assert param.grad is not None - param.grad = None - for param in model1.parameters(): - assert param.grad is not None - param.grad = None - for param in model2.parameters(): - assert param.grad is not None - param.grad = None + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model0, pt_model0, reset_gradient=True) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) def test_multiple_chained_ortmodules_training(): + device = 'cuda' N, D_in, H, D_out = 32, 128, 500, 128 - model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') - model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') - model1 = ORTModule(model1) - model2 = ORTModule(model2) + pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model1 = 
ORTModule(copy.deepcopy(pt_model1)) + ort_model2 = ORTModule(copy.deepcopy(pt_model2)) - all_params = list(model1.parameters()) + list(model2.parameters()) + def run_step(layers1, layers2, x): + prediction = layers2(layers1(x)) + loss = prediction.sum() + loss.backward() for step in range(10): - x = torch.randn(N, D_in, device='cuda', requires_grad=True) - output1 = model1(x) - output2 = model2(output1) - s = output2.sum() - s.backward() + x = torch.randn(N, D_in, device=device, requires_grad=True) + run_step(pt_model1, pt_model2, x) + run_step(ort_model1, ort_model2, x) - assert x.grad is not None - for param in all_params: - assert param.grad is not None - param.grad = None + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) def test_mixed_nnmodule_ortmodules_training(): + device = 'cuda' N, D_in, H, D_out = 32, 128, 500, 128 - model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') - model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') - model3 = NeuralNetMultiplePositionalArguments(D_in, H, D_out).to('cuda') - model1 = ORTModule(model1) - # model2 is intentionally left as nn.module - model3 = ORTModule(model3) + pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + pt_model3 = NeuralNetMultiplePositionalArguments(D_in, H, D_out).to(device) - all_params = list(model1.parameters()) + list(model2.parameters()) + list(model3.parameters()) - - for step in range(10): - x1 = torch.randn(N, D_in, device='cuda', requires_grad=True) - x2 = torch.randn(N, D_in, device='cuda', requires_grad=True) + ort_model1 = ORTModule(copy.deepcopy(pt_model1)) + ort_model2 = copy.deepcopy(pt_model2) # model2 is intentionally left as nn.module + ort_model3 = ORTModule(copy.deepcopy(pt_model3)) + def run_step(model1, model2, model3, x1, x2): a1 = 
model1(x1) a2 = model2(x2) a3 = model3(torch.sin(a1), torch.cos(a2)) loss = a3.sum() loss.backward() - assert x1.grad is not None and x2.grad is not None - for param in all_params: - assert param.grad is not None - param.grad = None + for step in range(10): + x1 = torch.randn(N, D_in, device=device) + x2 = torch.randn(N, D_in, device=device) + run_step(pt_model1, pt_model2, pt_model3, x1, x2) + run_step(ort_model1, ort_model2, ort_model3, x1, x2) + + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model3, pt_model3) @pytest.mark.parametrize("device", ['cuda', 'cpu']) def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device):