Add FP8 support for eye (#139974)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139974
Approved by: https://github.com/jgong5, https://github.com/malfet
This commit is contained in:
Jiang, Yanbing 2024-12-24 05:22:31 +00:00 committed by PyTorch MergeBot
parent 060ee14753
commit 783065637e
4 changed files with 22 additions and 10 deletions

View file

@ -663,8 +663,10 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
result.zero_();
int64_t sz = std::min<int64_t>(n, m);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBFloat16, kHalf, kBool, result.scalar_type(), "eye", [&]() -> void {
AT_DISPATCH_V2(
result.scalar_type(),
"eye",
[&]() -> void {
scalar_t* result_data = result.data_ptr<scalar_t>();
at::parallel_for(
0, sz, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
@ -672,8 +674,12 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
result_data[i * (result.strides()[0] + result.strides()[1])] =
1;
});
});
},
kBFloat16,
kHalf,
kBool,
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES));
return result;
}

View file

@ -213,14 +213,15 @@ static void aminmax_kernel(
}
// CPU kernel for the elementwise "where" selection.
//
// Dispatches on iter.dtype() under the name "where_cpu" and runs cpu_kernel
// over `iter` with a lambda that takes (cond_val, self_val, other_val) and
// yields self_val when cond_val is true, other_val otherwise.
//
// NOTE(review): this excerpt appears to be a diff rendering with the +/-
// markers stripped -- both the AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4 opener
// and the AT_DISPATCH_V2 opener (and both closers) are present. Presumably
// only the AT_DISPATCH_V2 form is live after the change; confirm against the
// real file before relying on this text as compilable code.
static void where_kernel_impl(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool,
AT_DISPATCH_V2(
iter.dtype(), "where_cpu", [&] {
cpu_kernel(
iter,
// The condition is always read as bool; both value operands and the result
// share the dispatched scalar_t.
[=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
return cond_val ? self_val : other_val;
});
});
},
// Dtype list for the V2 dispatch: the four explicitly named types plus the
// AT_ALL_TYPES_AND_COMPLEX and AT_FLOAT8_TYPES groups.
kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES));
}
static void isposinf_kernel_impl(TensorIteratorBase& iter) {

View file

@ -1,6 +1,7 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/NumericUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/cuda/Loops.cuh>
@ -12,13 +13,14 @@ namespace at::native {
namespace {
// GPU kernel for the elementwise "where" selection.
//
// Dispatches on iter.dtype() under the name "where_cuda" and runs gpu_kernel
// over `iter` with a GPU_LAMBDA that takes (cond_val, self_val, other_val)
// and yields self_val when cond_val is true, other_val otherwise.
//
// NOTE(review): this excerpt appears to be a diff rendering with the +/-
// markers stripped -- both the AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4 opener
// and the AT_DISPATCH_V2 opener (and both closers) are present. Presumably
// only the AT_DISPATCH_V2 form is live after the change; confirm against the
// real file before relying on this text as compilable code.
void where_kernel_impl(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] {
AT_DISPATCH_V2(iter.dtype(), "where_cuda", [&] {
gpu_kernel(
iter,
// The condition is always read as bool; both value operands and the result
// share the dispatched scalar_t.
[=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
return cond_val ? self_val : other_val;
});
});
},
// Dtype list for the V2 dispatch: the four explicitly named types plus the
// AT_ALL_TYPES_AND_COMPLEX and AT_FLOAT8_TYPES groups.
kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES));
}
void isposinf_kernel_impl(TensorIteratorBase &iter) {

View file

@ -21,7 +21,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
empty_types, complex_types_and, integral_types, custom_types,
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
)
from torch.testing._internal.common_device_type import \
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@ -19154,7 +19154,7 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
)),
OpInfo('eye',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypes=all_types_complex_float8_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_eye,
error_inputs_func=error_inputs_eye,
supports_out=True,
@ -19176,6 +19176,9 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
# UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
# "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
)),
OpInfo('empty_permuted',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),