mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[jit] Allow instance overrides of ignored methods (#61076)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61076 Previously we would always retrieve ignored methods from the type, which doesn't work when the user has overriden the ignored method for a specific instance. This PR changes things up so we retrieve the ignored method as a bound method from the object being scripted, unwrap it, then re-bind it to the scriptmodule. Test Plan: Imported from OSS Differential Revision: D29504421 Pulled By: suo fbshipit-source-id: 14649863ea69a8d2180dd2c4341ec9a826039de1
This commit is contained in:
parent
ccfdb30644
commit
e5ae0e652d
2 changed files with 22 additions and 1 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
import typing_extensions
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
|
|
@ -729,3 +730,23 @@ class TestRecursiveScript(JitTestCase):
|
|||
self.checkModule(mod, (torch.rand(2, 2),))
|
||||
mod.foo = None
|
||||
self.checkModule(mod, (torch.rand(2, 2),))
|
||||
|
||||
def test_override_instance_method_ignore(self):
|
||||
class M(torch.nn.Module):
|
||||
@torch.jit.ignore
|
||||
def i_am_ignored(self):
|
||||
return "old"
|
||||
|
||||
m = M()
|
||||
|
||||
# Override the ignored method by binding a new method to this instance.
|
||||
@torch.jit.ignore
|
||||
def i_am_ignored(self):
|
||||
return "new"
|
||||
|
||||
m.i_am_ignored = types.MethodType(i_am_ignored, m)
|
||||
self.assertEqual(m.i_am_ignored(), "new")
|
||||
|
||||
# ScriptModule should correctly reflect the override.
|
||||
s = torch.jit.script(m)
|
||||
self.assertEqual(s.i_am_ignored(), "new")
|
||||
|
|
|
|||
|
|
@ -499,7 +499,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
|
|||
continue
|
||||
item = getattr(nn_module, name, None)
|
||||
if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
|
||||
unbound_function = getattr(type(nn_module), name)
|
||||
unbound_function = getattr(nn_module, name).__func__
|
||||
bound_method = unbound_function.__get__(script_module)
|
||||
setattr(script_module, name, bound_method)
|
||||
elif concrete_type.is_ignored_attribute(name):
|
||||
|
|
|
|||
Loading…
Reference in a new issue