implement torch._foreach_rsqrt (#134574)
Related: #133367

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134574
Approved by: https://github.com/eqy, https://github.com/janeyx99
parent 8cb0b932a1
commit 71d8bb7ede
10 changed files with 80 additions and 2 deletions
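For context, the new op applies torch.rsqrt to every tensor in a list. A minimal sketch of the intended semantics (values kept positive to stay in rsqrt's real domain):

    import torch

    xs = [torch.rand(3) + 0.5 for _ in range(4)]   # strictly positive inputs
    ys = torch._foreach_rsqrt(xs)                  # out-of-place: returns a new list
    for x, y in zip(xs, ys):
        assert torch.allclose(y, torch.rsqrt(x))   # same as per-tensor rsqrt
    torch._foreach_rsqrt_(xs)                      # in-place variant, mutates xs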
Foreach CPU slow-path kernels — pull in the generated native header:

@@ -44,6 +44,7 @@
 #include <ATen/ops/_foreach_pow_native.h>
 #include <ATen/ops/_foreach_reciprocal_native.h>
 #include <ATen/ops/_foreach_round_native.h>
+#include <ATen/ops/_foreach_rsqrt_native.h>
 #include <ATen/ops/_foreach_sigmoid_native.h>
 #include <ATen/ops/_foreach_sign_native.h>
 #include <ATen/ops/_foreach_sin_native.h>
@@ -393,6 +394,7 @@ FOREACH_UNARY_OP(tanh)
 FOREACH_UNARY_OP(sin)
 FOREACH_UNARY_OP(sinh)
 FOREACH_UNARY_OP(round)
+FOREACH_UNARY_OP(rsqrt)
 FOREACH_UNARY_OP(lgamma)
 FOREACH_UNARY_OP(frac)
 FOREACH_UNARY_OP(trunc)
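FOREACH_UNARY_OP(rsqrt) stamps out the slow-path kernels (foreach_tensor_rsqrt_slow and foreach_tensor_rsqrt_slow_) named in the schema further down. Conceptually the slow path is just a per-tensor loop — a rough Python sketch of the semantics, not the actual macro expansion:

    import torch

    # rough analogue of the generated foreach_tensor_rsqrt_slow kernels
    def foreach_rsqrt_slow(tensors):
        return [torch.rsqrt(t) for t in tensors]

    def foreach_rsqrt_slow_(tensors):
        for t in tensors:
            t.rsqrt_()   # mutate each tensor in place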
CUDA foreach unary kernels — include the CUDA math-compat header for rsqrt:

@@ -1,6 +1,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/Dispatch.h>
 #include <ATen/native/ForeachUtils.h>
+#include <c10/cuda/CUDAMathCompat.h>
 #include <c10/util/TypeSafeSignMath.h>
 #include <ATen/native/cuda/ForeachFunctors.cuh>
 
@@ -28,6 +29,7 @@
 #include <ATen/ops/_foreach_neg_native.h>
 #include <ATen/ops/_foreach_reciprocal_native.h>
 #include <ATen/ops/_foreach_round_native.h>
+#include <ATen/ops/_foreach_rsqrt_native.h>
 #include <ATen/ops/_foreach_sigmoid_native.h>
 #include <ATen/ops/_foreach_sign_native.h>
 #include <ATen/ops/_foreach_sin_native.h>
@@ -304,11 +306,35 @@ struct Sign {
   }
 };
 
+template <typename T>
+struct Rsqrt {
+  C10_DEVICE T operator()(T t) const {
+    return c10::cuda::compat::rsqrt(t);
+  }
+};
+
+template <>
+struct Rsqrt<c10::complex<float>> {
+  C10_DEVICE c10::complex<float> operator()(c10::complex<float> t) const {
+    const auto one = c10::complex<float>(1.0, 0);
+    return one / std::sqrt(t);
+  }
+};
+
+template <>
+struct Rsqrt<c10::complex<double>> {
+  C10_DEVICE c10::complex<double> operator()(c10::complex<double> t) const {
+    const auto one = c10::complex<double>(1.0, 0);
+    return one / std::sqrt(t);
+  }
+};
+
 OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, sigmoid, Sigmoid)
 OP_CUSTOM_FUNCTOR(floating_half_bfloat16, round, Round)
 OP_CUSTOM_FUNCTOR(floating_half_bfloat16, frac, Trunc)
 OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, reciprocal, Reciprocal)
 OP_CUSTOM_FUNCTOR(floating_half_bfloat16, sign, Sign)
+OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, rsqrt, Rsqrt)
 
 // note(mkozuki): tensor dtype checks of `neg` kernels.
 // Since `check_foreach_api_restrictions` don't require all the tensors to have
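The complex specializations fall back to 1/sqrt(z), presumably because the compat rsqrt wrapper covers only real types. A quick sanity check of that identity against the eager API (hypothetical values):

    import torch

    z = torch.tensor([1 + 2j, -3 + 0.5j], dtype=torch.complex64)
    assert torch.allclose(torch.rsqrt(z), 1 / torch.sqrt(z))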
aten/src/ATen/native/native_functions.yaml — schemas for the functional and in-place variants, plus the autogenerated out= overload:

@@ -11293,6 +11293,21 @@
     CUDA: foreach_tensor_round_cuda_
   autogen: _foreach_round.out
 
+- func: _foreach_rsqrt(Tensor[] self) -> Tensor[]
+  device_check: NoCheck  # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_rsqrt_slow
+    CUDA: foreach_tensor_rsqrt_cuda
+
+- func: _foreach_rsqrt_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck  # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_rsqrt_slow_
+    CUDA: foreach_tensor_rsqrt_cuda_
+  autogen: _foreach_rsqrt.out
+
 - func: _foreach_sigmoid(Tensor[] self) -> Tensor[]
   device_check: NoCheck  # foreach kernels fall back to slow path when tensor are on different devices
   variants: function
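The autogen line produces an out= overload alongside the functional and in-place variants. Assuming the usual foreach conventions, the three surface roughly as follows (the out= call via torch.ops is an assumption about the generated schema):

    import torch

    xs = [torch.rand(2) + 0.5, torch.rand(3) + 0.5]
    ys = torch._foreach_rsqrt(xs)                    # functional
    torch._foreach_rsqrt_(xs)                        # in-place
    outs = [torch.empty_like(x) for x in xs]
    torch.ops.aten._foreach_rsqrt.out(xs, out=outs)  # assumed autogenerated out= overload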
Alphabetized aten operator list (expect/allowlist file) — register the three new overloads:

@@ -303,6 +303,9 @@ aten::_foreach_reciprocal_
 aten::_foreach_round
 aten::_foreach_round.out
 aten::_foreach_round_
+aten::_foreach_rsqrt
+aten::_foreach_rsqrt.out
+aten::_foreach_rsqrt_
 aten::_foreach_sigmoid
 aten::_foreach_sigmoid.out
 aten::_foreach_sigmoid_
Inductor foreach tests — add the op to the unary ops under test:

@@ -56,6 +56,7 @@ un_ops_under_test = [
     torch._foreach_sign,
     torch._foreach_abs,
     torch._foreach_sqrt,
+    torch._foreach_rsqrt,
 ]
 compose_ops = [torch._foreach_addcdiv, torch._foreach_addcmul]
 all_ops = parametrize(
Foreach op tests — sample strictly positive values for ops with a restricted domain:

@@ -1389,6 +1389,7 @@ class TestForeach(TestCase):
                 "_foreach_log",
                 "_foreach_pow",
                 "_foreach_sqrt",
+                "_foreach_rsqrt",
             )
         ):
             value_range = {"low": 0.5, "high": 1.0}
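rsqrt joins the ops whose samples are drawn from [0.5, 1.0] because its real-domain output degenerates at and below zero:

    import torch

    torch.rsqrt(torch.tensor(0.0))    # tensor(inf)
    torch.rsqrt(torch.tensor(-1.0))   # tensor(nan)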
Dynamo trace rules — mark the new bindings as in-graph functions:

@@ -1452,6 +1452,8 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch._foreach_round",
         "torch._foreach_sigmoid_",
         "torch._foreach_sigmoid",
+        "torch._foreach_rsqrt_",
+        "torch._foreach_rsqrt",
         "torch._foreach_sign_",
         "torch._foreach_sign",
         "torch._foreach_sin_",
Inductor lowering — register the pointwise lowering next to sqrt:

@@ -6128,6 +6128,7 @@ foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
 register_foreach_pointwise(aten._foreach_div.Tensor, div)
 foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div)
 register_foreach_pointwise(aten._foreach_sqrt, sqrt)
+register_foreach_pointwise(aten._foreach_rsqrt, rsqrt)
 register_foreach_pointwise(aten._foreach_maximum.List, maximum)
 register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum)
 register_foreach_pointwise(aten._foreach_minimum.List, minimum)
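With the lowering registered, torch.compile can treat the op like the other foreach pointwise ops. A minimal smoke test, assuming a CUDA device is available:

    import torch

    @torch.compile
    def f(xs):
        return torch._foreach_rsqrt(xs)

    xs = [torch.rand(8, device="cuda") + 0.5 for _ in range(3)]
    print(f(xs)[0])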
Adafactor multi-tensor path — fuse the sqrt + reciprocal pair into a single rsqrt pass (the commit's 2 deletions):

@@ -584,8 +584,7 @@ def _multi_tensor_adafactor(
 
         # square the eps1 as we sqrt after to keep eps1's magnitude
         torch._foreach_clamp_min_(var_estimates, eps1 * eps1)
-        torch._foreach_sqrt_(var_estimates)
-        torch._foreach_reciprocal_(var_estimates)
+        torch._foreach_rsqrt_(var_estimates)
         torch._foreach_mul_(var_estimates, device_grads)
         updates = var_estimates
 
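The two forms are mathematically identical, since 1/sqrt(x) = rsqrt(x); the change just replaces two kernel launches per tensor list with one. A quick equivalence check on hypothetical values:

    import torch

    a = [torch.rand(4) + 0.5]
    b = [t.clone() for t in a]
    torch._foreach_sqrt_(a)
    torch._foreach_reciprocal_(a)      # old: two foreach passes
    torch._foreach_rsqrt_(b)           # new: one foreach pass
    assert torch.allclose(a[0], b[0])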
OpInfo database — a ForeachFuncInfo entry for rsqrt, with in-place meta tests expected to fail on integral dtypes:

@@ -10141,6 +10141,34 @@ foreach_unary_op_db: List[OpInfo] = [
             ),
         ),
     ),
+    ForeachFuncInfo(
+        'rsqrt',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        backward_requires_result=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+        ),
+    ),
     ForeachFuncInfo(
         'ceil',
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
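The expectedFailure decorators cover in-place runs on integral dtypes: in eager, in-place rsqrt on an integer tensor raises because the floating-point result cannot be cast back to the input dtype, which is presumably where the meta in-place tests diverge. For example:

    import torch

    try:
        torch._foreach_rsqrt_([torch.ones(2, dtype=torch.long)])
    except RuntimeError as e:
        print(e)  # result type can't be cast to the desired output type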