mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Default permissions for torch.hub downloads (#82869)
### Description The `download_url_to_file` function in torch.hub uses a temporary file to prevent overriding a local working checkpoint with a broken download.This temporary file is created using `NamedTemporaryFile`. However, since `NamedTemporaryFile` creates files with overly restrictive permissions (0600), the resulting download will not have default permissions and will not respect umask on Linux (since moving the file will retain the restrictive permissions of the temporary file). This is especially problematic when trying to share model checkpoints between multiple users as other users will not even have read access to the file. The change in this PR fixes the issue by using custom code to create the temporary file without changing the permissions to 0600 (unfortunately there is no way to override the permissions behaviour of existing Python standard library code). This ensures that the downloaded checkpoint file correctly have the default permissions applied. If a user wants to apply more restrictive permissions, they can do so via usual means (i.e. by setting umask). See these similar issues in other projects for even more context: * https://github.com/borgbackup/borg/issues/6400 * https://github.com/borgbackup/borg/issues/6933 * https://github.com/zarr-developers/zarr-python/issues/325 ### Issue https://github.com/pytorch/pytorch/issues/81297 ### Testing Extended the unit test `test_download_url_to_file` to also check permissions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/82869 Approved by: https://github.com/vmoens
This commit is contained in:
parent
64d5851b1f
commit
85b0e03df8
2 changed files with 18 additions and 2 deletions
|
|
@ -96,6 +96,12 @@ class TestHub(TestCase):
|
|||
hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, f, progress=False)
|
||||
loaded_state = torch.load(f)
|
||||
self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
|
||||
# Check that the downloaded file has default file permissions
|
||||
f_ref = os.path.join(tmpdir, 'reference')
|
||||
open(f_ref, 'w').close()
|
||||
expected_permissions = oct(os.stat(f_ref).st_mode & 0o777)
|
||||
actual_permissions = oct(os.stat(f).st_mode & 0o777)
|
||||
assert actual_permissions == expected_permissions
|
||||
|
||||
@retry(Exception, tries=3)
|
||||
def test_load_state_dict_from_url(self):
|
||||
|
|
|
|||
14
torch/hub.py
14
torch/hub.py
|
|
@ -8,6 +8,7 @@ import shutil
|
|||
import sys
|
||||
import tempfile
|
||||
import torch
|
||||
import uuid
|
||||
import warnings
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
|
@ -628,9 +629,18 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
|
|||
# We deliberately save it in a temp file and move it after
|
||||
# download is complete. This prevents a local working checkpoint
|
||||
# being overridden by a broken download.
|
||||
# We deliberately do not use NamedTemporaryFile to avoid restrictive
|
||||
# file permissions being applied to the downloaded file.
|
||||
dst = os.path.expanduser(dst)
|
||||
dst_dir = os.path.dirname(dst)
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
|
||||
for seq in range(tempfile.TMP_MAX):
|
||||
tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
|
||||
try:
|
||||
f = open(tmp_dst, 'w+b')
|
||||
except FileExistsError:
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')
|
||||
|
||||
try:
|
||||
if hash_prefix is not None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue