From 8a170fbacd72359b51828c2cb07dabc7e8b6a3df Mon Sep 17 00:00:00 2001
From: Michael Suo
Date: Wed, 31 Mar 2021 00:56:11 -0700
Subject: [PATCH] [package] fix mangling issues with TorchScript (#54915)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54915

TorchScript and torch.package have different mangling schemes. To avoid
them interfering with each other, we should undo the torch.package
mangling before processing anything with TorchScript (since TorchScript
independently makes sure that no names collide).

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D27410472

Pulled By: suo

fbshipit-source-id: d1cc013c532d9abb7fb9615122bc465ded4785bb
---
 test/package/package_a/fake_interface.py    |  48 +++++++
 test/package/package_a/fake_script_class.py |  16 +++
 test/package/test_model.py                  |  13 +-
 test/package/test_package_script.py         | 145 ++++++++++++++++++++
 torch/_jit_internal.py                      |   6 +
 torch/csrc/jit/python/pybind_utils.h        |  32 ++---
 torch/csrc/jit/python/python_ir.cpp         |   5 +-
 torch/csrc/jit/python/script_init.cpp       |  18 ++-
 torch/jit/_script.py                        |  11 +-
 torch/jit/_state.py                         |  31 +++--
 torch/jit/annotations.py                    |  23 ++--
 torch/testing/_internal/jit_utils.py        |   2 +-
 12 files changed, 297 insertions(+), 53 deletions(-)
 create mode 100644 test/package/package_a/fake_interface.py
 create mode 100644 test/package/package_a/fake_script_class.py
 create mode 100644 test/package/test_package_script.py

diff --git a/test/package/package_a/fake_interface.py b/test/package/package_a/fake_interface.py
new file mode 100644
index 00000000000..111febf0e9b
--- /dev/null
+++ b/test/package/package_a/fake_interface.py
@@ -0,0 +1,48 @@
+import torch
+from torch import Tensor
+
+
+@torch.jit.interface
+class ModuleInterface(torch.nn.Module):
+    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
+        pass
+
+
+class OrigModule(torch.nn.Module):
+    """A module that implements ModuleInterface."""
+
+    def __init__(self):
+        super(OrigModule, self).__init__()
+
+    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
+        return inp1 + inp2 + 1
+
+    def two(self, input: Tensor) -> Tensor:
+        return input + 2
+
+    def forward(self, input: Tensor) -> Tensor:
+        return input + self.one(input, input) + 1
+
+
+class NewModule(torch.nn.Module):
+    """A *different* module that implements ModuleInterface."""
+
+    def __init__(self):
+        super(NewModule, self).__init__()
+
+    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
+        return inp1 * inp2 + 1
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self.one(input, input + 1)
+
+
+class UsesInterface(torch.nn.Module):
+    proxy_mod: ModuleInterface
+
+    def __init__(self):
+        super().__init__()
+        self.proxy_mod = OrigModule()  # type: ignore
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self.proxy_mod.one(input, input)  # type: ignore
diff --git a/test/package/package_a/fake_script_class.py b/test/package/package_a/fake_script_class.py
new file mode 100644
index 00000000000..ea14480263b
--- /dev/null
+++ b/test/package/package_a/fake_script_class.py
@@ -0,0 +1,16 @@
+import torch
+
+@torch.jit.script
+class MyScriptClass:  # flake8: noqa
+    """Intended to be scripted."""
+    def __init__(self, x):
+        self.foo = x
+
+    def set_foo(self, x):
+        self.foo = x
+
+@torch.jit.script
+def uses_script_class(x):
+    """Intended to be scripted."""
+    foo = MyScriptClass(x)
+    return foo.foo
diff --git a/test/package/test_model.py b/test/package/test_model.py
index 800e113b966..12f9a231763 100644
--- a/test/package/test_model.py
+++ b/test/package/test_model.py
@@ -187,7 +187,18 @@ class ModelTest(PackageTestCase):
         i = PackageImporter(f1)
         loaded = i.load_pickle("model", "pickled")
-        torch.jit.script(loaded)
+
+        # Model should script successfully.
+        scripted = torch.jit.script(loaded)
+
+        # Scripted model should save and load successfully.
+        f2 = BytesIO()
+        torch.jit.save(scripted, f2)
+        f2.seek(0)
+        loaded = torch.jit.load(f2)
+
+        input = torch.rand(1, 3, 224, 224)
+        self.assertTrue(torch.allclose(loaded(input), resnet(input)))


 if __name__ == "__main__":
diff --git a/test/package/test_package_script.py b/test/package/test_package_script.py
new file mode 100644
index 00000000000..9ce09310219
--- /dev/null
+++ b/test/package/test_package_script.py
@@ -0,0 +1,145 @@
+from io import BytesIO
+from textwrap import dedent
+
+import torch
+from torch.package import (
+    PackageExporter,
+    PackageImporter,
+)
+from torch.testing._internal.common_utils import run_tests
+
+try:
+    from .common import PackageTestCase
+except ImportError:
+    # Support the case where we run this file directly.
+    from common import PackageTestCase  # type: ignore
+
+
+class TestPackageScript(PackageTestCase):
+    """Tests for compatibility with TorchScript."""
+
+    def test_package_interface(self):
+        """Packaging an interface class should work correctly."""
+
+        import package_a.fake_interface as fake
+
+        uses_interface = fake.UsesInterface()
+        scripted = torch.jit.script(uses_interface)
+        scripted.proxy_mod = torch.jit.script(fake.NewModule())
+
+        buffer = BytesIO()
+        with PackageExporter(buffer, verbose=False) as pe:
+            pe.save_pickle("model", "model.pkl", uses_interface)
+        buffer.seek(0)
+
+        package_importer = PackageImporter(buffer)
+        loaded = package_importer.load_pickle("model", "model.pkl")
+
+        scripted_loaded = torch.jit.script(loaded)
+        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())
+
+        input = torch.tensor(1)
+
+        self.assertTrue(torch.allclose(scripted(input), scripted_loaded(input)))
+
+    def test_different_package_interface(self):
+        """Test a case where the interface defined in the package is
+        different from the one defined in the loading environment, to make
+        sure TorchScript can distinguish between the two.
+        """
+        # Import one version of the interface
+        import package_a.fake_interface as fake
+
+        # Simulate a package that contains a different version of the
+        # interface, with the exact same name.
+        buffer = BytesIO()
+        with PackageExporter(buffer, verbose=False) as pe:
+            pe.save_source_string(
+                fake.__name__,
+                dedent(
+                    """\
+                    import torch
+                    from torch import Tensor
+
+                    @torch.jit.interface
+                    class ModuleInterface(torch.nn.Module):
+                        def one(self, inp1: Tensor) -> Tensor:
+                            pass
+
+                    class ImplementsInterface(torch.nn.Module):
+                        def one(self, inp1: Tensor) -> Tensor:
+                            return inp1 + 1
+
+                    class UsesInterface(torch.nn.Module):
+                        proxy_mod: ModuleInterface
+
+                        def __init__(self):
+                            super().__init__()
+                            self.proxy_mod = ImplementsInterface()
+
+                        def forward(self, input: Tensor) -> Tensor:
+                            return self.proxy_mod.one(input)
+                    """
+                ),
+            )
+        buffer.seek(0)
+
+        package_importer = PackageImporter(buffer)
+        diff_fake = package_importer.import_module(fake.__name__)
+        # We should be able to script successfully.
+        torch.jit.script(diff_fake.UsesInterface())
+
+    def test_package_script_class(self):
+        import package_a.fake_script_class as fake
+
+        buffer = BytesIO()
+        with PackageExporter(buffer, verbose=False) as pe:
+            pe.save_module(fake.__name__)
+        buffer.seek(0)
+
+        package_importer = PackageImporter(buffer)
+        loaded = package_importer.import_module(fake.__name__)
+
+        input = torch.tensor(1)
+        self.assertTrue(
+            torch.allclose(
+                fake.uses_script_class(input), loaded.uses_script_class(input)
+            )
+        )
+
+    def test_different_package_script_class(self):
+        """Test a case where the script class defined in the package is
+        different from the one defined in the loading environment, to make
+        sure TorchScript can distinguish between the two.
+        """
+        import package_a.fake_script_class as fake
+
+        # Simulate a package that contains a different version of the
+        # script class, with the attribute `bar` instead of `foo`.
+        buffer = BytesIO()
+        with PackageExporter(buffer, verbose=False) as pe2:
+            pe2.save_source_string(
+                fake.__name__,
+                dedent(
+                    """\
+                    import torch
+
+                    @torch.jit.script
+                    class MyScriptClass:
+                        def __init__(self, x):
+                            self.bar = x
+                    """
+                ),
+            )
+        buffer.seek(0)
+
+        package_importer = PackageImporter(buffer)
+        diff_fake = package_importer.import_module(fake.__name__)
+        input = torch.rand(2, 3)
+        loaded_script_class = diff_fake.MyScriptClass(input)
+        orig_script_class = fake.MyScriptClass(input)
+        self.assertTrue(torch.allclose(loaded_script_class.bar, orig_script_class.foo))
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index d7f1259d67f..6a1c250970e 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -21,6 +21,7 @@ import builtins
 import torch.distributed.rpc
 from torch._utils_internal import get_source_lines_and_file
 from torch.futures import Future
+import torch.package._mangling as package_mangling
 from typing import Tuple, List, Dict, Optional, Union, Any, TypeVar, Generic, Callable  # noqa: F401

 if sys.version_info[:2] > (3, 7):
@@ -926,6 +927,11 @@ def _qualified_name(obj):
     # raise RuntimeError(f"Could not get qualified name for class '{name}': "
     #                    f"the attr {name} on module {module_name} is not the the class")

+    # torch.package and TorchScript have separate mangling schemes to avoid
+    # name collisions from multiple packages. To avoid them interfering with
+    # each other, remove the package mangling here.
+    module_name = package_mangling.demangle(module_name)
+
     # __main__ is a builtin module, so rewrite it to "__torch__".
if module_name == "__main__": module_name = "__torch__" diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index fd9ae6c5be9..e2492897148 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -334,17 +334,12 @@ inline InferredType tryToInferType(py::handle input) { py::bool_ isClass = py::module::import("inspect").attr("isclass")(input.get_type()); if (py::cast(isClass)) { - py::str qualifiedName = py::module::import("torch._jit_internal") - .attr("_qualified_name")(input.get_type()); - auto pyClass = py::module::import("torch.jit._state") - .attr("_get_script_class")(qualifiedName); - if (!pyClass.is_none()) { - auto cu = get_python_cu(); - const auto classname = - c10::QualifiedName(py::cast(qualifiedName)); - auto class_type = cu->get_class(classname); - TORCH_INTERNAL_ASSERT(class_type); - return InferredType(class_type); + auto scriptClass = py::module::import("torch.jit._state") + .attr("_get_script_class")(input.get_type()); + if (!scriptClass.is_none()) { + auto classType = py::cast(scriptClass); + TORCH_INTERNAL_ASSERT(classType); + return InferredType(classType); } } @@ -630,13 +625,14 @@ inline IValue returnToIValue(const TypePtr& type, py::handle object) { } } -inline py::object getScriptedClassOrError(const std::string& name) { - auto py_class = py::module::import("torch.jit._state") - .attr("_get_script_class")(name.c_str()); +inline py::object getScriptedClassOrError(const c10::NamedTypePtr& classType) { + auto py_class = + py::module::import("torch.jit._state") + .attr("_get_python_class")(classType->name()->qualifiedName()); if (py_class.is_none()) { std::stringstream err; err << "Unknown reference to ScriptClass "; - err << name; + err << classType->name()->qualifiedName(); err << ". 
(Did you forget to import it?)"; throw std::runtime_error(err.str()); } @@ -724,7 +720,7 @@ inline py::object toPyObject(IValue ivalue) { } const auto classType = pyCu->get_class(c10::QualifiedName(obj->name())); AT_ASSERT(classType); - auto pyClass = getScriptedClassOrError(obj->name()); + auto pyClass = getScriptedClassOrError(obj->type()); auto pyObj = pyClass.attr("__new__")(pyClass); const auto numAttrs = classType->numAttributes(); @@ -745,9 +741,7 @@ inline py::object toPyObject(IValue ivalue) { return py::cast(std::make_shared(ivalue.toFuture())); } else if (ivalue.isEnum()) { auto enum_holder = ivalue.toEnumHolder(); - auto qualified_class_name = enum_holder->qualifiedClassName(); - - auto py_class = getScriptedClassOrError(qualified_class_name); + auto py_class = getScriptedClassOrError(enum_holder->type()); return py_class.attr(enum_holder->name().c_str()); } else if (ivalue.isRRef()) { #ifdef USE_RPC diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 0601926297e..012c97c6571 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -868,7 +868,10 @@ void initPythonIRBindings(PyObject* module_) { .def(py::init([](const std::string& qualified_name) { return get_python_cu()->get_class(c10::QualifiedName(qualified_name)); })) - .def("name", [](ClassType& self) { return self.name()->name(); }); + .def("name", [](ClassType& self) { return self.name()->name(); }) + .def("qualified_name", [](ClassType& self) { + return self.name()->qualifiedName(); + }); py::class_>(m, "EnumType") .def(py::init([](const std::string& qualified_name, TypePtr value_type, diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index ba19c9ac5f9..0cad40ebb15 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1497,7 +1497,10 @@ void initJitScriptBindings(PyObject* module) { << "Torchscript does not support class inheritance."; } auto cu = get_python_cu(); - const auto classname = c10::QualifiedName(qualifiedName); + auto classname = c10::QualifiedName(qualifiedName); + if (cu->get_type(classname) != nullptr) { + classname = cu->mangle(classname); + } auto classType = ClassType::create(classname, cu); cu->register_type(classType); std::vector methodRcbs, propRcbs; @@ -1552,6 +1555,7 @@ void initJitScriptBindings(PyObject* module) { default_it->second)); ++defs_it; } + return classType; }); m.def( "_jit_script_interface_compile", @@ -1559,11 +1563,15 @@ void initJitScriptBindings(PyObject* module) { const ClassDef& classDef, const ResolutionCallback& rcb, bool is_module) { + auto cu = get_python_cu(); + auto className = c10::QualifiedName(qualifiedName); + if (cu->get_type(className) != nullptr) { + className = cu->mangle(className); + } + get_python_cu()->define_interface( - c10::QualifiedName(qualifiedName), - classDef, - pythonResolver(rcb), - is_module); + className, classDef, pythonResolver(rcb), is_module); + return className.qualifiedName(); }); py::class_( diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 56f870aad5a..4a5506ebc38 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -122,8 +122,9 @@ def _is_new_style_class(cls): def _compile_and_register_class(obj, rcb, qualified_name): ast = get_jit_class_def(obj, obj.__name__) defaults = torch.jit.frontend.get_default_args_for_class(obj) - torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb) - torch.jit._state._add_script_class(obj, qualified_name) + 
script_class = torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb) + torch.jit._state._add_script_class(obj, script_class) + return script_class # These OrderedDictWrapper classes replace the actual OrderedDicts in @@ -1148,10 +1149,10 @@ def interface(obj): # instead of a class interface type, an module interface type only compile # the user provided methods as part of the interface ast = get_jit_class_def(obj, obj.__name__) - torch._C._jit_script_interface_compile( + mangled_classname = torch._C._jit_script_interface_compile( qualified_name, ast, rcb, is_module_interface ) - obj.__torch_script_interface__ = True + obj.__torch_script_interface__ = mangled_classname return obj @@ -1161,7 +1162,7 @@ def _recursive_compile_class(obj, loc): # case it fails error_stack = torch._C.CallStack(_qual_name, loc) rcb = _jit_internal.createResolutionCallbackForClassMethods(obj) - _compile_and_register_class(obj, rcb, _qual_name) + return _compile_and_register_class(obj, rcb, _qual_name) CompilationUnit = torch._C.CompilationUnit set_module(CompilationUnit, "torch.jit") diff --git a/torch/jit/_state.py b/torch/jit/_state.py index eb81c1e463a..b87ba5dd60d 100644 --- a/torch/jit/_state.py +++ b/torch/jit/_state.py @@ -57,19 +57,30 @@ def enable(): _python_cu = torch._C.CompilationUnit() -# qualified_name => ScriptClass mapping +# python class => ScriptClass mapping _script_classes = {} - -def _add_script_class(cls, name): - global _script_classes - _script_classes[name] = cls +_name_to_pyclass = {} -def _get_script_class(name): - global _script_classes - if name not in _script_classes: - return None - return _script_classes[name] +def _add_script_class(python_class, script_class): + _script_classes[python_class] = script_class + _name_to_pyclass[script_class.qualified_name()] = python_class + + +def _get_script_class(python_class): + override = getattr(python_class, "_jit_override_qualname", None) + if override is not None: + python_class = _get_python_class(override) + return _script_classes.get(python_class, None) + + +def _get_python_class(qualified_name): + return _name_to_pyclass.get(qualified_name, None) + + +def _clear_class_state(): + _script_classes.clear() + _name_to_pyclass.clear() # Caching: we currently cache compilation of free functions and overloaded functions. 
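Note: the torch/jit/_state.py hunk above keys the script-class registry by the Python class object and keeps a reverse index from the TorchScript qualified name back to the class. Keying by identity rather than by name is what lets two same-named classes (say, a local one and one loaded from a torch.package) coexist. A minimal standalone sketch of that design, where `FakeClassType` is a hypothetical stand-in for `torch._C.ClassType` and the `___torch_mangle_0` atom is only an assumed example of TorchScript's collision-avoidance mangling:

    # Sketch: identity-keyed script-class registry with a reverse name index.
    _script_classes = {}   # python class -> compiled class type
    _name_to_pyclass = {}  # TorchScript qualified name -> python class

    class FakeClassType:
        """Hypothetical stand-in for torch._C.ClassType."""
        def __init__(self, qualified_name):
            self._qualified_name = qualified_name

        def qualified_name(self):
            return self._qualified_name

    def add_script_class(python_class, script_class):
        _script_classes[python_class] = script_class
        _name_to_pyclass[script_class.qualified_name()] = python_class

    class Foo:  # the "local" Foo
        pass

    PackagedFoo = type("Foo", (), {})  # a distinct class, also named "Foo"

    add_script_class(Foo, FakeClassType("__torch__.Foo"))
    add_script_class(PackagedFoo, FakeClassType("__torch__.___torch_mangle_0.Foo"))

    # Both registrations survive; a registry keyed by source-level qualified
    # name would have collapsed the two "Foo" classes into one entry.
    assert _script_classes[Foo] is not _script_classes[PackagedFoo]
    assert _name_to_pyclass["__torch__.Foo"] is Foo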
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index 2c8fd77f979..91f78c7bbfe 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -10,7 +10,7 @@ from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingLi from ._state import _get_script_class from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \ - ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType, \ + ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, NoneType, \ DeviceObjType, StreamObjType, FutureType, EnumType @@ -324,7 +324,7 @@ def try_ann_to_type(ann, loc): if ann is type(None): return NoneType.get() if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"): - return InterfaceType(_qualified_name(ann)) + return InterfaceType(ann.__torch_script_interface__) if ann is torch.device: return DeviceObjType.get() if ann is torch.Stream: @@ -332,18 +332,19 @@ def try_ann_to_type(ann, loc): if ann is torch.dtype: return IntType.get() # dtype not yet bound in as its own type if inspect.isclass(ann) and issubclass(ann, enum.Enum): - qualified_name = _qualified_name(ann) - if _get_script_class(qualified_name) is None: - torch.jit._script._recursive_compile_class(ann, loc) - return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc), list(ann)) + if _get_script_class(ann) is None: + scripted_class = torch.jit._script._recursive_compile_class(ann, loc) + name = scripted_class.qualified_name() + else: + name = _qualified_name(ann) + return EnumType(name, get_enum_value_type(ann, loc), list(ann)) if inspect.isclass(ann): - qualified_name = _qualified_name(ann) - if _get_script_class(qualified_name) is not None: - return ClassType(qualified_name) + maybe_script_class = _get_script_class(ann) + if maybe_script_class is not None: + return maybe_script_class ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception) if torch._jit_internal.can_compile_class(ann) and not issubclass(ann, ignored_builtin_classes): - torch.jit._script._recursive_compile_class(ann, loc) - return ClassType(qualified_name) + return torch.jit._script._recursive_compile_class(ann, loc) # Maybe resolve a NamedTuple to a Tuple Type def fake_rcb(key): diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index b3ac61055bb..7c277567411 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -58,7 +58,7 @@ def do_input_map(fn, input): def clear_class_registry(): torch._C._jit_clear_class_registry() torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() - torch.jit._state._script_classes.clear() + torch.jit._state._clear_class_state() def get_execution_plan(graph_executor_state): execution_plans = list(graph_executor_state.execution_plans.values())
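Note: end to end, the fix works because torch.package prefixes every module it imports with a package-unique mangle parent, and the `_qualified_name` change in torch/_jit_internal.py strips that prefix before TorchScript derives its own qualified name. A minimal sketch of the intended behavior; `demangle` is the private helper this patch calls, and the exact `<torch_package_0>` prefix format is an implementation detail assumed here:

    import torch.package._mangling as package_mangling

    # A module loaded through a PackageImporter carries a mangle parent:
    mangled = "<torch_package_0>.package_a.fake_script_class"
    assert package_mangling.demangle(mangled) == "package_a.fake_script_class"

    # Names that were never package-mangled pass through unchanged, so the
    # demangle call in _qualified_name is safe for ordinary imports too:
    assert package_mangling.demangle("package_a.fake_script_class") == \
        "package_a.fake_script_class"

With the prefix stripped, the packaged and non-packaged copies of a class start from the same base qualified name (for example `__torch__.package_a.fake_script_class.MyScriptClass`), and the new `cu->mangle(...)` calls in script_init.cpp keep the compiled types distinct when both are registered.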