diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index aa36b35ebbc..3c5bba96053 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -2325,7 +2325,11 @@ static Tensor& linalg_vector_norm_impl(const Tensor& self, const Scalar& scalar_
   TORCH_CHECK(!result.defined() || out_dtype == result.scalar_type(),
       "linalg.vector_norm expected out tensor dtype ", out_dtype, " but got: ", result.scalar_type());

-  auto iter = make_reduction("vector_norm", result, self_, dim, keepdim, in_dtype, out_dtype);
+  // omit in_dtype in the following call, to avoid make_reduction explicitly casting input to out_dtype
+  auto iter = isComplexType(self.scalar_type()) ?
+      make_reduction("vector_norm", result, self_, dim, keepdim, in_dtype, out_dtype) :
+      make_reduction("vector_norm", result, self_, dim, keepdim, out_dtype);
+
   linalg_vector_norm_stub(iter.device_type(), iter, ord);
   return result;
 }
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 33046fc5497..93e91842e4a 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -1030,7 +1030,10 @@ static Tensor& norm_out(Tensor &result, const Tensor &self, const optional
diff --git a/aten/src/ATen/native/cuda/ReduceNormKernel.cu b/aten/src/ATen/native/cuda/ReduceNormKernel.cu
--- a/aten/src/ATen/native/cuda/ReduceNormKernel.cu
+++ b/aten/src/ATen/native/cuda/ReduceNormKernel.cu
-    return norm_kernel_cuda_impl(iter, p);
-  } else if (iter.dtype(1) == kHalf && iter.input_dtype() == kFloat) {
+static void norm_dispatch(TensorIterator& iter, const Scalar& ord){
+  if (iter.dtype(0) == kHalf) {
+    return norm_kernel_cuda_impl(iter, ord);
+  } else if (iter.input_dtype() == kHalf && iter.dtype(0) == kFloat) {
     // type promotion that does cast and reduction in a single kernel
-    return norm_kernel_cuda_impl(iter, p);
+    return norm_kernel_cuda_impl(iter, ord);
   }
-  else if(iter.input_dtype() == kBFloat16) {
-    return norm_kernel_cuda_impl(iter, p);
-  } else if (iter.dtype(1) == kBFloat16 && iter.input_dtype() == kFloat) {
+  else if(iter.dtype(0) == kBFloat16) {
+    return norm_kernel_cuda_impl(iter, ord);
+  } else if (iter.input_dtype() == kBFloat16 && iter.dtype(0) == kFloat) {
     // type promotion that does cast and reduction in a single kernel
-    return norm_kernel_cuda_impl(iter, p);
+    return norm_kernel_cuda_impl(iter, ord);
   }
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_cuda", [&] {
-    norm_kernel_cuda_impl(iter, p);
+    norm_kernel_cuda_impl(iter, ord);
   });
 }
+static void norm_kernel_cuda(TensorIterator& iter, const Scalar& ord) {
+  norm_dispatch(iter, ord);
+}
+
 static void linalg_vector_norm_kernel_cuda(TensorIterator& iter, Scalar ord) {
   TORCH_CHECK(ord.isFloatingPoint(), "linalg.vector_norm expects ord to be float");
-  if (iter.output().scalar_type() == kHalf) {
-    return norm_kernel_cuda_impl(iter, ord);
-  } else if (iter.input_dtype() == kHalf && iter.output().scalar_type() == kFloat) {
-    // type promotion that does cast and reduction in a single kernel
-    return norm_kernel_cuda_impl(iter, ord);
-  }
-  else if(iter.output().scalar_type() == kBFloat16) {
-    return norm_kernel_cuda_impl(iter, ord);
-  } else if (iter.input_dtype() == kBFloat16 && iter.output().scalar_type() == kFloat) {
-    // type promotion that does cast and reduction in a single kernel
-    return norm_kernel_cuda_impl(iter, ord);
-  }
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "linalg_vector_norm_cuda", [&] {
-    norm_kernel_cuda_impl(iter, ord);
-  });
+  norm_dispatch(iter, ord);
 }
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 0fdd82218be..ed8d7145f00 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -1562,6 +1562,24 @@ class TestLinalg(TestCase):
         for ord in ord_settings:
             run_test_case(input, ord, dim, keepdim)
+
+    @onlyCUDA
+    @dtypes(torch.bfloat16, torch.float16)
+    def test_norm_fused_type_promotion(self, device, dtype):
+        x = torch.randn(10, device=device, dtype=dtype)
+
+        def profile_and_check(fn, x, kwargs, fn_name):
+            with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p:
+                fn(x, **kwargs, dtype=torch.float)
+            # smoke check that profiler returned some events
+            self.assertTrue(fn_name in map(lambda e: e.name, p.events()))
+            # test that there was no explicit copy
+            self.assertFalse("aten::to" in map(lambda e: e.name, p.events()))
+
+        for f, kwargs, fn_name in zip((torch.norm, torch.linalg.vector_norm), ({"p" : 2}, {}),
+                                      ("aten::norm", "aten::linalg_vector_norm")):
+            profile_and_check(f, x, kwargs, fn_name)
+

     @skipMeta # https://github.com/pytorch/pytorch/issues/53739
     @skipCPUIfNoLapack
     @skipCUDAIfNoMagma