Make custom_fwd a no-op when not executed under autocast (#36171)

Summary:
Currently, a custom autograd function written with
```
torch.cuda.amp.custom_fwd(cast_inputs=dtype)
def forward(ctx, *args):
    ...
```
casts incoming floating-point CUDA tensors to `dtype` unconditionally, regardless of whether the function executes in an autocast-enabled region. That behavior was a mistake: autocast-disabled regions should give the user control of input types, and `custom_fwd(cast_inputs=dtype)`-decorated functions should behave like the native fp32list/fp16list functions. The C++-side casting wrappers have no effect when autocast is disabled, and `custom_fwd`'s casting should behave the same way.

This PR changes `custom_fwd` so it casts only in autocast-enabled regions, and also updates `custom_fwd` to ignore fp64 inputs, matching the C++ wrappers.
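
As a rough illustration of the behavior after this change (a minimal sketch, not part of the diff below; the `MyFloat32Func` name, shapes, and dtypes are arbitrary, and a CUDA device is assumed):
```
import torch

class MyFloat32Func(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, a, b):
        # In an autocast-enabled region, a and b arrive here already cast to float32
        # (fp64 inputs are never cast). Outside autocast, they keep the caller's dtype.
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)

x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
y = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)

with torch.cuda.amp.autocast():
    out = MyFloat32Func.apply(x, y)   # inputs cast to float32, out.dtype is torch.float32

out = MyFloat32Func.apply(x, y)       # autocast disabled: custom_fwd is a no-op, out.dtype stays torch.float16
```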
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36171

Differential Revision: D22179511

Pulled By: ngimel

fbshipit-source-id: 5a93d070179a43206066bce19da0a5a19ecaabbd
Michael Carilli 2020-06-23 10:21:21 -07:00 committed by Facebook GitHub Bot
parent f652abc1dd
commit 3b040c478a
4 changed files with 41 additions and 22 deletions


@@ -74,8 +74,8 @@ but ``a.addmm_(b, c)`` and ``a.addmm(b, c, out=d)`` cannot.
 For best performance and stability, prefer out-of-place ops in autocast-enabled
 regions.
-Ops called with an explicit `dtype=...` argument are not eligible,
-and will produce output that respects the `dtype` argument.
+Ops called with an explicit ``dtype=...`` argument are not eligible,
+and will produce output that respects the ``dtype`` argument.
 Op-Specific Behavior
 --------------------


@@ -44,7 +44,7 @@ Typical Mixed Precision Training
 # Scales loss. Calls backward() on scaled loss to create scaled gradients.
 # Backward passes under autocast are not recommended.
-# Backward ops run in the same precision that autocast used for corresponding forward ops.
+# Backward ops run in the same dtype autocast chose for corresponding forward ops.
 scaler.scale(loss).backward()
 # scaler.step() first unscales the gradients of the optimizer's assigned params.
@@ -368,8 +368,8 @@ the relevant case below.
 Functions with multiple inputs or autocastable ops
 --------------------------------------------------
-Apply :func:`custom_fwd` and :func:`custom_bwd` (with no arguments) to ``forward`` and ``backward``
-respectively. These ensure ``forward`` executes with the current autocast state and ``backward``
+Apply :func:`custom_fwd<custom_fwd>` and :func:`custom_bwd<custom_bwd>` (with no arguments) to ``forward`` and
+``backward`` respectively. These ensure ``forward`` executes with the current autocast state and ``backward``
 executes with the same autocast state as ``forward`` (which can prevent type mismatch errors)::
 class MyMM(torch.autograd.Function):
@@ -391,13 +391,14 @@ Now ``MyMM`` can be invoked anywhere, without disabling autocast or manually cas
 with autocast():
 output = mymm(input1, input2)
-Functions that need a particular dtype
---------------------------------------
+Functions that need a particular ``dtype``
+------------------------------------------
 Consider a custom function that requires ``torch.float32`` inputs.
 Apply :func:`custom_fwd(cast_inputs=torch.float32)<custom_fwd>` to ``forward``
 and :func:`custom_bwd<custom_bwd>` (with no arguments) to ``backward``.
-With ``cast_inputs=torch.float32``, these disable autocast in ``forward`` and ``backward``,
-and ``forward`` casts incoming floating-point CUDA Tensors to ``float32``::
+If ``forward`` runs in an autocast-enabled region, the decorators cast floating-point CUDA Tensor
+inputs to ``float32``, and locally disable autocast during ``forward`` and ``backward``::
 class MyFloat32Func(torch.autograd.Function):
 @staticmethod
@@ -411,7 +412,7 @@ and ``forward`` casts incoming floating-point CUDA Tensors to ``float32``::
 def backward(ctx, grad):
 ...
-Now ``MyFloat32Func`` can be invoked anywhere, without disabling autocast or manually casting inputs::
+Now ``MyFloat32Func`` can be invoked anywhere, without manually disabling autocast or casting inputs::
 func = MyFloat32Func.apply


@@ -64,6 +64,7 @@ types = [
 torch.HalfTensor,
 ]
+
 def make_sparse_tensor(t, n, *sizes):
 assert t.is_sparse
 tensor = t()
@@ -76,6 +77,7 @@ def make_sparse_tensor(t, n, *sizes):
 _cycles_per_ms = None
+
 def get_cycles_per_ms():
 """Approximate number of cycles per millisecond for torch.cuda._sleep"""
 global _cycles_per_ms
@@ -2764,14 +2766,14 @@ t2.start()
 loss = output.sum()
 loss.backward()
-def test_autocast_custom_disabled(self):
+def test_autocast_custom_cast_inputs(self):
 class MyMM(torch.autograd.Function):
 @staticmethod
 @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
-def forward(ctx, a, container):
+def forward(ctx, a, container, expect_type):
 b = container[1][0]
-self.assertTrue(a.dtype is torch.float32)
-self.assertTrue(b.dtype is torch.float32)
+self.assertTrue(a.dtype is expect_type)
+self.assertTrue(b.dtype is expect_type)
 self.assertFalse(torch.is_autocast_enabled())
 ctx.save_for_backward(a, b)
 return a.mm(b)
@@ -2781,7 +2783,7 @@ t2.start()
 def backward(ctx, grad):
 self.assertFalse(torch.is_autocast_enabled())
 a, b = ctx.saved_tensors
-return grad.mm(b.t()), None
+return grad.mm(b.t()), None, None
 mymm = MyMM.apply
@@ -2792,11 +2794,17 @@ t2.start()
 y = (0, {0: torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=False)})
 with torch.cuda.amp.autocast():
-output = mymm(x, y)
+output = mymm(x, y, torch.float32)
 self.assertTrue(output.dtype is torch.float32)
 loss = output.sum()
 loss.backward()
+# Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region.
+output = mymm(x, y, torch.float16)
+self.assertTrue(output.dtype is torch.float16)
+loss = output.sum()
+loss.backward()
 def test_autocast_cat_jit(self):
 # Reported at https://github.com/pytorch/pytorch/issues/38958


@@ -140,7 +140,8 @@ class autocast(object):
 # may be falsely detected as "Iterables."
 def _cast(value, dtype):
 if isinstance(value, torch.Tensor):
-return value.to(dtype) if (value.is_floating_point() and value.is_cuda) else value
+is_eligible = (value.is_floating_point() and value.is_cuda and (value.dtype is not torch.float64))
+return value.to(dtype) if is_eligible else value
 elif isinstance(value, string_classes):
 return value
 elif isinstance(value, np.ndarray):
@@ -169,10 +170,15 @@ def custom_fwd(fwd=None, **kwargs):
 :class:`torch.autograd.Function`). See the :ref:`example page<amp-custom-examples>` for more detail.
 Arguments:
-cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``, casts incoming
-floating-point Tensors to the target dtype (non-floating-point Tensors are not affected),
-and causes ``forward`` to execute with autocast disabled.
+cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
+when ``forward`` runs in an autocast-enabled region, casts incoming
+floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
+then executes ``forward`` with autocast disabled.
 If ``None``, ``forward``'s internal ops execute with the current autocast state.
+.. note::
+If the decorated ``forward`` is called outside an autocast-enabled region,
+:func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
 """
 if fwd is None:
 if len(kwargs) == 0:
@@ -194,9 +200,13 @@ def custom_fwd(fwd=None, **kwargs):
 args[0]._fwd_used_autocast = torch.is_autocast_enabled()
 return fwd(*args, **kwargs)
 else:
+autocast_context = torch.is_autocast_enabled()
 args[0]._fwd_used_autocast = False
-with autocast(enabled=False):
-return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
+if autocast_context:
+with autocast(enabled=False):
+return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
+else:
+return fwd(*args, **kwargs)
 return decorate_fwd
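
For context on the `_fwd_used_autocast` flag set in `decorate_fwd` above, here is a minimal sketch (simplified, not part of this diff) of how a matching `custom_bwd` decorator can replay the forward pass's autocast state during `backward`:
```
import functools

from torch.cuda.amp import autocast

def custom_bwd(bwd):
    # Run backward under the same autocast state the decorated forward ran in,
    # which decorate_fwd recorded on the ctx object (args[0]) as _fwd_used_autocast.
    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        with autocast(args[0]._fwd_used_autocast):
            return bwd(*args, **kwargs)
    return decorate_bwd
```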