lshift and rshift stop supporting floating types (#77146)

Fixes #74358

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77146
Approved by: https://github.com/ngimel
Xiang Gao 2022-05-11 22:29:30 +00:00 committed by PyTorch MergeBot
parent 166a466e7f
commit cc9d0f309e
9 changed files with 85 additions and 198 deletions
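
Note on the user-visible change (an illustration, not part of the diff): the shift
operators now dispatch only for integral dtypes, and the old floating-point
emulation (a * 2**b for left shift, a / 2**b for right shift) is gone.

    import torch

    # Integer tensors still go through the bitwise shift kernels.
    x = torch.tensor([1, 2, 4, 8], dtype=torch.int32)
    print(x << 1)                           # tensor([ 2,  4,  8, 16], dtype=torch.int32)
    print(torch.bitwise_right_shift(x, 2))  # tensor([0, 0, 1, 2], dtype=torch.int32)

    # Float tensors used to be shifted via pow(); with this change the
    # integral-only dispatch raises a RuntimeError instead.
    f = torch.tensor([1.0, 2.0])
    try:
        f << 1
    except RuntimeError as err:
        print("float shift rejected:", err)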

@@ -312,26 +312,12 @@ void bitwise_xor_kernel(TensorIteratorBase& iter) {
 }
 
 void lshift_kernel(TensorIteratorBase& iter) {
-  if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double) {
-    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "lshift_cpu", [&]() {
-      auto base_vec = Vectorized<scalar_t>((scalar_t)(2));
-      cpu_kernel_vec(
-          iter,
-          [=](scalar_t a, scalar_t b) -> scalar_t {
-            return a * std::pow((scalar_t)(2), b);
-          },
-          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
-            return a * base_vec.pow(b);
-          });
-    });
-  } else {
-    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() {
-      cpu_kernel(iter,
-          [](scalar_t a, scalar_t b) -> scalar_t {
-            return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
-          });
-    });
-  }
+  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() {
+    cpu_kernel(iter,
+        [](scalar_t a, scalar_t b) -> scalar_t {
+          return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
+        });
+  });
 }
 
 void logical_and_kernel(TensorIterator& iter) {
@@ -392,26 +378,12 @@ void logical_xor_kernel(TensorIterator& iter) {
 }
 
 void rshift_kernel(TensorIteratorBase& iter) {
-  if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double) {
-    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "rshift_cpu", [&]() {
-      auto base_vec = Vectorized<scalar_t>((scalar_t)(2));
-      cpu_kernel_vec(
-          iter,
-          [=](scalar_t a, scalar_t b) -> scalar_t {
-            return a / std::pow((scalar_t)(2), b);
-          },
-          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
-            return a / base_vec.pow(b);
-          });
-    });
-  } else {
-    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [&]() {
-      cpu_kernel(iter,
-          [](scalar_t a, scalar_t b) -> scalar_t {
-            return a >> b;
-          });
-    });
-  }
+  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [&]() {
+    cpu_kernel(iter,
+        [](scalar_t a, scalar_t b) -> scalar_t {
+          return a >> b;
+        });
+  });
 }
 
 void lt_kernel(TensorIteratorBase& iter) {
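
The two surviving CPU kernels differ slightly: the left shift is done on the
unsigned reinterpretation of `a` (so shifting negative values avoids C++
undefined behavior), while the right shift is a plain `a >> b`, i.e. an
arithmetic shift for signed types on the usual two's-complement platforms.
A small sketch of the resulting behavior:

    import torch

    a = torch.tensor([-1, -8, 3], dtype=torch.int32)

    # Left shift moves the bit pattern left: -1 << 1 == -2, -8 << 1 == -16.
    print(torch.bitwise_left_shift(a, 1))   # tensor([ -2, -16,   6], dtype=torch.int32)

    # Right shift sign-extends negative values: -1 >> 2 == -1, -8 >> 2 == -2.
    print(torch.bitwise_right_shift(a, 2))  # tensor([-1, -2,  0], dtype=torch.int32)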

@@ -12,47 +12,21 @@ namespace at { namespace native {
 
 void lshift_kernel_cuda(TensorIteratorBase& iter) {
-  if (iter.dtype() == ScalarType::Float ||
-      iter.dtype() == ScalarType::Double ||
-      iter.dtype() == ScalarType::Half ||
-      iter.dtype() == ScalarType::BFloat16) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "lshift_cuda", [&]() {
-      gpu_kernel_with_scalars(
-          iter,
-          []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-            return a * std::pow(static_cast<scalar_t>(2), b);
-          });
-    });
-  } else {
-    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cuda", [&]() {
-      gpu_kernel_with_scalars(iter,
-          []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-            return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
-          });
-    });
-  }
+  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cuda", [&]() {
+    gpu_kernel_with_scalars(iter,
+        []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+          return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
+        });
+  });
 }
 
 void rshift_kernel_cuda(TensorIteratorBase& iter) {
-  if (iter.dtype() == ScalarType::Float ||
-      iter.dtype() == ScalarType::Double ||
-      iter.dtype() == ScalarType::Half ||
-      iter.dtype() == ScalarType::BFloat16) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "rshift_cuda", [&]() {
-      gpu_kernel_with_scalars(
-          iter,
-          []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-            return a / std::pow(static_cast<scalar_t>(2), b);
-          });
-    });
-  } else {
-    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cuda", [&]() {
-      gpu_kernel_with_scalars(iter,
-          []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-            return a >> b;
-          });
-    });
-  }
+  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cuda", [&]() {
+    gpu_kernel_with_scalars(iter,
+        []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+          return a >> b;
+        });
+  });
 }
 
 REGISTER_DISPATCH(lshift_stub, &lshift_kernel_cuda);
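
On CUDA the same restriction also removes the Half/BFloat16 paths that used to
be emulated through pow. A minimal sketch (only meaningful on a machine with a
CUDA device):

    import torch

    if torch.cuda.is_available():
        i = torch.arange(4, dtype=torch.int16, device="cuda")
        print(i << 2)  # integral dtypes still hit the CUDA bit-shift kernel

        h = torch.ones(4, dtype=torch.float16, device="cuda")
        try:
            h << 2  # previously computed as h * 2**2; now rejected by the dispatch
        except RuntimeError as err:
            print("half shift rejected:", err)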

@@ -8986,25 +8986,30 @@ TEST_F(LazyOpsTest, TestBitwiseXorScalarInPlace) {
 
 TEST_F(LazyOpsTest, TestLshift) {
   torch::Tensor input = torch::ones(
-      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
+      {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
   torch::Tensor shift_amount = torch::randint(
-      16, input.sizes(), torch::TensorOptions().device(DefaultDevice()));
+      16,
+      input.sizes(),
+      torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
   torch::Tensor result = torch::__lshift__(input, shift_amount);
   ForEachDevice([&](const torch::Device& device) {
     torch::Tensor lazy_input = CopyToDevice(input, device);
     torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device);
-    torch::Tensor lazy_result = torch::__lshift__(lazy_input, lazy_shift_amount);
+    torch::Tensor lazy_result =
+        torch::__lshift__(lazy_input, lazy_shift_amount);
     AllClose(result, lazy_result);
   });
 }
 
 TEST_F(LazyOpsTest, TestLshiftInPlace) {
   torch::Tensor input = torch::ones(
-      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
+      {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
   ForEachDevice([&](const torch::Device& device) {
     torch::Tensor lazy_input = CopyToDevice(input, device);
     torch::Tensor shift_amount = torch::randint(
-        16, input.sizes(), torch::TensorOptions().device(DefaultDevice()));
+        16,
+        input.sizes(),
+        torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
     torch::Tensor result = input.__ilshift__(shift_amount);
     torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device);
     torch::Tensor lazy_result = lazy_input.__ilshift__(lazy_shift_amount);
@@ -9015,7 +9020,7 @@ TEST_F(LazyOpsTest, TestLshiftInPlace) {
 
 TEST_F(LazyOpsTest, TestLshiftScalar) {
   torch::Tensor input = torch::ones(
-      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
+      {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
   torch::Scalar shift_amount = 3;
   torch::Tensor result = torch::__lshift__(input, shift_amount);
   ForEachDevice([&](const torch::Device& device) {
@@ -9027,7 +9032,7 @@ TEST_F(LazyOpsTest, TestLshiftScalar) {
 
 TEST_F(LazyOpsTest, TestLshiftScalarInPlace) {
   torch::Tensor input = torch::ones(
-      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
+      {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
   torch::Scalar shift_amount = 3;
   ForEachDevice([&](const torch::Device& device) {
     torch::Tensor lazy_input = CopyToDevice(input, device);
@@ -9040,25 +9045,30 @@ TEST_F(LazyOpsTest, TestLshiftScalarInPlace) {
 
 TEST_F(LazyOpsTest, TestRshift) {
   torch::Tensor input = torch::ones(
-      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
+      {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
   torch::Tensor shift_amount = torch::randint(
-      16, input.sizes(), torch::TensorOptions().device(DefaultDevice()));
+      16,
+      input.sizes(),
+      torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
   torch::Tensor result = torch::__rshift__(input, shift_amount);
   ForEachDevice([&](const torch::Device& device) {
     torch::Tensor lazy_input = CopyToDevice(input, device);
     torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device);
-    torch::Tensor lazy_result = torch::__rshift__(lazy_input, lazy_shift_amount);
+    torch::Tensor lazy_result =
+        torch::__rshift__(lazy_input, lazy_shift_amount);
     AllClose(result, lazy_result);
   });
 }
 
 TEST_F(LazyOpsTest, TestRshiftInPlace) {
   torch::Tensor input = torch::ones(
-      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
+      {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
   ForEachDevice([&](const torch::Device& device) {
     torch::Tensor lazy_input = CopyToDevice(input, device);
     torch::Tensor shift_amount = torch::randint(
-        16, input.sizes(), torch::TensorOptions().device(DefaultDevice()));
+        16,
+        input.sizes(),
+        torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
     torch::Tensor result = input.__irshift__(shift_amount);
     torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device);
     torch::Tensor lazy_result = lazy_input.__irshift__(lazy_shift_amount);
@@ -9069,7 +9079,7 @@ TEST_F(LazyOpsTest, TestRshiftInPlace) {
 
 TEST_F(LazyOpsTest, TestRshiftScalar) {
   torch::Tensor input = torch::ones(
-      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
+      {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
   torch::Scalar shift_amount = 3;
   torch::Tensor result = torch::__rshift__(input, shift_amount);
   ForEachDevice([&](const torch::Device& device) {
@@ -9081,7 +9091,7 @@ TEST_F(LazyOpsTest, TestRshiftScalar) {
 
 TEST_F(LazyOpsTest, TestRshiftScalarInPlace) {
   torch::Tensor input = torch::ones(
-      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
+      {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
   torch::Scalar shift_amount = 3;
   ForEachDevice([&](const torch::Device& device) {
     torch::Tensor lazy_input = CopyToDevice(input, device);

@@ -3,48 +3,22 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
-    output: "onnx::Pow_4"
-    name: "Constant_0"
-    op_type: "Constant"
+    input: "onnx::BitShift_0"
+    input: "onnx::BitShift_7"
+    output: "3"
+    name: "BitShift_0"
+    op_type: "BitShift"
     attribute {
-      name: "value"
-      t {
-        data_type: 1
-        raw_data: "\000\000\000@"
-      }
-      type: TENSOR
+      name: "direction"
+      s: "RIGHT"
+      type: STRING
     }
   }
   node {
-    input: "onnx::Pow_4"
-    input: "onnx::Pow_11"
-    output: "onnx::Cast_5"
-    name: "Pow_1"
-    op_type: "Pow"
-  }
-  node {
-    input: "onnx::Cast_5"
-    output: "onnx::Div_6"
-    name: "Cast_2"
-    op_type: "Cast"
-    attribute {
-      name: "to"
-      i: 1
-      type: INT
-    }
-  }
-  node {
-    input: "onnx::Div_0"
-    input: "onnx::Div_6"
-    output: "7"
-    name: "Div_3"
-    op_type: "Div"
-  }
-  node {
-    input: "onnx::BitShift_1"
-    input: "onnx::BitShift_12"
-    output: "10"
-    name: "BitShift_4"
+    input: "onnx::BitShift_0"
+    input: "onnx::BitShift_8"
+    output: "6"
+    name: "BitShift_1"
     op_type: "BitShift"
     attribute {
       name: "direction"
@@ -54,36 +28,17 @@ graph {
   }
   name: "torch_jit"
   initializer {
-    data_type: 1
-    name: "onnx::Pow_11"
-    raw_data: "\000\000\200?"
+    data_type: 2
+    name: "onnx::BitShift_7"
+    raw_data: "\001"
   }
   initializer {
     data_type: 2
-    name: "onnx::BitShift_12"
+    name: "onnx::BitShift_8"
     raw_data: "\002"
   }
   input {
-    name: "onnx::Div_0"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 3
-          }
-          dim {
-            dim_value: 4
-          }
-          dim {
-            dim_value: 2
-          }
-        }
-      }
-    }
-  }
-  input {
-    name: "onnx::BitShift_1"
+    name: "onnx::BitShift_0"
     type {
       tensor_type {
         elem_type: 2
@@ -102,10 +57,10 @@ graph {
     }
   }
   output {
-    name: "7"
+    name: "3"
     type {
      tensor_type {
-        elem_type: 1
+        elem_type: 2
        shape {
          dim {
            dim_value: 3
@@ -121,7 +76,7 @@ graph {
    }
  }
  output {
-    name: "10"
+    name: "6"
    type {
      tensor_type {
        elem_type: 2

@@ -921,12 +921,11 @@ class TestOperators(TestCase):
     def test_bitshift(self):
         class BitshiftModel(torch.nn.Module):
-            def forward(self, input, input2):
-                return input >> 1, input2 >> 2
+            def forward(self, input):
+                return input >> 1, input >> 2
 
-        input = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2)
-        input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
-        self.assertONNX(BitshiftModel(), (input, input2), opset_version=11)
+        input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
+        self.assertONNX(BitshiftModel(), input, opset_version=11)
 
     @skipIfCaffe2
     def test_layer_norm_aten(self):
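
The expect file above is regenerated from this test. Roughly the same graph can
be produced directly with torch.onnx.export; a sketch (writing to an in-memory
buffer instead of going through the test harness):

    import io
    import torch

    class BitshiftModel(torch.nn.Module):
        def forward(self, input):
            return input >> 1, input >> 2

    # uint8 is an integer dtype that ONNX BitShift accepts, so both outputs now
    # lower to BitShift nodes instead of the old Constant/Pow/Cast/Div chain.
    model_input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
    buffer = io.BytesIO()
    torch.onnx.export(BitshiftModel(), (model_input,), buffer, opset_version=11)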

@@ -3936,17 +3936,16 @@ class _TestONNXRuntime:
     def test_bitshift(self):
         class BitshiftModel(torch.nn.Module):
-            def forward(self, input, input2):
+            def forward(self, input):
                 return (
                     input >> 1,
-                    input << 3.1,
-                    input2 >> torch.tensor([1, 2]),
-                    input2 << 4.2,
+                    input << 3,
+                    input >> torch.tensor([1, 2]),
+                    input << 4,
                 )
 
-        input = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2)
-        input2 = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
-        self.run_test(BitshiftModel(), (input, input2))
+        input = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
+        self.run_test(BitshiftModel(), input)
 
     def test_bitshift_other_fp(self):
         class BitshiftModel(torch.nn.Module):

@@ -3375,30 +3375,6 @@ class TestBinaryUfuncs(TestCase):
             torch_op(a, 2), torch.tensor(numpy_op(a_np, 2), device=device)
         )
 
-    def test_bitwise_shift_float(self, device):
-        ops = [
-            (torch.bitwise_left_shift, lambda x, y: x * 2.0**y),
-            (operator.lshift, lambda x, y: x * 2.0**y),
-            (torch.bitwise_right_shift, lambda x, y: x / 2.0**y),
-            (operator.rshift, lambda x, y: x / 2.0**y),
-        ]
-        for torch_op, expected_op in ops:
-            # int tensor x float
-            a = torch.tensor([19, -20, -21, 22], dtype=torch.int64, device=device)
-            self.assertEqual(
-                torch_op(a, 1.8), torch.floor(expected_op(a, 1)).to(a.dtype)
-            )
-            # float tensor x int scalar
-            a = torch.tensor(
-                [19.1, -20.2, -21.3, 22.4], dtype=torch.float32, device=device
-            )
-            self.assertEqual(torch_op(a, 2), expected_op(a, 2))
-            # float tensor x float scalar
-            a = torch.tensor(
-                [19.1, -20.2, -21.3, 22.4], dtype=torch.float32, device=device
-            )
-            self.assertEqual(torch_op(a, 2.2), expected_op(a, 2.2))
-
     @onlyNativeDeviceTypes
     @dtypes(
         *list(
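
The deleted test covered the old floating-point emulation. Code that relied on
that behavior can express it directly (a suggested equivalent, not something
this PR adds), for example with torch.ldexp:

    import torch

    x = torch.tensor([19.1, -20.2, -21.3, 22.4])

    # Old float "left shift" semantics: x * 2**y
    left = torch.ldexp(x, torch.tensor(2))
    # Old float "right shift" semantics: x / 2**y
    right = x / 2.0 ** 2

    print(torch.allclose(left, x * 4.0), torch.allclose(right, x / 4.0))  # True True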

@@ -1371,12 +1371,13 @@ add_docstr(torch.bitwise_left_shift,
 bitwise_left_shift(input, other, *, out=None) -> Tensor
 
 Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits.
-The result will have the same dtype as :attr:`input`.
+The result will have the same dtype as :attr:`input`. The input tensor must
+be of integral types.
 
 The operation applied is:
 
 .. math::
-    \text{{out}}_i = \text{{input}}_i \times 2 ^ {{\text{{other}}_i}}
+    \text{{out}}_i = \text{{input}}_i << \text{{other}}_i
 
 Args:
     input (Tensor or Scalar): the first input tensor
@@ -1396,12 +1397,13 @@ add_docstr(torch.bitwise_right_shift,
 bitwise_right_shift(input, other, *, out=None) -> Tensor
 
 Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits.
-The result will have the same dtype as :attr:`input`.
+The result will have the same dtype as :attr:`input`. The input tensor must
+be of integral types.
 
 The operation applied is:
 
 .. math::
-    \text{{out}}_i = \text{{input}}_i / 2 ^ {{\text{{other}}_i}}
+    \text{{out}}_i = \text{{input}}_i >> \text{{other}}_i
 
 Args:
     input (Tensor or Scalar): the first input tensor
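
A quick worked example of the updated formulas (illustrative only, not taken
from the diff):

    import torch

    a = torch.tensor([-1, -2, 3], dtype=torch.int8)
    b = torch.tensor([1, 0, 3], dtype=torch.int8)
    # out_i = input_i << other_i
    print(torch.bitwise_left_shift(a, b))   # tensor([-2, -2, 24], dtype=torch.int8)

    c = torch.tensor([-2, -7, 31], dtype=torch.int8)
    # out_i = input_i >> other_i
    print(torch.bitwise_right_shift(c, b))  # tensor([-1, -7,  3], dtype=torch.int8)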

@@ -21,7 +21,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
     _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
     floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
-    all_types, double_types, empty_types, complex_types_and
+    all_types, double_types, empty_types, complex_types_and, integral_types
 )
 from torch.testing._internal.common_device_type import \
     (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -9988,8 +9988,8 @@
                     supports_autograd=False),
     BinaryUfuncInfo('bitwise_left_shift',
                     op=torch.bitwise_left_shift,
-                    dtypes=all_types(),
-                    dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+                    dtypes=integral_types(),
+                    dtypesIfCUDA=integral_types(),
                     supports_autograd=False,
                     supports_one_python_scalar=True,
                     rhs_make_tensor_kwargs=dict(low=0),
@@ -9998,8 +9998,8 @@
                     )),
     BinaryUfuncInfo('bitwise_right_shift',
                     op=torch.bitwise_right_shift,
-                    dtypes=all_types(),
-                    dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+                    dtypes=integral_types(),
+                    dtypesIfCUDA=integral_types(),
                     supports_autograd=False,
                     supports_one_python_scalar=True,
                     rhs_make_tensor_kwargs=dict(low=0),
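
For reference, a sketch of what the narrowed OpInfo coverage means (assuming
integral_types() still returns the five standard integer dtypes):

    from torch.testing._internal.common_dtype import all_types, integral_types

    # Before: the shift OpInfos sampled all_types() (integers plus float32/float64,
    # with float16/bfloat16 added on CUDA). After: only the integer dtypes.
    print(integral_types())  # expected: uint8, int8, int16, int32, int64
    print(all_types())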