From cc9d0f309ec72cc53ffbbca3044ddb1a61c8ecf3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 11 May 2022 22:29:30 +0000 Subject: [PATCH] lshift and rshift stop support floating types (#77146) Fixes #74358 Pull Request resolved: https://github.com/pytorch/pytorch/pull/77146 Approved by: https://github.com/ngimel --- aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 48 +++-------- .../ATen/native/cuda/BinaryShiftOpsKernels.cu | 46 +++------- test/cpp/lazy/test_lazy_ops.cpp | 38 ++++++--- .../expect/TestOperators.test_bitshift.expect | 85 +++++-------------- test/onnx/test_operators.py | 9 +- test/onnx/test_pytorch_onnx_onnxruntime.py | 13 ++- test/test_binary_ufuncs.py | 24 ------ torch/_torch_docs.py | 10 ++- .../_internal/common_methods_invocations.py | 10 +-- 9 files changed, 85 insertions(+), 198 deletions(-) diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index f1731417eb4..a3772675052 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -312,26 +312,12 @@ void bitwise_xor_kernel(TensorIteratorBase& iter) { } void lshift_kernel(TensorIteratorBase& iter) { - if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "lshift_cpu", [&]() { - auto base_vec = Vectorized((scalar_t)(2)); - cpu_kernel_vec( - iter, - [=](scalar_t a, scalar_t b) -> scalar_t { - return a * std::pow((scalar_t)(2), b); - }, - [=](Vectorized a, Vectorized b) { - return a * base_vec.pow(b); - }); + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() { + cpu_kernel(iter, + [](scalar_t a, scalar_t b) -> scalar_t { + return static_cast>(a) << b; }); - } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() { - cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> scalar_t { - return static_cast>(a) << b; - }); - }); - } + }); } void logical_and_kernel(TensorIterator& iter) { @@ -392,26 +378,12 @@ void logical_xor_kernel(TensorIterator& iter) { } void rshift_kernel(TensorIteratorBase& iter) { - if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "rshift_cpu", [&]() { - auto base_vec = Vectorized((scalar_t)(2)); - cpu_kernel_vec( - iter, - [=](scalar_t a, scalar_t b) -> scalar_t { - return a / std::pow((scalar_t)(2), b); - }, - [=](Vectorized a, Vectorized b) { - return a / base_vec.pow(b); + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [&]() { + cpu_kernel(iter, + [](scalar_t a, scalar_t b) -> scalar_t { + return a >> b; }); - }); - } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [&]() { - cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> scalar_t { - return a >> b; - }); - }); - } + }); } void lt_kernel(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu index 7f22ace666f..d6bd145c4f5 100644 --- a/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu @@ -12,47 +12,21 @@ namespace at { namespace native { void lshift_kernel_cuda(TensorIteratorBase& iter) { - if (iter.dtype() == ScalarType::Float || - iter.dtype() == ScalarType::Double || - iter.dtype() == ScalarType::Half || - iter.dtype() == ScalarType::BFloat16) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "lshift_cuda", [&]() { - gpu_kernel_with_scalars( - iter, - []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a * std::pow(static_cast(2), b); - }); + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cuda", [&]() { + gpu_kernel_with_scalars(iter, + []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return static_cast>(a) << b; }); - } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cuda", [&]() { - gpu_kernel_with_scalars(iter, - []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return static_cast>(a) << b; - }); - }); - } + }); } void rshift_kernel_cuda(TensorIteratorBase& iter) { - if (iter.dtype() == ScalarType::Float || - iter.dtype() == ScalarType::Double || - iter.dtype() == ScalarType::Half || - iter.dtype() == ScalarType::BFloat16) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "rshift_cuda", [&]() { - gpu_kernel_with_scalars( - iter, - []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a / std::pow(static_cast(2), b); - }); + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cuda", [&]() { + gpu_kernel_with_scalars(iter, + []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a >> b; }); - } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cuda", [&]() { - gpu_kernel_with_scalars(iter, - []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a >> b; - }); - }); - } + }); } REGISTER_DISPATCH(lshift_stub, &lshift_kernel_cuda); diff --git a/test/cpp/lazy/test_lazy_ops.cpp b/test/cpp/lazy/test_lazy_ops.cpp index 6ac65c90105..f12d357760e 100644 --- a/test/cpp/lazy/test_lazy_ops.cpp +++ b/test/cpp/lazy/test_lazy_ops.cpp @@ -8986,25 +8986,30 @@ TEST_F(LazyOpsTest, TestBitwiseXorScalarInPlace) { TEST_F(LazyOpsTest, TestLshift) { torch::Tensor input = torch::ones( - {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice())); torch::Tensor shift_amount = torch::randint( - 16, input.sizes(), torch::TensorOptions().device(DefaultDevice())); + 16, + input.sizes(), + torch::TensorOptions(torch::kInt32).device(DefaultDevice())); torch::Tensor result = torch::__lshift__(input, shift_amount); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device); - torch::Tensor lazy_result = torch::__lshift__(lazy_input, lazy_shift_amount); + torch::Tensor lazy_result = + torch::__lshift__(lazy_input, lazy_shift_amount); AllClose(result, lazy_result); }); } TEST_F(LazyOpsTest, TestLshiftInPlace) { torch::Tensor input = torch::ones( - {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor shift_amount = torch::randint( - 16, input.sizes(), torch::TensorOptions().device(DefaultDevice())); + 16, + input.sizes(), + torch::TensorOptions(torch::kInt32).device(DefaultDevice())); torch::Tensor result = input.__ilshift__(shift_amount); torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device); torch::Tensor lazy_result = lazy_input.__ilshift__(lazy_shift_amount); @@ -9015,7 +9020,7 @@ TEST_F(LazyOpsTest, TestLshiftInPlace) { TEST_F(LazyOpsTest, TestLshiftScalar) { torch::Tensor input = torch::ones( - {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice())); torch::Scalar shift_amount = 3; torch::Tensor result = torch::__lshift__(input, shift_amount); ForEachDevice([&](const torch::Device& device) { @@ -9027,7 +9032,7 @@ TEST_F(LazyOpsTest, TestLshiftScalar) { TEST_F(LazyOpsTest, TestLshiftScalarInPlace) { torch::Tensor input = torch::ones( - {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice())); torch::Scalar shift_amount = 3; ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -9040,25 +9045,30 @@ TEST_F(LazyOpsTest, TestLshiftScalarInPlace) { TEST_F(LazyOpsTest, TestRshift) { torch::Tensor input = torch::ones( - {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice())); torch::Tensor shift_amount = torch::randint( - 16, input.sizes(), torch::TensorOptions().device(DefaultDevice())); + 16, + input.sizes(), + torch::TensorOptions(torch::kInt32).device(DefaultDevice())); torch::Tensor result = torch::__rshift__(input, shift_amount); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device); - torch::Tensor lazy_result = torch::__rshift__(lazy_input, lazy_shift_amount); + torch::Tensor lazy_result = + torch::__rshift__(lazy_input, lazy_shift_amount); AllClose(result, lazy_result); }); } TEST_F(LazyOpsTest, TestRshiftInPlace) { torch::Tensor input = torch::ones( - {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor shift_amount = torch::randint( - 16, input.sizes(), torch::TensorOptions().device(DefaultDevice())); + 16, + input.sizes(), + torch::TensorOptions(torch::kInt32).device(DefaultDevice())); torch::Tensor result = input.__irshift__(shift_amount); torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device); torch::Tensor lazy_result = lazy_input.__irshift__(lazy_shift_amount); @@ -9069,7 +9079,7 @@ TEST_F(LazyOpsTest, TestRshiftInPlace) { TEST_F(LazyOpsTest, TestRshiftScalar) { torch::Tensor input = torch::ones( - {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice())); torch::Scalar shift_amount = 3; torch::Tensor result = torch::__rshift__(input, shift_amount); ForEachDevice([&](const torch::Device& device) { @@ -9081,7 +9091,7 @@ TEST_F(LazyOpsTest, TestRshiftScalar) { TEST_F(LazyOpsTest, TestRshiftScalarInPlace) { torch::Tensor input = torch::ones( - {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice())); torch::Scalar shift_amount = 3; ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); diff --git a/test/onnx/expect/TestOperators.test_bitshift.expect b/test/onnx/expect/TestOperators.test_bitshift.expect index 9332dafb8fa..10199d03efc 100644 --- a/test/onnx/expect/TestOperators.test_bitshift.expect +++ b/test/onnx/expect/TestOperators.test_bitshift.expect @@ -3,48 +3,22 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - output: "onnx::Pow_4" - name: "Constant_0" - op_type: "Constant" + input: "onnx::BitShift_0" + input: "onnx::BitShift_7" + output: "3" + name: "BitShift_0" + op_type: "BitShift" attribute { - name: "value" - t { - data_type: 1 - raw_data: "\000\000\000@" - } - type: TENSOR + name: "direction" + s: "RIGHT" + type: STRING } } node { - input: "onnx::Pow_4" - input: "onnx::Pow_11" - output: "onnx::Cast_5" - name: "Pow_1" - op_type: "Pow" - } - node { - input: "onnx::Cast_5" - output: "onnx::Div_6" - name: "Cast_2" - op_type: "Cast" - attribute { - name: "to" - i: 1 - type: INT - } - } - node { - input: "onnx::Div_0" - input: "onnx::Div_6" - output: "7" - name: "Div_3" - op_type: "Div" - } - node { - input: "onnx::BitShift_1" - input: "onnx::BitShift_12" - output: "10" - name: "BitShift_4" + input: "onnx::BitShift_0" + input: "onnx::BitShift_8" + output: "6" + name: "BitShift_1" op_type: "BitShift" attribute { name: "direction" @@ -54,36 +28,17 @@ graph { } name: "torch_jit" initializer { - data_type: 1 - name: "onnx::Pow_11" - raw_data: "\000\000\200?" + data_type: 2 + name: "onnx::BitShift_7" + raw_data: "\001" } initializer { data_type: 2 - name: "onnx::BitShift_12" + name: "onnx::BitShift_8" raw_data: "\002" } input { - name: "onnx::Div_0" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 3 - } - dim { - dim_value: 4 - } - dim { - dim_value: 2 - } - } - } - } - } - input { - name: "onnx::BitShift_1" + name: "onnx::BitShift_0" type { tensor_type { elem_type: 2 @@ -102,10 +57,10 @@ graph { } } output { - name: "7" + name: "3" type { tensor_type { - elem_type: 1 + elem_type: 2 shape { dim { dim_value: 3 @@ -121,7 +76,7 @@ graph { } } output { - name: "10" + name: "6" type { tensor_type { elem_type: 2 diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 2564f5b9a48..1cded5a9b9f 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -921,12 +921,11 @@ class TestOperators(TestCase): def test_bitshift(self): class BitshiftModel(torch.nn.Module): - def forward(self, input, input2): - return input >> 1, input2 >> 2 + def forward(self, input): + return input >> 1, input >> 2 - input = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2) - input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2) - self.assertONNX(BitshiftModel(), (input, input2), opset_version=11) + input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2) + self.assertONNX(BitshiftModel(), input, opset_version=11) @skipIfCaffe2 def test_layer_norm_aten(self): diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index de256088fea..e03566453f8 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -3936,17 +3936,16 @@ class _TestONNXRuntime: def test_bitshift(self): class BitshiftModel(torch.nn.Module): - def forward(self, input, input2): + def forward(self, input): return ( input >> 1, - input << 3.1, - input2 >> torch.tensor([1, 2]), - input2 << 4.2, + input << 3, + input >> torch.tensor([1, 2]), + input << 4, ) - input = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2) - input2 = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2) - self.run_test(BitshiftModel(), (input, input2)) + input = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2) + self.run_test(BitshiftModel(), input) def test_bitshift_other_fp(self): class BitshiftModel(torch.nn.Module): diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 2f9d7c3dd3b..12a814ddff7 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -3375,30 +3375,6 @@ class TestBinaryUfuncs(TestCase): torch_op(a, 2), torch.tensor(numpy_op(a_np, 2), device=device) ) - def test_bitwise_shift_float(self, device): - ops = [ - (torch.bitwise_left_shift, lambda x, y: x * 2.0**y), - (operator.lshift, lambda x, y: x * 2.0**y), - (torch.bitwise_right_shift, lambda x, y: x / 2.0**y), - (operator.rshift, lambda x, y: x / 2.0**y), - ] - for torch_op, expected_op in ops: - # int tensor x float - a = torch.tensor([19, -20, -21, 22], dtype=torch.int64, device=device) - self.assertEqual( - torch_op(a, 1.8), torch.floor(expected_op(a, 1)).to(a.dtype) - ) - # float tensor x int scalar - a = torch.tensor( - [19.1, -20.2, -21.3, 22.4], dtype=torch.float32, device=device - ) - self.assertEqual(torch_op(a, 2), expected_op(a, 2)) - # float tensor x float scalar - a = torch.tensor( - [19.1, -20.2, -21.3, 22.4], dtype=torch.float32, device=device - ) - self.assertEqual(torch_op(a, 2.2), expected_op(a, 2.2)) - @onlyNativeDeviceTypes @dtypes( *list( diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 0ea07355fd8..10200cd022b 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1371,12 +1371,13 @@ add_docstr(torch.bitwise_left_shift, bitwise_left_shift(input, other, *, out=None) -> Tensor Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. -The result will have the same dtype as :attr:`input`. +The result will have the same dtype as :attr:`input`. The input tensor must +be of integral types. The operation applied is: .. math:: - \text{{out}}_i = \text{{input}}_i \times 2 ^ {{\text{{other}}_i}} + \text{{out}}_i = \text{{input}}_i << \text{{other}}_i Args: input (Tensor or Scalar): the first input tensor @@ -1396,12 +1397,13 @@ add_docstr(torch.bitwise_right_shift, bitwise_right_shift(input, other, *, out=None) -> Tensor Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. -The result will have the same dtype as :attr:`input`. +The result will have the same dtype as :attr:`input`. The input tensor must +be of integral types. The operation applied is: .. math:: - \text{{out}}_i = \text{{input}}_i / 2 ^ {{\text{{other}}_i}} + \text{{out}}_i = \text{{input}}_i >> \text{{other}}_i Args: input (Tensor or Scalar): the first input tensor diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 1791e712ceb..d45966bfc54 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -21,7 +21,7 @@ from torch.testing import make_tensor from torch.testing._internal.common_dtype import ( _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, - all_types, double_types, empty_types, complex_types_and + all_types, double_types, empty_types, complex_types_and, integral_types ) from torch.testing._internal.common_device_type import \ (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, @@ -9988,8 +9988,8 @@ op_db: List[OpInfo] = [ supports_autograd=False), BinaryUfuncInfo('bitwise_left_shift', op=torch.bitwise_left_shift, - dtypes=all_types(), - dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypes=integral_types(), + dtypesIfCUDA=integral_types(), supports_autograd=False, supports_one_python_scalar=True, rhs_make_tensor_kwargs=dict(low=0), @@ -9998,8 +9998,8 @@ op_db: List[OpInfo] = [ )), BinaryUfuncInfo('bitwise_right_shift', op=torch.bitwise_right_shift, - dtypes=all_types(), - dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypes=integral_types(), + dtypesIfCUDA=integral_types(), supports_autograd=False, supports_one_python_scalar=True, rhs_make_tensor_kwargs=dict(low=0),