From bdaa7bbd7d29096c4f4402ab6edbfb744c7d404c Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 3 May 2024 13:29:27 -0700 Subject: [PATCH] [dynamo] fix potentially missing _torchdynamo_inline from ScriptFunction (#125447) Fix https://github.com/pytorch/pytorch/issues/119747 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125447 Approved by: https://github.com/jansel --- test/dynamo_expected_failures/TestScript.test_nested_breaks | 0 torch/jit/_script.py | 1 + 2 files changed, 1 insertion(+) delete mode 100644 test/dynamo_expected_failures/TestScript.test_nested_breaks diff --git a/test/dynamo_expected_failures/TestScript.test_nested_breaks b/test/dynamo_expected_failures/TestScript.test_nested_breaks deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/torch/jit/_script.py b/torch/jit/_script.py index a0adf60284f..c18843f7469 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -1391,6 +1391,7 @@ def script( _check_directly_compile_overloaded(obj) maybe_already_compiled_fn = _try_get_jit_cached_function(obj) if maybe_already_compiled_fn: + maybe_already_compiled_fn._torchdynamo_inline = obj # type: ignore[attr-defined] return maybe_already_compiled_fn ast = get_jit_def(obj, obj.__name__) if _rcb is None: