From 54e2f4b2013decbf764d0d6b22c561509cadc42c Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Fri, 24 Jan 2025 01:18:18 +0000 Subject: [PATCH] Fix lerp weight type promotion (#141117) Fixes #140601 Enable `promote_inputs_to_common_dtype` when tensors not same dtype when invoke `lerp` function. For `lerp_Tensor` - Check whether same `dtype` of tensors, enable promote if not - Remove type check assert For `lerp_Scalar` - Seems already enable `promote_inputs_to_common_dtype` by default, just remove the type check. Make sure promote behavior consistent with `lerp_Tensor` `lerp_Scalar` get TensorIteratorConfig from here https://github.com/pytorch/pytorch/blob/c37185c76ae4068899869e48a8388e78437508e8/aten/src/ATen/TensorIterator.cpp#L979-L985 **Test Result** Test case in issue passed ```python >>> import torch >>> >>> x = torch.ones(2, 2, dtype=torch.float64) >>> w = torch.ones(2, 2, dtype=torch.float64) >>> s = torch.tensor(2.2) >>> x.lerp_(w, s) tensor([[1., 1.], [1., 1.]], dtype=torch.float64) >>> x = torch.ones(2, 2, dtype=torch.float16) >>> w = torch.ones(2, 2, dtype=torch.float16) >>> s = torch.tensor(2.2) >>> x.lerp_(w, s) tensor([[1., 1.], [1., 1.]], dtype=torch.float16) ``` ```bash $ pytest test/test_binary_ufuncs.py -k 'test_lerp_tensor_type_promotion or test_lerp_scalar_type_promotion' ``` ![image](https://github.com/user-attachments/assets/288a5294-a9ee-47f3-bbf7-d4ff986f3ba8) ```bash $ lintrunner ``` ![image](https://github.com/user-attachments/assets/d469836f-5c49-4d89-a2fd-379cad4db3af) Pull Request resolved: https://github.com/pytorch/pytorch/pull/141117 Approved by: https://github.com/janeyx99 Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com> --- aten/src/ATen/native/Lerp.cpp | 10 ++++++++-- test/test_binary_ufuncs.py | 18 ++++++++++++++++++ torch/_meta_registrations.py | 9 +++++---- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/Lerp.cpp b/aten/src/ATen/native/Lerp.cpp index ecfdd75e559..c11838a8007 100644 --- a/aten/src/ATen/native/Lerp.cpp +++ b/aten/src/ATen/native/Lerp.cpp @@ -16,10 +16,16 @@ TORCH_META_FUNC(lerp_Tensor)( const Tensor& self, const Tensor& end, const Tensor& weight) { TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype()); - TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(), - " for `weight` but got dtype ", weight.dtype()); + bool promote_weight = weight.dim() == 0; + if (!promote_weight) { + TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(), + " for `weight` but got dtype ", weight.dtype()); + } build(at::TensorIteratorConfig() .allow_cpu_scalars(true) + .promote_inputs_to_common_dtype(promote_weight) + .enforce_safe_casting_to_output(promote_weight) + .cast_common_dtype_to_outputs(promote_weight) .add_output(maybe_get_output()) .add_const_input(self) .add_const_input(end) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 0035ab30ad0..bdc0d7329df 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -3519,6 +3519,24 @@ class TestBinaryUfuncs(TestCase): expected = torch.lerp(xref, yref, wref).to(dtype) self.assertEqual(actual, expected, atol=0.0, rtol=0.0) + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_lerp_weight_scalar_tensor_promotion(self, device, dtype): + start = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100) + end = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100) + weight = torch.rand((), dtype=torch.float, device=device) + + actual = torch.lerp(start, end, weight) + expected = start + weight.to(dtype) * (end - start) + self.assertEqual(expected, actual) + + @dtypes(torch.double, torch.cfloat, torch.cdouble) + def test_lerp_weight_tensor_promotion_error(self, device, dtype): + start = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100) + end = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100) + weight = torch.rand((5, 5), dtype=torch.float, device=device) + with self.assertRaisesRegex(RuntimeError, "expected dtype"): + torch.lerp(start, end, weight) + def _test_logaddexp(self, device, dtype, base2): if base2: ref_func = np.logaddexp2 diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 173c6c95b2b..15d39fdae27 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -6972,10 +6972,11 @@ def lerp(start, end, weight): ) args = [start, end] if isinstance(weight, TensorLike): - torch._check( - start.dtype == weight.dtype, - lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}", - ) + if weight.ndim != 0: + torch._check( + start.dtype == weight.dtype, + lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}", + ) args.append(weight) return elementwise_meta( *args, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT