diff --git a/docs/source/amp.rst b/docs/source/amp.rst
index ef50caaa976..0d8e0045e47 100644
--- a/docs/source/amp.rst
+++ b/docs/source/amp.rst
@@ -74,8 +74,8 @@
 but ``a.addmm_(b, c)`` and ``a.addmm(b, c, out=d)`` cannot.
 For best performance and stability, prefer out-of-place ops in autocast-enabled regions.
 
-Ops called with an explicit `dtype=...` argument are not eligible,
-and will produce output that respects the `dtype` argument.
+Ops called with an explicit ``dtype=...`` argument are not eligible,
+and will produce output that respects the ``dtype`` argument.
 
 Op-Specific Behavior
 --------------------
diff --git a/docs/source/notes/amp_examples.rst b/docs/source/notes/amp_examples.rst
index c2edbf6f7ca..9dead217056 100644
--- a/docs/source/notes/amp_examples.rst
+++ b/docs/source/notes/amp_examples.rst
@@ -44,7 +44,7 @@ Typical Mixed Precision Training
 
             # Scales loss. Calls backward() on scaled loss to create scaled gradients.
             # Backward passes under autocast are not recommended.
-            # Backward ops run in the same precision that autocast used for corresponding forward ops.
+            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
             scaler.scale(loss).backward()
 
             # scaler.step() first unscales the gradients of the optimizer's assigned params.
@@ -368,8 +368,8 @@ the relevant case below.
 Functions with multiple inputs or autocastable ops
 --------------------------------------------------
 
-Apply :func:`custom_fwd` and :func:`custom_bwd` (with no arguments) to ``forward`` and ``backward``
-respectively. These ensure ``forward`` executes with the current autocast state and ``backward``
+Apply :func:`custom_fwd` and :func:`custom_bwd` (with no arguments) to ``forward`` and
+``backward`` respectively. These ensure ``forward`` executes with the current autocast state and ``backward``
 executes with the same autocast state as ``forward`` (which can prevent type mismatch errors)::
 
     class MyMM(torch.autograd.Function):
@@ -391,13 +391,14 @@ Now ``MyMM`` can be invoked anywhere, without disabling autocast or manually cas
     with autocast():
         output = mymm(input1, input2)
 
-Functions that need a particular dtype
---------------------------------------
+Functions that need a particular ``dtype``
+------------------------------------------
 
+Consider a custom function that requires ``torch.float32`` inputs.
 Apply :func:`custom_fwd(cast_inputs=torch.float32)` to ``forward`` and
 :func:`custom_bwd` (with no arguments) to ``backward``.
-With ``cast_inputs=torch.float32``, these disable autocast in ``forward`` and ``backward``,
-and ``forward`` casts incoming floating-point CUDA Tensors to ``float32``::
+If ``forward`` runs in an autocast-enabled region, the decorators cast floating-point CUDA Tensor
+inputs to ``float32``, and locally disable autocast during ``forward`` and ``backward``::
 
     class MyFloat32Func(torch.autograd.Function):
         @staticmethod
@@ -411,7 +412,7 @@ and ``forward`` casts incoming floating-point CUDA Tensors to ``float32``::
         def backward(ctx, grad):
             ...
 
-Now ``MyFloat32Func`` can be invoked anywhere, without disabling autocast or manually casting inputs::
+Now ``MyFloat32Func`` can be invoked anywhere, without manually disabling autocast or casting inputs::
 
     func = MyFloat32Func.apply
diff --git a/test/test_cuda.py b/test/test_cuda.py
index e9f4af63573..3bfe747b3cd 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -64,6 +64,7 @@ types = [
     torch.HalfTensor,
 ]
 
+
 def make_sparse_tensor(t, n, *sizes):
     assert t.is_sparse
     tensor = t()
@@ -76,6 +77,7 @@ def make_sparse_tensor(t, n, *sizes):
 
 _cycles_per_ms = None
 
+
 def get_cycles_per_ms():
     """Approximate number of cycles per millisecond for torch.cuda._sleep"""
     global _cycles_per_ms
@@ -2764,14 +2766,14 @@ t2.start()
             loss = output.sum()
             loss.backward()
 
-    def test_autocast_custom_disabled(self):
+    def test_autocast_custom_cast_inputs(self):
         class MyMM(torch.autograd.Function):
             @staticmethod
             @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
-            def forward(ctx, a, container):
+            def forward(ctx, a, container, expect_type):
                 b = container[1][0]
-                self.assertTrue(a.dtype is torch.float32)
-                self.assertTrue(b.dtype is torch.float32)
+                self.assertTrue(a.dtype is expect_type)
+                self.assertTrue(b.dtype is expect_type)
                 self.assertFalse(torch.is_autocast_enabled())
                 ctx.save_for_backward(a, b)
                 return a.mm(b)
@@ -2781,7 +2783,7 @@ t2.start()
             def backward(ctx, grad):
                 self.assertFalse(torch.is_autocast_enabled())
                 a, b = ctx.saved_tensors
-                return grad.mm(b.t()), None
+                return grad.mm(b.t()), None, None
 
         mymm = MyMM.apply
 
@@ -2792,11 +2794,17 @@ t2.start()
         y = (0, {0: torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=False)})
 
         with torch.cuda.amp.autocast():
-            output = mymm(x, y)
+            output = mymm(x, y, torch.float32)
             self.assertTrue(output.dtype is torch.float32)
             loss = output.sum()
             loss.backward()
 
+        # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region.
+        output = mymm(x, y, torch.float16)
+        self.assertTrue(output.dtype is torch.float16)
+        loss = output.sum()
+        loss.backward()
+
     def test_autocast_cat_jit(self):
         # Reported at https://github.com/pytorch/pytorch/issues/38958
diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py
index 49fbc69b7cf..8bac02fc39f 100644
--- a/torch/cuda/amp/autocast_mode.py
+++ b/torch/cuda/amp/autocast_mode.py
@@ -140,7 +140,8 @@ class autocast(object):
 # may be falsely detected as "Iterables."
 def _cast(value, dtype):
     if isinstance(value, torch.Tensor):
-        return value.to(dtype) if (value.is_floating_point() and value.is_cuda) else value
+        is_eligible = (value.is_floating_point() and value.is_cuda and (value.dtype is not torch.float64))
+        return value.to(dtype) if is_eligible else value
     elif isinstance(value, string_classes):
         return value
     elif isinstance(value, np.ndarray):
@@ -169,10 +170,15 @@ def custom_fwd(fwd=None, **kwargs):
     :class:`torch.autograd.Function`). See the :ref:`example page` for more detail.
 
     Arguments:
-        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``, casts incoming
-            floating-point Tensors to the target dtype (non-floating-point Tensors are not affected),
-            and causes ``forward`` to execute with autocast disabled.
+        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
+            when ``forward`` runs in an autocast-enabled region, casts incoming
+            floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
+            then executes ``forward`` with autocast disabled.
             If ``None``, ``forward``'s internal ops execute with the current autocast state.
+
+    .. note::
+        If the decorated ``forward`` is called outside an autocast-enabled region,
+        :func:`custom_fwd` is a no-op and ``cast_inputs`` has no effect.
     """
     if fwd is None:
         if len(kwargs) == 0:
@@ -194,9 +200,13 @@
             args[0]._fwd_used_autocast = torch.is_autocast_enabled()
             return fwd(*args, **kwargs)
         else:
+            autocast_context = torch.is_autocast_enabled()
             args[0]._fwd_used_autocast = False
-            with autocast(enabled=False):
-                return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
+            if autocast_context:
+                with autocast(enabled=False):
+                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
+            else:
+                return fwd(*args, **kwargs)
     return decorate_fwd
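For context (not part of the patch): a minimal sketch of the behavior this change documents and tests, assuming a CUDA-capable build; the ``Float32MM`` name, shapes, and dtypes below are illustrative. Inside an autocast-enabled region, ``custom_fwd(cast_inputs=torch.float32)`` casts floating-point CUDA Tensor inputs and runs ``forward`` with autocast disabled; outside one, it leaves dtypes untouched::

    import torch

    class Float32MM(torch.autograd.Function):
        @staticmethod
        @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
        def forward(ctx, a, b):
            # Inside an autocast region, a and b arrive here as float32;
            # outside one, they keep whatever dtype the caller passed.
            ctx.save_for_backward(a, b)
            return a.mm(b)

        @staticmethod
        @torch.cuda.amp.custom_bwd
        def backward(ctx, grad):
            a, b = ctx.saved_tensors
            return grad.mm(b.t()), a.t().mm(grad)

    x = torch.randn(8, 8, device="cuda", dtype=torch.float16, requires_grad=True)
    y = torch.randn(8, 8, device="cuda", dtype=torch.float16, requires_grad=True)

    with torch.cuda.amp.autocast():
        out = Float32MM.apply(x, y)
    print(out.dtype)  # torch.float32: inputs were cast, autocast disabled in forward

    out = Float32MM.apply(x, y)
    print(out.dtype)  # torch.float16: custom_fwd is a no-op outside autocast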