From e5ae0e652d77828af6780507ec01d0f6e763db8d Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Thu, 1 Jul 2021 09:25:24 -0700 Subject: [PATCH] [jit] Allow instance overrides of ignored methods (#61076) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61076 Previously we would always retrieve ignored methods from the type, which doesn't work when the user has overriden the ignored method for a specific instance. This PR changes things up so we retrieve the ignored method as a bound method from the object being scripted, unwrap it, then re-bind it to the scriptmodule. Test Plan: Imported from OSS Differential Revision: D29504421 Pulled By: suo fbshipit-source-id: 14649863ea69a8d2180dd2c4341ec9a826039de1 --- test/jit/test_recursive_script.py | 21 +++++++++++++++++++++ torch/jit/_recursive.py | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index 7baf9555e2a..0f04b0b1daf 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -1,5 +1,6 @@ import os import sys +import types import typing import typing_extensions from typing import List, Dict, Optional, Tuple @@ -729,3 +730,23 @@ class TestRecursiveScript(JitTestCase): self.checkModule(mod, (torch.rand(2, 2),)) mod.foo = None self.checkModule(mod, (torch.rand(2, 2),)) + + def test_override_instance_method_ignore(self): + class M(torch.nn.Module): + @torch.jit.ignore + def i_am_ignored(self): + return "old" + + m = M() + + # Override the ignored method by binding a new method to this instance. + @torch.jit.ignore + def i_am_ignored(self): + return "new" + + m.i_am_ignored = types.MethodType(i_am_ignored, m) + self.assertEqual(m.i_am_ignored(), "new") + + # ScriptModule should correctly reflect the override. + s = torch.jit.script(m) + self.assertEqual(s.i_am_ignored(), "new") diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index d88a985edbe..2df0fbf4d96 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -499,7 +499,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): continue item = getattr(nn_module, name, None) if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item): - unbound_function = getattr(type(nn_module), name) + unbound_function = getattr(nn_module, name).__func__ bound_method = unbound_function.__get__(script_module) setattr(script_module, name, bound_method) elif concrete_type.is_ignored_attribute(name):