mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Allow None As Autograd Context (#9315)
* Allow none ctx * Update orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py Co-authored-by: pengwa <pengwa@microsoft.com> * Address a comment Co-authored-by: pengwa <pengwa@microsoft.com>
This commit is contained in:
parent
b64b2d48f3
commit
d2d480a0db
2 changed files with 78 additions and 15 deletions
|
|
@ -93,10 +93,25 @@ def call_python_forward_function(
|
|||
ctx = arg.grad_fn
|
||||
first_tensor_output = arg
|
||||
break
|
||||
if training_mode_flag:
|
||||
# Must extract one valid context from result tensors.
|
||||
assert ctx is not None
|
||||
|
||||
# Context can be None because not every autograd.Function is differentiable. The logic at
|
||||
# https://github.com/pytorch/pytorch/blob/d701357d921ef167d42c125e65b6f7da6be3ad0f/torch/csrc/autograd/custom_function.cpp#L209?
|
||||
# means that if none of the forward function's outputs is differentiable, grad_fn will be None (it is never set).
|
||||
# For example,
|
||||
# class Bar(torch.autograd.Function):
|
||||
# # A non-differentiable autograd Function whose forward output
|
||||
# # doesn't have grad_fn attribute.
|
||||
# @staticmethod
|
||||
# def forward(ctx, x):
|
||||
# y = torch.ones_like(x)
|
||||
# return y
|
||||
|
||||
# @staticmethod
|
||||
# def backward(ctx, dy):
|
||||
# dx = torch.zeros_like(dy)
|
||||
# return dx
|
||||
|
||||
if training_mode_flag and ctx:
|
||||
# FORWARD BACKWARD FUNCTION CONNECTIONS
|
||||
# input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function
|
||||
# ↓ ↑
|
||||
|
|
@ -115,9 +130,6 @@ def call_python_forward_function(
|
|||
saved_tensors = [t for t in ctx.saved_tensors if t is not None]
|
||||
torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, saved_tensors)
|
||||
torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output)
|
||||
else:
|
||||
# Context must not present under non-training mode.
|
||||
assert ctx is None
|
||||
return ctx
|
||||
|
||||
if isinstance(result, torch.Tensor):
|
||||
|
|
|
|||
|
|
@ -15,9 +15,11 @@ from onnxruntime.training.ortmodule import ORTModule
|
|||
torch.manual_seed(1)
|
||||
onnxruntime.set_seed(1)
|
||||
|
||||
|
||||
def torch_version_lower_than(v):
|
||||
return LooseVersion(torch.__version__) < LooseVersion(v)
|
||||
|
||||
|
||||
def test_GeLU():
|
||||
@torch.jit.script
|
||||
def bias_gelu(bias, y):
|
||||
|
|
@ -424,6 +426,7 @@ def test_InplaceUpdateInputAsOutputNotRequireGradWithMarkDirty():
|
|||
|
||||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).")
|
||||
def test_InplaceUpdateInputAsOutputRequireGrad():
|
||||
class InplaceUpdateInputAsOutputRequireGradFunction(torch.autograd.Function):
|
||||
|
|
@ -477,6 +480,7 @@ def test_InplaceUpdateInputAsOutputRequireGrad():
|
|||
run_training_test_and_compare(
|
||||
model_builder, input_generator, label_input, ignore_grad_compare=True)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).")
|
||||
def test_InplaceUpdateInputNotAsOutputRequireGrad():
|
||||
class InplaceUpdateInputNotAsOutputRequireGradFunction(torch.autograd.Function):
|
||||
|
|
@ -629,7 +633,7 @@ def test_EvalTest():
|
|||
|
||||
|
||||
@pytest.mark.skipif(torch_version_lower_than("1.10.0"),
|
||||
reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function')
|
||||
reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function')
|
||||
def test_TwoOutputFunction():
|
||||
class TwoOutputFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
|
@ -766,7 +770,7 @@ def test_InnerModuleCall():
|
|||
|
||||
|
||||
@pytest.mark.skipif(torch_version_lower_than("1.10.0"),
|
||||
reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function')
|
||||
reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function')
|
||||
def test_Share_Input():
|
||||
class TwoOutputFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
|
@ -818,7 +822,8 @@ def test_Share_Input():
|
|||
# Test multi-input and multi-output custom function.
|
||||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
run_training_test_and_compare(model_builder, input_generator_with_requires_grad, label_input)
|
||||
run_training_test_and_compare(
|
||||
model_builder, input_generator_with_requires_grad, label_input)
|
||||
|
||||
|
||||
def test_MultipleStream_InForwardFunction():
|
||||
|
|
@ -833,7 +838,7 @@ def test_MultipleStream_InForwardFunction():
|
|||
# on different stream
|
||||
with torch.cuda.stream(stream):
|
||||
stream.wait_stream(default_stream)
|
||||
input= input * 2
|
||||
input = input * 2
|
||||
default_stream.wait_stream(stream)
|
||||
return input
|
||||
|
||||
|
|
@ -860,7 +865,6 @@ def test_MultipleStream_InForwardFunction():
|
|||
def input_generator():
|
||||
return torch.tensor([2.8, 3.4], requires_grad=True)
|
||||
|
||||
|
||||
# Generate a label that has the same shape as the forward output.
|
||||
label_input = torch.ones([output_size])
|
||||
|
||||
|
|
@ -868,6 +872,7 @@ def test_MultipleStream_InForwardFunction():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input,
|
||||
expected_outputs=[torch.tensor([0.224, 0.272])])
|
||||
|
||||
|
||||
def test_NonDefaultStream_InForwardFunction1():
|
||||
class MultipleStreamFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
|
@ -907,13 +912,12 @@ def test_NonDefaultStream_InForwardFunction1():
|
|||
def input_generator():
|
||||
return torch.tensor([2.8, 3.4], requires_grad=True)
|
||||
|
||||
|
||||
# Generate a label that has the same shape as the forward output.
|
||||
label_input = torch.ones([output_size])
|
||||
|
||||
# Test multi-input and multi-output custom function.
|
||||
run_training_test_and_compare(model_builder, input_generator, label_input,
|
||||
expected_outputs=[torch.tensor([0.224, 0.272])])
|
||||
expected_outputs=[torch.tensor([0.224, 0.272])])
|
||||
|
||||
|
||||
def test_NonDefaultStream_InForwardFunction2():
|
||||
|
|
@ -954,7 +958,6 @@ def test_NonDefaultStream_InForwardFunction2():
|
|||
def input_generator():
|
||||
return torch.tensor([2.8, 3.4], requires_grad=True)
|
||||
|
||||
|
||||
# Generate a label that has the same shape as the forward output.
|
||||
label_input = torch.ones([output_size])
|
||||
|
||||
|
|
@ -1003,10 +1006,58 @@ def test_NonDefaultStreamInplaceUpdate_InForwardFunction():
|
|||
def input_generator():
|
||||
return torch.tensor([2.8, 3.4], requires_grad=True)
|
||||
|
||||
|
||||
# Generate a label that has the same shape as the forward output.
|
||||
label_input = torch.ones([output_size])
|
||||
|
||||
# Test multi-input and multi-output custom function.
|
||||
run_training_test_and_compare(model_builder, input_generator, label_input,
|
||||
expected_outputs=[torch.tensor([0.224, 0.272])])
|
||||
|
||||
|
||||
def test_non_differentiable_autograd_function():
|
||||
class Bar(torch.autograd.Function):
|
||||
# A non-differentiable autograd Function whose forward output
|
||||
# doesn't have grad_fn attribute.
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
y = torch.ones_like(x)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy):
|
||||
raise NotImplementedError()
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
# Module calling non-differentiable function.
|
||||
def __init__(self):
|
||||
super(Foo, self).__init__()
|
||||
self._linear = torch.nn.Linear(2, 3)
|
||||
|
||||
def forward(self, x):
|
||||
y = Bar.apply(x)
|
||||
z = self._linear(y)
|
||||
return z
|
||||
|
||||
def run():
|
||||
m = Foo().to('cuda')
|
||||
x = torch.rand((2, 2), dtype=torch.float).to('cuda')
|
||||
|
||||
# Baseline.
|
||||
y_ref = m(x)
|
||||
print('Ref:')
|
||||
print(y_ref)
|
||||
|
||||
m = ORTModule(m)
|
||||
|
||||
# Inference mode.
|
||||
y_infer = m(x)
|
||||
print(y_infer)
|
||||
assert torch.allclose(y_ref, y_infer)
|
||||
|
||||
# Training mode.
|
||||
m.train()
|
||||
y_train = m(x)
|
||||
print('Train:')
|
||||
assert torch.allclose(y_ref, y_train)
|
||||
|
||||
run()
|
||||
|
|
|
|||
Loading…
Reference in a new issue