mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
remove old function
This commit is contained in:
parent
33f9e49a5b
commit
98aa2bdad6
8 changed files with 87 additions and 127 deletions
|
|
@ -71,7 +71,6 @@ from .utils import (
|
|||
copy_func,
|
||||
default_cache_path,
|
||||
define_sagemaker_information,
|
||||
get_file_from_repo,
|
||||
get_torch_version,
|
||||
has_file,
|
||||
http_user_agent,
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from typing import Dict, Optional, Union
|
|||
from ...configuration_utils import PretrainedConfig
|
||||
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo, logging
|
||||
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
|
|
@ -220,7 +220,7 @@ def get_feature_extractor_config(
|
|||
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
|
||||
token = use_auth_token
|
||||
|
||||
resolved_config_file = get_file_from_repo(
|
||||
resolved_config_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
FEATURE_EXTRACTOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
|
|
@ -230,6 +230,9 @@ def get_feature_extractor_config(
|
|||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
if resolved_config_file is None:
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from ...image_processing_utils_fast import BaseImageProcessorFast
|
|||
from ...utils import (
|
||||
CONFIG_NAME,
|
||||
IMAGE_PROCESSOR_NAME,
|
||||
get_file_from_repo,
|
||||
cached_file,
|
||||
is_timm_config_dict,
|
||||
is_timm_local_checkpoint,
|
||||
is_torchvision_available,
|
||||
|
|
@ -284,7 +284,7 @@ def get_image_processor_config(
|
|||
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
|
||||
token = use_auth_token
|
||||
|
||||
resolved_config_file = get_file_from_repo(
|
||||
resolved_config_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
IMAGE_PROCESSOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
|
|
@ -294,6 +294,9 @@ def get_image_processor_config(
|
|||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
if resolved_config_file is None:
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from ...feature_extraction_utils import FeatureExtractionMixin
|
|||
from ...image_processing_utils import ImageProcessingMixin
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
|
||||
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, get_file_from_repo, logging
|
||||
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
|
|
@ -251,15 +251,21 @@ class AutoProcessor:
|
|||
processor_auto_map = None
|
||||
|
||||
# First, let's see if we have a processor or preprocessor config.
|
||||
# Filter the kwargs for `get_file_from_repo`.
|
||||
get_file_from_repo_kwargs = {
|
||||
key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs
|
||||
# Filter the kwargs for `cached_file`.
|
||||
cached_file_kwargs = {
|
||||
key: kwargs[key] for key in inspect.signature(cached_file).parameters.keys() if key in kwargs
|
||||
}
|
||||
# We don't want to raise
|
||||
cached_file_kwargs.update(
|
||||
{
|
||||
"_raise_exceptions_for_gated_repo": False,
|
||||
"_raise_exceptions_for_missing_entries": False,
|
||||
"_raise_exceptions_for_connection_errors": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Let's start by checking whether the processor class is saved in a processor config
|
||||
processor_config_file = get_file_from_repo(
|
||||
pretrained_model_name_or_path, PROCESSOR_NAME, **get_file_from_repo_kwargs
|
||||
)
|
||||
processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs)
|
||||
if processor_config_file is not None:
|
||||
config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
processor_class = config_dict.get("processor_class", None)
|
||||
|
|
@ -268,8 +274,8 @@ class AutoProcessor:
|
|||
|
||||
if processor_class is None:
|
||||
# If not found, let's check whether the processor class is saved in an image processor config
|
||||
preprocessor_config_file = get_file_from_repo(
|
||||
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **get_file_from_repo_kwargs
|
||||
preprocessor_config_file = cached_file(
|
||||
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
|
||||
)
|
||||
if preprocessor_config_file is not None:
|
||||
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
|
@ -288,8 +294,8 @@ class AutoProcessor:
|
|||
|
||||
if processor_class is None:
|
||||
# Next, let's check whether the processor class is saved in a tokenizer
|
||||
tokenizer_config_file = get_file_from_repo(
|
||||
pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs
|
||||
tokenizer_config_file = cached_file(
|
||||
pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs
|
||||
)
|
||||
if tokenizer_config_file is not None:
|
||||
with open(tokenizer_config_file, encoding="utf-8") as reader:
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import numpy as np
|
|||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...utils import logging
|
||||
from ...utils.hub import get_file_from_repo
|
||||
from ...utils.hub import cached_file
|
||||
from ..auto import AutoTokenizer
|
||||
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ class BarkProcessor(ProcessorMixin):
|
|||
"""
|
||||
|
||||
if speaker_embeddings_dict_path is not None:
|
||||
speaker_embeddings_path = get_file_from_repo(
|
||||
speaker_embeddings_path = cached_file(
|
||||
pretrained_processor_name_or_path,
|
||||
speaker_embeddings_dict_path,
|
||||
subfolder=kwargs.pop("subfolder", None),
|
||||
|
|
@ -97,6 +97,9 @@ class BarkProcessor(ProcessorMixin):
|
|||
local_files_only=kwargs.pop("local_files_only", False),
|
||||
token=kwargs.pop("use_auth_token", None),
|
||||
revision=kwargs.pop("revision", None),
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
if speaker_embeddings_path is None:
|
||||
logger.warning(
|
||||
|
|
@ -182,7 +185,7 @@ class BarkProcessor(ProcessorMixin):
|
|||
f"Voice preset unrecognized, missing {key} as a key in self.speaker_embeddings[{voice_preset}]."
|
||||
)
|
||||
|
||||
path = get_file_from_repo(
|
||||
path = cached_file(
|
||||
self.speaker_embeddings.get("repo_or_path", "/"),
|
||||
voice_preset_paths[key],
|
||||
subfolder=kwargs.pop("subfolder", None),
|
||||
|
|
@ -193,6 +196,9 @@ class BarkProcessor(ProcessorMixin):
|
|||
local_files_only=kwargs.pop("local_files_only", False),
|
||||
token=kwargs.pop("use_auth_token", None),
|
||||
revision=kwargs.pop("revision", None),
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
if path is None:
|
||||
raise ValueError(
|
||||
|
|
|
|||
|
|
@ -91,7 +91,6 @@ from .hub import (
|
|||
define_sagemaker_information,
|
||||
download_url,
|
||||
extract_commit_hash,
|
||||
get_file_from_repo,
|
||||
has_file,
|
||||
http_user_agent,
|
||||
is_offline_mode,
|
||||
|
|
|
|||
|
|
@ -644,105 +644,6 @@ def cached_files(
|
|||
return resolved_files
|
||||
|
||||
|
||||
# TODO: deprecate `get_file_from_repo` or document it differently?
|
||||
# Docstring is exactly the same as `cached_repo` but behavior is slightly different. If file is missing or if
|
||||
# there is a connection error, `cached_repo` will return None while `get_file_from_repo` will raise an error.
|
||||
# IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
|
||||
def get_file_from_repo(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
filename: str,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
resume_download: Optional[bool] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
subfolder: str = "",
|
||||
**deprecated_kwargs,
|
||||
):
|
||||
"""
|
||||
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
||||
|
||||
Args:
|
||||
path_or_repo (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the *model id* of a model repo on huggingface.co.
|
||||
- a path to a *directory* potentially containing the file.
|
||||
filename (`str`):
|
||||
The name of the file to locate in `path_or_repo`.
|
||||
cache_dir (`str` or `os.PathLike`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
||||
cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
||||
exist.
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
||||
Will be removed in v5 of Transformers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, will only try to load the tokenizer configuration from local files.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||
specify the folder name here.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `token=True` is required when you want to use a private model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
`Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
|
||||
file does not exist.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# Download a tokenizer configuration from huggingface.co and cache.
|
||||
tokenizer_config = get_file_from_repo("google-bert/bert-base-uncased", "tokenizer_config.json")
|
||||
# This model does not have a tokenizer config so the result will be None.
|
||||
tokenizer_config = get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json")
|
||||
```
|
||||
"""
|
||||
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if token is not None:
|
||||
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
|
||||
token = use_auth_token
|
||||
|
||||
return cached_file(
|
||||
path_or_repo_id=path_or_repo,
|
||||
filename=filename,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
|
||||
|
||||
def download_url(url, proxies=None):
|
||||
"""
|
||||
Downloads a given url in a temporary file. This function is not safe to use in multiple processes. Its only use is
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ from transformers.utils import (
|
|||
TRANSFORMERS_CACHE,
|
||||
WEIGHTS_NAME,
|
||||
cached_file,
|
||||
get_file_from_repo,
|
||||
has_file,
|
||||
)
|
||||
|
||||
|
|
@ -117,18 +116,45 @@ class GetFromCacheTests(unittest.TestCase):
|
|||
assert has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir)
|
||||
|
||||
def test_get_file_from_repo_distant(self):
|
||||
# `get_file_from_repo` returns None if the file does not exist
|
||||
self.assertIsNone(get_file_from_repo("google-bert/bert-base-cased", "ahah.txt"))
|
||||
# should return None if the file does not exist
|
||||
self.assertIsNone(
|
||||
cached_file(
|
||||
"google-bert/bert-base-cased",
|
||||
"ahah.txt",
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
)
|
||||
|
||||
# The function raises if the repository does not exist.
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
|
||||
get_file_from_repo("bert-base-case", CONFIG_NAME)
|
||||
cached_file(
|
||||
"bert-base-case",
|
||||
CONFIG_NAME,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
|
||||
# The function raises if the revision does not exist.
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
|
||||
get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME, revision="ahaha")
|
||||
cached_file(
|
||||
"google-bert/bert-base-cased",
|
||||
CONFIG_NAME,
|
||||
revision="ahaha",
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
|
||||
resolved_file = get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME)
|
||||
resolved_file = cached_file(
|
||||
"google-bert/bert-base-cased",
|
||||
CONFIG_NAME,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
# The name is the cached name which is not very easy to test, so instead we load the content.
|
||||
config = json.loads(open(resolved_file, "r").read())
|
||||
self.assertEqual(config["hidden_size"], 768)
|
||||
|
|
@ -137,9 +163,26 @@ class GetFromCacheTests(unittest.TestCase):
|
|||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
filename = Path(tmp_dir) / "a.txt"
|
||||
filename.touch()
|
||||
self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename))
|
||||
self.assertEqual(
|
||||
cached_file(
|
||||
tmp_dir,
|
||||
"a.txt",
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
),
|
||||
str(filename),
|
||||
)
|
||||
|
||||
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
|
||||
self.assertIsNone(
|
||||
cached_file(
|
||||
tmp_dir,
|
||||
"b.txt",
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
)
|
||||
|
||||
def test_get_file_gated_repo(self):
|
||||
"""Test download file from a gated repo fails with correct message when not authenticated."""
|
||||
|
|
|
|||
Loading…
Reference in a new issue