mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "Add _foreach_clamp (#106574)"
This reverts commit 2b560d3c3a.
Reverted https://github.com/pytorch/pytorch/pull/106574 on behalf of https://github.com/kit1980 due to breaking internal windows builds ([comment](https://github.com/pytorch/pytorch/pull/106574#issuecomment-1675400335))
This commit is contained in:
parent
c9cdcb299a
commit
354484ea6d
10 changed files with 25 additions and 344 deletions
|
|
@ -2,7 +2,6 @@
|
|||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/ForeachUtils.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
|
|
@ -20,7 +19,6 @@
|
|||
#include <ATen/ops/_foreach_ceil_native.h>
|
||||
#include <ATen/ops/_foreach_clamp_max_native.h>
|
||||
#include <ATen/ops/_foreach_clamp_min_native.h>
|
||||
#include <ATen/ops/_foreach_clamp_native.h>
|
||||
#include <ATen/ops/_foreach_cos_native.h>
|
||||
#include <ATen/ops/_foreach_cosh_native.h>
|
||||
#include <ATen/ops/_foreach_div_native.h>
|
||||
|
|
@ -410,33 +408,4 @@ std::vector<Tensor> foreach_scalar_pow_list_kernel_slow(
|
|||
return result;
|
||||
}
|
||||
|
||||
std::vector<Tensor> foreach_tensor_clamp_scalar_kernel_slow(
|
||||
TensorList self,
|
||||
const optional<Scalar>& min,
|
||||
const optional<Scalar>& max) {
|
||||
TORCH_CHECK(
|
||||
min.has_value() || max.has_value(),
|
||||
"Either `min` or `max` must be specified");
|
||||
check_foreach_api_restrictions(self);
|
||||
std::vector<Tensor> result;
|
||||
result.reserve(self.size());
|
||||
for (const auto& t : self) {
|
||||
result.emplace_back(t.clamp(min, max));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void foreach_tensor_clamp_scalar_kernel_slow_(
|
||||
TensorList self,
|
||||
const optional<Scalar>& min,
|
||||
const optional<Scalar>& max) {
|
||||
TORCH_CHECK(
|
||||
min.has_value() || max.has_value(),
|
||||
"Either `min` or `max` must be specified");
|
||||
check_foreach_api_restrictions(self);
|
||||
for (auto& t : self) {
|
||||
t.clamp_(min, max);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <ATen/native/BinaryOps.h>
|
||||
#include <ATen/native/ForeachUtils.h>
|
||||
#include <ATen/native/TensorCompare.h>
|
||||
#include <ATen/native/cuda/ForeachFunctors.cuh>
|
||||
#include <ATen/native/cuda/ForeachMinMaxFunctors.cuh>
|
||||
|
||||
|
|
@ -13,7 +11,6 @@
|
|||
#include <ATen/ops/_foreach_add_native.h>
|
||||
#include <ATen/ops/_foreach_clamp_max_native.h>
|
||||
#include <ATen/ops/_foreach_clamp_min_native.h>
|
||||
#include <ATen/ops/_foreach_clamp_native.h>
|
||||
#include <ATen/ops/_foreach_div_native.h>
|
||||
#include <ATen/ops/_foreach_mul_native.h>
|
||||
#include <ATen/ops/_foreach_pow_native.h>
|
||||
|
|
@ -245,163 +242,4 @@ std::vector<Tensor> foreach_tensor_sub_scalar_kernel_cuda(
|
|||
FOREACH_BINARY_OP_SCALAR(all_types_half_bfloat16, clamp_max, minimum, false);
|
||||
FOREACH_BINARY_OP_SCALAR(all_types_half_bfloat16, clamp_min, maximum, false);
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
__forceinline__ C10_DEVICE T clamp(
|
||||
const T v,
|
||||
const T lower,
|
||||
const T upper,
|
||||
const at::native::detail::ClampLimits& clamp_kind) {
|
||||
// Propagate nan, which doesn't propagate automatically for ROCm
|
||||
if (at::_isnan(v)) {
|
||||
return v;
|
||||
} else if (clamp_kind == at::native::detail::ClampLimits::Min) {
|
||||
return ::max(v, lower);
|
||||
} else if (clamp_kind == at::native::detail::ClampLimits::Max) {
|
||||
return ::min(v, upper);
|
||||
} else {
|
||||
return ::min(::max(v, lower), upper);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct ClampFunctor {
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
__forceinline__ C10_DEVICE T operator()(
|
||||
const int chunk_size,
|
||||
TensorListMetadata<depth>& tl,
|
||||
opmath_t lower,
|
||||
opmath_t upper,
|
||||
const at::native::detail::ClampLimits& clamp_kind) {
|
||||
static_assert(depth == 1 || depth == 2, "");
|
||||
static_assert(depth >= r_args_depth, "");
|
||||
static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
auto n = tl.numel_for_tensor[tensor_loc];
|
||||
|
||||
T* args[depth];
|
||||
const bool all_aligned =
|
||||
init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
|
||||
n -= chunk_idx * chunk_size;
|
||||
T r_args[r_args_depth][kILP];
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
|
||||
for (int64_t i_start = threadIdx.x;
|
||||
i_start * kILP < n && i_start * kILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(r_args[0], args[0], 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < kILP; ii++) {
|
||||
r_args[0][ii] = clamp(
|
||||
static_cast<opmath_t>(r_args[0][ii]), lower, upper, clamp_kind);
|
||||
}
|
||||
// store
|
||||
load_store(args[res_arg_index], r_args[0], i_start, 0);
|
||||
}
|
||||
} else {
|
||||
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * kILP) {
|
||||
load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < kILP; ii++) {
|
||||
r_args[0][ii] = clamp(
|
||||
static_cast<opmath_t>(r_args[0][ii]), lower, upper, clamp_kind);
|
||||
}
|
||||
store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
at::native::detail::ClampLimits get_clamp_kind(
|
||||
const optional<Scalar>& min,
|
||||
const optional<Scalar>& max) {
|
||||
if (min.has_value() && max.has_value()) {
|
||||
return at::native::detail::ClampLimits::MinMax;
|
||||
} else if (min.has_value()) {
|
||||
return at::native::detail::ClampLimits::Min;
|
||||
} else {
|
||||
return at::native::detail::ClampLimits::Max;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> foreach_tensor_clamp_scalar_kernel_cuda(
|
||||
TensorList self,
|
||||
const optional<Scalar>& min,
|
||||
const optional<Scalar>& max) {
|
||||
check_foreach_api_restrictions(self);
|
||||
if (!can_use_fast_route(
|
||||
ArrayRef<TensorList>{self}, ArrayRef<Scalar>{}, true)) {
|
||||
return foreach_tensor_clamp_scalar_kernel_slow(self, min, max);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
min.has_value() || max.has_value(),
|
||||
"Either `min` or `max` must be specified");
|
||||
std::vector<Tensor> result;
|
||||
result.reserve(self.size());
|
||||
for (const auto& t : self) {
|
||||
result.emplace_back(at::native::empty_like(t));
|
||||
}
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists{self.vec(), result};
|
||||
AT_DISPATCH_ALL_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
self[0].scalar_type(),
|
||||
"foreach_tensor_clamp_scalar_kernel_cuda",
|
||||
[&]() {
|
||||
using opmath_t = typename at::opmath_type<scalar_t>;
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
ClampFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
min.has_value() ? min.value().to<opmath_t>()
|
||||
: max.value().to<opmath_t>(),
|
||||
max.has_value() ? max.value().to<opmath_t>()
|
||||
: min.value().to<opmath_t>(),
|
||||
get_clamp_kind(min, max));
|
||||
});
|
||||
return tensor_lists[1];
|
||||
}
|
||||
|
||||
void foreach_tensor_clamp_scalar_kernel_cuda_(
|
||||
TensorList self,
|
||||
const optional<Scalar>& min,
|
||||
const optional<Scalar>& max) {
|
||||
check_foreach_api_restrictions(self);
|
||||
if (!can_use_fast_route(
|
||||
ArrayRef<TensorList>{self}, ArrayRef<Scalar>{}, true)) {
|
||||
return foreach_tensor_clamp_scalar_kernel_slow_(self, min, max);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
min.has_value() || max.has_value(),
|
||||
"Either `min` or `max` must be specified");
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists{self.vec()};
|
||||
AT_DISPATCH_ALL_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
self[0].scalar_type(),
|
||||
"foreach_tensor_clamp_scalar_kernel_cuda",
|
||||
[&]() {
|
||||
using opmath_t = typename at::opmath_type<scalar_t>;
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
ClampFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0>(),
|
||||
min.has_value() ? min.value().to<opmath_t>()
|
||||
: max.value().to<opmath_t>(),
|
||||
max.has_value() ? max.value().to<opmath_t>()
|
||||
: min.value().to<opmath_t>(),
|
||||
get_clamp_kind(min, max));
|
||||
});
|
||||
}
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -10121,25 +10121,6 @@
|
|||
CUDA: foreach_tensor_div_scalarlist_kernel_cuda_
|
||||
autogen: _foreach_div.ScalarList_out
|
||||
|
||||
# TODO(crcrpar): Think of adding overloads for `TensorList`s and `ScalarList`s
|
||||
# ref: https://github.com/pytorch/pytorch/issues/106931
|
||||
- func: _foreach_clamp(Tensor[] self, Scalar? min=None, Scalar? max=None) -> Tensor[]
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
variants: function
|
||||
cpp_no_default_args: ['min']
|
||||
dispatch:
|
||||
CPU: foreach_tensor_clamp_scalar_kernel_slow
|
||||
CUDA: foreach_tensor_clamp_scalar_kernel_cuda
|
||||
|
||||
- func: _foreach_clamp_(Tensor(a!)[] self, Scalar? min=None, Scalar? max=None) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
variants: function
|
||||
cpp_no_default_args: ['min']
|
||||
dispatch:
|
||||
CPU: foreach_tensor_clamp_scalar_kernel_slow_
|
||||
CUDA: foreach_tensor_clamp_scalar_kernel_cuda_
|
||||
autogen: _foreach_clamp.out
|
||||
|
||||
- func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
variants: function
|
||||
|
|
|
|||
|
|
@ -153,11 +153,6 @@ def supports(o, factory_methods):
|
|||
print("Skipping {} Because of Arg: {} ({}) ".format(
|
||||
o['name'], arg['type'], arg['dynamic_type']))
|
||||
return False
|
||||
|
||||
# skip _foreach_clamp(Tensor[], Scalar?, Scalar?)
|
||||
if o['name'] in {"_foreach_clamp"}:
|
||||
print(f"Skipping {o['name']} because it has multiple scalar arguments with default values.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
|
@ -332,5 +327,4 @@ if __name__ == '__main__':
|
|||
top_env['implementations'].append(IMPLEMENTATION_TEMPLATE.substitute(env))
|
||||
top_env['cases'].append(CASE_TEMPLATE.substitute(env))
|
||||
key += 1
|
||||
|
||||
write(os.path.join(args.install_dir, args.output_prefix + "aten_op.h"), OP_TEMPLATE.substitute(top_env))
|
||||
|
|
|
|||
|
|
@ -161,9 +161,6 @@ aten::_foreach_atan_
|
|||
aten::_foreach_ceil
|
||||
aten::_foreach_ceil.out
|
||||
aten::_foreach_ceil_
|
||||
aten::_foreach_clamp
|
||||
aten::_foreach_clamp.out
|
||||
aten::_foreach_clamp_
|
||||
aten::_foreach_clamp_max.List
|
||||
aten::_foreach_clamp_max.List_out
|
||||
aten::_foreach_clamp_max.Scalar
|
||||
|
|
|
|||
|
|
@ -40,14 +40,6 @@ class RegularFuncWrapper:
|
|||
if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
|
||||
# binary op with tensorlist and scalar.
|
||||
inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
|
||||
# path for `_foreach_clamp(TensorList, Scalar?, Scalar?)`
|
||||
if (
|
||||
len(inputs) == 3
|
||||
and (isinstance(inputs[1], Number) or inputs[1] is None)
|
||||
and (isinstance(inputs[2], Number) or inputs[2] is None)
|
||||
):
|
||||
inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
|
||||
inputs[2] = [inputs[2] for _ in range(len(inputs[0]))]
|
||||
return [self.func(*i, **kwargs) for i in zip(*inputs)]
|
||||
|
||||
|
||||
|
|
@ -144,7 +136,7 @@ class TestForeach(TestCase):
|
|||
op(inputs, self.is_cuda, is_fastpath, zero_size=zero_size)
|
||||
return
|
||||
|
||||
ref_inputs = [[t.clone().detach() for t in inputs[0]], *inputs[1:]] if is_inplace else inputs
|
||||
ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs
|
||||
try:
|
||||
with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
|
||||
actual = op(inputs, self.is_cuda, is_fastpath, zero_size=zero_size)
|
||||
|
|
@ -181,18 +173,19 @@ class TestForeach(TestCase):
|
|||
def test_binary_op(self, device, dtype, op, is_fastpath):
|
||||
scalar_self_arg_test_complete = False
|
||||
for i, sample in enumerate(op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)):
|
||||
(rhs_arg,) = sample.args
|
||||
zero_size = sample.kwargs.pop("zero_size")
|
||||
kwargs = {} or sample.kwargs
|
||||
alpha = kwargs.pop("alpha", None)
|
||||
disable_fastpath = kwargs.pop("disable_fastpath") if is_fastpath else False
|
||||
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
|
||||
self._binary_test(
|
||||
dtype, wrapped_op, ref, [sample.input, *sample.args],
|
||||
dtype, wrapped_op, ref, [sample.input, rhs_arg],
|
||||
is_fastpath and not disable_fastpath, False,
|
||||
alpha=alpha, zero_size=zero_size, scalar_self_arg=False,
|
||||
)
|
||||
self._binary_test(
|
||||
dtype, inplace_op, inplace_ref, [sample.input, *sample.args],
|
||||
dtype, inplace_op, inplace_ref, [sample.input, rhs_arg],
|
||||
is_fastpath and not disable_fastpath, True,
|
||||
alpha=alpha, zero_size=zero_size, scalar_self_arg=False,
|
||||
)
|
||||
|
|
@ -200,49 +193,46 @@ class TestForeach(TestCase):
|
|||
if op.supports_autograd and dtype in floating_types() and not zero_size:
|
||||
transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
|
||||
tensors = transformed_sample.input
|
||||
ref_tensors, ref_rhs_arg = clone(tensors), clone(transformed_sample.args)
|
||||
(rhs_arg,) = transformed_sample.args
|
||||
ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
|
||||
try:
|
||||
sum(
|
||||
wrapped_op([tensors, *transformed_sample.args], is_cuda=False, is_fastpath=False, zero_size=zero_size)
|
||||
wrapped_op([tensors, rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
|
||||
).mean().backward()
|
||||
except RuntimeError:
|
||||
with self.assertRaises(RuntimeError):
|
||||
sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
|
||||
sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
|
||||
else:
|
||||
sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
|
||||
sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
|
||||
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
|
||||
if isinstance(transformed_sample.args[0], list) and isinstance(transformed_sample.args[0][0], torch.Tensor):
|
||||
self.assertEqual([t.grad for t in transformed_sample.args[0]], [t.grad for t in ref_rhs_arg[0]])
|
||||
if isinstance(rhs_arg, list) and isinstance(rhs_arg[0], torch.Tensor):
|
||||
self.assertEqual([t.grad for t in rhs_arg], [t.grad for t in ref_rhs_arg])
|
||||
tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
|
||||
ref_tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
|
||||
inplace_op([tensors, *transformed_sample.args], is_cuda=False, is_fastpath=False, zero_size=zero_size)
|
||||
inplace_op([tensors, rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
|
||||
assert_multiple_grad_fns(tensors, self)
|
||||
|
||||
# note(crcrpar): the following ops' reference torch functions don't have the overload with Scalar/ScalarList.
|
||||
is_foreach_max_min_imum_with_scalar_or_scalarlist = (
|
||||
inplace_op.func in (torch._foreach_minimum_, torch._foreach_maximum_)
|
||||
and (
|
||||
isinstance(transformed_sample.args[0], Number)
|
||||
or (
|
||||
isinstance(transformed_sample.args[0], list)
|
||||
and isinstance(transformed_sample.args[0][0], Number)
|
||||
)
|
||||
isinstance(rhs_arg, Number) or (isinstance(rhs_arg, list) and isinstance(rhs_arg[0], Number))
|
||||
)
|
||||
)
|
||||
if not is_foreach_max_min_imum_with_scalar_or_scalarlist:
|
||||
inplace_ref([ref_tensors, *ref_rhs_arg])
|
||||
inplace_ref([ref_tensors, rhs_arg])
|
||||
torch.autograd.backward(sum([t.clone() for t in tensors]).sum(), inputs=tensors)
|
||||
torch.autograd.backward(sum([t.clone() for t in ref_tensors]).sum(), inputs=ref_tensors)
|
||||
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
|
||||
if (
|
||||
op.supports_scalar_self_arg
|
||||
and isinstance(sample.args[0], Number)
|
||||
and isinstance(rhs_arg, Number)
|
||||
and not scalar_self_arg_test_complete
|
||||
and not zero_size
|
||||
):
|
||||
scalar_self_arg_test_complete = True
|
||||
self._binary_test(
|
||||
dtype, wrapped_op, ref, [sample.args[0], sample.input], is_fastpath, False,
|
||||
dtype, wrapped_op, ref, [rhs_arg, sample.input], is_fastpath, False,
|
||||
alpha=alpha, scalar_self_arg=True, zero_size=False,
|
||||
)
|
||||
if op.supports_autograd and dtype == torch.float32 and not zero_size:
|
||||
|
|
@ -566,10 +556,7 @@ class TestForeach(TestCase):
|
|||
self.assertIsNone(runtime_error)
|
||||
|
||||
@skipIfTorchDynamo("Different error msgs, TODO")
|
||||
@ops(
|
||||
filter(lambda op: op.name != "_foreach_clamp", foreach_binary_op_db),
|
||||
dtypes=OpDTypes.supported,
|
||||
)
|
||||
@ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
|
||||
def test_binary_op_list_error_cases(self, device, dtype, op):
|
||||
foreach_op, foreach_op_, ref, ref_ = op.method_variant, op.inplace_variant, op.ref, op.ref_inplace
|
||||
tensors1 = []
|
||||
|
|
@ -632,7 +619,7 @@ class TestForeach(TestCase):
|
|||
foreach_op_([tensor1], [tensor2])
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
|
||||
@ops(filter(lambda op: op.name != "_foreach_clamp", foreach_binary_op_db), dtypes=OpDTypes.supported)
|
||||
@ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
|
||||
def test_binary_op_list_slow_path(self, device, dtype, op):
|
||||
foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
|
||||
# 0-strides
|
||||
|
|
@ -681,10 +668,7 @@ class TestForeach(TestCase):
|
|||
dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
|
||||
zero_size=False, alpha=None, scalar_self_arg=False)
|
||||
|
||||
@ops(
|
||||
filter(lambda op: op.name != "_foreach_clamp", foreach_binary_op_db),
|
||||
dtypes=floating_types_and(torch.half, torch.bfloat16),
|
||||
)
|
||||
@ops(foreach_binary_op_db, dtypes=floating_types_and(torch.half, torch.bfloat16))
|
||||
def test_binary_op_float_inf_nan(self, device, dtype, op):
|
||||
inputs = (
|
||||
[
|
||||
|
|
@ -739,7 +723,7 @@ class TestForeach(TestCase):
|
|||
self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)
|
||||
|
||||
@onlyCUDA
|
||||
@ops(filter(lambda op: op.name != "_foreach_clamp", foreach_binary_op_db))
|
||||
@ops(foreach_binary_op_db)
|
||||
def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
|
||||
# `tensors1`: ['cuda', 'cpu']
|
||||
# `tensors2`: ['cuda', 'cpu']
|
||||
|
|
|
|||
|
|
@ -223,7 +223,7 @@ def _register_foreach_lowering(aten_fn, decomp_fn):
|
|||
|
||||
@functools.wraps(decomp_fn)
|
||||
def wrapped(*args, **kwargs):
|
||||
assert len(args) <= 3
|
||||
assert len(args) <= 2
|
||||
out = decomp_fn(*args, **kwargs)
|
||||
validate_ir(out)
|
||||
return out
|
||||
|
|
@ -4428,8 +4428,8 @@ logical_xor = register_pointwise(
|
|||
)
|
||||
maximum = register_pointwise(aten.maximum)
|
||||
minimum = register_pointwise(aten.minimum)
|
||||
clamp_min = register_lowering(aten.clamp_min)(maximum)
|
||||
clamp_max = register_lowering(aten.clamp_max)(minimum)
|
||||
register_lowering(aten.clamp_min)(maximum)
|
||||
register_lowering(aten.clamp_max)(minimum)
|
||||
neg = register_pointwise(aten.neg)
|
||||
reciprocal = register_pointwise_numeric(aten.reciprocal)
|
||||
register_pointwise(aten.remainder)
|
||||
|
|
@ -4478,17 +4478,6 @@ register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
|
|||
register_foreach_pointwise(aten._foreach_sign, sign)
|
||||
|
||||
|
||||
def _clamp_impl(v, min=None, max=None):
|
||||
if min is not None:
|
||||
v = clamp_min(v, min)
|
||||
if max is not None:
|
||||
v = clamp_max(v, max)
|
||||
return v
|
||||
|
||||
|
||||
register_foreach_pointwise(aten._foreach_clamp, _clamp_impl)
|
||||
|
||||
|
||||
def register_inplace(aten_op, outplace_op):
|
||||
@register_lowering(aten_op, type_promotion_kind=None)
|
||||
def fn(*args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -2956,29 +2956,6 @@ def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
|
|||
)
|
||||
|
||||
|
||||
@register_meta([aten._foreach_clamp.default])
|
||||
def meta__foreach_clamp(self, min, max):
|
||||
torch._check(
|
||||
isinstance(self, List),
|
||||
lambda: f"self must be a tensor list but got {type(self)}",
|
||||
)
|
||||
torch._check(
|
||||
min is not None or max is not None, lambda: "`min` or `max` must be specified"
|
||||
)
|
||||
return [torch.empty_like(t) for t in self]
|
||||
|
||||
|
||||
@register_meta([aten._foreach_clamp_.default])
|
||||
def meta__foreach_clamp_(self, min, max):
|
||||
torch._check(
|
||||
isinstance(self, List),
|
||||
lambda: f"self must be a tensor list but got {type(self)}",
|
||||
)
|
||||
torch._check(
|
||||
min is not None or max is not None, lambda: "`min` or `max` must be specified"
|
||||
)
|
||||
|
||||
|
||||
@register_meta([aten._fused_adam_.default])
|
||||
def meta__fused_adam_(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -315,7 +315,8 @@ def _multi_tensor_rprop(
|
|||
|
||||
# update stepsizes with step size updates
|
||||
torch._foreach_mul_(grouped_step_sizes, signs)
|
||||
torch._foreach_clamp_(grouped_step_sizes, step_size_min, step_size_max)
|
||||
for step_size in grouped_step_sizes:
|
||||
step_size.clamp_(step_size_min, step_size_max)
|
||||
|
||||
# for dir<0, dfdx=0
|
||||
# for dir>=0 dfdx=dfdx
|
||||
|
|
|
|||
|
|
@ -8552,46 +8552,6 @@ class foreach_lerp_sample_func(foreach_inputs_sample_func):
|
|||
raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}")
|
||||
|
||||
|
||||
class foreach_clamp_sample_func(foreach_inputs_sample_func):
|
||||
|
||||
def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
|
||||
num_input_tensors_specified = "num_input_tensors" in kwargs
|
||||
num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors
|
||||
assert isinstance(num_input_tensors, list)
|
||||
_foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
|
||||
_foreach_inputs_kwargs["requires_grad"] = requires_grad
|
||||
|
||||
# zero_size tensor
|
||||
if dtype == torch.float32 and (not num_input_tensors_specified) and ("cuda" in device):
|
||||
zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs)
|
||||
zero_size_foreach_inputs_kwargs["zero_size"] = True
|
||||
input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs)
|
||||
args = np.random.uniform(size=(2,)).tolist()
|
||||
kwargs = {
|
||||
"zero_size": True,
|
||||
"disable_fastpath": dtype in integral_types_and(torch.bool),
|
||||
}
|
||||
yield SampleInput(input, *args, **kwargs)
|
||||
|
||||
for num_tensors, args in product(
|
||||
num_input_tensors,
|
||||
(
|
||||
(-1, 1),
|
||||
(-1, None),
|
||||
(None, 1),
|
||||
(3, 1),
|
||||
),
|
||||
):
|
||||
_foreach_inputs_kwargs["zero_size"] = False
|
||||
input = sample_inputs_foreach(
|
||||
None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
|
||||
kwargs = {
|
||||
"zero_size": False,
|
||||
"disable_fastpath": dtype in integral_types_and(torch.bool),
|
||||
}
|
||||
yield SampleInput(input, *args, **kwargs)
|
||||
|
||||
|
||||
class foreach_pointwise_sample_func(foreach_inputs_sample_func):
|
||||
|
||||
def __init__(
|
||||
|
|
@ -8925,15 +8885,6 @@ foreach_binary_op_db: List[OpInfo] = [
|
|||
supports_autograd=True,
|
||||
supports_forward_ad=True,
|
||||
),
|
||||
ForeachFuncInfo(
|
||||
"clamp",
|
||||
dtypes=all_types_and(torch.bfloat16),
|
||||
dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
|
||||
supports_alpha_param=False,
|
||||
sample_inputs_func=foreach_clamp_sample_func(2, True, True),
|
||||
supports_autograd=True,
|
||||
supports_forward_ad=True,
|
||||
),
|
||||
ForeachFuncInfo(
|
||||
"clamp_min",
|
||||
dtypes=all_types_and(torch.bfloat16),
|
||||
|
|
|
|||
Loading…
Reference in a new issue