diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index 7baf9555e2a..0f04b0b1daf 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -1,5 +1,6 @@ import os import sys +import types import typing import typing_extensions from typing import List, Dict, Optional, Tuple @@ -729,3 +730,23 @@ class TestRecursiveScript(JitTestCase): self.checkModule(mod, (torch.rand(2, 2),)) mod.foo = None self.checkModule(mod, (torch.rand(2, 2),)) + + def test_override_instance_method_ignore(self): + class M(torch.nn.Module): + @torch.jit.ignore + def i_am_ignored(self): + return "old" + + m = M() + + # Override the ignored method by binding a new method to this instance. + @torch.jit.ignore + def i_am_ignored(self): + return "new" + + m.i_am_ignored = types.MethodType(i_am_ignored, m) + self.assertEqual(m.i_am_ignored(), "new") + + # ScriptModule should correctly reflect the override. + s = torch.jit.script(m) + self.assertEqual(s.i_am_ignored(), "new") diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index d88a985edbe..2df0fbf4d96 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -499,7 +499,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): continue item = getattr(nn_module, name, None) if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item): - unbound_function = getattr(type(nn_module), name) + unbound_function = getattr(nn_module, name).__func__ bound_method = unbound_function.__get__(script_module) setattr(script_module, name, bound_method) elif concrete_type.is_ignored_attribute(name):