fix t5 assert error (#8501)

Co-authored-by: Ethan Tao <ettao@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
ytaous 2021-07-27 09:04:01 -07:00 committed by GitHub
parent b4baac888c
commit 1ae32655b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 0 deletions

View file

@ -137,6 +137,11 @@ class TrainingManager(GraphExecutionManager):
for idx in self._graph_info.module_output_indices_requires_save_for_backward:
ctx.save_for_backward(user_outputs[idx])
# Mark the outputs tensors non-differentiable if requires_grad is False in _graph_info
# This will return torch the output tensors with correct requires_grad settings
for idx in self._graph_info.output_grad_indices_non_differentiable:
ctx.mark_non_differentiable(user_outputs[idx])
return user_outputs
@staticmethod

View file

@ -184,6 +184,22 @@ class NeuralNetNonDifferentiableOutput(torch.nn.Module):
return out1, mask1, out2, mask2 # intentionally place the non-differentiable output in the middle
class NeuralNetChainedLayersWithNonDifferentiableOutput(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNetChainedLayersWithNonDifferentiableOutput, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
def forward(self, input1, mask1):
out = self.fc1(input1)
out1 = self.relu(out)
out2 = self.fc2(out1)
# this will trigger torch to set requires_grad = True for mask tensor
mask = mask1
return out2, mask
class NeuralNetPartialNoGradModel(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNetPartialNoGradModel, self).__init__()
@ -750,6 +766,34 @@ def test_module_with_non_differential_output():
_test_helpers.assert_values_are_close(ort_mask2, pt_mask2)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
def test_multiple_chained_ortmodules_with_non_differential_output():
device = 'cuda'
N, D_in, H, D_out = 32, 128, 64, 10
pt_model = NeuralNetChainedLayersWithNonDifferentiableOutput(D_in, H, D_out).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
pt_model2 = NeuralNetChainedLayersWithNonDifferentiableOutput(D_in, H, D_out).to(device)
ort_model2 = ORTModule(copy.deepcopy(pt_model2))
def run_step(layer1, layer2, x, mask1):
prediction, mask = layer1(x, mask1)
prediction, mask = layer2(x, mask)
loss = prediction.sum()
loss.backward()
return prediction, mask
x = torch.randn(N, D_in, device=device)
mask1 = torch.zeros(1, device=device)
pt_prediction, pt_mask = run_step(pt_model, pt_model2, x, mask1)
# ensure no AssertionError message for chained ortmodules, e.g.:
# ORT found the 1-th module output 'output-1' is non-differentiable according to the onnx graph.
# However, the gradient value is still provided by PyTorch's autograd engine.
ort_prediction, ort_mask = run_step(ort_model, ort_model2, x, mask1)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
@pytest.mark.parametrize("loss_with_duplicated_output", [False, True])
def test_duplicated_output(loss_with_duplicated_output):
class NeuralNet(torch.nn.Module):