mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
fix t5 assert error (#8501)
Co-authored-by: Ethan Tao <ettao@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
b4baac888c
commit
1ae32655b3
2 changed files with 49 additions and 0 deletions
|
|
@ -137,6 +137,11 @@ class TrainingManager(GraphExecutionManager):
|
|||
for idx in self._graph_info.module_output_indices_requires_save_for_backward:
|
||||
ctx.save_for_backward(user_outputs[idx])
|
||||
|
||||
# Mark the output tensors as non-differentiable if requires_grad is False in _graph_info
|
||||
# This way the output tensors are returned to torch with the correct requires_grad settings
|
||||
for idx in self._graph_info.output_grad_indices_non_differentiable:
|
||||
ctx.mark_non_differentiable(user_outputs[idx])
|
||||
|
||||
return user_outputs
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -184,6 +184,22 @@ class NeuralNetNonDifferentiableOutput(torch.nn.Module):
|
|||
|
||||
return out1, mask1, out2, mask2 # intentionally place the non-differentiable output in the middle
|
||||
|
||||
class NeuralNetChainedLayersWithNonDifferentiableOutput(torch.nn.Module):
    """Linear -> ReLU -> Linear network that also hands its mask input back.

    ``forward`` returns a pair: the output of the second linear layer and
    the very same mask object it was given, so instances can be chained by
    feeding one instance's mask output into the next instance.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        """Create the layers.

        Args:
            input_size: feature count expected by the first linear layer.
            hidden_size: feature count between the two linear layers.
            num_classes: feature count produced by the second linear layer.
        """
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input1, mask1):
        """Return ``(fc2(relu(fc1(input1))), mask1)``."""
        hidden = self.relu(self.fc1(input1))
        logits = self.fc2(hidden)

        # Original author's note: handing the mask back as a module output
        # will trigger torch to set requires_grad = True for the mask tensor.
        passthrough_mask = mask1

        return logits, passthrough_mask
|
||||
|
||||
class NeuralNetPartialNoGradModel(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(NeuralNetPartialNoGradModel, self).__init__()
|
||||
|
|
@ -750,6 +766,34 @@ def test_module_with_non_differential_output():
|
|||
_test_helpers.assert_values_are_close(ort_mask2, pt_mask2)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
def test_multiple_chained_ortmodules_with_non_differential_output():
    """Chain two ORTModule-wrapped layers through their pass-through mask output.

    The same two-layer pipeline is run once with plain PyTorch modules and
    once with ORTModule wrappers around deep copies of them. The test then
    checks that the final prediction, the final mask, and the second
    layer's gradients agree between the two runs.
    """
    device = 'cuda'
    N, D_in, H, D_out = 32, 128, 64, 10

    pt_model = NeuralNetChainedLayersWithNonDifferentiableOutput(D_in, H, D_out).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    pt_model2 = NeuralNetChainedLayersWithNonDifferentiableOutput(D_in, H, D_out).to(device)
    ort_model2 = ORTModule(copy.deepcopy(pt_model2))

    def run_step(layer1, layer2, x, mask1):
        # Only the mask flows from layer1 into layer2; layer2 is fed the
        # original input x again and layer1's prediction is overwritten.
        prediction, mask = layer1(x, mask1)
        prediction, mask = layer2(x, mask)
        loss = prediction.sum()
        loss.backward()
        return prediction, mask

    x = torch.randn(N, D_in, device=device)
    mask1 = torch.zeros(1, device=device)

    pt_prediction, pt_mask = run_step(pt_model, pt_model2, x, mask1)
    # Ensure no AssertionError is raised for chained ORTModules, e.g.:
    #   ORT found the 1-th module output 'output-1' is non-differentiable according to the onnx graph.
    #   However, the gradient value is still provided by PyTorch's autograd engine.
    ort_prediction, ort_mask = run_step(ort_model, ort_model2, x, mask1)

    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
    # Also verify the pass-through (non-differentiable) mask output, which
    # was previously fetched but never compared.
    _test_helpers.assert_values_are_close(ort_mask, pt_mask)
    _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
|
||||
|
||||
@pytest.mark.parametrize("loss_with_duplicated_output", [False, True])
|
||||
def test_duplicated_output(loss_with_duplicated_output):
|
||||
class NeuralNet(torch.nn.Module):
|
||||
|
|
|
|||
Loading…
Reference in a new issue