Revert "Fix RMSNorm epsilon value type for BF16 or FP16 (#142848)"

This reverts commit 07e23653cd.

Reverted https://github.com/pytorch/pytorch/pull/142848 on behalf of https://github.com/izaitsevfb due to breaking internal tests, see D68355212 ([comment](https://github.com/pytorch/pytorch/pull/142848#issuecomment-2605734067))
This commit is contained in:
PyTorch MergeBot 2025-01-21 21:04:43 +00:00
parent bac62341eb
commit 895659cb41

View file

@ -278,27 +278,18 @@ Tensor rms_norm_symint(
input.scalar_type(),
"rms_norm",
[&] {
scalar_t eps_val;
if (!eps.has_value()) {
eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();
} else {
eps_val = eps.value();
}
// upcast is needed for fp16 and bf16
c10::ScalarType opmath_t = toOpMathType(input.scalar_type());
Tensor upcasted_input = input.to(opmath_t);
Tensor rqrst_input;
// opmath_t would be one of [Double, Float, ComplexFloat, ComplexDouble]
if (opmath_t == at::ScalarType::Float || opmath_t == at::ScalarType::ComplexFloat) {
float eps_val = std::numeric_limits<float>::epsilon();
if (eps.has_value()) {
eps_val = eps.value();
}
rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
} else {
double eps_val = std::numeric_limits<double>::epsilon();
if (eps.has_value()) {
eps_val = eps.value();
}
rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
}
auto rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
Tensor result = upcasted_input.mul(rqrst_input).type_as(input);
if (weight_opt.has_value()) {