From 783065637eabfc6f176af795aba822dd28bdc762 Mon Sep 17 00:00:00 2001
From: "Jiang, Yanbing"
Date: Tue, 24 Dec 2024 05:22:31 +0000
Subject: [PATCH] Add FP8 support for eye (#139974)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139974
Approved by: https://github.com/jgong5, https://github.com/malfet
---
 aten/src/ATen/native/TensorFactories.cpp         | 14 ++++++++++----
 aten/src/ATen/native/cpu/TensorCompareKernel.cpp |  5 +++--
 aten/src/ATen/native/cuda/TensorCompare.cu       |  6 ++++--
 .../_internal/common_methods_invocations.py      |  7 +++++--
 4 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 5168c66cb1f..bbd3672412b 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -663,8 +663,10 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
   result.zero_();

   int64_t sz = std::min(n, m);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      kBFloat16, kHalf, kBool, result.scalar_type(), "eye", [&]() -> void {
+  AT_DISPATCH_V2(
+      result.scalar_type(),
+      "eye",
+      [&]() -> void {
         scalar_t* result_data = result.data_ptr<scalar_t>();
         at::parallel_for(
             0, sz, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
@@ -672,8 +674,12 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
                 result_data[i * (result.strides()[0] + result.strides()[1])] = 1;
             });
-      });
-
+      },
+      kBFloat16,
+      kHalf,
+      kBool,
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+      AT_EXPAND(AT_FLOAT8_TYPES));
   return result;
 }

diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index 4f5040085c3..2c52a61fc55 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -213,14 +213,15 @@ static void aminmax_kernel(
 }

 static void where_kernel_impl(TensorIterator &iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool,
+  AT_DISPATCH_V2(
     iter.dtype(), "where_cpu", [&] {
       cpu_kernel(
         iter,
         [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
           return cond_val ? self_val : other_val;
         });
-  });
+  },
+  kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES));
 }

 static void isposinf_kernel_impl(TensorIteratorBase& iter) {

diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu
index 1751bc95a38..ab38c1975d1 100644
--- a/aten/src/ATen/native/cuda/TensorCompare.cu
+++ b/aten/src/ATen/native/cuda/TensorCompare.cu
@@ -1,6 +1,7 @@
 #define TORCH_ASSERT_NO_OPERATORS
 #include
 #include
+#include <ATen/Dispatch_v2.h>
 #include
 #include
 #include
@@ -12,13 +13,14 @@ namespace at::native {
 namespace {

 void where_kernel_impl(TensorIterator &iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] {
+  AT_DISPATCH_V2(iter.dtype(), "where_cuda", [&] {
     gpu_kernel(
       iter,
       [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
         return cond_val ? self_val : other_val;
       });
-  });
+  },
+  kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES));
 }

 void isposinf_kernel_impl(TensorIteratorBase &iter) {

diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4063c2dc1fe..969d465a23b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -21,7 +21,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
     _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
     floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
-    empty_types, complex_types_and, integral_types, custom_types,
+    empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
 )
 from torch.testing._internal.common_device_type import \
     (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -19154,7 +19154,7 @@ op_db: List[OpInfo] = [
                    DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
                )),
     OpInfo('eye',
-           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           dtypes=all_types_complex_float8_and(torch.bool, torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_eye,
            error_inputs_func=error_inputs_eye,
            supports_out=True,
@@ -19176,6 +19176,9 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
                # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                            dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
            )),
     OpInfo('empty_permuted',
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
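For reviewers, a quick way to exercise the new dtype coverage is the snippet below. It is a minimal sketch, assuming a PyTorch build that includes this patch; the shapes and values are illustrative only and are not part of the change itself.

    import torch

    # eye now dispatches for FP8 dtypes on CPU (previously raised "not implemented for 'Float8_e4m3fn'").
    eye_fp8 = torch.eye(3, dtype=torch.float8_e4m3fn)

    # where now dispatches for FP8 dtypes on CPU and CUDA.
    cond = torch.tensor([True, False, True])
    a = torch.arange(3, dtype=torch.float32).to(torch.float8_e5m2)
    b = torch.zeros(3).to(torch.float8_e5m2)
    out = torch.where(cond, a, b)

    print(eye_fp8.dtype, out.dtype)  # torch.float8_e4m3fn torch.float8_e5m2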