[jit] Allow instance overrides of ignored methods (#61076)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61076

Previously we would always retrieve ignored methods from the
type, which doesn't work when the user has overriden the ignored method
for a specific instance.

This PR changes things up so we retrieve the ignored method as a bound
method from the object being scripted, unwrap it, then re-bind it to the
scriptmodule.

Test Plan: Imported from OSS

Differential Revision: D29504421

Pulled By: suo

fbshipit-source-id: 14649863ea69a8d2180dd2c4341ec9a826039de1
This commit is contained in:
Michael Suo 2021-07-01 09:25:24 -07:00 committed by Facebook GitHub Bot
parent ccfdb30644
commit e5ae0e652d
2 changed files with 22 additions and 1 deletions

View file

@ -1,5 +1,6 @@
import os
import sys
import types
import typing
import typing_extensions
from typing import List, Dict, Optional, Tuple
@ -729,3 +730,23 @@ class TestRecursiveScript(JitTestCase):
self.checkModule(mod, (torch.rand(2, 2),))
mod.foo = None
self.checkModule(mod, (torch.rand(2, 2),))
def test_override_instance_method_ignore(self):
class M(torch.nn.Module):
@torch.jit.ignore
def i_am_ignored(self):
return "old"
m = M()
# Override the ignored method by binding a new method to this instance.
@torch.jit.ignore
def i_am_ignored(self):
return "new"
m.i_am_ignored = types.MethodType(i_am_ignored, m)
self.assertEqual(m.i_am_ignored(), "new")
# ScriptModule should correctly reflect the override.
s = torch.jit.script(m)
self.assertEqual(s.i_am_ignored(), "new")

View file

@ -499,7 +499,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
continue
item = getattr(nn_module, name, None)
if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
unbound_function = getattr(type(nn_module), name)
unbound_function = getattr(nn_module, name).__func__
bound_method = unbound_function.__get__(script_module)
setattr(script_module, name, bound_method)
elif concrete_type.is_ignored_attribute(name):