repeat_interleaves meta function

Taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78602

Approved by: https://github.com/mruberry
This commit is contained in:
Edward Z. Yang 2022-06-02 06:32:07 -07:00 committed by PyTorch MergeBot
parent cc6a51c9f3
commit 9446f9678a
2 changed files with 11 additions and 1 deletions

View file

@ -440,6 +440,9 @@ def run_meta_crossref(
if func is torch.tensor_split:
# Use original indices_or_sections, this argument is data dependent
meta_args = (meta_args[0], args[1]) + meta_args[2:]
elif func is torch.ops.aten.repeat_interleave.Tensor:
if kwargs.get("output_size", None) is None:
meta_args = args
try:
# Suppress warnings, this doesn't matter for test_meta.py
# but it does matter if you want to use this decorator
@ -840,7 +843,6 @@ meta_dispatch_expected_failures = {
aten.polar.default: {f64, f32},
aten.prelu.default: {bf16, f64, f32},
aten.relu.default: {i64, bf16, u8, f32, i8, f64, i16, i32},
aten.repeat_interleave.Tensor: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.roll.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.rrelu_with_noise.default: {bf16, f64, f32},
aten.searchsorted.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},

View file

@ -135,3 +135,11 @@ def meta_adaptive_avg_pool2d(self, output_size):
def meta_adaptive_avg_pool3d(self, output_size):
check(self.ndim == 4 or self.ndim == 5, f"Expected 4D or 5D tensor, but got {self.shape}")
return self.new_empty(self.shape[:-3] + tuple(output_size))
@torch.library.impl(meta_lib, "repeat_interleave.Tensor")
def meta_repeat_interleave_Tensor(repeats, output_size=None):
if output_size is None:
raise RuntimeError(
"cannot repeat_interleave a meta tensor without output_size"
)
return repeats.new_empty(output_size)