mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Make custom_fwd a no-op when not executed under autocast (#36171)
Summary:
Currently, a custom autograd function written with
```
@torch.cuda.amp.custom_fwd(cast_inputs=dtype)
def forward(ctx, *args):
...
```
casts incoming floating-point CUDA tensors to `dtype` unconditionally, regardless of whether the function executes in an autocast-enabled region. I think I had the wrong idea there. Autocast-disabled regions should give the user control of input types. Also, `custom_fwd(cast_inputs=dtype)`-decorated functions' behavior should align with native fp32list/fp16list functions. C++-side casting wrappers have no effect when autocast is disabled, and `custom_fwd`'s casting should behave the same way.
The present PR changes `custom_fwd` so it only casts in autocast-enabled regions (also updates custom_fwd to ignore fp64 inputs, like the C++ wrappers).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36171
Differential Revision: D22179511
Pulled By: ngimel
fbshipit-source-id: 5a93d070179a43206066bce19da0a5a19ecaabbd
This commit is contained in:
parent
f652abc1dd
commit
3b040c478a
4 changed files with 41 additions and 22 deletions
|
|
@ -74,8 +74,8 @@ but ``a.addmm_(b, c)`` and ``a.addmm(b, c, out=d)`` cannot.
|
|||
For best performance and stability, prefer out-of-place ops in autocast-enabled
|
||||
regions.
|
||||
|
||||
Ops called with an explicit `dtype=...` argument are not eligible,
|
||||
and will produce output that respects the `dtype` argument.
|
||||
Ops called with an explicit ``dtype=...`` argument are not eligible,
|
||||
and will produce output that respects the ``dtype`` argument.
|
||||
|
||||
Op-Specific Behavior
|
||||
--------------------
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ Typical Mixed Precision Training
|
|||
|
||||
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
|
||||
# Backward passes under autocast are not recommended.
|
||||
# Backward ops run in the same precision that autocast used for corresponding forward ops.
|
||||
# Backward ops run in the same dtype autocast chose for corresponding forward ops.
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# scaler.step() first unscales the gradients of the optimizer's assigned params.
|
||||
|
|
@ -368,8 +368,8 @@ the relevant case below.
|
|||
Functions with multiple inputs or autocastable ops
|
||||
--------------------------------------------------
|
||||
|
||||
Apply :func:`custom_fwd` and :func:`custom_bwd` (with no arguments) to ``forward`` and ``backward``
|
||||
respectively. These ensure ``forward`` executes with the current autocast state and ``backward``
|
||||
Apply :func:`custom_fwd<custom_fwd>` and :func:`custom_bwd<custom_bwd>` (with no arguments) to ``forward`` and
|
||||
``backward`` respectively. These ensure ``forward`` executes with the current autocast state and ``backward``
|
||||
executes with the same autocast state as ``forward`` (which can prevent type mismatch errors)::
|
||||
|
||||
class MyMM(torch.autograd.Function):
|
||||
|
|
@ -391,13 +391,14 @@ Now ``MyMM`` can be invoked anywhere, without disabling autocast or manually cas
|
|||
with autocast():
|
||||
output = mymm(input1, input2)
|
||||
|
||||
Functions that need a particular dtype
|
||||
--------------------------------------
|
||||
Functions that need a particular ``dtype``
|
||||
------------------------------------------
|
||||
|
||||
Consider a custom function that requires ``torch.float32`` inputs.
|
||||
Apply :func:`custom_fwd(cast_inputs=torch.float32)<custom_fwd>` to ``forward``
|
||||
and :func:`custom_bwd<custom_bwd>` (with no arguments) to ``backward``.
|
||||
With ``cast_inputs=torch.float32``, these disable autocast in ``forward`` and ``backward``,
|
||||
and ``forward`` casts incoming floating-point CUDA Tensors to ``float32``::
|
||||
If ``forward`` runs in an autocast-enabled region, the decorators cast floating-point CUDA Tensor
|
||||
inputs to ``float32``, and locally disable autocast during ``forward`` and ``backward``::
|
||||
|
||||
class MyFloat32Func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
|
@ -411,7 +412,7 @@ and ``forward`` casts incoming floating-point CUDA Tensors to ``float32``::
|
|||
def backward(ctx, grad):
|
||||
...
|
||||
|
||||
Now ``MyFloat32Func`` can be invoked anywhere, without disabling autocast or manually casting inputs::
|
||||
Now ``MyFloat32Func`` can be invoked anywhere, without manually disabling autocast or casting inputs::
|
||||
|
||||
func = MyFloat32Func.apply
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ types = [
|
|||
torch.HalfTensor,
|
||||
]
|
||||
|
||||
|
||||
def make_sparse_tensor(t, n, *sizes):
|
||||
assert t.is_sparse
|
||||
tensor = t()
|
||||
|
|
@ -76,6 +77,7 @@ def make_sparse_tensor(t, n, *sizes):
|
|||
|
||||
_cycles_per_ms = None
|
||||
|
||||
|
||||
def get_cycles_per_ms():
|
||||
"""Approximate number of cycles per millisecond for torch.cuda._sleep"""
|
||||
global _cycles_per_ms
|
||||
|
|
@ -2764,14 +2766,14 @@ t2.start()
|
|||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
def test_autocast_custom_disabled(self):
|
||||
def test_autocast_custom_cast_inputs(self):
|
||||
class MyMM(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, a, container):
|
||||
def forward(ctx, a, container, expect_type):
|
||||
b = container[1][0]
|
||||
self.assertTrue(a.dtype is torch.float32)
|
||||
self.assertTrue(b.dtype is torch.float32)
|
||||
self.assertTrue(a.dtype is expect_type)
|
||||
self.assertTrue(b.dtype is expect_type)
|
||||
self.assertFalse(torch.is_autocast_enabled())
|
||||
ctx.save_for_backward(a, b)
|
||||
return a.mm(b)
|
||||
|
|
@ -2781,7 +2783,7 @@ t2.start()
|
|||
def backward(ctx, grad):
|
||||
self.assertFalse(torch.is_autocast_enabled())
|
||||
a, b = ctx.saved_tensors
|
||||
return grad.mm(b.t()), None
|
||||
return grad.mm(b.t()), None, None
|
||||
|
||||
mymm = MyMM.apply
|
||||
|
||||
|
|
@ -2792,11 +2794,17 @@ t2.start()
|
|||
y = (0, {0: torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=False)})
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
output = mymm(x, y)
|
||||
output = mymm(x, y, torch.float32)
|
||||
self.assertTrue(output.dtype is torch.float32)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
# Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region.
|
||||
output = mymm(x, y, torch.float16)
|
||||
self.assertTrue(output.dtype is torch.float16)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
def test_autocast_cat_jit(self):
|
||||
# Reported at https://github.com/pytorch/pytorch/issues/38958
|
||||
|
||||
|
|
|
|||
|
|
@ -140,7 +140,8 @@ class autocast(object):
|
|||
# may be falsely detected as "Iterables."
|
||||
def _cast(value, dtype):
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.to(dtype) if (value.is_floating_point() and value.is_cuda) else value
|
||||
is_eligible = (value.is_floating_point() and value.is_cuda and (value.dtype is not torch.float64))
|
||||
return value.to(dtype) if is_eligible else value
|
||||
elif isinstance(value, string_classes):
|
||||
return value
|
||||
elif isinstance(value, np.ndarray):
|
||||
|
|
@ -169,10 +170,15 @@ def custom_fwd(fwd=None, **kwargs):
|
|||
:class:`torch.autograd.Function`). See the :ref:`example page<amp-custom-examples>` for more detail.
|
||||
|
||||
Arguments:
|
||||
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``, casts incoming
|
||||
floating-point Tensors to the target dtype (non-floating-point Tensors are not affected),
|
||||
and causes ``forward`` to execute with autocast disabled.
|
||||
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
|
||||
when ``forward`` runs in an autocast-enabled region, casts incoming
|
||||
floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
|
||||
then executes ``forward`` with autocast disabled.
|
||||
If ``None``, ``forward``'s internal ops execute with the current autocast state.
|
||||
|
||||
.. note::
|
||||
If the decorated ``forward`` is called outside an autocast-enabled region,
|
||||
:func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
|
||||
"""
|
||||
if fwd is None:
|
||||
if len(kwargs) == 0:
|
||||
|
|
@ -194,9 +200,13 @@ def custom_fwd(fwd=None, **kwargs):
|
|||
args[0]._fwd_used_autocast = torch.is_autocast_enabled()
|
||||
return fwd(*args, **kwargs)
|
||||
else:
|
||||
autocast_context = torch.is_autocast_enabled()
|
||||
args[0]._fwd_used_autocast = False
|
||||
with autocast(enabled=False):
|
||||
return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
|
||||
if autocast_context:
|
||||
with autocast(enabled=False):
|
||||
return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
|
||||
else:
|
||||
return fwd(*args, **kwargs)
|
||||
return decorate_fwd
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue