Refactor the constant _ONE in orttraining_test_ortmodule_api.py (#15128)

Follow up of
https://github.com/microsoft/onnxruntime/pull/15097#discussion_r1142399537
This commit is contained in:
Justin Chu 2023-03-28 08:59:51 -07:00 committed by GitHub
parent 41ddcd30a1
commit 710d095124
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -536,28 +536,25 @@ def test_forward_call_positional_and_keyword_arguments():
prediction.backward()
_ONE = torch.FloatTensor([1])
@pytest.mark.parametrize(
"forward_function",
[
lambda model: model(_ONE),
lambda model: model(x=_ONE),
lambda model: model(_ONE, None, None),
lambda model: model(_ONE, None, z=None),
lambda model: model(_ONE, None),
lambda model: model(x=_ONE, y=_ONE),
lambda model: model(y=_ONE, x=_ONE),
lambda model: model(y=_ONE, z=None, x=_ONE),
lambda model: model(_ONE, None, z=_ONE),
lambda model: model(x=_ONE, z=_ONE),
lambda model: model(_ONE, z=_ONE),
lambda model: model(_ONE, z=_ONE, y=_ONE),
lambda model: model(_ONE, _ONE, _ONE),
lambda model: model(_ONE, None, _ONE),
lambda model: model(z=_ONE, x=_ONE, y=_ONE),
lambda model: model(z=_ONE, x=_ONE, y=None),
lambda model: model(torch.tensor([1.0])),
lambda model: model(x=torch.tensor([1.0])),
lambda model: model(torch.tensor([1.0]), None, None),
lambda model: model(torch.tensor([1.0]), None, z=None),
lambda model: model(torch.tensor([1.0]), None),
lambda model: model(x=torch.tensor([1.0]), y=torch.tensor([1.0])),
lambda model: model(y=torch.tensor([1.0]), x=torch.tensor([1.0])),
lambda model: model(y=torch.tensor([1.0]), z=None, x=torch.tensor([1.0])),
lambda model: model(torch.tensor([1.0]), None, z=torch.tensor([1.0])),
lambda model: model(x=torch.tensor([1.0]), z=torch.tensor([1.0])),
lambda model: model(torch.tensor([1.0]), z=torch.tensor([1.0])),
lambda model: model(torch.tensor([1.0]), z=torch.tensor([1.0]), y=torch.tensor([1.0])),
lambda model: model(torch.tensor([1.0]), torch.tensor([1.0]), torch.tensor([1.0])),
lambda model: model(torch.tensor([1.0]), None, torch.tensor([1.0])),
lambda model: model(z=torch.tensor([1.0]), x=torch.tensor([1.0]), y=torch.tensor([1.0])),
lambda model: model(z=torch.tensor([1.0]), x=torch.tensor([1.0]), y=None),
],
)
def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_function):