diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 1bf9e47140..fe5f0d5e16 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -536,28 +536,25 @@ def test_forward_call_positional_and_keyword_arguments(): prediction.backward() -_ONE = torch.FloatTensor([1]) - - @pytest.mark.parametrize( "forward_function", [ - lambda model: model(_ONE), - lambda model: model(x=_ONE), - lambda model: model(_ONE, None, None), - lambda model: model(_ONE, None, z=None), - lambda model: model(_ONE, None), - lambda model: model(x=_ONE, y=_ONE), - lambda model: model(y=_ONE, x=_ONE), - lambda model: model(y=_ONE, z=None, x=_ONE), - lambda model: model(_ONE, None, z=_ONE), - lambda model: model(x=_ONE, z=_ONE), - lambda model: model(_ONE, z=_ONE), - lambda model: model(_ONE, z=_ONE, y=_ONE), - lambda model: model(_ONE, _ONE, _ONE), - lambda model: model(_ONE, None, _ONE), - lambda model: model(z=_ONE, x=_ONE, y=_ONE), - lambda model: model(z=_ONE, x=_ONE, y=None), + lambda model: model(torch.tensor([1.0])), + lambda model: model(x=torch.tensor([1.0])), + lambda model: model(torch.tensor([1.0]), None, None), + lambda model: model(torch.tensor([1.0]), None, z=None), + lambda model: model(torch.tensor([1.0]), None), + lambda model: model(x=torch.tensor([1.0]), y=torch.tensor([1.0])), + lambda model: model(y=torch.tensor([1.0]), x=torch.tensor([1.0])), + lambda model: model(y=torch.tensor([1.0]), z=None, x=torch.tensor([1.0])), + lambda model: model(torch.tensor([1.0]), None, z=torch.tensor([1.0])), + lambda model: model(x=torch.tensor([1.0]), z=torch.tensor([1.0])), + lambda model: model(torch.tensor([1.0]), z=torch.tensor([1.0])), + lambda model: model(torch.tensor([1.0]), z=torch.tensor([1.0]), y=torch.tensor([1.0])), + lambda model: model(torch.tensor([1.0]), torch.tensor([1.0]), torch.tensor([1.0])), + lambda model: model(torch.tensor([1.0]), None, torch.tensor([1.0])), + lambda model: model(z=torch.tensor([1.0]), x=torch.tensor([1.0]), y=torch.tensor([1.0])), + lambda model: model(z=torch.tensor([1.0]), x=torch.tensor([1.0]), y=None), ], ) def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_function):