Allow None As Autograd Context (#9315)

* Allow none ctx

* Update orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py

Co-authored-by: pengwa <pengwa@microsoft.com>

* Address a comment

Co-authored-by: pengwa <pengwa@microsoft.com>
This commit is contained in:
Wei-Sheng Chin 2021-10-21 20:37:36 -07:00 committed by GitHub
parent b64b2d48f3
commit d2d480a0db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 78 additions and 15 deletions

View file

@ -93,10 +93,25 @@ def call_python_forward_function(
ctx = arg.grad_fn
first_tensor_output = arg
break
if training_mode_flag:
# Must extract one valid context from result tensors.
assert ctx is not None
# Context can be None because not all autograd.Function subclasses are differentiable. The logic at
# https://github.com/pytorch/pytorch/blob/d701357d921ef167d42c125e65b6f7da6be3ad0f/torch/csrc/autograd/custom_function.cpp#L209?
# means that if none of the forward function's outputs are differentiable, grad_fn is not set (it stays None).
# For example,
# class Bar(torch.autograd.Function):
# # A non-differentiable autograd Function whose forward output
# # doesn't have a grad_fn attribute.
# @staticmethod
# def forward(ctx, x):
# y = torch.ones_like(x)
# return y
# @staticmethod
# def backward(ctx, dy):
# dx = torch.zeros_like(dy)
# return dx
if training_mode_flag and ctx:
# FORWARD BACKWARD FUNCTION CONNECTIONS
# input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function
# ↓ ↑
@ -115,9 +130,6 @@ def call_python_forward_function(
saved_tensors = [t for t in ctx.saved_tensors if t is not None]
torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, saved_tensors)
torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output)
else:
# Context must not present under non-training mode.
assert ctx is None
return ctx
if isinstance(result, torch.Tensor):

View file

@ -15,9 +15,11 @@ from onnxruntime.training.ortmodule import ORTModule
torch.manual_seed(1)
onnxruntime.set_seed(1)
def torch_version_lower_than(v):
    """Return True when the installed PyTorch version is older than ``v``.

    Both the running ``torch.__version__`` and ``v`` are parsed with
    ``LooseVersion`` before being compared.
    """
    installed = LooseVersion(torch.__version__)
    requested = LooseVersion(v)
    return installed < requested
def test_GeLU():
@torch.jit.script
def bias_gelu(bias, y):
@ -424,6 +426,7 @@ def test_InplaceUpdateInputAsOutputNotRequireGradWithMarkDirty():
run_training_test_and_compare(model_builder, input_generator, label_input)
@pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).")
def test_InplaceUpdateInputAsOutputRequireGrad():
class InplaceUpdateInputAsOutputRequireGradFunction(torch.autograd.Function):
@ -477,6 +480,7 @@ def test_InplaceUpdateInputAsOutputRequireGrad():
run_training_test_and_compare(
model_builder, input_generator, label_input, ignore_grad_compare=True)
@pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).")
def test_InplaceUpdateInputNotAsOutputRequireGrad():
class InplaceUpdateInputNotAsOutputRequireGradFunction(torch.autograd.Function):
@ -629,7 +633,7 @@ def test_EvalTest():
@pytest.mark.skipif(torch_version_lower_than("1.10.0"),
reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function')
reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function')
def test_TwoOutputFunction():
class TwoOutputFunction(torch.autograd.Function):
@staticmethod
@ -766,7 +770,7 @@ def test_InnerModuleCall():
@pytest.mark.skipif(torch_version_lower_than("1.10.0"),
reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function')
reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function')
def test_Share_Input():
class TwoOutputFunction(torch.autograd.Function):
@staticmethod
@ -818,7 +822,8 @@ def test_Share_Input():
# Test multi-input and multi-output custom function.
run_training_test_and_compare(model_builder, input_generator, label_input)
run_training_test_and_compare(model_builder, input_generator_with_requires_grad, label_input)
run_training_test_and_compare(
model_builder, input_generator_with_requires_grad, label_input)
def test_MultipleStream_InForwardFunction():
@ -833,7 +838,7 @@ def test_MultipleStream_InForwardFunction():
# on different stream
with torch.cuda.stream(stream):
stream.wait_stream(default_stream)
input= input * 2
input = input * 2
default_stream.wait_stream(stream)
return input
@ -860,7 +865,6 @@ def test_MultipleStream_InForwardFunction():
def input_generator():
return torch.tensor([2.8, 3.4], requires_grad=True)
# Generate a label that has the same shape as the forward output.
label_input = torch.ones([output_size])
@ -868,6 +872,7 @@ def test_MultipleStream_InForwardFunction():
run_training_test_and_compare(model_builder, input_generator, label_input,
expected_outputs=[torch.tensor([0.224, 0.272])])
def test_NonDefaultStream_InForwardFunction1():
class MultipleStreamFunction(torch.autograd.Function):
@staticmethod
@ -907,13 +912,12 @@ def test_NonDefaultStream_InForwardFunction1():
def input_generator():
return torch.tensor([2.8, 3.4], requires_grad=True)
# Generate a label that has the same shape as the forward output.
label_input = torch.ones([output_size])
# Test multi-input and multi-output custom function.
run_training_test_and_compare(model_builder, input_generator, label_input,
expected_outputs=[torch.tensor([0.224, 0.272])])
expected_outputs=[torch.tensor([0.224, 0.272])])
def test_NonDefaultStream_InForwardFunction2():
@ -954,7 +958,6 @@ def test_NonDefaultStream_InForwardFunction2():
def input_generator():
return torch.tensor([2.8, 3.4], requires_grad=True)
# Generate a label that has the same shape as the forward output.
label_input = torch.ones([output_size])
@ -1003,10 +1006,58 @@ def test_NonDefaultStreamInplaceUpdate_InForwardFunction():
def input_generator():
return torch.tensor([2.8, 3.4], requires_grad=True)
# Generate a label that has the same shape as the forward output.
label_input = torch.ones([output_size])
# Test multi-input and multi-output custom function.
run_training_test_and_compare(model_builder, input_generator, label_input,
expected_outputs=[torch.tensor([0.224, 0.272])])
def test_non_differentiable_autograd_function():
    """Run a module that calls a non-differentiable autograd Function via ORTModule.

    The plain PyTorch output is used as the baseline; the ORTModule-wrapped
    module must reproduce it both in evaluation mode and in training mode.
    Requires a CUDA device, because the model and input are moved to 'cuda'.
    """

    class Bar(torch.autograd.Function):
        # A non-differentiable autograd Function whose forward output
        # doesn't have a grad_fn attribute.
        @staticmethod
        def forward(ctx, x):
            y = torch.ones_like(x)
            return y

        @staticmethod
        def backward(ctx, dy):
            # This test only performs forward passes, so backward is not
            # invoked here.
            raise NotImplementedError()

    class Foo(torch.nn.Module):
        # Module calling the non-differentiable function.
        def __init__(self):
            super(Foo, self).__init__()
            self._linear = torch.nn.Linear(2, 3)

        def forward(self, x):
            y = Bar.apply(x)
            z = self._linear(y)
            return z

    def run():
        m = Foo().to('cuda')
        x = torch.rand((2, 2), dtype=torch.float).to('cuda')

        # Baseline.
        y_ref = m(x)
        print('Ref:')
        print(y_ref)

        m = ORTModule(m)

        # Inference mode. A freshly constructed torch.nn.Module starts in
        # training mode, so switch to evaluation mode explicitly; otherwise
        # this leg would repeat the training-mode check below.
        m.eval()
        y_infer = m(x)
        print('Infer:')
        print(y_infer)
        assert torch.allclose(y_ref, y_infer)

        # Training mode.
        m.train()
        y_train = m(x)
        print('Train:')
        print(y_train)
        assert torch.allclose(y_ref, y_train)

    run()