avoid explicitly casting low precision inputs to fp32 in norm (#59134)

Summary:
Per title. `norm` with fp16/bfloat16 inputs and an fp32 output on CUDA no longer does an explicit cast of the input to fp32; the cast is instead fused into the reduction kernel.
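
For context, a minimal sketch of the call pattern this change affects (illustrative only, not part of the diff; assumes a CUDA device):

```python
import torch

# low-precision input on CUDA, fp32 result requested via dtype=
x = torch.randn(1024, device="cuda", dtype=torch.float16)

# previously the input was first materialized in fp32 (an explicit aten::to copy)
# before the reduction; now the cast happens inside the reduction kernel itself
out = torch.norm(x, p=2, dtype=torch.float)            # fp32 result
out2 = torch.linalg.vector_norm(x, dtype=torch.float)  # takes the same fused path
```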

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59134

Reviewed By: mruberry

Differential Revision: D28775729

Pulled By: ngimel

fbshipit-source-id: 896daa4f02e8a817cb7cb99ae8a93c02fa8dd5e9
Authored by Natalia Gimelshein on 2021-05-29 00:47:00 -07:00; committed by Facebook GitHub Bot
parent d68df54269
commit 1871d4e604
5 changed files with 45 additions and 30 deletions

@@ -2325,7 +2325,11 @@ static Tensor& linalg_vector_norm_impl(const Tensor& self, const Scalar& scalar_
   TORCH_CHECK(!result.defined() || out_dtype == result.scalar_type(),
     "linalg.vector_norm expected out tensor dtype ", out_dtype,
     " but got: ", result.scalar_type());
-  auto iter = make_reduction("vector_norm", result, self_, dim, keepdim, in_dtype, out_dtype);
+  // omit in_dtype in the following call, to avoid make_reduction explicitly casting input to out_dtype
+  auto iter = isComplexType(self.scalar_type()) ?
+      make_reduction("vector_norm", result, self_, dim, keepdim, in_dtype, out_dtype) :
+      make_reduction("vector_norm", result, self_, dim, keepdim, out_dtype);
   linalg_vector_norm_stub(iter.device_type(), iter, ord);
   return result;
 }
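
The complex branch above keeps `in_dtype` on purpose: for complex inputs the reduction has to run over complex values even though the norm itself is real, so the input cannot simply be read as `out_dtype`. A quick illustration of that dtype relationship:

```python
import torch

z = torch.randn(8, dtype=torch.complex64)

# the reduction runs over complex values, but a norm is real-valued, so the
# output dtype is the real value type of the input (complex64 -> float32)
print(torch.linalg.vector_norm(z).dtype)  # torch.float32
```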

@@ -1030,7 +1030,10 @@ static Tensor& norm_out(Tensor &result, const Tensor &self, const optional<Scala
   ScalarType out_dtype = result.defined() ? result.scalar_type() : (opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type()));
-  auto iter = make_reduction("norm", result, self, dim, keepdim, in_dtype, out_dtype);
+  // omit in_dtype in the following call, to avoid make_reduction explicitly casting input to out_dtype
+  auto iter = isComplexType(self.scalar_type()) ?
+      make_reduction("norm", result, self, dim, keepdim, in_dtype, out_dtype) :
+      make_reduction("norm", result, self, dim, keepdim, out_dtype);
   if (iter.numel() == 0) {
     result.zero_();

@@ -210,9 +210,9 @@ static TensorIterator make_reduction(
   // efficiency.
   // not generalize this to common mismatched input/output types to avoid cross
   // product of templated kernel launches.
-  const bool gpu_f16_to_f32 = (
-      self.is_cuda() && self.scalar_type() == kHalf && out_dtype == kFloat);
-  auto in_dtype = gpu_f16_to_f32 ? self.scalar_type() : out_dtype;
+  const bool gpu_lowp_to_f32 = (
+      self.is_cuda() && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat);
+  auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
   return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
 }
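
A hedged Python paraphrase of the `gpu_lowp_to_f32` condition above (the helper name is mine; the authoritative logic is the C++ in the hunk): only CUDA reductions with an fp16/bf16 input and an fp32 output keep the low-precision input dtype, so the kernel can read low precision and accumulate in fp32; every other case still reads the input as `out_dtype`:

```python
import torch

def reduction_in_dtype(x: torch.Tensor, out_dtype: torch.dtype) -> torch.dtype:
    # mirrors gpu_lowp_to_f32: keep the fp16/bf16 input dtype so the CUDA kernel
    # can fuse the cast into the reduction instead of copying to fp32 first
    gpu_lowp_to_f32 = (
        x.is_cuda
        and x.dtype in (torch.float16, torch.bfloat16)
        and out_dtype == torch.float32
    )
    return x.dtype if gpu_lowp_to_f32 else out_dtype
```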

@@ -45,41 +45,31 @@ void norm_kernel_cuda_impl(TensorIterator& iter, const Scalar& val) {
 }
 
-static void norm_kernel_cuda(TensorIterator& iter, const Scalar& p) {
-  if (iter.input_dtype() == kHalf) {
-    return norm_kernel_cuda_impl<at::Half, float>(iter, p);
-  } else if (iter.dtype(1) == kHalf && iter.input_dtype() == kFloat) {
+static void norm_dispatch(TensorIterator& iter, const Scalar& ord) {
+  if (iter.dtype(0) == kHalf) {
+    return norm_kernel_cuda_impl<at::Half, float>(iter, ord);
+  } else if (iter.input_dtype() == kHalf && iter.dtype(0) == kFloat) {
     // type promotion that does cast and reduction in a single kernel
-    return norm_kernel_cuda_impl<at::Half, float, float>(iter, p);
+    return norm_kernel_cuda_impl<at::Half, float, float>(iter, ord);
   }
-  else if (iter.input_dtype() == kBFloat16) {
-    return norm_kernel_cuda_impl<at::BFloat16, float>(iter, p);
-  } else if (iter.dtype(1) == kBFloat16 && iter.input_dtype() == kFloat) {
+  else if (iter.dtype(0) == kBFloat16) {
+    return norm_kernel_cuda_impl<at::BFloat16, float>(iter, ord);
+  } else if (iter.input_dtype() == kBFloat16 && iter.dtype(0) == kFloat) {
     // type promotion that does cast and reduction in a single kernel
-    return norm_kernel_cuda_impl<at::BFloat16, float, float>(iter, p);
+    return norm_kernel_cuda_impl<at::BFloat16, float, float>(iter, ord);
   }
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_cuda", [&] {
-    norm_kernel_cuda_impl<scalar_t>(iter, p);
+    norm_kernel_cuda_impl<scalar_t>(iter, ord);
   });
 }
 
+static void norm_kernel_cuda(TensorIterator& iter, const Scalar& ord) {
+  norm_dispatch(iter, ord);
+}
+
 static void linalg_vector_norm_kernel_cuda(TensorIterator& iter, Scalar ord) {
   TORCH_CHECK(ord.isFloatingPoint(), "linalg.vector_norm expects ord to be float");
-  if (iter.output().scalar_type() == kHalf) {
-    return norm_kernel_cuda_impl<at::Half, float>(iter, ord);
-  } else if (iter.input_dtype() == kHalf && iter.output().scalar_type() == kFloat) {
-    // type promotion that does cast and reduction in a single kernel
-    return norm_kernel_cuda_impl<at::Half, float, float>(iter, ord);
-  }
-  else if (iter.output().scalar_type() == kBFloat16) {
-    return norm_kernel_cuda_impl<at::BFloat16, float>(iter, ord);
-  } else if (iter.input_dtype() == kBFloat16 && iter.output().scalar_type() == kFloat) {
-    // type promotion that does cast and reduction in a single kernel
-    return norm_kernel_cuda_impl<at::BFloat16, float, float>(iter, ord);
-  }
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "linalg_vector_norm_cuda", [&] {
-    norm_kernel_cuda_impl<scalar_t>(iter, ord);
-  });
+  norm_dispatch(iter, ord);
 }
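
For reference, a hypothetical table of what `norm_dispatch` selects, keyed by (input dtype, output dtype); the tuples are the template arguments `<scalar_t, acc_t[, out_t]>` of `norm_kernel_cuda_impl` (this Python rendering is mine, not part of the diff):

```python
import torch

# (input dtype, output dtype) -> norm_kernel_cuda_impl template arguments
NORM_DISPATCH = {
    (torch.float16,  torch.float16):  ("at::Half", "float"),
    (torch.float16,  torch.float32):  ("at::Half", "float", "float"),      # fused cast + reduce
    (torch.bfloat16, torch.bfloat16): ("at::BFloat16", "float"),
    (torch.bfloat16, torch.float32):  ("at::BFloat16", "float", "float"),  # fused cast + reduce
}
# all other dtypes fall through to AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES,
# which instantiates norm_kernel_cuda_impl<scalar_t> on the input dtype
```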

@@ -1562,6 +1562,24 @@ class TestLinalg(TestCase):
             for ord in ord_settings:
                 run_test_case(input, ord, dim, keepdim)
 
+    @onlyCUDA
+    @dtypes(torch.bfloat16, torch.float16)
+    def test_norm_fused_type_promotion(self, device, dtype):
+        x = torch.randn(10, device=device, dtype=dtype)
+
+        def profile_and_check(fn, x, kwargs, fn_name):
+            with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p:
+                fn(x, **kwargs, dtype=torch.float)
+            # smoke check that profiler returned some events
+            self.assertTrue(fn_name in map(lambda e: e.name, p.events()))
+            # test that there was no explicit copy
+            self.assertFalse("aten::to" in map(lambda e: e.name, p.events()))
+
+        for f, kwargs, fn_name in zip((torch.norm, torch.linalg.vector_norm), ({"p": 2}, {}),
+                ("aten::norm", "aten::linalg_vector_norm")):
+            profile_and_check(f, x, kwargs, fn_name)
+
     @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
     @skipCPUIfNoLapack
     @skipCUDAIfNoMagma
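
Outside the test harness, the same no-copy property can be spot-checked in a few lines (assumes a CUDA build; this mirrors the test above rather than adding new coverage):

```python
import torch

x = torch.randn(10, device="cuda", dtype=torch.float16)
with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p:
    torch.norm(x, p=2, dtype=torch.float)

# with the fused path there should be no explicit cast kernel in the trace
assert "aten::to" not in [e.name for e in p.events()]
```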