implement torch._foreach_rsqrt (#134574)

Related:
- #133367

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134574
Approved by: https://github.com/eqy, https://github.com/janeyx99
Masaki Kozuki 2024-11-12 15:34:35 +00:00 committed by PyTorch MergeBot
parent 8cb0b932a1
commit 71d8bb7ede
10 changed files with 80 additions and 2 deletions
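For orientation before the per-file diffs: the PR adds `torch._foreach_rsqrt` and its in-place sibling `torch._foreach_rsqrt_`, which apply `rsqrt` (`1/sqrt(x)`) across a list of tensors. A minimal usage sketch, assuming a build that contains this change:

```python
import torch

tensors = [torch.rand(3) + 0.5, torch.rand(4, 5) + 0.5]

# Out-of-place: one result tensor per input tensor.
out = torch._foreach_rsqrt(tensors)

# Each result should match the per-tensor unary op, i.e. 1 / sqrt(t).
for o, t in zip(out, tensors):
    torch.testing.assert_close(o, torch.rsqrt(t))

# In-place variant mutates the inputs.
torch._foreach_rsqrt_(tensors)
```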


@@ -44,6 +44,7 @@
#include <ATen/ops/_foreach_pow_native.h>
#include <ATen/ops/_foreach_reciprocal_native.h>
#include <ATen/ops/_foreach_round_native.h>
#include <ATen/ops/_foreach_rsqrt_native.h>
#include <ATen/ops/_foreach_sigmoid_native.h>
#include <ATen/ops/_foreach_sign_native.h>
#include <ATen/ops/_foreach_sin_native.h>
@@ -393,6 +394,7 @@ FOREACH_UNARY_OP(tanh)
FOREACH_UNARY_OP(sin)
FOREACH_UNARY_OP(sinh)
FOREACH_UNARY_OP(round)
FOREACH_UNARY_OP(rsqrt)
FOREACH_UNARY_OP(lgamma)
FOREACH_UNARY_OP(frac)
FOREACH_UNARY_OP(trunc)
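`FOREACH_UNARY_OP(rsqrt)` expands into the slow-path kernels (`foreach_tensor_rsqrt_slow` / `foreach_tensor_rsqrt_slow_`) referenced later in `native_functions.yaml`; conceptually they behave like a per-tensor loop. A Python model of that behavior (illustrative only, not the actual C++ expansion):

```python
from typing import List

import torch

def foreach_rsqrt_slow_reference(tensors: List[torch.Tensor]) -> List[torch.Tensor]:
    # Conceptual model of the slow path: apply the unary op tensor-by-tensor.
    return [torch.rsqrt(t) for t in tensors]
```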


@@ -1,6 +1,7 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/TypeSafeSignMath.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
@@ -28,6 +29,7 @@
#include <ATen/ops/_foreach_neg_native.h>
#include <ATen/ops/_foreach_reciprocal_native.h>
#include <ATen/ops/_foreach_round_native.h>
#include <ATen/ops/_foreach_rsqrt_native.h>
#include <ATen/ops/_foreach_sigmoid_native.h>
#include <ATen/ops/_foreach_sign_native.h>
#include <ATen/ops/_foreach_sin_native.h>
@@ -304,11 +306,35 @@ struct Sign {
  }
};

template <typename T>
struct Rsqrt {
  C10_DEVICE T operator()(T t) const {
    return c10::cuda::compat::rsqrt(t);
  }
};

template <>
struct Rsqrt<c10::complex<float>> {
  C10_DEVICE c10::complex<float> operator()(c10::complex<float> t) const {
    const auto one = c10::complex<float>(1.0, 0);
    return one / std::sqrt(t);
  }
};

template <>
struct Rsqrt<c10::complex<double>> {
  C10_DEVICE c10::complex<double> operator()(c10::complex<double> t) const {
    const auto one = c10::complex<double>(1.0, 0);
    return one / std::sqrt(t);
  }
};
OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, sigmoid, Sigmoid)
OP_CUSTOM_FUNCTOR(floating_half_bfloat16, round, Round)
OP_CUSTOM_FUNCTOR(floating_half_bfloat16, frac, Trunc)
OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, reciprocal, Reciprocal)
OP_CUSTOM_FUNCTOR(floating_half_bfloat16, sign, Sign)
OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, rsqrt, Rsqrt)
// note(mkozuki): tensor dtype checks of `neg` kernels.
// Since `check_foreach_api_restrictions` don't require all the tensors to have

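For complex inputs the functor returns `1 / sqrt(z)` rather than calling `c10::cuda::compat::rsqrt`, presumably because the CUDA `rsqrt` intrinsic is real-valued only. A quick check of that identity from Python (illustrative; not part of the diff):

```python
import torch

z = torch.randn(4, dtype=torch.complex64)

# Mirror the complex specializations above: one / std::sqrt(t).
expected = 1 / torch.sqrt(z)
torch.testing.assert_close(torch.rsqrt(z), expected)
```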

@@ -11293,6 +11293,21 @@
    CUDA: foreach_tensor_round_cuda_
  autogen: _foreach_round.out

- func: _foreach_rsqrt(Tensor[] self) -> Tensor[]
  device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
  variants: function
  dispatch:
    CompositeExplicitAutograd: foreach_tensor_rsqrt_slow
    CUDA: foreach_tensor_rsqrt_cuda

- func: _foreach_rsqrt_(Tensor(a!)[] self) -> ()
  device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
  variants: function
  dispatch:
    CompositeExplicitAutograd: foreach_tensor_rsqrt_slow_
    CUDA: foreach_tensor_rsqrt_cuda_
  autogen: _foreach_rsqrt.out

- func: _foreach_sigmoid(Tensor[] self) -> Tensor[]
  device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
  variants: function
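The `device_check: NoCheck` annotation means the dispatcher does not insist that every tensor in the list share one device; as the comment notes, mixed-device inputs simply take the per-tensor slow path instead of the fused CUDA kernel. A sketch (assumes a CUDA device is available):

```python
import torch

xs = [torch.rand(3) + 0.5, torch.rand(3, device="cuda") + 0.5]

# Mixed devices: still valid, but handled by the slow path rather than
# the multi-tensor-apply CUDA kernel.
out = torch._foreach_rsqrt(xs)
```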


@@ -303,6 +303,9 @@ aten::_foreach_reciprocal_
aten::_foreach_round
aten::_foreach_round.out
aten::_foreach_round_
aten::_foreach_rsqrt
aten::_foreach_rsqrt.out
aten::_foreach_rsqrt_
aten::_foreach_sigmoid
aten::_foreach_sigmoid.out
aten::_foreach_sigmoid_


@@ -56,6 +56,7 @@ un_ops_under_test = [
    torch._foreach_sign,
    torch._foreach_abs,
    torch._foreach_sqrt,
    torch._foreach_rsqrt,
]
compose_ops = [torch._foreach_addcdiv, torch._foreach_addcmul]
all_ops = parametrize(


@@ -1389,6 +1389,7 @@ class TestForeach(TestCase):
"_foreach_log",
"_foreach_pow",
"_foreach_sqrt",
"_foreach_rsqrt",
)
):
value_range = {"low": 0.5, "high": 1.0}
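Adding `_foreach_rsqrt` to this group restricts sampled inputs to `[0.5, 1.0]`, presumably because `rsqrt` (and its gradient, `-0.5 * x**-1.5`) blows up as the input approaches zero, which would make tolerance-based comparisons flaky. For a sense of scale:

```python
import torch

x = torch.tensor([1e-8, 0.5, 1.0])
print(torch.rsqrt(x))   # tensor([1.0000e+04, 1.4142e+00, 1.0000e+00])
```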


@@ -1452,6 +1452,8 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._foreach_round",
"torch._foreach_sigmoid_",
"torch._foreach_sigmoid",
"torch._foreach_rsqrt_",
"torch._foreach_rsqrt",
"torch._foreach_sign_",
"torch._foreach_sign",
"torch._foreach_sin_",


@@ -6128,6 +6128,7 @@ foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
register_foreach_pointwise(aten._foreach_div.Tensor, div)
foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div)
register_foreach_pointwise(aten._foreach_sqrt, sqrt)
register_foreach_pointwise(aten._foreach_rsqrt, rsqrt)
register_foreach_pointwise(aten._foreach_maximum.List, maximum)
register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum)
register_foreach_pointwise(aten._foreach_minimum.List, minimum)
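The Dynamo allowlist entries above and the Inductor `register_foreach_pointwise(aten._foreach_rsqrt, rsqrt)` registration together let `torch.compile` trace the new op and lower it as a fusable pointwise kernel instead of falling back. A minimal sketch (illustrative; requires a build containing this PR):

```python
import torch

@torch.compile
def foreach_normalize(xs):
    # Captured by Dynamo and lowered by Inductor via the registration above.
    return torch._foreach_rsqrt(xs)

xs = [torch.rand(8) + 0.5 for _ in range(3)]
out = foreach_normalize(xs)
```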


@@ -584,8 +584,7 @@ def _multi_tensor_adafactor(
# square the eps1 as we sqrt after to keep eps1's magnitude
torch._foreach_clamp_min_(var_estimates, eps1 * eps1)
torch._foreach_sqrt_(var_estimates)
torch._foreach_reciprocal_(var_estimates)
torch._foreach_rsqrt_(var_estimates)
torch._foreach_mul_(var_estimates, device_grads)
updates = var_estimates
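The Adafactor change collapses the two-pass `sqrt` + `reciprocal` update into a single fused `rsqrt` pass; both compute `1/sqrt(x)` elementwise, so the result is unchanged up to floating-point rounding. A standalone check of that equivalence (variable names are illustrative, not the optimizer's internals):

```python
import torch

var_estimates = [torch.rand(4) + 0.1 for _ in range(2)]
old_style = [v.clone() for v in var_estimates]

# Previous form: two foreach passes over the list.
torch._foreach_sqrt_(old_style)
torch._foreach_reciprocal_(old_style)

# New form: a single fused pass.
torch._foreach_rsqrt_(var_estimates)

for a, b in zip(var_estimates, old_style):
    torch.testing.assert_close(a, b)
```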


@@ -10141,6 +10141,34 @@ foreach_unary_op_db: List[OpInfo] = [
            ),
        ),
    ),
    ForeachFuncInfo(
        'rsqrt',
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
        supports_inplace_autograd=True,
        supports_forward_ad=True,
        backward_requires_result=True,
        decorators=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestMeta",
                "test_dispatch_meta_inplace",
                dtypes=integral_types_and(torch.bool),
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestMeta",
                "test_dispatch_symbolic_meta_inplace",
                dtypes=integral_types_and(torch.bool),
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestMeta",
                "test_meta_inplace",
                dtypes=integral_types_and(torch.bool),
            ),
        ),
    ),
    ForeachFuncInfo(
        'ceil',
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
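The `expectedFailure` decorators on the in-place meta tests reflect that, like other float-producing unary ops, `rsqrt` cannot be applied in place to integral or bool tensors: the promoted floating-point result cannot be written back into the original storage. Roughly (illustrative only):

```python
import torch

ints = [torch.arange(1, 5)]          # int64 inputs

torch._foreach_rsqrt(ints)           # fine: out-of-place promotes to float

try:
    torch._foreach_rsqrt_(ints)      # in-place on an integral dtype raises
except RuntimeError as err:
    print("expected failure:", err)
```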