diff --git a/test/test_meta.py b/test/test_meta.py index 4532bb35677..f6487ae6619 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -440,6 +440,9 @@ def run_meta_crossref( if func is torch.tensor_split: # Use original indices_or_sections, this argument is data dependent meta_args = (meta_args[0], args[1]) + meta_args[2:] + elif func is torch.ops.aten.repeat_interleave.Tensor: + if kwargs.get("output_size", None) is None: + meta_args = args try: # Suppress warnings, this doesn't matter for test_meta.py # but it does matter if you want to use this decorator @@ -840,7 +843,6 @@ meta_dispatch_expected_failures = { aten.polar.default: {f64, f32}, aten.prelu.default: {bf16, f64, f32}, aten.relu.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, - aten.repeat_interleave.Tensor: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, aten.roll.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, aten.rrelu_with_noise.default: {bf16, f64, f32}, aten.searchsorted.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index cc6b8f986a6..6a3edb53fde 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -135,3 +135,11 @@ def meta_adaptive_avg_pool2d(self, output_size): def meta_adaptive_avg_pool3d(self, output_size): check(self.ndim == 4 or self.ndim == 5, f"Expected 4D or 5D tensor, but got {self.shape}") return self.new_empty(self.shape[:-3] + tuple(output_size)) + +@torch.library.impl(meta_lib, "repeat_interleave.Tensor") +def meta_repeat_interleave_Tensor(repeats, output_size=None): + if output_size is None: + raise RuntimeError( + "cannot repeat_interleave a meta tensor without output_size" + ) + return repeats.new_empty(output_size)