diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index acf4a253af..8395e87ed4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -137,6 +137,11 @@ class TrainingManager(GraphExecutionManager): for idx in self._graph_info.module_output_indices_requires_save_for_backward: ctx.save_for_backward(user_outputs[idx]) + # Mark the output tensors as non-differentiable if requires_grad is False in _graph_info. + # This way torch receives the output tensors with the correct requires_grad settings. + for idx in self._graph_info.output_grad_indices_non_differentiable: + ctx.mark_non_differentiable(user_outputs[idx]) + return user_outputs @staticmethod diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 020ff632be..b3b4da772c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -184,6 +184,22 @@ class NeuralNetNonDifferentiableOutput(torch.nn.Module): return out1, mask1, out2, mask2 # intentionally place the non-differentiable output in the middle +class NeuralNetChainedLayersWithNonDifferentiableOutput(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(NeuralNetChainedLayersWithNonDifferentiableOutput, self).__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, input1, mask1): + out = self.fc1(input1) + out1 = self.relu(out) + out2 = self.fc2(out1) + # this will trigger torch to set requires_grad = True for mask tensor + mask = mask1 + + return out2, mask + class
NeuralNetPartialNoGradModel(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetPartialNoGradModel, self).__init__() @@ -750,6 +766,34 @@ def test_module_with_non_differential_output(): _test_helpers.assert_values_are_close(ort_mask2, pt_mask2) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) +def test_multiple_chained_ortmodules_with_non_differential_output(): + device = 'cuda' + N, D_in, H, D_out = 32, 128, 64, 10 + pt_model = NeuralNetChainedLayersWithNonDifferentiableOutput(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_model2 = NeuralNetChainedLayersWithNonDifferentiableOutput(D_in, H, D_out).to(device) + ort_model2 = ORTModule(copy.deepcopy(pt_model2)) + + def run_step(layer1, layer2, x, mask1): + prediction, mask = layer1(x, mask1) + prediction, mask = layer2(x, mask) + loss = prediction.sum() + loss.backward() + return prediction, mask + + x = torch.randn(N, D_in, device=device) + mask1 = torch.zeros(1, device=device) + + pt_prediction, pt_mask = run_step(pt_model, pt_model2, x, mask1) + # Ensure that chaining ORTModules does not raise an AssertionError such as: + # ORT found the 1-th module output 'output-1' is non-differentiable according to the onnx graph. + # However, the gradient value is still provided by PyTorch's autograd engine. + ort_prediction, ort_mask = run_step(ort_model, ort_model2, x, mask1) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) + @pytest.mark.parametrize("loss_with_duplicated_output", [False, True]) def test_duplicated_output(loss_with_duplicated_output): class NeuralNet(torch.nn.Module):