mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary:
Currently, a custom autograd function written with
```
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=dtype)
def forward(ctx, *args):
    ...
```
casts incoming floating-point CUDA tensors to `dtype` unconditionally, regardless of whether the function executes in an autocast-enabled region. I think I had the wrong idea there. Autocast-disabled regions should give the user control of input types. Also, functions decorated with `custom_fwd(cast_inputs=dtype)` should behave like native fp32list/fp16list functions: the C++-side casting wrappers have no effect when autocast is disabled, and `custom_fwd`'s casting should behave the same way.
The present PR changes `custom_fwd` so it only casts in autocast-enabled regions. It also updates `custom_fwd` to ignore fp64 inputs, like the C++ wrappers.
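The casting rule described above can be modeled in a few lines of plain Python. This is a hedged sketch, not PyTorch's implementation: `FakeTensor`, `is_eligible`, and `maybe_cast_inputs` are hypothetical names, a tensor is reduced to a dtype string and a device string, and the sketch only looks at top-level positional arguments (keyword arguments and nested containers are ignored here).

```python
from dataclasses import dataclass, replace

# Hypothetical stand-in for a tensor: only the attributes the casting rule inspects.
@dataclass(frozen=True)
class FakeTensor:
    dtype: str   # e.g. "float16", "float32", "float64", "int64"
    device: str  # e.g. "cuda", "cpu"

# Floating-point dtypes this model knows about (an assumption for illustration).
FLOATING = {"float16", "float32", "float64"}

def is_eligible(t: FakeTensor) -> bool:
    # Only floating-point CUDA tensors are cast; fp64 is left alone, like the C++ wrappers.
    return t.device == "cuda" and t.dtype in FLOATING and t.dtype != "float64"

def maybe_cast_inputs(args, dtype, autocast_enabled):
    # Outside an autocast-enabled region the user keeps control: inputs pass through unchanged.
    if not autocast_enabled:
        return tuple(args)
    # Inside autocast, eligible tensors get the requested dtype; everything else is untouched.
    return tuple(
        replace(a, dtype=dtype) if isinstance(a, FakeTensor) and is_eligible(a) else a
        for a in args
    )

if __name__ == "__main__":
    args = (
        FakeTensor("float16", "cuda"),  # eligible
        FakeTensor("float64", "cuda"),  # fp64: ignored
        FakeTensor("float16", "cpu"),   # not CUDA: ignored
        FakeTensor("int64", "cuda"),    # not floating point: ignored
        3.0,                            # not a tensor: ignored
    )
    print(maybe_cast_inputs(args, "float32", autocast_enabled=True))
    print(maybe_cast_inputs(args, "float32", autocast_enabled=False))
```

Returning the inputs untouched when autocast is off is what makes a decorated custom function line up with native ops: in both cases, casting only happens when the surrounding region has autocast enabled.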
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36171
Differential Revision: D22179511
Pulled By: ngimel
fbshipit-source-id: 5a93d070179a43206066bce19da0a5a19ecaabbd
| Name |
|---|
| _static |
| _templates |
| _templates-stable |
| community |
| notes |
| rpc |
| scripts |
| __config__.rst |
| amp.rst |
| autograd.rst |
| bottleneck.rst |
| checkpoint.rst |
| conf.py |
| cpp_extension.rst |
| cpp_index.rst |
| cuda.rst |
| cudnn_persistent_rnn.rst |
| data.rst |
| distributed.rst |
| distributions.rst |
| dlpack.rst |
| docutils.conf |
| futures.rst |
| hub.rst |
| index.rst |
| jit.rst |
| jit_builtin_functions.rst |
| jit_language_reference.rst |
| jit_python_reference.rst |
| jit_unsupported.rst |
| math-quantizer-equation.png |
| model_zoo.rst |
| multiprocessing.rst |
| name_inference.rst |
| named_tensor.rst |
| nn.functional.rst |
| nn.init.rst |
| nn.rst |
| onnx.rst |
| optim.rst |
| packages.rst |
| quantization.rst |
| random.rst |
| rpc.rst |
| sparse.rst |
| storage.rst |
| tensor_attributes.rst |
| tensor_view.rst |
| tensorboard.rst |
| tensors.rst |
| torch.rst |
| type_info.rst |