mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Applies new import merging and sorting from µsort v1.0. When merging imports, µsort will make a best effort to move associated comments to match merged elements, but there are known limitations due to the dynamic nature of Python and developer tooling. These changes should not produce any dangerous runtime changes, but may require touch-ups to satisfy linters and other tooling. Note that µsort uses case-insensitive, lexicographical sorting, which results in a different ordering compared to isort. This provides a more consistent sorting order, matching the case-insensitive order used when sorting import statements by module name, and ensures that "frog", "FROG", and "Frog" always sort next to each other. For details on µsort's sorting and merging semantics, see the user guide: https://usort.readthedocs.io/en/stable/guide.html#sorting Test Plan: S271899 Reviewed By: lisroach Differential Revision: D36402110 Pull Request resolved: https://github.com/pytorch/pytorch/pull/78973 Approved by: https://github.com/osalpekar
288 lines
9.5 KiB
Python
288 lines
9.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Owner(s): ["oncall: package/deploy"]
|
|
|
|
import inspect
|
|
import platform
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from textwrap import dedent
|
|
from unittest import skipIf
|
|
|
|
from torch.package import is_from_package, PackageExporter, PackageImporter
|
|
from torch.package.package_exporter import PackagingError
|
|
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests
|
|
|
|
try:
|
|
from .common import PackageTestCase
|
|
except ImportError:
|
|
# Support the case where we run this file directly.
|
|
from common import PackageTestCase
|
|
|
|
|
|
class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    @staticmethod
    def _strip_root_line(file_structure):
        """Render ``file_structure`` as text without its first (root) line.

        The root line is dropped because its rendering differs between
        operating systems for an in-memory buffer. The remaining lines are
        dedented so they can be compared directly against expected trees
        written with ``textwrap.dedent``.
        """
        return dedent("\n".join(str(file_structure).split("\n")[1:]))

    def test_file_structure(self):
        """
        Tests package's Directory structure representation of a zip file. Ensures
        that the returned Directory prints what is expected and filters
        inputs/outputs correctly.
        """
        buffer = BytesIO()

        # Nested entries are prefixed with "│   " (a bar plus three spaces),
        # so each child's pointer lines up four columns under its parent.
        export_plain = dedent(
            """\
            ├── .data
            │   ├── extern_modules
            │   ├── python_version
            │   └── version
            ├── main
            │   └── main
            ├── obj
            │   └── obj.pkl
            ├── package_a
            │   ├── __init__.py
            │   └── subpackage.py
            └── module_a.py
            """
        )
        export_include = dedent(
            """\
            ├── obj
            │   └── obj.pkl
            └── package_a
                └── subpackage.py
            """
        )
        import_exclude = dedent(
            """\
            ├── .data
            │   ├── extern_modules
            │   ├── python_version
            │   └── version
            ├── main
            │   └── main
            ├── obj
            │   └── obj.pkl
            ├── package_a
            │   ├── __init__.py
            │   └── subpackage.py
            └── module_a.py
            """
        )

        with PackageExporter(buffer) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.intern("**")
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

        buffer.seek(0)
        hi = PackageImporter(buffer)

        # Unfiltered view of every file in the archive.
        file_structure = hi.file_structure()
        self.assertEqual(self._strip_root_line(file_structure), export_plain)

        # Only paths matching the include globs are kept.
        file_structure = hi.file_structure(include=["**/subpackage.py", "**/*.pkl"])
        self.assertEqual(self._strip_root_line(file_structure), export_include)

        # The exclude glob matches nothing in this archive, so the tree is
        # expected to equal the unfiltered one.
        file_structure = hi.file_structure(exclude="**/*.storage")
        self.assertEqual(self._strip_root_line(file_structure), import_exclude)

    def test_python_version(self):
        """
        Tests that the current python version is stored in the package and is available
        via PackageImporter's python_version() method.
        """
        buffer = BytesIO()

        with PackageExporter(buffer) as he:
            from package_a.test_module import SimpleTest

            he.intern("**")
            obj = SimpleTest()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        hi = PackageImporter(buffer)

        self.assertEqual(hi.python_version(), platform.python_version())

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_python_version_from_package(self):
        """Tests loading a package with a python version embedded"""
        # Checked-in fixture archive that sits next to this test file.
        importer1 = PackageImporter(
            str(Path(__file__).parent / "package_e" / "test_nn_module.pt")
        )
        # The fixture was written by Python 3.9.7, so that version is embedded.
        self.assertEqual(importer1.python_version(), "3.9.7")

    def test_file_structure_has_file(self):
        """
        Test Directory's has_file() method.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import package_a.subpackage

            he.intern("**")
            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
        self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
        # A path without the ".py" suffix names no stored file.
        self.assertFalse(file_structure.has_file("package_a/subpackage"))

    def test_exporter_content_lists(self):
        """
        Test content list API for PackageExporter's contained modules.
        """

        with PackageExporter(BytesIO()) as he:
            import package_b

            he.extern("package_b.subpackage_1")
            he.mock("package_b.subpackage_2")
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
            self.assertEqual(he.externed_modules(), ["package_b.subpackage_1"])
            self.assertEqual(he.mocked_modules(), ["package_b.subpackage_2"])
            self.assertEqual(
                he.interned_modules(),
                ["package_b", "package_b.subpackage_0.subsubpackage_0"],
            )
            self.assertEqual(he.get_rdeps("package_b.subpackage_2"), ["package_b"])

        with self.assertRaises(PackagingError):
            with PackageExporter(BytesIO()) as he:
                import package_b

                he.deny("package_b")
                he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
                # NOTE(review): presumably the PackagingError for the denied
                # module is raised when the exporter closes. This assertion
                # relies on still being reachable inside the block.
                self.assertEqual(he.denied_modules(), ["package_b"])

    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        # Regularly imported module vs. the same module loaded from the package.
        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        # Same distinction for an object vs. its unpickled packaged copy.
        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        # The packaged class must yield the same source lines as the original.
        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)

    def test_dunder_package_present(self):
        """
        The attribute '__torch_package__' should be populated on imported modules.
        """
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        self.assertTrue(hasattr(mod, "__torch_package__"))

    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        # The packaged copy reports True; the regular import reports False.
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())

    def test_std_lib_sys_hackery_checks(self):
        """
        The standard library performs sys.module assignment hackery which
        causes modules who do this hackery to fail on import. See
        https://github.com/pytorch/pytorch/issues/57490 for more information.
        """
        import package_a.std_sys_module_hacks

        buffer = BytesIO()
        mod = package_a.std_sys_module_hacks.Module()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", mod)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.load_pickle("obj", "obj.pkl")
        # Calling the unpickled object must succeed (no exception expected).
        mod()
|
|
|
|
|
if __name__ == "__main__":
    # Only when executed as a script (not when imported by a test collector):
    # hand off to run_tests imported from torch.testing._internal.common_utils.
    run_tests()
|