mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
avoid explicitly casting low precision inputs to fp32 in norm (#59134)
Summary: Per title. Now `norm` with fp16/bfloat16 inputs and fp32 outputs on CUDA won't do an explicit cast. Pull Request resolved: https://github.com/pytorch/pytorch/pull/59134 Reviewed By: mruberry Differential Revision: D28775729 Pulled By: ngimel fbshipit-source-id: 896daa4f02e8a817cb7cb99ae8a93c02fa8dd5e9
This commit is contained in:
parent
d68df54269
commit
1871d4e604
5 changed files with 45 additions and 30 deletions
|
|
@ -2325,7 +2325,11 @@ static Tensor& linalg_vector_norm_impl(const Tensor& self, const Scalar& scalar_
|
|||
TORCH_CHECK(!result.defined() || out_dtype == result.scalar_type(),
|
||||
"linalg.vector_norm expected out tensor dtype ", out_dtype,
|
||||
" but got: ", result.scalar_type());
|
||||
auto iter = make_reduction("vector_norm", result, self_, dim, keepdim, in_dtype, out_dtype);
|
||||
// omit in_dtype in the following call, to avoid make_reduction explicitly casting input to out_dtype
|
||||
auto iter = isComplexType(self.scalar_type()) ?
|
||||
make_reduction("vector_norm", result, self_, dim, keepdim, in_dtype, out_dtype) :
|
||||
make_reduction("vector_norm", result, self_, dim, keepdim, out_dtype);
|
||||
|
||||
linalg_vector_norm_stub(iter.device_type(), iter, ord);
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1030,7 +1030,10 @@ static Tensor& norm_out(Tensor &result, const Tensor &self, const optional<Scala
|
|||
|
||||
ScalarType out_dtype = result.defined() ? result.scalar_type() : (opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type()));
|
||||
|
||||
auto iter = make_reduction("norm", result, self, dim, keepdim, in_dtype, out_dtype);
|
||||
// omit in_dtype in the following call, to avoid make_reduction explicitly casting input to out_dtype
|
||||
auto iter = isComplexType(self.scalar_type()) ?
|
||||
make_reduction("norm", result, self, dim, keepdim, in_dtype, out_dtype) :
|
||||
make_reduction("norm", result, self, dim, keepdim, out_dtype);
|
||||
|
||||
if (iter.numel() == 0) {
|
||||
result.zero_();
|
||||
|
|
|
|||
|
|
@ -210,9 +210,9 @@ static TensorIterator make_reduction(
|
|||
// efficiency.
|
||||
// not generalize this to common mismatched input/output types to avoid cross
|
||||
// product of templated kernel launches.
|
||||
const bool gpu_f16_to_f32 = (
|
||||
self.is_cuda() && self.scalar_type() == kHalf && out_dtype == kFloat);
|
||||
auto in_dtype = gpu_f16_to_f32 ? self.scalar_type() : out_dtype;
|
||||
const bool gpu_lowp_to_f32 = (
|
||||
self.is_cuda() && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat);
|
||||
auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
|
||||
return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -45,41 +45,31 @@ void norm_kernel_cuda_impl(TensorIterator& iter, const Scalar& val) {
|
|||
|
||||
}
|
||||
|
||||
static void norm_kernel_cuda(TensorIterator& iter, const Scalar& p) {
|
||||
if (iter.input_dtype() == kHalf) {
|
||||
return norm_kernel_cuda_impl<at::Half, float>(iter, p);
|
||||
} else if (iter.dtype(1) == kHalf && iter.input_dtype() == kFloat) {
|
||||
static void norm_dispatch(TensorIterator& iter, const Scalar& ord){
|
||||
if (iter.dtype(0) == kHalf) {
|
||||
return norm_kernel_cuda_impl<at::Half, float>(iter, ord);
|
||||
} else if (iter.input_dtype() == kHalf && iter.dtype(0) == kFloat) {
|
||||
// type promotion that does cast and reduction in a single kernel
|
||||
return norm_kernel_cuda_impl<at::Half, float, float>(iter, p);
|
||||
return norm_kernel_cuda_impl<at::Half, float, float>(iter, ord);
|
||||
}
|
||||
else if(iter.input_dtype() == kBFloat16) {
|
||||
return norm_kernel_cuda_impl<at::BFloat16, float>(iter, p);
|
||||
} else if (iter.dtype(1) == kBFloat16 && iter.input_dtype() == kFloat) {
|
||||
else if(iter.dtype(0) == kBFloat16) {
|
||||
return norm_kernel_cuda_impl<at::BFloat16, float>(iter, ord);
|
||||
} else if (iter.input_dtype() == kBFloat16 && iter.dtype(0) == kFloat) {
|
||||
// type promotion that does cast and reduction in a single kernel
|
||||
return norm_kernel_cuda_impl<at::BFloat16, float, float>(iter, p);
|
||||
return norm_kernel_cuda_impl<at::BFloat16, float, float>(iter, ord);
|
||||
}
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_cuda", [&] {
|
||||
norm_kernel_cuda_impl<scalar_t>(iter, p);
|
||||
norm_kernel_cuda_impl<scalar_t>(iter, ord);
|
||||
});
|
||||
}
|
||||
|
||||
static void norm_kernel_cuda(TensorIterator& iter, const Scalar& ord) {
|
||||
norm_dispatch(iter, ord);
|
||||
}
|
||||
|
||||
static void linalg_vector_norm_kernel_cuda(TensorIterator& iter, Scalar ord) {
|
||||
TORCH_CHECK(ord.isFloatingPoint(), "linalg.vector_norm expects ord to be float");
|
||||
if (iter.output().scalar_type() == kHalf) {
|
||||
return norm_kernel_cuda_impl<at::Half, float>(iter, ord);
|
||||
} else if (iter.input_dtype() == kHalf && iter.output().scalar_type() == kFloat) {
|
||||
// type promotion that does cast and reduction in a single kernel
|
||||
return norm_kernel_cuda_impl<at::Half, float, float>(iter, ord);
|
||||
}
|
||||
else if(iter.output().scalar_type() == kBFloat16) {
|
||||
return norm_kernel_cuda_impl<at::BFloat16, float>(iter, ord);
|
||||
} else if (iter.input_dtype() == kBFloat16 && iter.output().scalar_type() == kFloat) {
|
||||
// type promotion that does cast and reduction in a single kernel
|
||||
return norm_kernel_cuda_impl<at::BFloat16, float, float>(iter, ord);
|
||||
}
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "linalg_vector_norm_cuda", [&] {
|
||||
norm_kernel_cuda_impl<scalar_t>(iter, ord);
|
||||
});
|
||||
norm_dispatch(iter, ord);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1562,6 +1562,24 @@ class TestLinalg(TestCase):
|
|||
for ord in ord_settings:
|
||||
run_test_case(input, ord, dim, keepdim)
|
||||
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.bfloat16, torch.float16)
|
||||
def test_norm_fused_type_promotion(self, device, dtype):
|
||||
x = torch.randn(10, device=device, dtype=dtype)
|
||||
|
||||
def profile_and_check(fn, x, kwargs, fn_name):
|
||||
with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p:
|
||||
fn(x, **kwargs, dtype=torch.float)
|
||||
# smoke check that profiler returned some events
|
||||
self.assertTrue(fn_name in map(lambda e: e.name, p.events()))
|
||||
# test that there was no explicit copy
|
||||
self.assertFalse("aten::to" in map(lambda e: e.name, p.events()))
|
||||
|
||||
for f, kwargs, fn_name in zip((torch.norm, torch.linalg.vector_norm), ({"p" : 2}, {}),
|
||||
("aten::norm", "aten::linalg_vector_norm")):
|
||||
profile_and_check(f, x, kwargs, fn_name)
|
||||
|
||||
@skipMeta # https://github.com/pytorch/pytorch/issues/53739
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
|
|
|
|||
Loading…
Reference in a new issue