Fix typo in _normalize ref (#137079)

I think this should basically make no difference numerically, but it does have some ramifications on things like CSE.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137079
Approved by: https://github.com/Skylion007
ghstack dependencies: #136826, #137043, #137049, #137065
This commit is contained in:
chilli 2024-10-01 16:07:54 -07:00 committed by PyTorch MergeBot
parent 6374a19a6e
commit 2b329d3bf1

View file

@ -3140,7 +3140,7 @@ def _normalize(
a_acc, dim=norm_dims, unbiased=False, keepdim=True
)
rstd = torch.rsqrt(biased_var + eps)
out = (a - mean) * rstd
out = (a_acc - mean) * rstd
return out, mean, rstd