pytorch/torch/distributed/tensor
Mayank Mishra e5657024b5 Fix loss_parallel with BF16 logits (#130550)
Fixes #130549

This PR explicitly specifies the dtype of the `grad_input` buffer, which fixes the error raised when the logits are BF16.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130550
Approved by: https://github.com/tianyu-l
2024-07-12 15:47:38 +00:00
parallel Fix loss_parallel with BF16 logits (#130550) 2024-07-12 15:47:38 +00:00
__init__.py