[dynamo] avoid graph break on repeat_interleave.self_int (#99528)
Addresses the convit_base failure (https://github.com/pytorch/torchdynamo/issues/1886, mentioned in https://github.com/pytorch/pytorch/issues/93777); also helps models like EleutherAI/gpt-j-6B.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99528
Approved by: https://github.com/ezyang
parent: ecd2c71871
commit: e5c9a0fcf5
9 changed files with 59 additions and 9 deletions
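Why the `self_int` overload can avoid the break: with tensor `repeats` the output length is `repeats.sum()`, which is data-dependent, while with a plain int it is just `size(dim) * repeats` and can be computed ahead of time. A minimal eager-mode illustration (not part of the patch):

import torch

x = torch.arange(4.0)

# int repeats: output length is statically numel(x) * 2 == 8
print(torch.repeat_interleave(x, 2).shape)  # torch.Size([8])

# tensor repeats: output length is repeats.sum(), known only from the data
r = torch.tensor([1, 0, 3, 2])
print(torch.repeat_interleave(x, r).shape)  # torch.Size([6])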
@@ -102,8 +102,17 @@ Tensor repeat_interleave(
     int64_t repeats,
     c10::optional<int64_t> dim,
     c10::optional<int64_t> output_size) {
+  Tensor input = self;
   at::Tensor repeats_ = at::empty(1, self.options().dtype(at::kLong)).fill_(repeats);
-  return at::native::repeat_interleave(self, repeats_, dim, output_size);
+  if (!output_size) {
+    if (!dim) {
+      input = input.flatten();
+      dim = 0;
+    }
+    auto input_size = input.sym_size(dim.value()).guard_int(__FILE__, __LINE__);
+    output_size = input_size * repeats;
+  }
+  return at::native::repeat_interleave(input, repeats_, dim, output_size);
 }
 
 Tensor repeat_interleave_symint(
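Roughly the same control flow in Python, for readers who don't speak ATen; `repeat_interleave_self_int` is an illustrative name, not an API introduced by the patch:

import torch

def repeat_interleave_self_int(self, repeats: int, dim=None, output_size=None):
    # Mirrors the C++ change: when `repeats` is a plain int and no
    # output_size was passed, the result length is size(dim) * repeats,
    # so precompute it instead of leaving the shape data-dependent.
    input = self
    if output_size is None:
        if dim is None:
            input = input.flatten()
            dim = 0
        output_size = input.size(dim) * repeats
    repeats_ = torch.full((1,), repeats, dtype=torch.long)
    return torch.repeat_interleave(input, repeats_, dim=dim, output_size=output_size)

# Same shape as the built-in int overload:
x = torch.randn(4, 16, 1, 64)
assert repeat_interleave_self_int(x, 2, dim=3).shape == torch.repeat_interleave(x, 2, dim=3).shape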
@@ -3,7 +3,7 @@ adv_inception_v3,pass,0
 beit_base_patch16_224,pass,0
 botnet26t_256,pass,0
 coat_lite_mini,pass,0
-convit_base,fail_to_run,4
+convit_base,pass,0
 convmixer_768_32,pass,0
 convnext_base,pass,0
 crossvit_9_240,pass,0
@@ -3,7 +3,7 @@ adv_inception_v3,pass,9
 beit_base_patch16_224,pass,9
 botnet26t_256,pass,11
 coat_lite_mini,pass,9
-convit_base,fail_to_run,8
+convit_base,pass,12
 convmixer_768_32,pass,6
 convnext_base,pass,9
 crossvit_9_240,pass,9
@@ -3,7 +3,7 @@ adv_inception_v3,pass,0
 beit_base_patch16_224,pass,0
 botnet26t_256,pass,0
 coat_lite_mini,pass,0
-convit_base,pass,15
+convit_base,pass,0
 convmixer_768_32,pass,0
 convnext_base,pass,0
 crossvit_9_240,pass,0
@@ -3,7 +3,7 @@ adv_inception_v3,pass,9
 beit_base_patch16_224,pass,9
 botnet26t_256,pass,11
 coat_lite_mini,pass,9
-convit_base,pass,25
+convit_base,pass,12
 convmixer_768_32,pass,6
 convnext_base,pass,9
 crossvit_9_240,pass,9
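These CSVs are dynamo benchmark expectation files: model name, status, and (reading the last column as the harness usually does) the recorded graph-break count; convit_base flips from fail_to_run to pass and its counts shrink. A quick local spot-check in the same spirit, using the compile counter the new test below also relies on:

import torch
import torch._dynamo
from torch._dynamo.testing import CompileCounter

def f(x):
    return torch.repeat_interleave(x + 1, 2, 3) + 1

cnts = CompileCounter()
torch._dynamo.optimize(cnts)(f)(torch.randn(4, 16, 1, 64))
print(cnts.frame_count)  # 1 == the whole function compiled as a single graph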
@@ -2213,6 +2213,35 @@ def fn():
         opt_fn(x, Foo.BAR)
         self.assertEqual(cnts.op_count, 1)
 
+    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
+    def test_repeat_interleave_graphbreaks(self):
+        def fn_no_breaks(x):
+            # no breaks on self_int
+            x += 1
+            x = torch.repeat_interleave(x, 2, 3)
+            x += 1
+            return x
+
+        def fn_has_breaks(x):
+            # breaks on self_Tensor
+            x += 1
+            x = torch.repeat_interleave(x, torch.tensor(2), 3)
+            x += 1
+            return x
+
+        x = torch.randn([4, 16, 1, 64])
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnts)(fn_no_breaks)
+        opt_fn(x)
+        self.assertEqual(cnts.frame_count, 1)
+
+        torch._dynamo.reset()
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnts)(fn_has_breaks)
+        opt_fn(x)
+        self.assertEqual(cnts.frame_count, 2)
+
     def test_id_of_nn_module(self):
         class M(torch.nn.Module):
             def forward(self, x, ref_id):
@@ -579,7 +579,7 @@ meta disagrees with real impl:
         else:
             seen_succeeded.setdefault(func, set()).add(dtype)
             if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
-                raise RuntimeError(f"unexpected success {resolve_name(func)}")
+                raise RuntimeError(f"unexpected success {resolve_name(func)} {meta_args} {meta_kwargs}")
 
     return rs
 
@@ -603,7 +603,6 @@ meta_function_expected_failures = {
     torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
     torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
     torch.ormqr : {f64, c64, c128, f32},
-    torch.repeat_interleave : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
     torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
     torch.bincount : {i32, i64, u8, i16, i8},
     torch.frexp : {f64, f16, bf16, f32},
@@ -642,6 +641,10 @@ meta_function_expected_failures_only_outplace = {
     torch.nn.functional.rrelu : {f64, bf16, f32},
 }
 
+meta_function_expected_failures_conditional = {
+    torch.repeat_interleave : (lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats", None), int)),
+}
+
 """
 # This is some sample code for how we could dump these dicts into YAML
 # file for easier reading/writing
@@ -799,6 +802,8 @@ class MetaCrossRefFunctionMode(torch.overrides.TorchFunctionMode):
             test_expect = TestExpect.XFAILURE
         elif self.dtype in meta_function_device_expected_failures[self.device_type].get(func, set()):
             test_expect = TestExpect.XFAILURE
+        elif meta_function_expected_failures_conditional.get(func, lambda *_, **__: False)(self.dtype, *args, **kwargs):
+            test_expect = TestExpect.XFAILURE
         elif not self.inplace and \
                 self.dtype in meta_function_device_expected_failures_only_outplace[self.device_type].get(func, set()):
             test_expect = TestExpect.XFAILURE
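The conditional-xfail pattern above generalizes the flat `func -> dtypes` sets: the table value is a predicate over the whole call, so an expected failure can depend on argument types, not dtype alone. A standalone sketch of the same lookup (names are illustrative, not the test suite's):

# Hypothetical, self-contained illustration of a conditional expected-failure table.
expected_failures_conditional = {
    "repeat_interleave": lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats"), int),
}

def is_expected_failure(func_name, dtype, *args, **kwargs):
    # The default predicate rejects nothing, so unlisted ops never xfail here.
    pred = expected_failures_conditional.get(func_name, lambda *_, **__: False)
    return pred(dtype, *args, **kwargs)

print(is_expected_failure("repeat_interleave", "f32", repeats=2))       # False: int repeats now has a meta impl
print(is_expected_failure("repeat_interleave", "f32", repeats=[1, 2]))  # True: tensor-like repeats stays data-dependent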
@@ -1888,7 +1888,6 @@ else:
                 lambda: x.nonzero(),
                 lambda: _cond_fn(y),
                 lambda: torch.nn.functional.one_hot(ind),
-                lambda: torch.repeat_interleave(x, 2),
                 lambda: torch.repeat_interleave(x, repeats))
             for f, level in product(expect_no_sync, (1, 2)):
                 _no_sync_helper(f, level)
@@ -445,12 +445,20 @@ def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
 # index.Tensor data-dependent in only some conditions
 @register_op_impl(
     lambda func: torch.Tag.dynamic_output_shape in func.tags  # type: ignore[attr-defined]
-    and func not in [aten.index.Tensor, aten.nonzero.default]
+    and func
+    not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
 )
 def dyn_shape(fake_mode, func, *args, **kwargs):
     raise DynamicOutputShapeException(func)
 
 
+@register_op_impl(lambda func: func is aten.repeat_interleave.Tensor)
+def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
+    if output_size is None:
+        raise DynamicOutputShapeException(func)
+    return repeats.new_empty(output_size)
+
+
 @register_op_impl(lambda func: func is torch.ops.aten._local_scalar_dense.default)
 def local_scalar_dense(fake_mode, func, arg):
     if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs:
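What the new fake-tensor rule buys, as a standalone check; this uses the public `FakeTensorMode` import path of this era of PyTorch and is a version-sensitive sketch, not part of the patch:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(4, 16, 1, 64)
    # The int overload now precomputes output_size, so the decomposed
    # repeat_interleave.Tensor call has a known shape and no longer has
    # to raise DynamicOutputShapeException.
    y = torch.repeat_interleave(x, 2, dim=3)
    print(y.shape)  # torch.Size([4, 16, 1, 128])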