diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index ae32b0304c..940f29edbb 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -93,10 +93,25 @@ def call_python_forward_function( ctx = arg.grad_fn first_tensor_output = arg break - if training_mode_flag: - # Must extract one valid context from result tensors. - assert ctx is not None + # Context can be None because not all autograd.Function's are differentiable. The function + # https://github.com/pytorch/pytorch/blob/d701357d921ef167d42c125e65b6f7da6be3ad0f/torch/csrc/autograd/custom_function.cpp#L209? + # means if all outputs of the forward function are not differentiable, then grad_fn will be None (will not be set). + # For example, + # class Bar(torch.autograd.Function): + # # A non-differentiable autograd Function whose forward output + # # doesn't have grad_fn attribute. + # @staticmethod + # def forward(ctx, x): + # y = torch.ones_like(x) + # return y + + # @staticmethod + # def backward(ctx, dy): + # dx = torch.zeros_like(dy) + # return dx + + if training_mode_flag and ctx: # FORWARD BACKWARD FUNCTION CONNECTIONS # input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function # ↓ ↑ @@ -115,9 +130,6 @@ def call_python_forward_function( saved_tensors = [t for t in ctx.saved_tensors if t is not None] torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, saved_tensors) torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output) - else: - # Context must not present under non-training mode. 
- assert ctx is None return ctx if isinstance(result, torch.Tensor): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index a71fd330f7..e868fea506 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -15,9 +15,11 @@ from onnxruntime.training.ortmodule import ORTModule torch.manual_seed(1) onnxruntime.set_seed(1) + def torch_version_lower_than(v): return LooseVersion(torch.__version__) < LooseVersion(v) + def test_GeLU(): @torch.jit.script def bias_gelu(bias, y): @@ -424,6 +426,7 @@ def test_InplaceUpdateInputAsOutputNotRequireGradWithMarkDirty(): run_training_test_and_compare(model_builder, input_generator, label_input) + @pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).") def test_InplaceUpdateInputAsOutputRequireGrad(): class InplaceUpdateInputAsOutputRequireGradFunction(torch.autograd.Function): @@ -477,6 +480,7 @@ def test_InplaceUpdateInputAsOutputRequireGrad(): run_training_test_and_compare( model_builder, input_generator, label_input, ignore_grad_compare=True) + @pytest.mark.skip(reason="This test is not correct. 
All tensors modified by in-place operattions should be mark_dirty(...).") def test_InplaceUpdateInputNotAsOutputRequireGrad(): class InplaceUpdateInputNotAsOutputRequireGradFunction(torch.autograd.Function): @@ -629,7 +633,7 @@ def test_EvalTest(): @pytest.mark.skipif(torch_version_lower_than("1.10.0"), - reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function') + reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function') def test_TwoOutputFunction(): class TwoOutputFunction(torch.autograd.Function): @staticmethod @@ -766,7 +770,7 @@ def test_InnerModuleCall(): @pytest.mark.skipif(torch_version_lower_than("1.10.0"), - reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function') + reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function') def test_Share_Input(): class TwoOutputFunction(torch.autograd.Function): @staticmethod @@ -818,7 +822,8 @@ def test_Share_Input(): # Test multi-input and multi-output custom function. run_training_test_and_compare(model_builder, input_generator, label_input) - run_training_test_and_compare(model_builder, input_generator_with_requires_grad, label_input) + run_training_test_and_compare( + model_builder, input_generator_with_requires_grad, label_input) def test_MultipleStream_InForwardFunction(): @@ -833,7 +838,7 @@ def test_MultipleStream_InForwardFunction(): # on different stream with torch.cuda.stream(stream): stream.wait_stream(default_stream) - input= input * 2 + input = input * 2 default_stream.wait_stream(stream) return input @@ -860,7 +865,6 @@ def test_MultipleStream_InForwardFunction(): def input_generator(): return torch.tensor([2.8, 3.4], requires_grad=True) - # generate a label that have same shape as forward output. 
label_input = torch.ones([output_size]) @@ -868,6 +872,7 @@ def test_MultipleStream_InForwardFunction(): run_training_test_and_compare(model_builder, input_generator, label_input, expected_outputs=[torch.tensor([0.224, 0.272])]) + def test_NonDefaultStream_InForwardFunction1(): class MultipleStreamFunction(torch.autograd.Function): @staticmethod @@ -907,13 +912,12 @@ def test_NonDefaultStream_InForwardFunction1(): def input_generator(): return torch.tensor([2.8, 3.4], requires_grad=True) - # generate a label that have same shape as forward output. label_input = torch.ones([output_size]) # Test multi-input and multi-output custom function. run_training_test_and_compare(model_builder, input_generator, label_input, - expected_outputs=[torch.tensor([0.224, 0.272])]) + expected_outputs=[torch.tensor([0.224, 0.272])]) def test_NonDefaultStream_InForwardFunction2(): @@ -954,7 +958,6 @@ def test_NonDefaultStream_InForwardFunction2(): def input_generator(): return torch.tensor([2.8, 3.4], requires_grad=True) - # generate a label that have same shape as forward output. label_input = torch.ones([output_size]) @@ -1003,10 +1006,58 @@ def test_NonDefaultStreamInplaceUpdate_InForwardFunction(): def input_generator(): return torch.tensor([2.8, 3.4], requires_grad=True) - # generate a label that have same shape as forward output. label_input = torch.ones([output_size]) # Test multi-input and multi-output custom function. run_training_test_and_compare(model_builder, input_generator, label_input, expected_outputs=[torch.tensor([0.224, 0.272])]) + + +def test_non_differentiable_autograd_function(): + class Bar(torch.autograd.Function): + # A non-differentiable autograd Function whose forward output + # doesn't have grad_fn attribute. + @staticmethod + def forward(ctx, x): + y = torch.ones_like(x) + return y + + @staticmethod + def backward(ctx, dy): + raise NotImplementedError() + + class Foo(torch.nn.Module): + # Module calling non-differentiable function. 
+ def __init__(self): + super(Foo, self).__init__() + self._linear = torch.nn.Linear(2, 3) + + def forward(self, x): + y = Bar.apply(x) + z = self._linear(y) + return z + + def run(): + m = Foo().to('cuda') + x = torch.rand((2, 2), dtype=torch.float).to('cuda') + + # Baseline. + y_ref = m(x) + print('Ref:') + print(y_ref) + + m = ORTModule(m) + + # Inference mode. + y_infer = m(x) + print(y_infer) + assert torch.allclose(y_ref, y_infer) + + # Training mode. + m.train() + y_train = m(x) + print('Train:') + assert torch.allclose(y_ref, y_train) + + run()