mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Refactor the constant _ONE in orttraining_test_ortmodule_api.py (#15128)
Follow up of https://github.com/microsoft/onnxruntime/pull/15097#discussion_r1142399537
This commit is contained in:
parent
41ddcd30a1
commit
710d095124
1 changed files with 16 additions and 19 deletions
|
|
@ -536,28 +536,25 @@ def test_forward_call_positional_and_keyword_arguments():
|
|||
prediction.backward()
|
||||
|
||||
|
||||
_ONE = torch.FloatTensor([1])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"forward_function",
|
||||
[
|
||||
lambda model: model(_ONE),
|
||||
lambda model: model(x=_ONE),
|
||||
lambda model: model(_ONE, None, None),
|
||||
lambda model: model(_ONE, None, z=None),
|
||||
lambda model: model(_ONE, None),
|
||||
lambda model: model(x=_ONE, y=_ONE),
|
||||
lambda model: model(y=_ONE, x=_ONE),
|
||||
lambda model: model(y=_ONE, z=None, x=_ONE),
|
||||
lambda model: model(_ONE, None, z=_ONE),
|
||||
lambda model: model(x=_ONE, z=_ONE),
|
||||
lambda model: model(_ONE, z=_ONE),
|
||||
lambda model: model(_ONE, z=_ONE, y=_ONE),
|
||||
lambda model: model(_ONE, _ONE, _ONE),
|
||||
lambda model: model(_ONE, None, _ONE),
|
||||
lambda model: model(z=_ONE, x=_ONE, y=_ONE),
|
||||
lambda model: model(z=_ONE, x=_ONE, y=None),
|
||||
lambda model: model(torch.tensor([1.0])),
|
||||
lambda model: model(x=torch.tensor([1.0])),
|
||||
lambda model: model(torch.tensor([1.0]), None, None),
|
||||
lambda model: model(torch.tensor([1.0]), None, z=None),
|
||||
lambda model: model(torch.tensor([1.0]), None),
|
||||
lambda model: model(x=torch.tensor([1.0]), y=torch.tensor([1.0])),
|
||||
lambda model: model(y=torch.tensor([1.0]), x=torch.tensor([1.0])),
|
||||
lambda model: model(y=torch.tensor([1.0]), z=None, x=torch.tensor([1.0])),
|
||||
lambda model: model(torch.tensor([1.0]), None, z=torch.tensor([1.0])),
|
||||
lambda model: model(x=torch.tensor([1.0]), z=torch.tensor([1.0])),
|
||||
lambda model: model(torch.tensor([1.0]), z=torch.tensor([1.0])),
|
||||
lambda model: model(torch.tensor([1.0]), z=torch.tensor([1.0]), y=torch.tensor([1.0])),
|
||||
lambda model: model(torch.tensor([1.0]), torch.tensor([1.0]), torch.tensor([1.0])),
|
||||
lambda model: model(torch.tensor([1.0]), None, torch.tensor([1.0])),
|
||||
lambda model: model(z=torch.tensor([1.0]), x=torch.tensor([1.0]), y=torch.tensor([1.0])),
|
||||
lambda model: model(z=torch.tensor([1.0]), x=torch.tensor([1.0]), y=None),
|
||||
],
|
||||
)
|
||||
def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_function):
|
||||
|
|
|
|||
Loading…
Reference in a new issue