mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[package] fix mangling issues with TorchScript (#54915)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54915 TorchScript and torch.package have different mangling schemes. To avoid them interfering with each other, we should undo the torch.package mangling before processing anything with TorchScript (since TS independently makes sure that no names collide). Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D27410472 Pulled By: suo fbshipit-source-id: d1cc013c532d9abb7fb9615122bc465ded4785bb
This commit is contained in:
parent
444e5f0b60
commit
8a170fbacd
12 changed files with 297 additions and 53 deletions
48
test/package/package_a/fake_interface.py
Normal file
48
test/package/package_a/fake_interface.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    """TorchScript interface: a conforming module must provide ``one``.

    A module attribute annotated with this type (see ``UsesInterface.proxy_mod``
    below) may be swapped for any implementation matching this signature.
    """

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        # Abstract method: implementations return a Tensor computed from both inputs.
        pass
|
||||
|
||||
|
||||
class OrigModule(torch.nn.Module):
    """A concrete module satisfying the ModuleInterface contract."""

    def __init__(self):
        super().__init__()

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        # Sum of both inputs, offset by one.
        total = inp1 + inp2
        return total + 1

    def two(self, input: Tensor) -> Tensor:
        # Shift the input by a constant of two.
        return input + 2

    def forward(self, input: Tensor) -> Tensor:
        # Combine the raw input with the interface method's result.
        combined = self.one(input, input)
        return input + combined + 1
|
||||
|
||||
|
||||
class NewModule(torch.nn.Module):
    """An alternative ModuleInterface implementation (multiplies rather than adds)."""

    def __init__(self):
        super().__init__()

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        # Product of the inputs, offset by one.
        prod = inp1 * inp2
        return prod + 1

    def forward(self, input: Tensor) -> Tensor:
        # Feed the input and its successor through the interface method.
        return self.one(input, input + 1)
|
||||
|
||||
|
||||
class UsesInterface(torch.nn.Module):
    """Holds a proxy module typed as the interface, not a concrete class."""

    # Annotating with the interface type lets a scripted instance accept any
    # conforming implementation (OrigModule, NewModule, ...) as proxy_mod.
    proxy_mod: ModuleInterface

    def __init__(self):
        super().__init__()
        # Default implementation; type: ignore because OrigModule does not
        # statically subclass ModuleInterface.
        self.proxy_mod = OrigModule()  # type: ignore

    def forward(self, input: Tensor) -> Tensor:
        # Dispatch through the interface method of whatever module is installed.
        return self.proxy_mod.one(input, input)  # type: ignore
|
||||
16
test/package/package_a/fake_script_class.py
Normal file
16
test/package/package_a/fake_script_class.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
import torch
|
||||
|
||||
@torch.jit.script
class MyScriptClass:  # flake8: noqa
    """Intended to be scripted: a minimal class holding one attribute."""

    def __init__(self, x):
        # `foo` stores whatever value the instance was constructed with.
        self.foo = x

    def set_foo(self, x):
        # Replace the stored value.
        self.foo = x
|
||||
|
||||
@torch.jit.script
def uses_script_class(x):
    """Intended to be scripted: round-trips ``x`` through MyScriptClass."""
    # Construct the script class and read the stored attribute back.
    foo = MyScriptClass(x)
    return foo.foo
|
||||
|
|
@ -187,7 +187,18 @@ class ModelTest(PackageTestCase):
|
|||
|
||||
i = PackageImporter(f1)
|
||||
loaded = i.load_pickle("model", "pickled")
|
||||
torch.jit.script(loaded)
|
||||
|
||||
# Model should script successfully.
|
||||
scripted = torch.jit.script(loaded)
|
||||
|
||||
# Scripted model should save and load successfully.
|
||||
f2 = BytesIO()
|
||||
torch.jit.save(scripted, f2)
|
||||
f2.seek(0)
|
||||
loaded = torch.jit.load(f2)
|
||||
|
||||
input = torch.rand(1, 3, 224, 224)
|
||||
self.assertTrue(torch.allclose((loaded(input)), resnet(input)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
145
test/package/test_package_script.py
Normal file
145
test/package/test_package_script.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
from io import BytesIO
|
||||
from textwrap import dedent
|
||||
|
||||
import torch
|
||||
from torch.package import (
|
||||
PackageExporter,
|
||||
PackageImporter,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
try:
|
||||
from .common import PackageTestCase
|
||||
except ImportError:
|
||||
# Support the case where we run this file directly.
|
||||
from common import PackageTestCase # type: ignore
|
||||
|
||||
|
||||
class TestPackageScript(PackageTestCase):
    """Tests for compatibility with TorchScript."""

    def test_package_interface(self):
        """Packaging an interface class should work correctly."""

        import package_a.fake_interface as fake

        uses_interface = fake.UsesInterface()
        scripted = torch.jit.script(uses_interface)
        # Swap in a different implementation via the interface-typed attribute.
        scripted.proxy_mod = torch.jit.script(fake.NewModule())

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_pickle("model", "model.pkl", uses_interface)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.load_pickle("model", "model.pkl")

        # The package-loaded copy must script and accept the same swap.
        scripted_loaded = torch.jit.script(loaded)
        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())

        input = torch.tensor(1)

        self.assertTrue(torch.allclose(scripted(input), scripted_loaded(input)))

    def test_different_package_interface(self):
        """Test a case where the interface defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        # Import one version of the interface
        import package_a.fake_interface as fake

        # Simulate a package that contains a different version of the
        # interface, with the exact same name.
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    from torch import Tensor

                    @torch.jit.interface
                    class ModuleInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            pass

                    class ImplementsInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            return inp1 + 1

                    class UsesInterface(torch.nn.Module):
                        proxy_mod: ModuleInterface

                        def __init__(self):
                            super().__init__()
                            self.proxy_mod = ImplementsInterface()

                        def forward(self, input: Tensor) -> Tensor:
                            return self.proxy_mod.one(input)
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        # We should be able to script successfully.
        torch.jit.script(diff_fake.UsesInterface())

    def test_package_script_class(self):
        """A packaged scripted class should behave the same after import."""
        import package_a.fake_script_class as fake

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_module(fake.__name__)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.import_module(fake.__name__)

        input = torch.tensor(1)
        self.assertTrue(
            torch.allclose(
                fake.uses_script_class(input), loaded.uses_script_class(input)
            )
        )

    def test_different_package_script_class(self):
        """Test a case where the script class defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        import package_a.fake_script_class as fake

        # Simulate a package that contains a different version of the
        # script class, with the attribute `bar` instead of `foo`.
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe2:
            pe2.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch

                    @torch.jit.script
                    class MyScriptClass:
                        def __init__(self, x):
                            self.bar = x
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        input = torch.rand(2, 3)
        # Same class name, but each version keeps its own attribute (`bar` vs `foo`).
        loaded_script_class = diff_fake.MyScriptClass(input)
        orig_script_class = fake.MyScriptClass(input)
        self.assertTrue(torch.allclose(loaded_script_class.bar, orig_script_class.foo))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Support running this file directly (see the ImportError fallback above
    # for the matching direct-run import of PackageTestCase).
    run_tests()
|
||||
|
|
@ -21,6 +21,7 @@ import builtins
|
|||
import torch.distributed.rpc
|
||||
from torch._utils_internal import get_source_lines_and_file
|
||||
from torch.futures import Future
|
||||
import torch.package._mangling as package_mangling
|
||||
from typing import Tuple, List, Dict, Optional, Union, Any, TypeVar, Generic, Callable # noqa: F401
|
||||
|
||||
if sys.version_info[:2] > (3, 7):
|
||||
|
|
@ -926,6 +927,11 @@ def _qualified_name(obj):
|
|||
# raise RuntimeError(f"Could not get qualified name for class '{name}': "
|
||||
# f"the attr {name} on module {module_name} is not the class")
|
||||
|
||||
# torch.package and TorchScript have separate mangling schemes to avoid
|
||||
# name collisions from multiple packages. To avoid them interfering with
|
||||
# each other, remove the package mangling here.
|
||||
module_name = package_mangling.demangle(module_name)
|
||||
|
||||
# __main__ is a builtin module, so rewrite it to "__torch__".
|
||||
if module_name == "__main__":
|
||||
module_name = "__torch__"
|
||||
|
|
|
|||
|
|
@ -334,17 +334,12 @@ inline InferredType tryToInferType(py::handle input) {
|
|||
py::bool_ isClass =
|
||||
py::module::import("inspect").attr("isclass")(input.get_type());
|
||||
if (py::cast<bool>(isClass)) {
|
||||
py::str qualifiedName = py::module::import("torch._jit_internal")
|
||||
.attr("_qualified_name")(input.get_type());
|
||||
auto pyClass = py::module::import("torch.jit._state")
|
||||
.attr("_get_script_class")(qualifiedName);
|
||||
if (!pyClass.is_none()) {
|
||||
auto cu = get_python_cu();
|
||||
const auto classname =
|
||||
c10::QualifiedName(py::cast<std::string>(qualifiedName));
|
||||
auto class_type = cu->get_class(classname);
|
||||
TORCH_INTERNAL_ASSERT(class_type);
|
||||
return InferredType(class_type);
|
||||
auto scriptClass = py::module::import("torch.jit._state")
|
||||
.attr("_get_script_class")(input.get_type());
|
||||
if (!scriptClass.is_none()) {
|
||||
auto classType = py::cast<ClassTypePtr>(scriptClass);
|
||||
TORCH_INTERNAL_ASSERT(classType);
|
||||
return InferredType(classType);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -630,13 +625,14 @@ inline IValue returnToIValue(const TypePtr& type, py::handle object) {
|
|||
}
|
||||
}
|
||||
|
||||
inline py::object getScriptedClassOrError(const std::string& name) {
|
||||
auto py_class = py::module::import("torch.jit._state")
|
||||
.attr("_get_script_class")(name.c_str());
|
||||
inline py::object getScriptedClassOrError(const c10::NamedTypePtr& classType) {
|
||||
auto py_class =
|
||||
py::module::import("torch.jit._state")
|
||||
.attr("_get_python_class")(classType->name()->qualifiedName());
|
||||
if (py_class.is_none()) {
|
||||
std::stringstream err;
|
||||
err << "Unknown reference to ScriptClass ";
|
||||
err << name;
|
||||
err << classType->name()->qualifiedName();
|
||||
err << ". (Did you forget to import it?)";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
|
@ -724,7 +720,7 @@ inline py::object toPyObject(IValue ivalue) {
|
|||
}
|
||||
const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
|
||||
AT_ASSERT(classType);
|
||||
auto pyClass = getScriptedClassOrError(obj->name());
|
||||
auto pyClass = getScriptedClassOrError(obj->type());
|
||||
auto pyObj = pyClass.attr("__new__")(pyClass);
|
||||
|
||||
const auto numAttrs = classType->numAttributes();
|
||||
|
|
@ -745,9 +741,7 @@ inline py::object toPyObject(IValue ivalue) {
|
|||
return py::cast(std::make_shared<PythonFutureWrapper>(ivalue.toFuture()));
|
||||
} else if (ivalue.isEnum()) {
|
||||
auto enum_holder = ivalue.toEnumHolder();
|
||||
auto qualified_class_name = enum_holder->qualifiedClassName();
|
||||
|
||||
auto py_class = getScriptedClassOrError(qualified_class_name);
|
||||
auto py_class = getScriptedClassOrError(enum_holder->type());
|
||||
return py_class.attr(enum_holder->name().c_str());
|
||||
} else if (ivalue.isRRef()) {
|
||||
#ifdef USE_RPC
|
||||
|
|
|
|||
|
|
@ -868,7 +868,10 @@ void initPythonIRBindings(PyObject* module_) {
|
|||
.def(py::init([](const std::string& qualified_name) {
|
||||
return get_python_cu()->get_class(c10::QualifiedName(qualified_name));
|
||||
}))
|
||||
.def("name", [](ClassType& self) { return self.name()->name(); });
|
||||
.def("name", [](ClassType& self) { return self.name()->name(); })
|
||||
.def("qualified_name", [](ClassType& self) {
|
||||
return self.name()->qualifiedName();
|
||||
});
|
||||
py::class_<EnumType, Type, std::shared_ptr<EnumType>>(m, "EnumType")
|
||||
.def(py::init([](const std::string& qualified_name,
|
||||
TypePtr value_type,
|
||||
|
|
|
|||
|
|
@ -1497,7 +1497,10 @@ void initJitScriptBindings(PyObject* module) {
|
|||
<< "Torchscript does not support class inheritance.";
|
||||
}
|
||||
auto cu = get_python_cu();
|
||||
const auto classname = c10::QualifiedName(qualifiedName);
|
||||
auto classname = c10::QualifiedName(qualifiedName);
|
||||
if (cu->get_type(classname) != nullptr) {
|
||||
classname = cu->mangle(classname);
|
||||
}
|
||||
auto classType = ClassType::create(classname, cu);
|
||||
cu->register_type(classType);
|
||||
std::vector<ResolverPtr> methodRcbs, propRcbs;
|
||||
|
|
@ -1552,6 +1555,7 @@ void initJitScriptBindings(PyObject* module) {
|
|||
default_it->second));
|
||||
++defs_it;
|
||||
}
|
||||
return classType;
|
||||
});
|
||||
m.def(
|
||||
"_jit_script_interface_compile",
|
||||
|
|
@ -1559,11 +1563,15 @@ void initJitScriptBindings(PyObject* module) {
|
|||
const ClassDef& classDef,
|
||||
const ResolutionCallback& rcb,
|
||||
bool is_module) {
|
||||
auto cu = get_python_cu();
|
||||
auto className = c10::QualifiedName(qualifiedName);
|
||||
if (cu->get_type(className) != nullptr) {
|
||||
className = cu->mangle(className);
|
||||
}
|
||||
|
||||
get_python_cu()->define_interface(
|
||||
c10::QualifiedName(qualifiedName),
|
||||
classDef,
|
||||
pythonResolver(rcb),
|
||||
is_module);
|
||||
className, classDef, pythonResolver(rcb), is_module);
|
||||
return className.qualifiedName();
|
||||
});
|
||||
|
||||
py::class_<torch::jit::ErrorReport::CallStack>(
|
||||
|
|
|
|||
|
|
@ -122,8 +122,9 @@ def _is_new_style_class(cls):
|
|||
def _compile_and_register_class(obj, rcb, qualified_name):
|
||||
ast = get_jit_class_def(obj, obj.__name__)
|
||||
defaults = torch.jit.frontend.get_default_args_for_class(obj)
|
||||
torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb)
|
||||
torch.jit._state._add_script_class(obj, qualified_name)
|
||||
script_class = torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb)
|
||||
torch.jit._state._add_script_class(obj, script_class)
|
||||
return script_class
|
||||
|
||||
|
||||
# These OrderedDictWrapper classes replace the actual OrderedDicts in
|
||||
|
|
@ -1148,10 +1149,10 @@ def interface(obj):
|
|||
# instead of a class interface type, an module interface type only compile
|
||||
# the user provided methods as part of the interface
|
||||
ast = get_jit_class_def(obj, obj.__name__)
|
||||
torch._C._jit_script_interface_compile(
|
||||
mangled_classname = torch._C._jit_script_interface_compile(
|
||||
qualified_name, ast, rcb, is_module_interface
|
||||
)
|
||||
obj.__torch_script_interface__ = True
|
||||
obj.__torch_script_interface__ = mangled_classname
|
||||
return obj
|
||||
|
||||
|
||||
|
|
@ -1161,7 +1162,7 @@ def _recursive_compile_class(obj, loc):
|
|||
# case it fails
|
||||
error_stack = torch._C.CallStack(_qual_name, loc)
|
||||
rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
|
||||
_compile_and_register_class(obj, rcb, _qual_name)
|
||||
return _compile_and_register_class(obj, rcb, _qual_name)
|
||||
|
||||
CompilationUnit = torch._C.CompilationUnit
|
||||
set_module(CompilationUnit, "torch.jit")
|
||||
|
|
|
|||
|
|
@ -57,19 +57,30 @@ def enable():
|
|||
_python_cu = torch._C.CompilationUnit()
|
||||
|
||||
|
||||
# qualified_name => ScriptClass mapping
|
||||
# python class => ScriptClass mapping
|
||||
_script_classes = {}
|
||||
|
||||
def _add_script_class(cls, name):
|
||||
global _script_classes
|
||||
_script_classes[name] = cls
|
||||
_name_to_pyclass = {}
|
||||
|
||||
|
||||
def _get_script_class(name):
|
||||
global _script_classes
|
||||
if name not in _script_classes:
|
||||
return None
|
||||
return _script_classes[name]
|
||||
def _add_script_class(python_class, script_class):
|
||||
_script_classes[python_class] = script_class
|
||||
_name_to_pyclass[script_class.qualified_name()] = python_class
|
||||
|
||||
|
||||
def _get_script_class(python_class):
|
||||
override = getattr(python_class, "_jit_override_qualname", None)
|
||||
if override is not None:
|
||||
python_class = _get_python_class(override)
|
||||
return _script_classes.get(python_class, None)
|
||||
|
||||
|
||||
def _get_python_class(qualified_name):
|
||||
return _name_to_pyclass.get(qualified_name, None)
|
||||
|
||||
|
||||
def _clear_class_state():
|
||||
_script_classes.clear()
|
||||
_name_to_pyclass.clear()
|
||||
|
||||
|
||||
# Caching: we currently cache compilation of free functions and overloaded functions.
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingLi
|
|||
from ._state import _get_script_class
|
||||
|
||||
from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
|
||||
ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType, \
|
||||
ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, NoneType, \
|
||||
DeviceObjType, StreamObjType, FutureType, EnumType
|
||||
|
||||
|
||||
|
|
@ -324,7 +324,7 @@ def try_ann_to_type(ann, loc):
|
|||
if ann is type(None):
|
||||
return NoneType.get()
|
||||
if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
|
||||
return InterfaceType(_qualified_name(ann))
|
||||
return InterfaceType(ann.__torch_script_interface__)
|
||||
if ann is torch.device:
|
||||
return DeviceObjType.get()
|
||||
if ann is torch.Stream:
|
||||
|
|
@ -332,18 +332,19 @@ def try_ann_to_type(ann, loc):
|
|||
if ann is torch.dtype:
|
||||
return IntType.get() # dtype not yet bound in as its own type
|
||||
if inspect.isclass(ann) and issubclass(ann, enum.Enum):
|
||||
qualified_name = _qualified_name(ann)
|
||||
if _get_script_class(qualified_name) is None:
|
||||
torch.jit._script._recursive_compile_class(ann, loc)
|
||||
return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc), list(ann))
|
||||
if _get_script_class(ann) is None:
|
||||
scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
|
||||
name = scripted_class.qualified_name()
|
||||
else:
|
||||
name = _qualified_name(ann)
|
||||
return EnumType(name, get_enum_value_type(ann, loc), list(ann))
|
||||
if inspect.isclass(ann):
|
||||
qualified_name = _qualified_name(ann)
|
||||
if _get_script_class(qualified_name) is not None:
|
||||
return ClassType(qualified_name)
|
||||
maybe_script_class = _get_script_class(ann)
|
||||
if maybe_script_class is not None:
|
||||
return maybe_script_class
|
||||
ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
|
||||
if torch._jit_internal.can_compile_class(ann) and not issubclass(ann, ignored_builtin_classes):
|
||||
torch.jit._script._recursive_compile_class(ann, loc)
|
||||
return ClassType(qualified_name)
|
||||
return torch.jit._script._recursive_compile_class(ann, loc)
|
||||
|
||||
# Maybe resolve a NamedTuple to a Tuple Type
|
||||
def fake_rcb(key):
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ def do_input_map(fn, input):
|
|||
def clear_class_registry():
|
||||
torch._C._jit_clear_class_registry()
|
||||
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
|
||||
torch.jit._state._script_classes.clear()
|
||||
torch.jit._state._clear_class_state()
|
||||
|
||||
def get_execution_plan(graph_executor_state):
|
||||
execution_plans = list(graph_executor_state.execution_plans.values())
|
||||
|
|
|
|||
Loading…
Reference in a new issue