Default permissions for torch.hub downloads (#82869)

### Description
The `download_url_to_file` function in torch.hub uses a temporary file to prevent overriding a local working checkpoint with a broken download.This temporary file is created using `NamedTemporaryFile`. However, since `NamedTemporaryFile` creates files with overly restrictive permissions (0600), the resulting download will not have default permissions and will not respect umask on Linux (since moving the file will retain the restrictive permissions of the temporary file). This is especially problematic when trying to share model checkpoints between multiple users as other users will not even have read access to the file.

The change in this PR fixes the issue by using custom code to create the temporary file without changing the permissions to 0600 (unfortunately there is no way to override the permissions behaviour of existing Python standard library code). This ensures that the downloaded checkpoint file correctly have the default permissions applied. If a user wants to apply more restrictive permissions, they can do so via usual means (i.e. by setting umask).

See these similar issues in other projects for even more context:
* https://github.com/borgbackup/borg/issues/6400
* https://github.com/borgbackup/borg/issues/6933
* https://github.com/zarr-developers/zarr-python/issues/325

### Issue
https://github.com/pytorch/pytorch/issues/81297

### Testing
Extended the unit test `test_download_url_to_file` to also check permissions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82869
Approved by: https://github.com/vmoens
This commit is contained in:
Aiden Nibali 2023-08-24 15:48:20 +00:00 committed by PyTorch MergeBot
parent 64d5851b1f
commit 85b0e03df8
2 changed files with 18 additions and 2 deletions

View file

@ -96,6 +96,12 @@ class TestHub(TestCase):
hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, f, progress=False)
loaded_state = torch.load(f)
self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
# Check that the downloaded file has default file permissions
f_ref = os.path.join(tmpdir, 'reference')
open(f_ref, 'w').close()
expected_permissions = oct(os.stat(f_ref).st_mode & 0o777)
actual_permissions = oct(os.stat(f).st_mode & 0o777)
assert actual_permissions == expected_permissions
@retry(Exception, tries=3)
def test_load_state_dict_from_url(self):

View file

@ -8,6 +8,7 @@ import shutil
import sys
import tempfile
import torch
import uuid
import warnings
import zipfile
from pathlib import Path
@ -628,9 +629,18 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
# We deliberately save it in a temp file and move it after
# download is complete. This prevents a local working checkpoint
# being overridden by a broken download.
# We deliberately do not use NamedTemporaryFile to avoid restrictive
# file permissions being applied to the downloaded file.
dst = os.path.expanduser(dst)
dst_dir = os.path.dirname(dst)
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
for seq in range(tempfile.TMP_MAX):
tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
try:
f = open(tmp_dst, 'w+b')
except FileExistsError:
continue
break
else:
raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')
try:
if hash_prefix is not None: