[package] fix mangling issues with TorchScript (#54915)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54915

TorchScript and torch.package have different mangling schemes. To avoid
them interfering with each other, we should undo the torch.package
mangling before processing anything with TorchScript (since TS
independently makes sure that no names collide).

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D27410472

Pulled By: suo

fbshipit-source-id: d1cc013c532d9abb7fb9615122bc465ded4785bb
This commit is contained in:
Michael Suo 2021-03-31 00:56:11 -07:00 committed by Facebook GitHub Bot
parent 444e5f0b60
commit 8a170fbacd
12 changed files with 297 additions and 53 deletions

View file

@ -0,0 +1,48 @@
import torch
from torch import Tensor


@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    """TorchScript interface: a conforming module must define `one`."""

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        pass


class OrigModule(torch.nn.Module):
    """A module that implements ModuleInterface."""

    def __init__(self):
        super(OrigModule, self).__init__()

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        # Satisfies ModuleInterface.one.
        return inp1 + inp2 + 1

    def two(self, input: Tensor) -> Tensor:
        # Extra method not required by the interface.
        return input + 2

    def forward(self, input: Tensor) -> Tensor:
        return input + self.one(input, input) + 1


class NewModule(torch.nn.Module):
    """A *different* module that implements ModuleInterface."""

    def __init__(self):
        super(NewModule, self).__init__()

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        # Same signature as OrigModule.one, different arithmetic.
        return inp1 * inp2 + 1

    def forward(self, input: Tensor) -> Tensor:
        return self.one(input, input + 1)


class UsesInterface(torch.nn.Module):
    # Typed as the interface so any conforming implementation
    # (OrigModule, NewModule, ...) can be swapped in after scripting.
    proxy_mod: ModuleInterface

    def __init__(self):
        super().__init__()
        self.proxy_mod = OrigModule()  # type: ignore

    def forward(self, input: Tensor) -> Tensor:
        return self.proxy_mod.one(input, input)  # type: ignore

View file

@ -0,0 +1,16 @@
import torch


@torch.jit.script
class MyScriptClass:  # flake8: noqa
    """Intended to be scripted."""

    def __init__(self, x):
        self.foo = x

    def set_foo(self, x):
        # Reassigns the attribute; exercised by packaging tests.
        self.foo = x


@torch.jit.script
def uses_script_class(x):
    """Intended to be scripted."""
    # Instantiate the script class and immediately read the attribute back.
    foo = MyScriptClass(x)
    return foo.foo

View file

@ -187,7 +187,18 @@ class ModelTest(PackageTestCase):
i = PackageImporter(f1)
loaded = i.load_pickle("model", "pickled")
torch.jit.script(loaded)
# Model should script successfully.
scripted = torch.jit.script(loaded)
# Scripted model should save and load successfully.
f2 = BytesIO()
torch.jit.save(scripted, f2)
f2.seek(0)
loaded = torch.jit.load(f2)
input = torch.rand(1, 3, 224, 224)
self.assertTrue(torch.allclose((loaded(input)), resnet(input)))
if __name__ == "__main__":

View file

@ -0,0 +1,145 @@
from io import BytesIO
from textwrap import dedent
import torch
from torch.package import (
PackageExporter,
PackageImporter,
)
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase # type: ignore
class TestPackageScript(PackageTestCase):
    """Tests for compatibility with TorchScript."""

    def test_package_interface(self):
        """Packaging an interface class should work correctly."""
        import package_a.fake_interface as fake

        uses_interface = fake.UsesInterface()
        scripted = torch.jit.script(uses_interface)
        # Swap in a different implementation of the interface.
        scripted.proxy_mod = torch.jit.script(fake.NewModule())

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_pickle("model", "model.pkl", uses_interface)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.load_pickle("model", "model.pkl")

        scripted_loaded = torch.jit.script(loaded)
        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())

        input = torch.tensor(1)

        # Original and round-tripped modules should agree.
        self.assertTrue(torch.allclose(scripted(input), scripted_loaded(input)))

    def test_different_package_interface(self):
        """Test a case where the interface defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        # Import one version of the interface
        import package_a.fake_interface as fake

        # Simulate a package that contains a different version of the
        # interface, with the exact same name.
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    from torch import Tensor
                    @torch.jit.interface
                    class ModuleInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            pass
                    class ImplementsInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            return inp1 + 1
                    class UsesInterface(torch.nn.Module):
                        proxy_mod: ModuleInterface
                        def __init__(self):
                            super().__init__()
                            self.proxy_mod = ImplementsInterface()
                        def forward(self, input: Tensor) -> Tensor:
                            return self.proxy_mod.one(input)
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        # We should be able to script successfully.
        torch.jit.script(diff_fake.UsesInterface())

    def test_package_script_class(self):
        import package_a.fake_script_class as fake

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_module(fake.__name__)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.import_module(fake.__name__)

        input = torch.tensor(1)
        # The packaged copy of the scripted function should behave like
        # the original.
        self.assertTrue(
            torch.allclose(
                fake.uses_script_class(input), loaded.uses_script_class(input)
            )
        )

    def test_different_package_script_class(self):
        """Test a case where the script class defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        import package_a.fake_script_class as fake

        # Simulate a package that contains a different version of the
        # script class, with the attribute `bar` instead of `foo`
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe2:
            pe2.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    @torch.jit.script
                    class MyScriptClass:
                        def __init__(self, x):
                            self.bar = x
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)

        input = torch.rand(2, 3)
        loaded_script_class = diff_fake.MyScriptClass(input)
        orig_script_class = fake.MyScriptClass(input)
        # Same input was stored under different attribute names in the
        # two versions of the class.
        self.assertTrue(torch.allclose(loaded_script_class.bar, orig_script_class.foo))


if __name__ == "__main__":
    run_tests()

View file

@ -21,6 +21,7 @@ import builtins
import torch.distributed.rpc
from torch._utils_internal import get_source_lines_and_file
from torch.futures import Future
import torch.package._mangling as package_mangling
from typing import Tuple, List, Dict, Optional, Union, Any, TypeVar, Generic, Callable # noqa: F401
if sys.version_info[:2] > (3, 7):
@ -926,6 +927,11 @@ def _qualified_name(obj):
# raise RuntimeError(f"Could not get qualified name for class '{name}': "
# f"the attr {name} on module {module_name} is not the the class")
# torch.package and TorchScript have separate mangling schemes to avoid
# name collisions from multiple packages. To avoid them interfering with
# each other, remove the package mangling here.
module_name = package_mangling.demangle(module_name)
# __main__ is a builtin module, so rewrite it to "__torch__".
if module_name == "__main__":
module_name = "__torch__"

View file

@ -334,17 +334,12 @@ inline InferredType tryToInferType(py::handle input) {
py::bool_ isClass =
py::module::import("inspect").attr("isclass")(input.get_type());
if (py::cast<bool>(isClass)) {
py::str qualifiedName = py::module::import("torch._jit_internal")
.attr("_qualified_name")(input.get_type());
auto pyClass = py::module::import("torch.jit._state")
.attr("_get_script_class")(qualifiedName);
if (!pyClass.is_none()) {
auto cu = get_python_cu();
const auto classname =
c10::QualifiedName(py::cast<std::string>(qualifiedName));
auto class_type = cu->get_class(classname);
TORCH_INTERNAL_ASSERT(class_type);
return InferredType(class_type);
auto scriptClass = py::module::import("torch.jit._state")
.attr("_get_script_class")(input.get_type());
if (!scriptClass.is_none()) {
auto classType = py::cast<ClassTypePtr>(scriptClass);
TORCH_INTERNAL_ASSERT(classType);
return InferredType(classType);
}
}
@ -630,13 +625,14 @@ inline IValue returnToIValue(const TypePtr& type, py::handle object) {
}
}
inline py::object getScriptedClassOrError(const std::string& name) {
auto py_class = py::module::import("torch.jit._state")
.attr("_get_script_class")(name.c_str());
inline py::object getScriptedClassOrError(const c10::NamedTypePtr& classType) {
auto py_class =
py::module::import("torch.jit._state")
.attr("_get_python_class")(classType->name()->qualifiedName());
if (py_class.is_none()) {
std::stringstream err;
err << "Unknown reference to ScriptClass ";
err << name;
err << classType->name()->qualifiedName();
err << ". (Did you forget to import it?)";
throw std::runtime_error(err.str());
}
@ -724,7 +720,7 @@ inline py::object toPyObject(IValue ivalue) {
}
const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
AT_ASSERT(classType);
auto pyClass = getScriptedClassOrError(obj->name());
auto pyClass = getScriptedClassOrError(obj->type());
auto pyObj = pyClass.attr("__new__")(pyClass);
const auto numAttrs = classType->numAttributes();
@ -745,9 +741,7 @@ inline py::object toPyObject(IValue ivalue) {
return py::cast(std::make_shared<PythonFutureWrapper>(ivalue.toFuture()));
} else if (ivalue.isEnum()) {
auto enum_holder = ivalue.toEnumHolder();
auto qualified_class_name = enum_holder->qualifiedClassName();
auto py_class = getScriptedClassOrError(qualified_class_name);
auto py_class = getScriptedClassOrError(enum_holder->type());
return py_class.attr(enum_holder->name().c_str());
} else if (ivalue.isRRef()) {
#ifdef USE_RPC

View file

@ -868,7 +868,10 @@ void initPythonIRBindings(PyObject* module_) {
.def(py::init([](const std::string& qualified_name) {
return get_python_cu()->get_class(c10::QualifiedName(qualified_name));
}))
.def("name", [](ClassType& self) { return self.name()->name(); });
.def("name", [](ClassType& self) { return self.name()->name(); })
.def("qualified_name", [](ClassType& self) {
return self.name()->qualifiedName();
});
py::class_<EnumType, Type, std::shared_ptr<EnumType>>(m, "EnumType")
.def(py::init([](const std::string& qualified_name,
TypePtr value_type,

View file

@ -1497,7 +1497,10 @@ void initJitScriptBindings(PyObject* module) {
<< "Torchscript does not support class inheritance.";
}
auto cu = get_python_cu();
const auto classname = c10::QualifiedName(qualifiedName);
auto classname = c10::QualifiedName(qualifiedName);
if (cu->get_type(classname) != nullptr) {
classname = cu->mangle(classname);
}
auto classType = ClassType::create(classname, cu);
cu->register_type(classType);
std::vector<ResolverPtr> methodRcbs, propRcbs;
@ -1552,6 +1555,7 @@ void initJitScriptBindings(PyObject* module) {
default_it->second));
++defs_it;
}
return classType;
});
m.def(
"_jit_script_interface_compile",
@ -1559,11 +1563,15 @@ void initJitScriptBindings(PyObject* module) {
const ClassDef& classDef,
const ResolutionCallback& rcb,
bool is_module) {
auto cu = get_python_cu();
auto className = c10::QualifiedName(qualifiedName);
if (cu->get_type(className) != nullptr) {
className = cu->mangle(className);
}
get_python_cu()->define_interface(
c10::QualifiedName(qualifiedName),
classDef,
pythonResolver(rcb),
is_module);
className, classDef, pythonResolver(rcb), is_module);
return className.qualifiedName();
});
py::class_<torch::jit::ErrorReport::CallStack>(

View file

@ -122,8 +122,9 @@ def _is_new_style_class(cls):
def _compile_and_register_class(obj, rcb, qualified_name):
    """Compile `obj` as a TorchScript class and register the result.

    Returns the compiled script class so callers can use it directly.
    """
    # `class_ast` avoids shadowing the stdlib `ast` module name.
    class_ast = get_jit_class_def(obj, obj.__name__)
    default_args = torch.jit.frontend.get_default_args_for_class(obj)
    script_class = torch._C._jit_script_class_compile(
        qualified_name, class_ast, default_args, rcb
    )
    # Remember the python class -> script class mapping for later lookups.
    torch.jit._state._add_script_class(obj, script_class)
    return script_class
# These OrderedDictWrapper classes replace the actual OrderedDicts in
@ -1148,10 +1149,10 @@ def interface(obj):
# instead of a class interface type, an module interface type only compile
# the user provided methods as part of the interface
ast = get_jit_class_def(obj, obj.__name__)
torch._C._jit_script_interface_compile(
mangled_classname = torch._C._jit_script_interface_compile(
qualified_name, ast, rcb, is_module_interface
)
obj.__torch_script_interface__ = True
obj.__torch_script_interface__ = mangled_classname
return obj
@ -1161,7 +1162,7 @@ def _recursive_compile_class(obj, loc):
# case it fails
error_stack = torch._C.CallStack(_qual_name, loc)
rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
_compile_and_register_class(obj, rcb, _qual_name)
return _compile_and_register_class(obj, rcb, _qual_name)
CompilationUnit = torch._C.CompilationUnit
set_module(CompilationUnit, "torch.jit")

View file

@ -57,19 +57,30 @@ def enable():
_python_cu = torch._C.CompilationUnit()
# qualified_name => ScriptClass mapping
# python class => ScriptClass mapping
_script_classes = {}
def _add_script_class(cls, name):
global _script_classes
_script_classes[name] = cls
_name_to_pyclass = {}
def _get_script_class(name):
global _script_classes
if name not in _script_classes:
return None
return _script_classes[name]
def _add_script_class(python_class, script_class):
_script_classes[python_class] = script_class
_name_to_pyclass[script_class.qualified_name()] = python_class
def _get_script_class(python_class):
override = getattr(python_class, "_jit_override_qualname", None)
if override is not None:
python_class = _get_python_class(override)
return _script_classes.get(python_class, None)
def _get_python_class(qualified_name):
return _name_to_pyclass.get(qualified_name, None)
def _clear_class_state():
_script_classes.clear()
_name_to_pyclass.clear()
# Caching: we currently cache compilation of free functions and overloaded functions.

View file

@ -10,7 +10,7 @@ from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingLi
from ._state import _get_script_class
from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType, \
ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, NoneType, \
DeviceObjType, StreamObjType, FutureType, EnumType
@ -324,7 +324,7 @@ def try_ann_to_type(ann, loc):
if ann is type(None):
return NoneType.get()
if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
return InterfaceType(_qualified_name(ann))
return InterfaceType(ann.__torch_script_interface__)
if ann is torch.device:
return DeviceObjType.get()
if ann is torch.Stream:
@ -332,18 +332,19 @@ def try_ann_to_type(ann, loc):
if ann is torch.dtype:
return IntType.get() # dtype not yet bound in as its own type
if inspect.isclass(ann) and issubclass(ann, enum.Enum):
qualified_name = _qualified_name(ann)
if _get_script_class(qualified_name) is None:
torch.jit._script._recursive_compile_class(ann, loc)
return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc), list(ann))
if _get_script_class(ann) is None:
scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
name = scripted_class.qualified_name()
else:
name = _qualified_name(ann)
return EnumType(name, get_enum_value_type(ann, loc), list(ann))
if inspect.isclass(ann):
qualified_name = _qualified_name(ann)
if _get_script_class(qualified_name) is not None:
return ClassType(qualified_name)
maybe_script_class = _get_script_class(ann)
if maybe_script_class is not None:
return maybe_script_class
ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
if torch._jit_internal.can_compile_class(ann) and not issubclass(ann, ignored_builtin_classes):
torch.jit._script._recursive_compile_class(ann, loc)
return ClassType(qualified_name)
return torch.jit._script._recursive_compile_class(ann, loc)
# Maybe resolve a NamedTuple to a Tuple Type
def fake_rcb(key):

View file

@ -58,7 +58,7 @@ def do_input_map(fn, input):
def clear_class_registry():
    # Reset all TorchScript class-registration state between tests:
    # the C++ compilation unit, the concrete-type cache, and the
    # python-side script-class mappings.
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()
def get_execution_plan(graph_executor_state):
execution_plans = list(graph_executor_state.execution_plans.values())