Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
Further reduce the number of calls to head for cached objects (#18871)
* Further reduce the number of calls to head for cached models/tokenizers/pipelines
* Fix tests
* Address review comments
Parent: 6678350c01
Commit: 71ff88fa4f

5 changed files with 36 additions and 10 deletions
@@ -120,6 +120,9 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_R
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"

# Return value when trying to load a file from cache but the file does not exist in the distant repo.
_CACHED_NO_EXIST = object()


def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
    """
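A note on the sentinel introduced above: `_CACHED_NO_EXIST` lets cache lookups distinguish "nothing cached yet" (None) from "we already asked the Hub and the file does not exist there". A minimal sketch of that pattern, using hypothetical helper names that are not part of transformers:

# Illustrative sketch only, not transformers code: a module-level sentinel that
# marks a *negative* cache entry, so repeated lookups never hit the network again.
_CACHED_NO_EXIST = object()

_lookup_results = {}  # hypothetical in-memory stand-in for the on-disk cache


def record_missing(filename: str) -> None:
    # Remember that the remote repo has no such file.
    _lookup_results[filename] = _CACHED_NO_EXIST


def lookup(filename: str):
    # Returns a cached path, the _CACHED_NO_EXIST sentinel, or None (nothing cached yet).
    return _lookup_results.get(filename)


record_missing("added_tokens.json")
result = lookup("added_tokens.json")
if result is None:
    print("not cached yet -> would need a HEAD call")
elif result is _CACHED_NO_EXIST:
    print("cached negative result -> skip the HEAD call")
else:
    print(f"cached file at {result}")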
@@ -222,6 +225,22 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
    """
    Explores the cache to return the latest cached file for a given revision.

    Args:
        cache_dir (`str` or `os.PathLike`): The folder where the cached files lie.
        repo_id (`str`): The ID of the repo on huggingface.co.
        filename (`str`): The filename to look for inside `repo_id`.
        revision (`str`, *optional*):
            The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
            provided either.
        commit_hash (`str`, *optional*): The (full) commit hash to look for inside the cache.

    Returns:
        `Optional[str]` or `_CACHED_NO_EXIST`:
            Will return `None` if the file was not cached. Otherwise:
            - The exact path to the cached file if it's found in the cache
            - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
              cached.
    """
    if commit_hash is not None and revision is not None:
        raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.")
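For readers following along, a quick sketch of calling the helper documented above once a repo is already cached. The module path (`transformers.utils.hub`) and the default cache location are assumptions on my part; only the signature comes from the hunk:

import os

# Assumption: the helpers from this diff are importable from transformers.utils.hub.
from transformers.utils.hub import _CACHED_NO_EXIST, try_to_load_from_cache

cache_dir = os.path.expanduser("~/.cache/huggingface/hub")  # assumed default cache location
resolved = try_to_load_from_cache(cache_dir, "hf-internal-testing/tiny-random-bert", "config.json")

if resolved is None:
    print("file not in cache -> loading it would still trigger network calls")
elif resolved is _CACHED_NO_EXIST:
    print("cache remembers the repo has no such file -> no HEAD call is made")
else:
    print(f"serving {resolved} straight from the local cache")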
@@ -244,6 +263,9 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
        with open(os.path.join(model_cache, "refs", revision)) as f:
            commit_hash = f.read()

    if os.path.isfile(os.path.join(model_cache, ".no_exist", commit_hash, filename)):
        return _CACHED_NO_EXIST

    cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
    if commit_hash not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
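The `.no_exist` check above reads an empty marker file under the model's cache folder. A self-contained sketch of writing and then consulting such a marker; the path layout mirrors the hunk, while the repo name and commit hash are made up for the example:

import os

model_cache = os.path.join("cache", "models--hf-internal-testing--tiny-random-bert")
commit_hash = "0123456789abcdef"  # hypothetical resolved commit
filename = "added_tokens.json"

# Record a negative lookup: create an empty marker under .no_exist/<commit>/<file>.
marker = os.path.join(model_cache, ".no_exist", commit_hash, filename)
os.makedirs(os.path.dirname(marker), exist_ok=True)
open(marker, "w").close()

# Later lookups can now short-circuit without any network traffic.
if os.path.isfile(os.path.join(model_cache, ".no_exist", commit_hash, filename)):
    print("file is known to be missing upstream -> return _CACHED_NO_EXIST")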
@@ -338,7 +360,10 @@ def cached_file(
        resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
        if not os.path.isfile(resolved_file):
            if _raise_exceptions_for_missing_entries:
-               raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
+               raise EnvironmentError(
+                   f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
+                   f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
+               )
            else:
                return None
        return resolved_file
@@ -352,7 +377,12 @@ def cached_file(
    # If the file is cached under that commit hash, we return it directly.
    resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
    if resolved_file is not None:
-       return resolved_file
+       if resolved_file is not _CACHED_NO_EXIST:
+           return resolved_file
+       elif not _raise_exceptions_for_missing_entries:
+           return None
+       else:
+           raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")

    user_agent = http_user_agent(user_agent)
    try:
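The new branch in cached_file is the heart of the change: a cached positive hit is returned directly, and a cached negative hit either returns None or raises, depending on `_raise_exceptions_for_missing_entries`, so no HEAD request is issued in either case. A compact, self-contained restatement of just that decision (illustrative, not the library code):

# Illustrative restatement of the branch above; not the library implementation.
_CACHED_NO_EXIST = object()


def resolve(resolved_file, raise_for_missing, full_filename="added_tokens.json", repo="some/repo"):
    if resolved_file is not None:
        if resolved_file is not _CACHED_NO_EXIST:
            return resolved_file  # real cached path: use it, no HEAD call needed
        elif not raise_for_missing:
            return None  # file known to be missing and the caller tolerates that
        else:
            raise EnvironmentError(f"Could not locate {full_filename} inside {repo}.")
    return "nothing cached -> would fall through to the Hub (HEAD/GET) here"


print(resolve("/path/in/cache/config.json", True))  # -> the cached path
print(resolve(_CACHED_NO_EXIST, False))             # -> None, still no network call
print(resolve(None, True))                          # -> falls through to the Hub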
@@ -370,6 +370,5 @@ class AutoModelTest(unittest.TestCase):
        with RequestCounter() as counter:
            _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
        self.assertEqual(counter.get_request_count, 0)
-       # There is no pytorch_model.bin so we still get one call for this one.
-       self.assertEqual(counter.head_request_count, 2)
+       self.assertEqual(counter.head_request_count, 1)
        self.assertEqual(counter.other_request_count, 0)
@@ -303,6 +303,5 @@ class TFAutoModelTest(unittest.TestCase):
        with RequestCounter() as counter:
            _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
        self.assertEqual(counter.get_request_count, 0)
-       # There is no pytorch_model.bin so we still get one call for this one.
-       self.assertEqual(counter.head_request_count, 2)
+       self.assertEqual(counter.head_request_count, 1)
        self.assertEqual(counter.other_request_count, 0)
@@ -349,6 +349,5 @@ class AutoTokenizerTest(unittest.TestCase):
        with RequestCounter() as counter:
            _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
        self.assertEqual(counter.get_request_count, 0)
-       # We still have one extra call because the model does not have a added_tokens.json file
-       self.assertEqual(counter.head_request_count, 2)
+       self.assertEqual(counter.head_request_count, 1)
        self.assertEqual(counter.other_request_count, 0)
@@ -884,8 +884,7 @@ class CustomPipelineTest(unittest.TestCase):
        with RequestCounter() as counter:
            _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
        self.assertEqual(counter.get_request_count, 0)
-       # We still have one extra call because the model does not have a added_tokens.json file
-       self.assertEqual(counter.head_request_count, 2)
+       self.assertEqual(counter.head_request_count, 1)
        self.assertEqual(counter.other_request_count, 0)
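All four test changes drop the expected HEAD count from 2 to 1, since the missing files (pytorch_model.bin, added_tokens.json) are now remembered as absent. The tests use RequestCounter from transformers' test utilities; purely as an illustration of the idea, and explicitly not the library's implementation, a HEAD/GET counter could be built by patching requests:

# Hypothetical sketch of a request counter similar in spirit to the RequestCounter
# used by these tests; this is NOT the transformers implementation.
from collections import Counter
from unittest.mock import patch

import requests


class SimpleRequestCounter:
    def __init__(self):
        self.counts = Counter()
        self._orig = requests.Session.request  # keep the real method around

    def __enter__(self):
        def counting_request(session, method, url, **kwargs):
            # Tally the HTTP verb, then delegate to the original implementation.
            self.counts[method.upper()] += 1
            return self._orig(session, method, url, **kwargs)

        self._patcher = patch("requests.Session.request", counting_request)
        self._patcher.start()
        return self

    def __exit__(self, *exc):
        self._patcher.stop()

    @property
    def head_request_count(self):
        return self.counts["HEAD"]

    @property
    def get_request_count(self):
        return self.counts["GET"]

    @property
    def other_request_count(self):
        return sum(count for method, count in self.counts.items() if method not in ("HEAD", "GET"))


# Usage mirroring the tests: wrap some HTTP activity and assert on the counts.
# (This example issues one real HEAD request.)
with SimpleRequestCounter() as counter:
    requests.head("https://huggingface.co")
print(counter.head_request_count, counter.get_request_count, counter.other_request_count)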