From f3f305ef3e6b5ec8ed405cb74658a8bbec58bede Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 18 Nov 2024 14:03:09 -0800 Subject: [PATCH] Fix condition for weights_only unpickler for DTensor (#140740) Same as #140739 but for DTensor (move safe globals for DTensor to `torch.distributed.tensor.__init__` and update error message to let user know `torch.distributed.tensor` must be imported to load DTensor) Differential Revision: [D65961690](https://our.internmc.facebook.com/intern/diff/D65961690) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140740 Approved by: https://github.com/malfet ghstack dependencies: #140739 --- test/distributed/_tensor/test_dtensor.py | 40 +++++++++++++++++++++--- torch/_weights_only_unpickler.py | 27 ++++++++-------- torch/distributed/tensor/__init__.py | 16 ++++++++++ 3 files changed, 66 insertions(+), 17 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index 668a8041714..212c0d84b1c 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -2,6 +2,9 @@ # Owner(s): ["oncall: distributed"] import os +import pathlib +import tempfile +import unittest from numpy.testing import assert_array_equal @@ -28,7 +31,7 @@ from torch.distributed.tensor.parallel import ( parallelize_module, RowwiseParallel, ) -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import IS_FBCODE, run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, @@ -542,6 +545,33 @@ class DTensorTest(DTensorTestBase): reloaded_st = torch.load(buffer, weights_only=True) self.assertEqual(sharded_tensor, reloaded_st) + @with_comms + @unittest.skipIf( + IS_FBCODE, + "subprocess import torch fails with ModuleNotFoundError: No module named 'torch' in fbcode", + ) + def test_dtensor_save_load_import(self): + for should_import in [True, False]: + device_mesh = self.build_device_mesh() + placements = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) + with tempfile.NamedTemporaryFile() as f: + torch.save(sharded_tensor, f) + import_string = ( + "import torch.distributed.tensor;" if should_import else "" + ) + filename = pathlib.Path(f.name) + err_msg = ( + ( + "_pickle.UnpicklingError: Weights only load failed. " + "``torch.distributed.tensor`` must be imported to load DTensors" + ) + if not should_import + else None + ) + self._attempt_load_from_subprocess(filename, import_string, err_msg) + class DTensorMeshTest(DTensorTestBase): @property @@ -943,9 +973,11 @@ class TestDTensorPlacementTypes(DTensorTestBase): from torch.distributed.tensor._collective_utils import unpad_tensor unpadded_list = [ - unpad_tensor(tensor, shard_placement.dim, pad_sizes[i]) - if pad_sizes[i] > 0 - else tensor + ( + unpad_tensor(tensor, shard_placement.dim, pad_sizes[i]) + if pad_sizes[i] > 0 + else tensor + ) for i, tensor in enumerate(splitted_tensor_list) ] expected_is_tensor_empty = [ diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 06828d05d6d..9a146632ded 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -169,19 +169,6 @@ def _get_allowed_globals(): "builtins.bytearray": bytearray, # for bytearray "builtins.set": set, # for set } - # Only add the dtensor related classes if the dtensor module is available - if hasattr(torch.distributed, "tensor"): - dtensor_rc: Dict[str, Any] = { - # DTensor related - "torch.distributed.device_mesh.DeviceMesh": torch.distributed.device_mesh.DeviceMesh, - "torch.distributed.tensor._dtensor_spec.DTensorSpec": torch.distributed.tensor._dtensor_spec.DTensorSpec, - "torch.distributed.tensor._dtensor_spec.TensorMeta": torch.distributed.tensor._dtensor_spec.TensorMeta, - "torch.distributed.tensor.DTensor": torch.distributed.tensor.DTensor, - "torch.distributed.tensor.placement_types.Partial": torch.distributed.tensor.placement_types.Partial, - "torch.distributed.tensor.placement_types.Replicate": torch.distributed.tensor.placement_types.Replicate, - "torch.distributed.tensor.placement_types.Shard": torch.distributed.tensor.placement_types.Shard, - } - rc.update(dtensor_rc) # dtype for t in torch.storage._dtype_to_storage_type_map().keys(): @@ -341,6 +328,20 @@ class Unpickler: raise UnpicklingError( "``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)" ) + elif full_path in ( + [ + "torch.distributed.device_mesh.DeviceMesh", + "torch.distributed.tensor._dtensor_spec.DTensorSpec", + "torch.distributed.tensor._dtensor_spec.TensorMeta", + "torch.distributed.tensor.DTensor", + "torch.distributed.tensor.placement_types.Partial", + "torch.distributed.tensor.placement_types.Replicate", + "torch.distributed.tensor.placement_types.Shard", + ] + ): + raise UnpicklingError( + "``torch.distributed.tensor`` must be imported to load DTensors" + ) else: raise UnpicklingError( f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " diff --git a/torch/distributed/tensor/__init__.py b/torch/distributed/tensor/__init__.py index f2746f60ba8..8bfd708f9dd 100644 --- a/torch/distributed/tensor/__init__.py +++ b/torch/distributed/tensor/__init__.py @@ -45,6 +45,22 @@ __all__ = [ "zeros", ] +# For weights_only torch.load +from ._dtensor_spec import DTensorSpec as _DTensorSpec, TensorMeta as _TensorMeta + + +torch.serialization.add_safe_globals( + [ + DeviceMesh, + _DTensorSpec, + _TensorMeta, + DTensor, + Partial, + Replicate, + Shard, + ] +) + # Append DTensor to the list of supported types for foreach implementation for optimizer # and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.