diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index 27fa214fba1..675c700c7cd 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -205,7 +205,7 @@ static void aminmax_kernel( } static void where_kernel_impl(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_cpu", [&] { cpu_kernel( iter, diff --git a/aten/src/ATen/native/cuda/PowKernel.cu b/aten/src/ATen/native/cuda/PowKernel.cu index f2a170e9705..a1e453455d1 100644 --- a/aten/src/ATen/native/cuda/PowKernel.cu +++ b/aten/src/ATen/native/cuda/PowKernel.cu @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -83,9 +84,69 @@ void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex base }); } +/* complex support impl */ +const char pow_scalar_base_name[] = "pow_scalar_base_kernel"; +template <> +void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex base) { + using scalar_t = c10::complex; + using opmath_t = at::opmath_type; + // For complex, thrust::pow uses the identity + // pow(a, b) = exp(log(a) * b) + const auto fct = std::log(opmath_t{base}); +#if AT_USE_JITERATOR() + static const auto pow_kernel_string = + jiterator_stringify(template T pow_scalar_base_kernel(T exp, T fct) { + return std::exp(fct * exp); + }); + jitted_gpu_kernel( + iter, + pow_kernel_string, + /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar, + /*scalar_val=*/0, + /*extra_args=*/std::make_tuple(fct)); +#else + gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t exp) -> scalar_t { + return std::exp(fct * opmath_t{exp}); + }); +#endif +} + +namespace { + +#if AT_USE_JITERATOR() +/* complex support impl */ +const char pow_name[] = "pow_kernel"; +static const auto pow_kernel_string = + jiterator_stringify(template T pow_kernel(T base, T exp) { + return std::pow(base, exp); + }); +#endif + +/* complex support impl */ +void pow_chalf_tensor_scalar_impl(TensorIteratorBase& iter, const Scalar& exp_scalar) { + using scalar_t = c10::complex; + using opmath_t = at::opmath_type; + auto exp = exp_scalar.to(); +#if AT_USE_JITERATOR() + jitted_gpu_kernel( + iter, + pow_kernel_string, + /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar, + /*scalar_val=*/0, + /*extra_args=*/std::make_tuple(exp)); +#else + gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base) -> scalar_t { + return std::pow(opmath_t{base}, exp); + }); +#endif +} + +} // anonymous namespace + void pow_tensor_tensor_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - kHalf, kBFloat16, iter.common_dtype(), "pow_cuda", [&] { + auto common_dtype = iter.common_dtype(); + if (common_dtype == kComplexHalf) { + using scalar_t = c10::complex; if (iter.is_cpu_scalar(1)) { const auto base = iter.scalar_value(1); iter.remove_operand(1); @@ -93,13 +154,38 @@ void pow_tensor_tensor_kernel(TensorIteratorBase& iter) { } else if (iter.is_cpu_scalar(2)) { const auto exp = iter.scalar_value(2); iter.remove_operand(2); - pow_tensor_scalar_kernel(iter, exp); + pow_chalf_tensor_scalar_impl(iter, exp); } else { - gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t { - return pow_(base, exp); - }); + using opmath_t = at::opmath_type; + TORCH_INTERNAL_ASSERT(!iter.is_cpu_scalar(1) && !iter.is_cpu_scalar(2)); +#if AT_USE_JITERATOR() + jitted_gpu_kernel( + iter, pow_kernel_string); +#else + gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t { + using opmath_t = at::opmath_type; + return pow_(opmath_t{base}, opmath_t{exp}); + }); +#endif } - }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + kHalf, kBFloat16, iter.common_dtype(), "pow_cuda", [&] { + if (iter.is_cpu_scalar(1)) { + const auto base = iter.scalar_value(1); + iter.remove_operand(1); + pow_scalar_tensor_impl(iter, base); + } else if (iter.is_cpu_scalar(2)) { + const auto exp = iter.scalar_value(2); + iter.remove_operand(2); + pow_tensor_scalar_kernel(iter, exp); + } else { + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t { + return pow_(base, exp); + }); + } + }); + } } @@ -140,6 +226,11 @@ void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar } } if (isComplexType(iter.common_dtype()) || exp_scalar.isComplex()) { + if (iter.common_dtype() == kComplexHalf) { + using scalar_t = c10::complex; + pow_chalf_tensor_scalar_impl(iter, exp_scalar); + return; + } AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "pow_cuda", [&]() { const auto exp = exp_scalar.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t { diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index f81c90c5651..88489234dd1 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -12,7 +12,7 @@ namespace at { namespace native { namespace { void where_kernel_impl(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] { gpu_kernel( iter, [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t { diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 295d1006ff2..19074b004dd 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -36,7 +36,7 @@ class C10_API Scalar { #define DEFINE_IMPLICIT_CTOR(type, name) \ Scalar(type vv) : Scalar(vv, true) {} - AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, DEFINE_IMPLICIT_CTOR) + AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR) AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR) #undef DEFINE_IMPLICIT_CTOR diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 5fefc589951..efeea466e1a 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1456,13 +1456,15 @@ class TestBinaryUfuncs(TestCase): self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4) else: self._do_pow_for_exponents(m1, exponents, math.pow, None) - if dtype != torch.half: - self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4) - else: + will_raise_error = dtype is torch.half and torch.device(device).type == 'cpu' + if will_raise_error: + # On CPU, # Half Tensor with complex exponents leads to computation dtype # of ComplexHalf for which this ops is not supported yet with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"): self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4) + else: + self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4) # base - number, exponent - tensor # contiguous @@ -1751,11 +1753,14 @@ class TestBinaryUfuncs(TestCase): first_exp[0] = first_exp[10] = first_exp[20] = 0 second_exp[0] = second_exp[10] = second_exp[20] = 0 for base in complexes: + # On CPU, # Half Tensor with complex base leads to computation dtype # of ComplexHalf for which this ops is not supported yet # NOTE: pow has fast-path when base is 1 which supports # ComplexHalf - if dtype is torch.half and base != (1 + 0j): + will_raise_error = torch.device(device).type == 'cpu' and \ + dtype is torch.half and base != (1 + 0j) + if will_raise_error: with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"): self._test_pow(base, first_exp) self._test_pow(base, second_exp) diff --git a/test/test_ops.py b/test/test_ops.py index 2e001463278..26e2f436c1d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -48,6 +48,7 @@ from torch.testing._internal.common_device_type import ( OpDTypes, skipMeta, ) +from torch.utils._pytree import tree_map import torch._prims as prims from torch._prims.context import TorchRefsMode @@ -1106,6 +1107,16 @@ class TestCommon(TestCase): *transformed_sample.args, **transformed_sample.kwargs, ) + # Since range of chalf is much less compared to cfloat, + # we get `inf`s easily (eg. with `pow`, `exp`), + # so we cast `cfloat` back to `chalf`. + expected = tree_map(lambda x: x.to(torch.complex32) if isinstance( + x, torch.Tensor) and x.dtype is torch.complex64 else x, expected) + + # `exact_dtype` is False because for ops like real, imag + # we get different dtypes for `actual` and `expected` + # `chalf` input -> `half` output + # `cfloat` input -> `float` output self.assertEqual(actual, expected, exact_dtype=False) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 1e642956327..1872c86827e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12590,15 +12590,7 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_forward_grad=False, - supports_out=False, - skips=( - # RuntimeError: "where_cpu" not implemented for 'ComplexHalf' - # RuntimeError: "where_cuda" not implemented for 'ComplexHalf' - DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', dtypes=(torch.chalf,)), - # RuntimeError: "where_cpu" not implemented for 'ComplexHalf' - # RuntimeError: "where_cuda" not implemented for 'ComplexHalf' - DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', dtypes=(torch.chalf,)), - )), + supports_out=False), OpInfo('masked_scatter', dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_masked_scatter, @@ -14819,12 +14811,13 @@ op_db: List[OpInfo] = [ reference_inputs_func=reference_inputs_permute), BinaryUfuncInfo('pow', dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), ref=np.power, # Due to AVX2 curently not being fully supported for Float16, log_vml_cpu can't be enabled # for Float16, causing this test to fail. pow's autograd for Float16 is thus currently # unsupported on CPU. backward_dtypes=floating_and_complex_types_and(torch.bfloat16), - backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), supports_inplace_autograd=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -14848,12 +14841,18 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_large_values', dtypes=[torch.int16, torch.int32, torch.int64]), # FIXME Complex values error with: Greatest absolute difference: nan at index + # Ref: https://github.com/pytorch/pytorch/issues/76853 + # For `chalf`, reference computation in `numpy` is computed in `cfloat`. + # Output of `chalf` saturates to `inf` quicker than reference due to its small range + # which leads to failure of this test. + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics', + dtypes=(torch.complex32,)), DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', - dtypes=[torch.complex64, torch.complex128]), + dtypes=(torch.complex32, torch.complex64, torch.complex128)), DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values', - dtypes=[torch.complex64, torch.complex128]), + dtypes=(torch.complex32, torch.complex64, torch.complex128)), DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', - dtypes=[torch.complex64, torch.complex128]), + dtypes=(torch.complex32, torch.complex64, torch.complex128)), )), BinaryUfuncInfo('float_power', ref=np.float_power, @@ -15103,7 +15102,8 @@ op_db: List[OpInfo] = [ UnaryUfuncInfo('sgn', ref=reference_sgn, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), - backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_sparse=True, @@ -15360,7 +15360,6 @@ op_db: List[OpInfo] = [ ref=np.tan, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), - backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), assert_autodiffed=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -15673,11 +15672,8 @@ op_db: List[OpInfo] = [ dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool), decorators=(precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2}),), - # TODO: add `torch.chalf` backward dtype support. - # AssertionError: The supported dtypes for angle on device type cuda are incorrect! - # The following dtypes did not work in backward but are listed by the OpInfo: {torch.complex32}. backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), - backward_dtypesIfCUDA=floating_and_complex_types(), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.chalf), supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_sparse_csr=True, @@ -17587,7 +17583,7 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), ), - dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf)), OpInfo('nonzero', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), sample_inputs_func=sample_inputs_nonzero,