Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
Revert "Fix RMSNorm epsilon value type for BF16 or FP16 (#142848)"
This reverts commit 07e23653cd.
Reverted https://github.com/pytorch/pytorch/pull/142848 on behalf of https://github.com/izaitsevfb due to breaking internal tests, see D68355212 ([comment](https://github.com/pytorch/pytorch/pull/142848#issuecomment-2605734067))
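For context on what is being reverted: #142848 chose the default epsilon from the float/double compute (opmath) type, whereas the restored code below derives it from the input's own value type, so an FP16 or BF16 input gets a much larger default. The standalone C++ snippet that follows is only an illustration of that magnitude gap, not part of the patch; the fp16/bf16 epsilons are written as literals (2^-10 and 2^-7) because std::numeric_limits has no standard specialization for those types.

    #include <cstdio>
    #include <limits>

    int main() {
      // Machine epsilon of each candidate type for the default RMSNorm eps.
      std::printf("double eps = %.3e\n", std::numeric_limits<double>::epsilon()); // ~2.220e-16
      std::printf("float  eps = %.3e\n", std::numeric_limits<float>::epsilon());  // ~1.192e-07
      std::printf("fp16   eps = %.3e\n", 9.765625e-04);  // 2^-10, not in <limits>
      std::printf("bf16   eps = %.3e\n", 7.8125e-03);    // 2^-7,  not in <limits>
      return 0;
    }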
This commit is contained in:
parent bac62341eb
commit 895659cb41
1 changed file with 8 additions and 17 deletions
@@ -278,27 +278,18 @@ Tensor rms_norm_symint(
       input.scalar_type(),
       "rms_norm",
       [&] {
+        scalar_t eps_val;
+        if (!eps.has_value()) {
+          eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();
+        } else {
+          eps_val = eps.value();
+        }
+
         // upcast is needed for fp16 and bf16
         c10::ScalarType opmath_t = toOpMathType(input.scalar_type());
         Tensor upcasted_input = input.to(opmath_t);

-        Tensor rqrst_input;
-
-        // opmath_t would be one of [Double, Float, ComplexFloat, ComplexDouble]
-        if (opmath_t == at::ScalarType::Float || opmath_t == at::ScalarType::ComplexFloat) {
-          float eps_val = std::numeric_limits<float>::epsilon();
-          if (eps.has_value()) {
-            eps_val = eps.value();
-          }
-          rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
-        } else {
-          double eps_val = std::numeric_limits<double>::epsilon();
-          if (eps.has_value()) {
-            eps_val = eps.value();
-          }
-          rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
-        }
-
+        auto rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
         Tensor result = upcasted_input.mul(rqrst_input).type_as(input);

         if (weight_opt.has_value()) {
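For readers skimming the hunk: rqrst_input is the reciprocal root mean square, and the result is the upcast input scaled by it. Below is a minimal scalar sketch of that computation, independent of ATen; the names mean_sq and inv_rms are illustrative and not from the patch.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Scalar sketch of the core of the hunk:
    //   rqrst_input = rsqrt(mean(x^2) + eps);  result = x * rqrst_input
    int main() {
      std::vector<double> x = {0.5, -1.0, 2.0, 0.25};
      const double eps = 1e-6;  // illustrative; the diff derives a type-dependent default

      double mean_sq = 0.0;
      for (double v : x) mean_sq += v * v;
      mean_sq /= static_cast<double>(x.size());

      const double inv_rms = 1.0 / std::sqrt(mean_sq + eps);  // the "rqrst_input" factor
      for (double v : x) std::printf("%f ", v * inv_rms);
      std::printf("\n");
      return 0;
    }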