[dynamo] avoid graph break on repeat_interleave.self_int (#99528)

Addresses the convit_base failure (https://github.com/pytorch/torchdynamo/issues/1886) mentioned in https://github.com/pytorch/pytorch/issues/93777.
The same graph break also affects models such as EleutherAI/gpt-j-6B.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99528
Approved by: https://github.com/ezyang
commit e5c9a0fcf5 (parent ecd2c71871)
Author: Jiong Gong
Date: 2023-04-21 18:30:36 +08:00
Committed by: PyTorch MergeBot

9 changed files with 59 additions and 9 deletions
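Background, not part of the diff: torch.repeat_interleave(x, n, dim) with a Python int n (the self_int overload) used to fall through to the tensor overload with an unknown output size, which Dynamo treated as a data-dependent shape and split the graph on. A minimal reproduction sketch, assuming PyTorch 2.x's torch.compile; the lambdas and variable names are illustrative:

import torch

x = torch.randn(4, 16, 1, 64)

# Int repeats: after this change the op stays in one graph, so
# fullgraph=True succeeds.
fn_int = torch.compile(lambda t: torch.repeat_interleave(t, 2, 3), fullgraph=True)
fn_int(x)

# Tensor repeats: the output shape depends on tensor *values*, so Dynamo
# still graph-breaks and fullgraph=True raises.
fn_tensor = torch.compile(
    lambda t: torch.repeat_interleave(t, torch.tensor(2), 3), fullgraph=True
)
try:
    fn_tensor(x)
except Exception as err:
    print(f"expected graph break: {type(err).__name__}")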


@@ -102,8 +102,17 @@ Tensor repeat_interleave(
     int64_t repeats,
     c10::optional<int64_t> dim,
     c10::optional<int64_t> output_size) {
+  Tensor input = self;
   at::Tensor repeats_ = at::empty(1, self.options().dtype(at::kLong)).fill_(repeats);
-  return at::native::repeat_interleave(self, repeats_, dim, output_size);
+  if (!output_size) {
+    if (!dim) {
+      input = input.flatten();
+      dim = 0;
+    }
+    auto input_size = input.sym_size(dim.value()).guard_int(__FILE__, __LINE__);
+    output_size = input_size * repeats;
+  }
+  return at::native::repeat_interleave(input, repeats_, dim, output_size);
 }

 Tensor repeat_interleave_symint(
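In plain terms, the hunk above legalizes output_size before dispatch: for an integer repeats, the output length along dim is simply size(dim) * repeats (flattening first when dim is unset), so no tensor data ever needs to be read. A rough Python rendering of the same logic, for illustration only (the helper name is made up):

import torch

def repeat_interleave_self_int(self, repeats: int, dim=None, output_size=None):
    # Mirrors the C++ above: compute output_size from static shape info.
    input = self
    if output_size is None:
        if dim is None:
            input = input.flatten()
            dim = 0
        output_size = input.size(dim) * repeats
    return torch.repeat_interleave(
        input, torch.tensor([repeats]), dim=dim, output_size=output_size
    )

x = torch.arange(6).reshape(2, 3)
assert torch.equal(repeat_interleave_self_int(x, 2), torch.repeat_interleave(x, 2))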


@@ -3,7 +3,7 @@ adv_inception_v3,pass,0
 beit_base_patch16_224,pass,0
 botnet26t_256,pass,0
 coat_lite_mini,pass,0
-convit_base,fail_to_run,4
+convit_base,pass,0
 convmixer_768_32,pass,0
 convnext_base,pass,0
 crossvit_9_240,pass,0


@@ -3,7 +3,7 @@ adv_inception_v3,pass,9
 beit_base_patch16_224,pass,9
 botnet26t_256,pass,11
 coat_lite_mini,pass,9
-convit_base,fail_to_run,8
+convit_base,pass,12
 convmixer_768_32,pass,6
 convnext_base,pass,9
 crossvit_9_240,pass,9


@@ -3,7 +3,7 @@ adv_inception_v3,pass,0
 beit_base_patch16_224,pass,0
 botnet26t_256,pass,0
 coat_lite_mini,pass,0
-convit_base,pass,15
+convit_base,pass,0
 convmixer_768_32,pass,0
 convnext_base,pass,0
 crossvit_9_240,pass,0


@@ -3,7 +3,7 @@ adv_inception_v3,pass,9
 beit_base_patch16_224,pass,9
 botnet26t_256,pass,11
 coat_lite_mini,pass,9
-convit_base,pass,25
+convit_base,pass,12
 convmixer_768_32,pass,6
 convnext_base,pass,9
 crossvit_9_240,pass,9


@@ -2213,6 +2213,35 @@ def fn():
         opt_fn(x, Foo.BAR)
         self.assertEqual(cnts.op_count, 1)

+    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
+    def test_repeat_interleave_graphbreaks(self):
+        def fn_no_breaks(x):
+            # no breaks on self_int
+            x += 1
+            x = torch.repeat_interleave(x, 2, 3)
+            x += 1
+            return x
+
+        def fn_has_breaks(x):
+            # breaks on self_Tensor
+            x += 1
+            x = torch.repeat_interleave(x, torch.tensor(2), 3)
+            x += 1
+            return x
+
+        x = torch.randn([4, 16, 1, 64])
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnts)(fn_no_breaks)
+        opt_fn(x)
+        self.assertEqual(cnts.frame_count, 1)
+
+        torch._dynamo.reset()
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnts)(fn_has_breaks)
+        opt_fn(x)
+        self.assertEqual(cnts.frame_count, 2)
+
     def test_id_of_nn_module(self):
         class M(torch.nn.Module):
             def forward(self, x, ref_id):
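A note on the assertions: CompileCounter.frame_count is the number of Python frames Dynamo compiled. fn_no_breaks is captured as a single graph, hence 1; the tensor-repeats call in fn_has_breaks has a data-dependent output shape, so Dynamo breaks the graph there and compiles the code before and after the break separately, hence 2.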


@@ -579,7 +579,7 @@ meta disagrees with real impl:
         else:
             seen_succeeded.setdefault(func, set()).add(dtype)

         if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
-            raise RuntimeError(f"unexpected success {resolve_name(func)}")
+            raise RuntimeError(f"unexpected success {resolve_name(func)} {meta_args} {meta_kwargs}")

         return rs
@@ -603,7 +603,6 @@ meta_function_expected_failures = {
     torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
     torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
     torch.ormqr : {f64, c64, c128, f32},
-    torch.repeat_interleave : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
     torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
     torch.bincount : {i32, i64, u8, i16, i8},
     torch.frexp : {f64, f16, bf16, f32},
@@ -642,6 +641,10 @@ meta_function_expected_failures_only_outplace = {
     torch.nn.functional.rrelu : {f64, bf16, f32},
 }

+meta_function_expected_failures_conditional = {
+    torch.repeat_interleave : (lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats", None), int)),
+}
+
 """
 # This is some sample code for how we could dump these dicts into YAML
 # file for easier reading/writing
@@ -799,6 +802,8 @@ class MetaCrossRefFunctionMode(torch.overrides.TorchFunctionMode):
             test_expect = TestExpect.XFAILURE
         elif self.dtype in meta_function_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
+        elif meta_function_expected_failures_conditional.get(func, lambda *_, **__: False)(self.dtype, *args, **kwargs):
+            test_expect = TestExpect.XFAILURE
         elif not self.inplace and \
                 self.dtype in meta_function_device_expected_failures_only_outplace[self.device_type].get(func, set()):
             test_expect = TestExpect.XFAILURE
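To see what the conditional table does, here is an illustrative evaluation of the new predicate, matching how the elif branch above invokes it (the sample arguments are made up):

import torch

pred = lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats", None), int)

assert pred(torch.float32, repeats=torch.tensor(2))  # tensor repeats: still xfails on meta
assert not pred(torch.float32, repeats=2)            # int repeats: now expected to pass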


@@ -1888,7 +1888,6 @@ else:
                        lambda: x.nonzero(),
                        lambda: _cond_fn(y),
                        lambda: torch.nn.functional.one_hot(ind),
-                       lambda: torch.repeat_interleave(x, 2),
                        lambda: torch.repeat_interleave(x, repeats))
         for f, level in product(expect_no_sync, (1, 2)):
             _no_sync_helper(f, level)
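The dropped lambda: torch.repeat_interleave(x, 2) reflects a side effect of computing output_size eagerly: with an int repeats the size comes from static shape metadata, so the CUDA path no longer reads tensor data back to the host; with tensor repeats it still synchronizes. A hedged sketch using torch.cuda.set_sync_debug_mode (requires a CUDA device):

import torch

if torch.cuda.is_available():
    torch.cuda.set_sync_debug_mode("warn")
    x = torch.randn(8, device="cuda")
    torch.repeat_interleave(x, 2)  # no warning: size known without a sync
    reps = torch.full((8,), 2, device="cuda", dtype=torch.long)
    torch.repeat_interleave(x, reps)  # warns: reading repeats syncs
    torch.cuda.set_sync_debug_mode("default")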


@@ -445,12 +445,20 @@ def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
 # index.Tensor data-dependent in only some conditions
 @register_op_impl(
     lambda func: torch.Tag.dynamic_output_shape in func.tags  # type: ignore[attr-defined]
-    and func not in [aten.index.Tensor, aten.nonzero.default]
+    and func
+    not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
 )
 def dyn_shape(fake_mode, func, *args, **kwargs):
     raise DynamicOutputShapeException(func)


+@register_op_impl(lambda func: func is aten.repeat_interleave.Tensor)
+def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
+    if output_size is None:
+        raise DynamicOutputShapeException(func)
+    return repeats.new_empty(output_size)
+
+
 @register_op_impl(lambda func: func is torch.ops.aten._local_scalar_dense.default)
 def local_scalar_dense(fake_mode, func, arg):
     if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs:
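Combined with the eager change, this impl lets repeat_interleave with an int repeats trace under fake tensors: the composite op now reaches aten.repeat_interleave.Tensor with output_size already set, so the registered impl can allocate repeats.new_empty(output_size) instead of raising. A small sketch, assuming the public FakeTensorMode entry point:

import torch
from torch._subclasses import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(4, 16, 1, 64)
    y = torch.repeat_interleave(x, 2, 3)  # no DynamicOutputShapeException
    print(y.shape)  # torch.Size([4, 16, 1, 128])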