Revert "Add _foreach_clamp (#106574)"

This reverts commit 2b560d3c3a.

Reverted https://github.com/pytorch/pytorch/pull/106574 on behalf of https://github.com/kit1980 due to breaking internal windows builds ([comment](https://github.com/pytorch/pytorch/pull/106574#issuecomment-1675400335))
PyTorch MergeBot 2023-08-11 21:05:02 +00:00
parent c9cdcb299a
commit 354484ea6d
10 changed files with 25 additions and 344 deletions


@@ -2,7 +2,6 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/core/Tensor.h>
 #include <ATen/native/ForeachUtils.h>
-#include <c10/util/Optional.h>
 #include <c10/util/irange.h>
 #ifndef AT_PER_OPERATOR_HEADERS
@@ -20,7 +19,6 @@
 #include <ATen/ops/_foreach_ceil_native.h>
 #include <ATen/ops/_foreach_clamp_max_native.h>
 #include <ATen/ops/_foreach_clamp_min_native.h>
-#include <ATen/ops/_foreach_clamp_native.h>
 #include <ATen/ops/_foreach_cos_native.h>
 #include <ATen/ops/_foreach_cosh_native.h>
 #include <ATen/ops/_foreach_div_native.h>
@@ -410,33 +408,4 @@ std::vector<Tensor> foreach_scalar_pow_list_kernel_slow(
   return result;
 }
 
-std::vector<Tensor> foreach_tensor_clamp_scalar_kernel_slow(
-    TensorList self,
-    const optional<Scalar>& min,
-    const optional<Scalar>& max) {
-  TORCH_CHECK(
-      min.has_value() || max.has_value(),
-      "Either `min` or `max` must be specified");
-  check_foreach_api_restrictions(self);
-  std::vector<Tensor> result;
-  result.reserve(self.size());
-  for (const auto& t : self) {
-    result.emplace_back(t.clamp(min, max));
-  }
-  return result;
-}
-
-void foreach_tensor_clamp_scalar_kernel_slow_(
-    TensorList self,
-    const optional<Scalar>& min,
-    const optional<Scalar>& max) {
-  TORCH_CHECK(
-      min.has_value() || max.has_value(),
-      "Either `min` or `max` must be specified");
-  check_foreach_api_restrictions(self);
-  for (auto& t : self) {
-    t.clamp_(min, max);
-  }
-}
 } // namespace at::native
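
For reference, the serial CPU fallback removed above amounts to per-tensor Tensor.clamp with a guard that at least one bound is present. A minimal Python sketch of the same contract (an illustration, not the C++ API itself):

import torch

def foreach_clamp_slow(tensors, min=None, max=None):
    # Mirrors the TORCH_CHECK above: at least one bound must be given.
    if min is None and max is None:
        raise RuntimeError("Either `min` or `max` must be specified")
    return [t.clamp(min, max) for t in tensors]

out = foreach_clamp_slow([torch.randn(3) for _ in range(2)], min=-1.0, max=1.0)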


@@ -1,9 +1,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/Dispatch.h>
-#include <ATen/NumericUtils.h>
 #include <ATen/native/BinaryOps.h>
 #include <ATen/native/ForeachUtils.h>
-#include <ATen/native/TensorCompare.h>
 #include <ATen/native/cuda/ForeachFunctors.cuh>
 #include <ATen/native/cuda/ForeachMinMaxFunctors.cuh>
@@ -13,7 +11,6 @@
 #include <ATen/ops/_foreach_add_native.h>
 #include <ATen/ops/_foreach_clamp_max_native.h>
 #include <ATen/ops/_foreach_clamp_min_native.h>
-#include <ATen/ops/_foreach_clamp_native.h>
 #include <ATen/ops/_foreach_div_native.h>
 #include <ATen/ops/_foreach_mul_native.h>
 #include <ATen/ops/_foreach_pow_native.h>
@@ -245,163 +242,4 @@ std::vector<Tensor> foreach_tensor_sub_scalar_kernel_cuda(
 FOREACH_BINARY_OP_SCALAR(all_types_half_bfloat16, clamp_max, minimum, false);
 FOREACH_BINARY_OP_SCALAR(all_types_half_bfloat16, clamp_min, maximum, false);
 
-namespace {
-template <typename T>
-__forceinline__ C10_DEVICE T clamp(
-    const T v,
-    const T lower,
-    const T upper,
-    const at::native::detail::ClampLimits& clamp_kind) {
-  // Propagate nan, which doesn't propagate automatically for ROCm
-  if (at::_isnan(v)) {
-    return v;
-  } else if (clamp_kind == at::native::detail::ClampLimits::Min) {
-    return ::max(v, lower);
-  } else if (clamp_kind == at::native::detail::ClampLimits::Max) {
-    return ::min(v, upper);
-  } else {
-    return ::min(::max(v, lower), upper);
-  }
-}
-} // namespace
-
-template <typename T, int depth, int r_args_depth, int res_arg_index>
-struct ClampFunctor {
-  using opmath_t = at::opmath_type<T>;
-  __forceinline__ C10_DEVICE T operator()(
-      const int chunk_size,
-      TensorListMetadata<depth>& tl,
-      opmath_t lower,
-      opmath_t upper,
-      const at::native::detail::ClampLimits& clamp_kind) {
-    static_assert(depth == 1 || depth == 2, "");
-    static_assert(depth >= r_args_depth, "");
-    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
-    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
-    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
-    auto n = tl.numel_for_tensor[tensor_loc];
-
-    T* args[depth];
-    const bool all_aligned =
-        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
-    n -= chunk_idx * chunk_size;
-    T r_args[r_args_depth][kILP];
-
-    // to make things simple, we put aligned case in a different code path
-    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
-      for (int64_t i_start = threadIdx.x;
-           i_start * kILP < n && i_start * kILP < chunk_size;
-           i_start += blockDim.x) {
-        // load
-        load_store(r_args[0], args[0], 0, i_start);
-#pragma unroll
-        for (int ii = 0; ii < kILP; ii++) {
-          r_args[0][ii] = clamp(
-              static_cast<opmath_t>(r_args[0][ii]), lower, upper, clamp_kind);
-        }
-        // store
-        load_store(args[res_arg_index], r_args[0], i_start, 0);
-      }
-    } else {
-      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
-           i_start += blockDim.x * kILP) {
-        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
-#pragma unroll
-        for (int ii = 0; ii < kILP; ii++) {
-          r_args[0][ii] = clamp(
-              static_cast<opmath_t>(r_args[0][ii]), lower, upper, clamp_kind);
-        }
-        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
-      }
-    }
-  }
-};
-
-at::native::detail::ClampLimits get_clamp_kind(
-    const optional<Scalar>& min,
-    const optional<Scalar>& max) {
-  if (min.has_value() && max.has_value()) {
-    return at::native::detail::ClampLimits::MinMax;
-  } else if (min.has_value()) {
-    return at::native::detail::ClampLimits::Min;
-  } else {
-    return at::native::detail::ClampLimits::Max;
-  }
-}
-
-std::vector<at::Tensor> foreach_tensor_clamp_scalar_kernel_cuda(
-    TensorList self,
-    const optional<Scalar>& min,
-    const optional<Scalar>& max) {
-  check_foreach_api_restrictions(self);
-  if (!can_use_fast_route(
-          ArrayRef<TensorList>{self}, ArrayRef<Scalar>{}, true)) {
-    return foreach_tensor_clamp_scalar_kernel_slow(self, min, max);
-  }
-  TORCH_CHECK(
-      min.has_value() || max.has_value(),
-      "Either `min` or `max` must be specified");
-
-  std::vector<Tensor> result;
-  result.reserve(self.size());
-  for (const auto& t : self) {
-    result.emplace_back(at::native::empty_like(t));
-  }
-  std::vector<std::vector<at::Tensor>> tensor_lists{self.vec(), result};
-
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      self[0].scalar_type(),
-      "foreach_tensor_clamp_scalar_kernel_cuda",
-      [&]() {
-        using opmath_t = typename at::opmath_type<scalar_t>;
-        multi_tensor_apply<2>(
-            tensor_lists,
-            ClampFunctor<
-                scalar_t,
-                /* depth */ 2,
-                /* r_args_depth */ 1,
-                /* res_arg_index */ 1>(),
-            min.has_value() ? min.value().to<opmath_t>()
-                            : max.value().to<opmath_t>(),
-            max.has_value() ? max.value().to<opmath_t>()
-                            : min.value().to<opmath_t>(),
-            get_clamp_kind(min, max));
-      });
-  return tensor_lists[1];
-}
-
-void foreach_tensor_clamp_scalar_kernel_cuda_(
-    TensorList self,
-    const optional<Scalar>& min,
-    const optional<Scalar>& max) {
-  check_foreach_api_restrictions(self);
-  if (!can_use_fast_route(
-          ArrayRef<TensorList>{self}, ArrayRef<Scalar>{}, true)) {
-    return foreach_tensor_clamp_scalar_kernel_slow_(self, min, max);
-  }
-  TORCH_CHECK(
-      min.has_value() || max.has_value(),
-      "Either `min` or `max` must be specified");
-
-  std::vector<std::vector<at::Tensor>> tensor_lists{self.vec()};
-
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      self[0].scalar_type(),
-      "foreach_tensor_clamp_scalar_kernel_cuda",
-      [&]() {
-        using opmath_t = typename at::opmath_type<scalar_t>;
-        multi_tensor_apply<1>(
-            tensor_lists,
-            ClampFunctor<
-                scalar_t,
-                /* depth */ 1,
-                /* r_args_depth */ 1,
-                /* res_arg_index */ 0>(),
-            min.has_value() ? min.value().to<opmath_t>()
-                            : max.value().to<opmath_t>(),
-            max.has_value() ? max.value().to<opmath_t>()
-                            : min.value().to<opmath_t>(),
-            get_clamp_kind(min, max));
-      });
-}
 } // namespace at::native
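
Two details in the removed CUDA path are easy to miss: the device-side clamp propagates NaN explicitly (the comment cites ROCm, where ::min/::max do not), and when only one bound is supplied the other functor argument is filled with the same scalar but never consulted, because get_clamp_kind routes execution to the single-bound branch. A pure-Python sketch of that scalar logic (names mirror the code above; this is not the device code itself):

import math

def get_clamp_kind(min_, max_):
    if min_ is not None and max_ is not None:
        return "MinMax"
    return "Min" if min_ is not None else "Max"

def clamp(v, lower, upper, kind):
    if math.isnan(v):  # propagate NaN, as the device helper does
        return v
    if kind == "Min":
        return max(v, lower)  # only the lower bound applies
    if kind == "Max":
        return min(v, upper)  # only the upper bound applies
    return min(max(v, lower), upper)

# Only `min` given: `upper` holds a filler value and is ignored.
print(clamp(5.0, lower=0.0, upper=0.0, kind=get_clamp_kind(0.0, None)))  # 5.0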


@@ -10121,25 +10121,6 @@
   CUDA: foreach_tensor_div_scalarlist_kernel_cuda_
   autogen: _foreach_div.ScalarList_out
 
-# TODO(crcrpar): Think of adding overloads for `TensorList`s and `ScalarList`s
-# ref: https://github.com/pytorch/pytorch/issues/106931
-- func: _foreach_clamp(Tensor[] self, Scalar? min=None, Scalar? max=None) -> Tensor[]
-  device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
-  variants: function
-  cpp_no_default_args: ['min']
-  dispatch:
-    CPU: foreach_tensor_clamp_scalar_kernel_slow
-    CUDA: foreach_tensor_clamp_scalar_kernel_cuda
-
-- func: _foreach_clamp_(Tensor(a!)[] self, Scalar? min=None, Scalar? max=None) -> ()
-  device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
-  variants: function
-  cpp_no_default_args: ['min']
-  dispatch:
-    CPU: foreach_tensor_clamp_scalar_kernel_slow_
-    CUDA: foreach_tensor_clamp_scalar_kernel_cuda_
-  autogen: _foreach_clamp.out
-
 - func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
   device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
   variants: function
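
The schema being deleted made both bounds optional Scalars while requiring at least one at runtime (the TORCH_CHECK in the kernels above), and cpp_no_default_args: ['min'] kept the C++ signature from defaulting min. On a build that still carried #106574, the Python-side contract looked like this (a sketch based only on the schema above):

import torch

xs = [torch.randn(3) for _ in range(4)]
torch._foreach_clamp(xs, min=-1.0, max=1.0)  # both bounds
torch._foreach_clamp(xs, max=1.0)            # upper bound only
torch._foreach_clamp(xs, min=-1.0)           # lower bound only
# Passing neither bound tripped: "Either `min` or `max` must be specified"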


@@ -153,11 +153,6 @@ def supports(o, factory_methods):
             print("Skipping {} Because of Arg: {} ({}) ".format(
                 o['name'], arg['type'], arg['dynamic_type']))
             return False
-
-    # skip _foreach_clamp(Tensor[], Scalar?, Scalar?)
-    if o['name'] in {"_foreach_clamp"}:
-        print(f"Skipping {o['name']} because it has multiple scalar arguments with default values.")
-        return False
     return True
@@ -332,5 +327,4 @@ if __name__ == '__main__':
         top_env['implementations'].append(IMPLEMENTATION_TEMPLATE.substitute(env))
         top_env['cases'].append(CASE_TEMPLATE.substitute(env))
         key += 1
-
     write(os.path.join(args.install_dir, args.output_prefix + "aten_op.h"), OP_TEMPLATE.substitute(top_env))


@@ -161,9 +161,6 @@ aten::_foreach_atan_
 aten::_foreach_ceil
 aten::_foreach_ceil.out
 aten::_foreach_ceil_
-aten::_foreach_clamp
-aten::_foreach_clamp.out
-aten::_foreach_clamp_
 aten::_foreach_clamp_max.List
 aten::_foreach_clamp_max.List_out
 aten::_foreach_clamp_max.Scalar


@@ -40,14 +40,6 @@ class RegularFuncWrapper:
         if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
             # binary op with tensorlist and scalar.
             inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
-        # path for `_foreach_clamp(TensorList, Scalar?, Scalar?)`
-        if (
-            len(inputs) == 3
-            and (isinstance(inputs[1], Number) or inputs[1] is None)
-            and (isinstance(inputs[2], Number) or inputs[2] is None)
-        ):
-            inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
-            inputs[2] = [inputs[2] for _ in range(len(inputs[0]))]
         return [self.func(*i, **kwargs) for i in zip(*inputs)]
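
The deleted branch let the reference wrapper expand (tensors, min, max) into per-tensor argument tuples so plain torch.clamp could serve as the reference for the fused op. A standalone sketch of that replicate-and-zip pattern (hypothetical helper name):

import torch

def reference_call(func, tensors, *scalars):
    # Replicate each scalar (or None) once per tensor, then zip.
    per_tensor = [[s] * len(tensors) for s in scalars]
    return [func(*args) for args in zip(tensors, *per_tensor)]

xs = [torch.randn(3) for _ in range(2)]
ref = reference_call(torch.clamp, xs, -1.0, 1.0)  # == [x.clamp(-1.0, 1.0) for x in xs]
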
@@ -144,7 +136,7 @@ class TestForeach(TestCase):
                 op(inputs, self.is_cuda, is_fastpath, zero_size=zero_size)
             return
 
-        ref_inputs = [[t.clone().detach() for t in inputs[0]], *inputs[1:]] if is_inplace else inputs
+        ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs
         try:
             with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
                 actual = op(inputs, self.is_cuda, is_fastpath, zero_size=zero_size)
@@ -181,18 +173,19 @@
     def test_binary_op(self, device, dtype, op, is_fastpath):
         scalar_self_arg_test_complete = False
         for i, sample in enumerate(op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)):
+            (rhs_arg,) = sample.args
             zero_size = sample.kwargs.pop("zero_size")
             kwargs = {} or sample.kwargs
             alpha = kwargs.pop("alpha", None)
             disable_fastpath = kwargs.pop("disable_fastpath") if is_fastpath else False
             wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
             self._binary_test(
-                dtype, wrapped_op, ref, [sample.input, *sample.args],
+                dtype, wrapped_op, ref, [sample.input, rhs_arg],
                 is_fastpath and not disable_fastpath, False,
                 alpha=alpha, zero_size=zero_size, scalar_self_arg=False,
             )
             self._binary_test(
-                dtype, inplace_op, inplace_ref, [sample.input, *sample.args],
+                dtype, inplace_op, inplace_ref, [sample.input, rhs_arg],
                 is_fastpath and not disable_fastpath, True,
                 alpha=alpha, zero_size=zero_size, scalar_self_arg=False,
             )
@@ -200,49 +193,46 @@
             if op.supports_autograd and dtype in floating_types() and not zero_size:
                 transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
                 tensors = transformed_sample.input
-                ref_tensors, ref_rhs_arg = clone(tensors), clone(transformed_sample.args)
+                (rhs_arg,) = transformed_sample.args
+                ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
                 try:
                     sum(
-                        wrapped_op([tensors, *transformed_sample.args], is_cuda=False, is_fastpath=False, zero_size=zero_size)
+                        wrapped_op([tensors, rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
                     ).mean().backward()
                 except RuntimeError:
                     with self.assertRaises(RuntimeError):
-                        sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
+                        sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
                 else:
-                    sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
+                    sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
                     self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
-                    if isinstance(transformed_sample.args[0], list) and isinstance(transformed_sample.args[0][0], torch.Tensor):
-                        self.assertEqual([t.grad for t in transformed_sample.args[0]], [t.grad for t in ref_rhs_arg[0]])
+                    if isinstance(rhs_arg, list) and isinstance(rhs_arg[0], torch.Tensor):
+                        self.assertEqual([t.grad for t in rhs_arg], [t.grad for t in ref_rhs_arg])
                 tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
                 ref_tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
-                inplace_op([tensors, *transformed_sample.args], is_cuda=False, is_fastpath=False, zero_size=zero_size)
+                inplace_op([tensors, rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
                 assert_multiple_grad_fns(tensors, self)
 
                 # note(crcrpar): the following ops' reference torch functions don't have the overload with Scalar/ScalarList.
                 is_foreach_max_min_imum_with_scalar_or_scalarlist = (
                     inplace_op.func in (torch._foreach_minimum_, torch._foreach_maximum_)
                     and (
-                        isinstance(transformed_sample.args[0], Number)
-                        or (
-                            isinstance(transformed_sample.args[0], list)
-                            and isinstance(transformed_sample.args[0][0], Number)
-                        )
+                        isinstance(rhs_arg, Number) or (isinstance(rhs_arg, list) and isinstance(rhs_arg[0], Number))
                    )
                )
                if not is_foreach_max_min_imum_with_scalar_or_scalarlist:
-                    inplace_ref([ref_tensors, *ref_rhs_arg])
+                    inplace_ref([ref_tensors, rhs_arg])
                    torch.autograd.backward(sum([t.clone() for t in tensors]).sum(), inputs=tensors)
                    torch.autograd.backward(sum([t.clone() for t in ref_tensors]).sum(), inputs=ref_tensors)
                    self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
            if (
                op.supports_scalar_self_arg
-                and isinstance(sample.args[0], Number)
+                and isinstance(rhs_arg, Number)
                and not scalar_self_arg_test_complete
                and not zero_size
            ):
                scalar_self_arg_test_complete = True
                self._binary_test(
-                    dtype, wrapped_op, ref, [sample.args[0], sample.input], is_fastpath, False,
+                    dtype, wrapped_op, ref, [rhs_arg, sample.input], is_fastpath, False,
                    alpha=alpha, scalar_self_arg=True, zero_size=False,
                )
            if op.supports_autograd and dtype == torch.float32 and not zero_size:
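
The restored (rhs_arg,) = sample.args unpacking assumes every binary foreach op carries exactly one right-hand-side argument; _foreach_clamp broke that invariant with its two optional bounds, which is why #106574 had switched these call sites to *sample.args. A minimal illustration (hypothetical values):

# One rhs argument (e.g. _foreach_add): unpacking succeeds.
(rhs_arg,) = (2.0,)           # rhs_arg == 2.0

# Two bounds, as _foreach_clamp used: the same unpacking raises.
try:
    (rhs_arg,) = (-1, 1)
except ValueError as e:
    print(e)                  # too many values to unpack (expected 1)
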
@@ -566,10 +556,7 @@ class TestForeach(TestCase):
         self.assertIsNone(runtime_error)
 
     @skipIfTorchDynamo("Different error msgs, TODO")
-    @ops(
-        filter(lambda op: op.name != "_foreach_clamp", foreach_binary_op_db),
-        dtypes=OpDTypes.supported,
-    )
+    @ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
     def test_binary_op_list_error_cases(self, device, dtype, op):
         foreach_op, foreach_op_, ref, ref_ = op.method_variant, op.inplace_variant, op.ref, op.ref_inplace
         tensors1 = []
@@ -632,7 +619,7 @@ class TestForeach(TestCase):
             foreach_op_([tensor1], [tensor2])
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
-    @ops(filter(lambda op: op.name != "_foreach_clamp", foreach_binary_op_db), dtypes=OpDTypes.supported)
+    @ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
     def test_binary_op_list_slow_path(self, device, dtype, op):
         foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
         # 0-strides
@@ -681,10 +668,7 @@
                 dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
                 zero_size=False, alpha=None, scalar_self_arg=False)
 
-    @ops(
-        filter(lambda op: op.name != "_foreach_clamp", foreach_binary_op_db),
-        dtypes=floating_types_and(torch.half, torch.bfloat16),
-    )
+    @ops(foreach_binary_op_db, dtypes=floating_types_and(torch.half, torch.bfloat16))
     def test_binary_op_float_inf_nan(self, device, dtype, op):
         inputs = (
             [
@@ -739,7 +723,7 @@
         self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)
 
     @onlyCUDA
-    @ops(filter(lambda op: op.name != "_foreach_clamp", foreach_binary_op_db))
+    @ops(foreach_binary_op_db)
     def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
         # `tensors1`: ['cuda', 'cpu']
         # `tensors2`: ['cuda', 'cpu']


@@ -223,7 +223,7 @@ def _register_foreach_lowering(aten_fn, decomp_fn):
     @functools.wraps(decomp_fn)
     def wrapped(*args, **kwargs):
-        assert len(args) <= 3
+        assert len(args) <= 2
         out = decomp_fn(*args, **kwargs)
         validate_ir(out)
         return out
@@ -4428,8 +4428,8 @@ logical_xor = register_pointwise(
 )
 maximum = register_pointwise(aten.maximum)
 minimum = register_pointwise(aten.minimum)
-clamp_min = register_lowering(aten.clamp_min)(maximum)
-clamp_max = register_lowering(aten.clamp_max)(minimum)
+register_lowering(aten.clamp_min)(maximum)
+register_lowering(aten.clamp_max)(minimum)
 neg = register_pointwise(aten.neg)
 reciprocal = register_pointwise_numeric(aten.reciprocal)
 register_pointwise(aten.remainder)
@@ -4478,17 +4478,6 @@ register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
 register_foreach_pointwise(aten._foreach_sign, sign)
 
 
-def _clamp_impl(v, min=None, max=None):
-    if min is not None:
-        v = clamp_min(v, min)
-    if max is not None:
-        v = clamp_max(v, max)
-    return v
-
-
-register_foreach_pointwise(aten._foreach_clamp, _clamp_impl)
-
 def register_inplace(aten_op, outplace_op):
     @register_lowering(aten_op, type_promotion_kind=None)
     def fn(*args, **kwargs):
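
The deleted _clamp_impl decomposed the foreach op into the existing clamp_min/clamp_max lowerings, which is also why those two lowerings lose their module-level names in this revert. The same decomposition holds at the tensor level, e.g.:

import torch

x = torch.randn(5)
# clamp(min, max) is clamp_min followed by clamp_max
assert torch.equal(x.clamp(min=-1.0, max=1.0), x.clamp_min(-1.0).clamp_max(1.0))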


@@ -2956,29 +2956,6 @@ def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
     )
 
 
-@register_meta([aten._foreach_clamp.default])
-def meta__foreach_clamp(self, min, max):
-    torch._check(
-        isinstance(self, List),
-        lambda: f"self must be a tensor list but got {type(self)}",
-    )
-    torch._check(
-        min is not None or max is not None, lambda: "`min` or `max` must be specified"
-    )
-    return [torch.empty_like(t) for t in self]
-
-
-@register_meta([aten._foreach_clamp_.default])
-def meta__foreach_clamp_(self, min, max):
-    torch._check(
-        isinstance(self, List),
-        lambda: f"self must be a tensor list but got {type(self)}",
-    )
-    torch._check(
-        min is not None or max is not None, lambda: "`min` or `max` must be specified"
-    )
-
-
 @register_meta([aten._fused_adam_.default])
 def meta__fused_adam_(
     self,
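
The removed meta kernels only validated arguments and described outputs; torch.empty_like on the meta device is what lets the compiler infer shapes and dtypes without allocating storage. A small illustration of that idea (not the registration API itself):

import torch

ts = [torch.empty(2, 3, device="meta"), torch.empty(4, device="meta")]
outs = [torch.empty_like(t) for t in ts]  # shape/dtype only, no data
print([o.shape for o in outs])  # [torch.Size([2, 3]), torch.Size([4])]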


@@ -315,7 +315,8 @@ def _multi_tensor_rprop(
 
         # update stepsizes with step size updates
         torch._foreach_mul_(grouped_step_sizes, signs)
-        torch._foreach_clamp_(grouped_step_sizes, step_size_min, step_size_max)
+        for step_size in grouped_step_sizes:
+            step_size.clamp_(step_size_min, step_size_max)
 
         # for dir<0, dfdx=0
         # for dir>=0 dfdx=dfdx
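
The replacement loop restores the pre-#106574 behavior: each step-size tensor is clamped in place with one kernel launch per tensor, where torch._foreach_clamp_ had fused the whole list into a single call. The values computed are identical, e.g.:

import torch

step_size_min, step_size_max = 1e-6, 50.0
sizes = [torch.full((3,), 100.0), torch.full((3,), 0.0)]
for step_size in sizes:  # the reverted-to form
    step_size.clamp_(step_size_min, step_size_max)
assert torch.equal(sizes[0], torch.full((3,), 50.0))   # clamped down to max
assert torch.equal(sizes[1], torch.full((3,), 1e-6))   # clamped up to min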


@@ -8552,46 +8552,6 @@ class foreach_lerp_sample_func(foreach_inputs_sample_func):
             raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}")
 
 
-class foreach_clamp_sample_func(foreach_inputs_sample_func):
-    def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
-        num_input_tensors_specified = "num_input_tensors" in kwargs
-        num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors
-        assert isinstance(num_input_tensors, list)
-        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
-        _foreach_inputs_kwargs["requires_grad"] = requires_grad
-
-        # zero_size tensor
-        if dtype == torch.float32 and (not num_input_tensors_specified) and ("cuda" in device):
-            zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs)
-            zero_size_foreach_inputs_kwargs["zero_size"] = True
-            input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs)
-            args = np.random.uniform(size=(2,)).tolist()
-            kwargs = {
-                "zero_size": True,
-                "disable_fastpath": dtype in integral_types_and(torch.bool),
-            }
-            yield SampleInput(input, *args, **kwargs)
-
-        for num_tensors, args in product(
-            num_input_tensors,
-            (
-                (-1, 1),
-                (-1, None),
-                (None, 1),
-                (3, 1),
-            ),
-        ):
-            _foreach_inputs_kwargs["zero_size"] = False
-            input = sample_inputs_foreach(
-                None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
-            kwargs = {
-                "zero_size": False,
-                "disable_fastpath": dtype in integral_types_and(torch.bool),
-            }
-            yield SampleInput(input, *args, **kwargs)
-
-
 class foreach_pointwise_sample_func(foreach_inputs_sample_func):
     def __init__(
@@ -8925,15 +8885,6 @@ foreach_binary_op_db: List[OpInfo] = [
         supports_autograd=True,
         supports_forward_ad=True,
     ),
-    ForeachFuncInfo(
-        "clamp",
-        dtypes=all_types_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
-        supports_alpha_param=False,
-        sample_inputs_func=foreach_clamp_sample_func(2, True, True),
-        supports_autograd=True,
-        supports_forward_ad=True,
-    ),
     ForeachFuncInfo(
         "clamp_min",
         dtypes=all_types_and(torch.bfloat16),
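
Note the (3, 1) pair among the removed sample arguments exercised min > max; torch.clamp defines that case by letting the upper bound win, so those samples stayed well-defined:

import torch

x = torch.tensor([-2.0, 0.0, 2.0])
print(torch.clamp(x, min=3, max=1))  # tensor([1., 1., 1.]) -- max wins when min > max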