mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57341 Require that users be explicit about what they are going to be interning. There are a lot of changes that are enabled by this. The new overall scheme is: PackageExporter maintains a dependency graph. Users can add to it, either explicitly (by issuing a `save_*` call) or implicitly (through dependency resolution). Users can also specify what action to take when PackageExporter encounters a module (deny, intern, mock, extern). Nothing (except pickles, though that can be changed with a small amount of work) is written to the zip archive until we are finalizing the package. At that point, we consult the dependency graph and write out the package exactly as it tells us to. This accomplishes two things: 1. We can gather up *all* packaging errors instead of showing them one at a time. 2. We require that users be explicit about what's going into packages, which is a common request. Differential Revision: D28114185 Test Plan: Imported from OSS Reviewed By: SplitInfinity Pulled By: suo fbshipit-source-id: fa1abf1c26be42b14c7e7cf3403ecf336ad4fc12
209 lines
6.6 KiB
Python
209 lines
6.6 KiB
Python
# -*- coding: utf-8 -*-
import inspect
from io import BytesIO
from textwrap import dedent

from torch.package import PackageExporter, PackageImporter, is_from_package
from torch.testing._internal.common_utils import run_tests

try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly (no parent package,
    # so the relative import above fails).
    from common import PackageTestCase


class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    @staticmethod
    def _file_structure_text(file_structure) -> str:
        """Render a Directory as dedented text without its first line.

        The first line is dropped because it differs across platforms
        (Windows/macOS/Unix) when the package lives in an in-memory buffer,
        so only the remaining tree lines are compared in tests.
        """
        tree_lines = str(file_structure).split("\n")[1:]
        return dedent("\n".join(tree_lines))

    def test_file_structure(self):
        """
        Tests package's Directory structure representation of a zip file. Ensures
        that the returned Directory prints what is expected and filters
        inputs/outputs correctly.
        """
        buffer = BytesIO()

        export_plain = dedent(
            """\
                ├── .data
                │   ├── extern_modules
                │   └── version
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """
        )
        export_include = dedent(
            """\
                ├── obj
                │   └── obj.pkl
                └── package_a
                    └── subpackage.py
            """
        )
        # No ``.storage`` files are written for this package, so excluding
        # that pattern must leave the tree identical to the unfiltered one.
        import_exclude = export_plain

        with PackageExporter(buffer, verbose=False) as exporter:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            # Every module must be given an explicit action; intern them all.
            exporter.intern("**")
            exporter.save_module(module_a.__name__)
            exporter.save_module(package_a.__name__)
            exporter.save_pickle("obj", "obj.pkl", obj)
            exporter.save_text("main", "main", "my string")

        buffer.seek(0)
        importer = PackageImporter(buffer)

        self.assertEqual(
            self._file_structure_text(importer.file_structure()),
            export_plain,
        )

        included = importer.file_structure(
            include=["**/subpackage.py", "**/*.pkl"]
        )
        self.assertEqual(self._file_structure_text(included), export_include)

        excluded = importer.file_structure(exclude="**/*.storage")
        self.assertEqual(self._file_structure_text(excluded), import_exclude)

    def test_file_structure_has_file(self):
        """
        Test Directory's has_file() method.
        """
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as exporter:
            import package_a.subpackage

            exporter.intern("**")
            obj = package_a.subpackage.PackageASubpackageObject()
            exporter.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
        self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
        # A module path without its ".py" suffix is not a file in the archive.
        self.assertFalse(file_structure.has_file("package_a/subpackage"))

    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        # The regular module/object must not be flagged; the packaged copies must.
        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)

    def test_dunder_package_present(self):
        """
        The attribute '__torch_package__' should be populated on imported modules.
        """
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        self.assertTrue(hasattr(mod, "__torch_package__"))

    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())


if __name__ == "__main__":
    # Allow running this test module by executing the file directly.
    run_tests()