diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp
index 480ea9706e2..ec129a93e0d 100644
--- a/aten/src/ATen/native/layer_norm.cpp
+++ b/aten/src/ATen/native/layer_norm.cpp
@@ -278,27 +278,18 @@ Tensor rms_norm_symint(
       input.scalar_type(), "rms_norm", [&] {
+        scalar_t eps_val;
+        if (!eps.has_value()) {
+          eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();
+        } else {
+          eps_val = eps.value();
+        }
+
         // upcast is needed for fp16 and bf16
         c10::ScalarType opmath_t = toOpMathType(input.scalar_type());
         Tensor upcasted_input = input.to(opmath_t);
-        Tensor rqrst_input;
-
-        // opmath_t would be one of [Double, Float, ComplexFloat, ComplexDouble]
-        if (opmath_t == at::ScalarType::Float || opmath_t == at::ScalarType::ComplexFloat) {
-          float eps_val = std::numeric_limits<float>::epsilon();
-          if (eps.has_value()) {
-            eps_val = eps.value();
-          }
-          rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
-        } else {
-          double eps_val = std::numeric_limits<double>::epsilon();
-          if (eps.has_value()) {
-            eps_val = eps.value();
-          }
-          rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
-        }
-
+        auto rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
         Tensor result = upcasted_input.mul(rqrst_input).type_as(input);
 
         if (weight_opt.has_value()) {