Fixes https://github.com/pytorch/pytorch/issues/104871

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105158
Approved by: https://github.com/SherlockNoMad
multi_margin_loss