From 215b14530ae4ff33b718cdee5f1e30f8dfea6317 Mon Sep 17 00:00:00 2001 From: "Jiang, Yanbing" Date: Sat, 17 Aug 2024 15:20:39 +0000 Subject: [PATCH] Add Half for sparse.mm reduce (#133672) This PR is to add Half support for sparse.mm reduce in CPU backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133672 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/cpu/SpmmReduceKernel.cpp | 12 ++++++------ aten/src/ATen/native/cpu/utils.h | 11 ++++++++++- test/inductor/test_torchinductor_opinfo.py | 2 +- test/test_sparse_csr.py | 8 ++++---- .../testing/_internal/common_methods_invocations.py | 2 +- 5 files changed, 22 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp b/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp index cf9749a60d4..b620985ba13 100644 --- a/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp +++ b/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp @@ -434,7 +434,7 @@ void spmm_reduce_kernel( const Tensor& values, const Tensor& other, ReductionType reduce_op) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() { AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() { AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() { spmm_reduce_kernel_impl( @@ -452,7 +452,7 @@ void spmm_reduce_arg_kernel( const Tensor& values, const Tensor& other, ReductionType reduce_op) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() { AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() { AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() { spmm_reduce_arg_kernel_impl( @@ -471,7 +471,7 @@ void spmm_reduce_backward_input_kernel( const Tensor& row_indices, ReductionType reduce_op) { TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN); - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() { AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_indices", [&]() { AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() { spmm_reduce_backward_input_kernel_impl( @@ -489,7 +489,7 @@ void spmm_reduce_backward_input_arg_kernel( const Tensor& arg_out, ReductionType reduce_op) { TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN); - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() { AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_arg_indices", [&]() { spmm_reduce_backward_input_arg_kernel_impl( grad_self, grad_out, col_indices, other, arg_out); @@ -502,7 +502,7 @@ void spmm_reduce_normalize_values_kernel( const Tensor& values, const Tensor& crow_indices, const Tensor& row_indices) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() { AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "spmm_reduce_normalize_values_indices", [&]() { spmm_reduce_normalize_values_kernel_impl( normalized_values, values, crow_indices, row_indices); @@ -545,7 +545,7 @@ void spmm_reduce_backward_other_arg_kernel( const Tensor& arg_out, ReductionType reduce_op) { TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN); - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() { AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_other_arg_indices", [&]() { spmm_reduce_backward_other_arg_kernel_impl( grad_other, grad_out, col_indices, values, arg_out); diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h index 641ac0cd061..bf6af9a1247 100644 --- a/aten/src/ATen/native/cpu/utils.h +++ b/aten/src/ATen/native/cpu/utils.h @@ -53,7 +53,7 @@ inline bool data_index_step(T& x, const T& X, Args&&... args) { return false; } -// Helper struct for bfloat16 vectorization +// Helper struct for bfloat16/float16 vectorization // Useful when you need float as immediate dtype or accumulate dtype using namespace vec; struct Vec2 { @@ -64,6 +64,10 @@ struct Vec2 { auto [v0, v1] = convert_bfloat16_float(Vectorized::loadu(ptr)); return {v0, v1}; } + static Vec2 loadu(const Half* ptr) { + auto [v0, v1] = convert_half_float(Vectorized::loadu(ptr)); + return {v0, v1}; + } static Vec2 loadu(const float* ptr) { return {Vectorized::loadu(ptr), Vectorized::loadu(ptr + Vectorized::size())}; } @@ -71,6 +75,10 @@ struct Vec2 { Vectorized val = convert_float_bfloat16(val0, val1); val.store(ptr); } + void store(Half* ptr) const { + Vectorized val = convert_float_half(val0, val1); + val.store(ptr); + } void store(float* ptr) const { val0.store(ptr); val1.store(ptr + Vectorized::size()); @@ -85,6 +93,7 @@ inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, template struct VectorizedType { using type = Vectorized; }; template <> struct VectorizedType { using type = Vec2; }; +template <> struct VectorizedType { using type = Vec2; }; template using VecType = typename VectorizedType::type; // Helper for mixed data type parameter Vec::load diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 25cfadb37cb..fdac0cc9cbe 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -226,7 +226,7 @@ inductor_expected_failures_single_sample["cpu"] = { ("normal", "in_place"): {f16, f32, f64}, ("normal", "number_mean"): {f16, f32, f64}, "normal": {f16, f32, f64}, - ("sparse.mm", "reduce"): {f32, f64}, + ("sparse.mm", "reduce"): {f32, f64, f16}, "sparse.sampled_addmm": {f32, f64}, "to_sparse": { f32, diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 14e8a85b571..8ab03c1424f 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -2575,7 +2575,7 @@ class TestSparseCSR(TestCase): torch.sparse.sampled_addmm(a_sparse, a, a_sparse) @onlyCPU - @dtypes(torch.float32, torch.float64, torch.bfloat16) + @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16) @precisionOverride({torch.bfloat16: 0.01}) def test_sparse_mm_reduce_sum(self, device, dtype): def run_test(m, n, k, nnz, train): @@ -2613,8 +2613,8 @@ class TestSparseCSR(TestCase): @skipIfTorchDynamo() @onlyCPU - @dtypes(torch.float32, torch.float64, torch.bfloat16) - @precisionOverride({torch.bfloat16: 0.01}) + @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16) + @precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01}) def test_sparse_mm_reduce(self, device, dtype): def run_test(m, n, k, nnz, reduce_type, index_dtype, train): csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype) @@ -2649,7 +2649,7 @@ class TestSparseCSR(TestCase): out = torch.sparse.mm(csr, mat, reduce_type) self.assertEqual(out, ref_out) - if train and dtype is not torch.bfloat16: + if train and dtype not in (torch.bfloat16, torch.float16): ref_out.sum().backward() out.sum().backward() diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index e2e5469a183..78f99a9012f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -13576,7 +13576,7 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'), )), OpInfo('sparse.mm', - dtypes=floating_types_and(torch.bfloat16), + dtypes=floating_types_and(torch.bfloat16, torch.float16), variant_test_name='reduce', supports_autograd=True, supports_out=False,