diff --git a/aten/src/ATen/native/Lerp.cpp b/aten/src/ATen/native/Lerp.cpp index ecfdd75e559..c11838a8007 100644 --- a/aten/src/ATen/native/Lerp.cpp +++ b/aten/src/ATen/native/Lerp.cpp @@ -16,10 +16,16 @@ TORCH_META_FUNC(lerp_Tensor)( const Tensor& self, const Tensor& end, const Tensor& weight) { TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype()); - TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(), - " for `weight` but got dtype ", weight.dtype()); + bool promote_weight = weight.dim() == 0; + if (!promote_weight) { + TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(), + " for `weight` but got dtype ", weight.dtype()); + } build(at::TensorIteratorConfig() .allow_cpu_scalars(true) + .promote_inputs_to_common_dtype(promote_weight) + .enforce_safe_casting_to_output(promote_weight) + .cast_common_dtype_to_outputs(promote_weight) .add_output(maybe_get_output()) .add_const_input(self) .add_const_input(end) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 0035ab30ad0..bdc0d7329df 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -3519,6 +3519,24 @@ class TestBinaryUfuncs(TestCase): expected = torch.lerp(xref, yref, wref).to(dtype) self.assertEqual(actual, expected, atol=0.0, rtol=0.0) + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_lerp_weight_scalar_tensor_promotion(self, device, dtype): + start = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100) + end = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100) + weight = torch.rand((), dtype=torch.float, device=device) + + actual = torch.lerp(start, end, weight) + expected = start + weight.to(dtype) * (end - start) + self.assertEqual(expected, actual) + + @dtypes(torch.double, torch.cfloat, torch.cdouble) + def test_lerp_weight_tensor_promotion_error(self, device, dtype): + start = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100) + end = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100) + weight = torch.rand((5, 5), dtype=torch.float, device=device) + with self.assertRaisesRegex(RuntimeError, "expected dtype"): + torch.lerp(start, end, weight) + def _test_logaddexp(self, device, dtype, base2): if base2: ref_func = np.logaddexp2 diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 173c6c95b2b..15d39fdae27 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -6972,10 +6972,11 @@ def lerp(start, end, weight): ) args = [start, end] if isinstance(weight, TensorLike): - torch._check( - start.dtype == weight.dtype, - lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}", - ) + if weight.ndim != 0: + torch._check( + start.dtype == weight.dtype, + lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}", + ) args.append(weight) return elementwise_meta( *args, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT