mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
repeat_interleaves meta function
Taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/78602 Approved by: https://github.com/mruberry
This commit is contained in:
parent
cc6a51c9f3
commit
9446f9678a
2 changed files with 11 additions and 1 deletions
|
|
@ -440,6 +440,9 @@ def run_meta_crossref(
|
|||
if func is torch.tensor_split:
|
||||
# Use original indices_or_sections, this argument is data dependent
|
||||
meta_args = (meta_args[0], args[1]) + meta_args[2:]
|
||||
elif func is torch.ops.aten.repeat_interleave.Tensor:
|
||||
if kwargs.get("output_size", None) is None:
|
||||
meta_args = args
|
||||
try:
|
||||
# Suppress warnings, this doesn't matter for test_meta.py
|
||||
# but it does matter if you want to use this decorator
|
||||
|
|
@ -840,7 +843,6 @@ meta_dispatch_expected_failures = {
|
|||
aten.polar.default: {f64, f32},
|
||||
aten.prelu.default: {bf16, f64, f32},
|
||||
aten.relu.default: {i64, bf16, u8, f32, i8, f64, i16, i32},
|
||||
aten.repeat_interleave.Tensor: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
|
||||
aten.roll.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
|
||||
aten.rrelu_with_noise.default: {bf16, f64, f32},
|
||||
aten.searchsorted.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
|
||||
|
|
|
|||
|
|
@ -135,3 +135,11 @@ def meta_adaptive_avg_pool2d(self, output_size):
|
|||
def meta_adaptive_avg_pool3d(self, output_size):
|
||||
check(self.ndim == 4 or self.ndim == 5, f"Expected 4D or 5D tensor, but got {self.shape}")
|
||||
return self.new_empty(self.shape[:-3] + tuple(output_size))
|
||||
|
||||
@torch.library.impl(meta_lib, "repeat_interleave.Tensor")
|
||||
def meta_repeat_interleave_Tensor(repeats, output_size=None):
|
||||
if output_size is None:
|
||||
raise RuntimeError(
|
||||
"cannot repeat_interleave a meta tensor without output_size"
|
||||
)
|
||||
return repeats.new_empty(output_size)
|
||||
|
|
|
|||
Loading…
Reference in a new issue