pytorch/test/package/test_package_script.py
Michael Suo 8a170fbacd [package] fix mangling issues with TorchScript (#54915)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54915

TorchScript and torch.package have different name-mangling schemes. To keep
them from interfering with each other, we should undo the torch.package
mangling before processing anything with TorchScript (since TorchScript
independently makes sure that no names collide).

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D27410472

Pulled By: suo

fbshipit-source-id: d1cc013c532d9abb7fb9615122bc465ded4785bb
2021-03-31 00:58:05 -07:00

145 lines
4.9 KiB
Python

from io import BytesIO
from textwrap import dedent
import torch
from torch.package import (
PackageExporter,
PackageImporter,
)
from torch.testing._internal.common_utils import run_tests
# Import the shared test base class. A relative import works when this file is
# loaded as part of the test package; the absolute fallback covers running the
# file directly (where there is no parent package).
try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase  # type: ignore
class TestPackageScript(PackageTestCase):
    """Tests for compatibility with TorchScript."""

    def test_package_interface(self):
        """Packaging an interface class should work correctly."""
        import package_a.fake_interface as fake

        uses_interface = fake.UsesInterface()
        scripted = torch.jit.script(uses_interface)
        scripted.proxy_mod = torch.jit.script(fake.NewModule())

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_pickle("model", "model.pkl", uses_interface)

        buffer.seek(0)
        package_importer = PackageImporter(buffer)
        loaded = package_importer.load_pickle("model", "model.pkl")

        # Script the copy loaded from the package and give it the same
        # replacement submodule as the locally scripted module, so the two
        # should produce the same output.
        scripted_loaded = torch.jit.script(loaded)
        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())

        inp = torch.tensor(1)
        self.assertTrue(torch.allclose(scripted(inp), scripted_loaded(inp)))

    def test_different_package_interface(self):
        """Test a case where the interface defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        # Import one version of the interface
        import package_a.fake_interface as fake

        # Script the local version first so TorchScript has compiled the
        # loading environment's ModuleInterface before it sees the packaged
        # one with the same name. Without this, whether a name collision is
        # actually exercised depends on some other test having scripted it.
        torch.jit.script(fake.UsesInterface())

        # Simulate a package that contains a different version of the
        # interface, with the exact same name.
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    from torch import Tensor

                    @torch.jit.interface
                    class ModuleInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            pass

                    class ImplementsInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            return inp1 + 1

                    class UsesInterface(torch.nn.Module):
                        proxy_mod: ModuleInterface

                        def __init__(self):
                            super().__init__()
                            self.proxy_mod = ImplementsInterface()

                        def forward(self, input: Tensor) -> Tensor:
                            return self.proxy_mod.one(input)
                    """
                ),
            )

        buffer.seek(0)
        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)

        # We should be able to script successfully, and the packaged
        # implementation (which adds 1 to its input) must be the one that runs.
        scripted_diff = torch.jit.script(diff_fake.UsesInterface())
        inp = torch.tensor(1)
        self.assertTrue(torch.allclose(scripted_diff(inp), inp + 1))

    def test_package_script_class(self):
        """A module using a TorchScript class should survive packaging, and the
        loaded copy should compute the same result as the local module.
        """
        import package_a.fake_script_class as fake

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_module(fake.__name__)

        buffer.seek(0)
        package_importer = PackageImporter(buffer)
        loaded = package_importer.import_module(fake.__name__)

        inp = torch.tensor(1)
        self.assertTrue(
            torch.allclose(
                fake.uses_script_class(inp), loaded.uses_script_class(inp)
            )
        )

    def test_different_package_script_class(self):
        """Test a case where the script class defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        import package_a.fake_script_class as fake

        # Simulate a package that contains a different version of the
        # script class, with the attribute `bar` instead of `foo`
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch

                    @torch.jit.script
                    class MyScriptClass:
                        def __init__(self, x):
                            self.bar = x
                    """
                ),
            )

        buffer.seek(0)
        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)

        inp = torch.rand(2, 3)
        loaded_script_class = diff_fake.MyScriptClass(inp)
        orig_script_class = fake.MyScriptClass(inp)
        # Each class must keep its own attribute layout: `bar` on the packaged
        # version, `foo` on the local one.
        self.assertTrue(torch.allclose(loaded_script_class.bar, orig_script_class.foo))
# Allow running this test file directly (outside of a test collector).
if __name__ == "__main__":
    run_tests()