diff --git a/src/transformers/hf_api.py b/src/transformers/hf_api.py index c8da5615e..c62cb6f05 100644 --- a/src/transformers/hf_api.py +++ b/src/transformers/hf_api.py @@ -17,7 +17,7 @@ import io import os from os.path import expanduser -from typing import List +from typing import Dict, List, Optional import requests from tqdm import tqdm @@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co" class S3Obj: + """ + Data structure that represents a file belonging to the current user. + """ + def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs): self.filename = filename self.LastModified = LastModified @@ -41,6 +45,50 @@ class PresignedUrl: self.type = type # mime-type to send to S3. +class S3Object: + """ + Data structure that represents a public file accessible on our S3. + """ + + def __init__( + self, + key: str, # S3 object key + etag: str, + lastModified: str, + size: int, + rfilename: str, # filename relative to config.json + **kwargs + ): + self.key = key + self.etag = etag + self.lastModified = lastModified + self.size = size + self.rfilename = rfilename + + +class ModelInfo: + """ + Info about a public model accessible from our S3. + """ + + def __init__( + self, + modelId: str, # id of model + key: str, # S3 object key of config.json + author: Optional[str] = None, + downloads: Optional[int] = None, + tags: List[str] = [], + siblings: List[Dict] = [], + **kwargs + ): + self.modelId = modelId + self.key = key + self.author = author + self.downloads = downloads + self.tags = tags + self.siblings = [S3Object(**x) for x in siblings] + + class HfApi: def __init__(self, endpoint=None): self.endpoint = endpoint if endpoint is not None else ENDPOINT @@ -129,6 +177,16 @@ class HfApi: r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename}) r.raise_for_status() + def model_list(self) -> List[ModelInfo]: + """ + Get the public list of all the models on huggingface, including the community models + """ + path = "{}/api/models".format(self.endpoint) + r = requests.get(path) + r.raise_for_status() + d = r.json() + return [ModelInfo(**x) for x in d] + class TqdmProgressFileReader: """ diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index c79139095..e1537bbfd 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -21,7 +21,7 @@ import unittest import requests from requests.exceptions import HTTPError -from transformers.hf_api import HfApi, HfFolder, PresignedUrl, S3Obj +from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj USER = "__DUMMY_TRANSFORMERS_USER__" @@ -36,10 +36,11 @@ FILES = [ os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"), ), ] +ENDPOINT_STAGING = "https://moon-staging.huggingface.co" class HfApiCommonTest(unittest.TestCase): - _api = HfApi(endpoint="https://moon-staging.huggingface.co") + _api = HfApi(endpoint=ENDPOINT_STAGING) class HfApiLoginTest(HfApiCommonTest): @@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest): self.assertIsInstance(o, S3Obj) +class HfApiPublicTest(unittest.TestCase): + def test_staging_model_list(self): + _api = HfApi(endpoint=ENDPOINT_STAGING) + _ = _api.model_list() + + def test_model_list(self): + _api = HfApi() + models = _api.model_list() + self.assertGreater(len(models), 100) + self.assertIsInstance(models[0], ModelInfo) + + class HfFolderTest(unittest.TestCase): def test_token_workflow(self): """