pytorch/torch/_decomp
Andres Lugo-Reyes 38b8b614a2 [ROCm] Implement forward AD for miopen_batch_norm (#125069)
Implements forward automatic differentiation support for miopen_batch_norm and unskips the associated unit tests. Also fixes a class of functorch-related unit tests that failed a contiguous-tensor assertion in BatchNorm_miopen.cpp. The fix is to send only tensors with at least 3 dimensions to miopen_batch_norm. The cudnn path already had this exact restriction, which is why the affected tests failed only on ROCm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125069
Approved by: https://github.com/jeffdaily, https://github.com/andrewor14
2024-05-14 19:09:50 +00:00
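Forward AD propagates a tangent alongside each primal value, so batch norm's Jacobian-vector product can be computed in one forward pass. Below is a minimal pure-Python sketch of that idea for training-mode batch norm on a single channel, using dual numbers. It is an illustration under stated assumptions, not the MIOpen kernel or PyTorch's own formula. `Dual`, `batch_norm_1ch`, `batch_norm_jvp` and `finite_diff` are hypothetical names. It assumes the biased batch variance is used for normalization. In PyTorch itself, forward AD is exposed through `torch.autograd.forward_ad` and `torch.func.jvp`.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Hypothetical dual number val + tan*e with e**2 == 0 (primal value plus tangent)."""
    val: float
    tan: float = 0.0

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __sub__(self, other):
        other = _lift(other)
        return Dual(self.val - other.val, self.tan - other.tan)

    def __mul__(self, other):
        # Product rule: (a*b)' = a'*b + a*b'
        other = _lift(other)
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)

    __rmul__ = __mul__

    def __truediv__(self, other):
        # Quotient rule: (a/b)' = (a'*b - a*b') / b**2
        other = _lift(other)
        return Dual(self.val / other.val,
                    (self.tan * other.val - self.val * other.tan) / (other.val ** 2))


def _lift(x):
    """Treat plain numbers as constants (zero tangent)."""
    return x if isinstance(x, Dual) else Dual(float(x))


def dsqrt(x: Dual) -> Dual:
    r = math.sqrt(x.val)
    return Dual(r, x.tan / (2.0 * r))


def batch_norm_1ch(xs, weight, bias, eps=1e-5):
    """Training-mode batch norm over one channel; assumes biased variance."""
    n = len(xs)
    mean = sum(xs, Dual(0.0)) / n
    var = sum(((x - mean) * (x - mean) for x in xs), Dual(0.0)) / n
    inv_std = Dual(1.0) / dsqrt(var + eps)
    return [(x - mean) * inv_std * weight + bias for x in xs]


def batch_norm_jvp(xs, tangents, weight=1.0, bias=0.0, eps=1e-5):
    """Forward-mode AD: seed each input with its tangent, read tangents off the outputs."""
    out = batch_norm_1ch([Dual(x, t) for x, t in zip(xs, tangents)], weight, bias, eps)
    return [d.val for d in out], [d.tan for d in out]


def finite_diff(xs, tangents, weight=1.0, bias=0.0, eps=1e-5, h=1e-6):
    """Central-difference estimate of the same JVP, used only as a cross-check."""
    def primal(v):
        return [d.val for d in batch_norm_1ch([Dual(a) for a in v], weight, bias, eps)]
    plus = primal([x + h * t for x, t in zip(xs, tangents)])
    minus = primal([x - h * t for x, t in zip(xs, tangents)])
    return [(p - m) / (2 * h) for p, m in zip(plus, minus)]


xs = [1.0, 2.0, 4.0, 7.0]
ts = [1.0, 0.0, -1.0, 0.5]
out, jvp = batch_norm_jvp(xs, ts, weight=2.0, bias=0.5)
fd = finite_diff(xs, ts, weight=2.0, bias=0.5)
for a, b in zip(jvp, fd):
    print(f"forward AD {a:+.6f}   finite diff {b:+.6f}")
```

The two columns should agree closely. Shifting every input by the same amount (tangent of all ones) yields a zero output tangent, because batch norm subtracts the batch mean.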
__init__.py [ROCm] Implement forward AD for miopen_batch_norm (#125069) 2024-05-14 19:09:50 +00:00
decompositions.py [ROCm] Implement forward AD for miopen_batch_norm (#125069) 2024-05-14 19:09:50 +00:00
decompositions_for_jvp.py [ROCm] Implement forward AD for miopen_batch_norm (#125069) 2024-05-14 19:09:50 +00:00
decompositions_for_rng.py