diff --git a/setup.py b/setup.py
index c07920520..25f503f8d 100644
--- a/setup.py
+++ b/setup.py
@@ -36,6 +36,12 @@ To create the package for pypi.
 from io import open
 from setuptools import find_packages, setup
 
+
+extras = {
+    'serving': ['uvicorn', 'fastapi']
+}
+extras['all'] = [package for packages in extras.values() for package in packages]
+
 setup(
     name="transformers",
     version="2.2.1",
@@ -61,6 +67,10 @@ setup(
             "transformers=transformers.__main__:main",
         ]
     },
+    extras_require=extras,
+    scripts=[
+        'transformers-cli'
+    ],
     # python_requires='>=3.5.0',
     tests_require=['pytest'],
     classifiers=[
diff --git a/transformers-cli b/transformers-cli
new file mode 100644
index 000000000..ef00d15aa
--- /dev/null
+++ b/transformers-cli
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+from argparse import ArgumentParser
+
+from transformers.commands.user import UserCommands
+
+
+if __name__ == '__main__':
+    parser = ArgumentParser(description='Transformers CLI tool', usage='transformers-cli <command> [<args>]')
+    commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
+
+    # Register commands
+    UserCommands.register_subcommand(commands_parser)
+
+    # Parse the command line
+    args = parser.parse_args()
+
+    # Each subcommand stores its command factory in `func`; when no
+    # subcommand was given the attribute is missing, so show help and fail.
+    if not hasattr(args, 'func'):
+        parser.print_help()
+        parser.exit(1)
+
+    # Build the selected command and run it
+    service = args.func(args)
+    service.run()
diff --git a/transformers/commands/__init__.py b/transformers/commands/__init__.py
new file mode 100644
index 000000000..bbdd5655f
--- /dev/null
+++ b/transformers/commands/__init__.py
@@ -0,0 +1,17 @@
+from abc import ABC, abstractmethod
+from argparse import ArgumentParser
+
+
+class BaseTransformersCLICommand(ABC):
+    """Interface implemented by the `transformers-cli` command groups."""
+
+    @staticmethod
+    @abstractmethod
+    def register_subcommand(parser: ArgumentParser):
+        """Attach this command's sub-parsers to `parser`."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def run(self):
+        """Execute the command."""
+        raise NotImplementedError()
diff --git a/transformers/commands/user.py b/transformers/commands/user.py
new file mode 100644
index 000000000..4b826f4dc
--- /dev/null
+++ b/transformers/commands/user.py
@@ -0,0 +1,147 @@
+from argparse import ArgumentParser
+from getpass import getpass
+import os
+import sys
+
+from transformers.commands import BaseTransformersCLICommand
+from transformers.hf_api import HfApi, HfFolder, HTTPError
+
+
+class UserCommands(BaseTransformersCLICommand):
+    """Registers the user subcommands: login, whoami, logout, ls and upload."""
+
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        """Add one sub-parser per command; each stores its command factory in `func`."""
+        login_parser = parser.add_parser('login')
+        login_parser.set_defaults(func=lambda args: LoginCommand(args))
+        whoami_parser = parser.add_parser('whoami')
+        whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
+        logout_parser = parser.add_parser('logout')
+        logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
+        list_parser = parser.add_parser('ls')
+        list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
+        # upload
+        upload_parser = parser.add_parser('upload')
+        upload_parser.add_argument('file', type=str, help='Local filepath of the file to upload.')
+        upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override object filename on S3.')
+        upload_parser.set_defaults(func=lambda args: UploadCommand(args))
+
+
+class BaseUserCommand:
+    """Common state of the user commands: the parsed arguments and an API client."""
+
+    def __init__(self, args):
+        self.args = args
+        self._api = HfApi()
+
+
+class LoginCommand(BaseUserCommand):
+    """Prompt for credentials, log in and save the returned token locally."""
+
+    def run(self):
+        print("""
+        _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
+        _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
+        _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
+        _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
+        _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
+
+        """)
+        username = input("Username: ")
+        password = getpass()
+        try:
+            token = self._api.login(username, password)
+        except HTTPError as e:
+            # probably invalid credentials, display error message.
+            print(e)
+            sys.exit(1)
+        HfFolder.save_token(token)
+        print("Login successful")
+        print("Your token:", token, "\n")
+        print("Your token has been saved to", HfFolder.path_token)
+
+
+class WhoamiCommand(BaseUserCommand):
+    """Print the user the stored token belongs to."""
+
+    def run(self):
+        token = HfFolder.get_token()
+        if token is None:
+            print("Not logged in")
+            sys.exit(1)
+        try:
+            user = self._api.whoami(token)
+        except HTTPError as e:
+            print(e)
+            sys.exit(1)
+        print(user)
+
+
+class LogoutCommand(BaseUserCommand):
+    """Delete the stored token, then log out on the server."""
+
+    def run(self):
+        token = HfFolder.get_token()
+        if token is None:
+            print("Not logged in")
+            sys.exit(1)
+        HfFolder.delete_token()
+        try:
+            self._api.logout(token)
+        except HTTPError as e:
+            # The local token is already deleted; report the server-side failure.
+            print(e)
+            sys.exit(1)
+        print("Successfully logged out.")
+
+
+class ListObjsCommand(BaseUserCommand):
+    """Print filename, LastModified, ETag and Size of each stored file."""
+
+    def run(self):
+        token = HfFolder.get_token()
+        if token is None:
+            print("Not logged in")
+            sys.exit(1)
+        try:
+            objs = self._api.list_objs(token)
+        except HTTPError as e:
+            print(e)
+            sys.exit(1)
+        if not objs:
+            print("No shared file yet")
+        for obj in objs:
+            print(
+                obj.filename,
+                obj.LastModified,
+                obj.ETag,
+                obj.Size
+            )
+
+
+class UploadCommand(BaseUserCommand):
+    """Upload a local file after asking for confirmation, then print its URL."""
+
+    def run(self):
+        token = HfFolder.get_token()
+        if token is None:
+            print("Not logged in")
+            sys.exit(1)
+        filepath = os.path.join(os.getcwd(), self.args.file)
+        filename = self.args.filename if self.args.filename is not None else os.path.basename(filepath)
+        print("About to upload file {} to S3 under filename {}".format(filepath, filename))
+        choice = input("Proceed? [Y/n] ").lower()
+        if choice not in ("", "y", "yes"):
+            print("Abort")
+            sys.exit()  # declining the upload is not an error
+        print("Uploading...")
+        try:
+            access_url = self._api.presign_and_upload(
+                token=token, filename=filename, filepath=filepath
+            )
+        except HTTPError as e:
+            print(e)
+            sys.exit(1)
+        print("Your file now lives at:")
+        print(access_url)
diff --git a/transformers/hf_api.py b/transformers/hf_api.py
new file mode 100644
index 000000000..238762ebf
--- /dev/null
+++ b/transformers/hf_api.py
@@ -0,0 +1,159 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import, division, print_function
+
+from typing import List, NamedTuple
+import os
+from os.path import expanduser
+
+import requests
+from requests.exceptions import HTTPError
+
+ENDPOINT = "https://huggingface.co"
+
+
+class S3Obj:
+    """One entry of the list returned by `HfApi.list_objs`."""
+
+    def __init__(self, filename: str, LastModified: str, ETag: str, Size: int):
+        self.filename = filename
+        self.LastModified = LastModified
+        self.ETag = ETag
+        self.Size = Size
+
+
+class PresignedUrl(NamedTuple):
+    """URL pair returned by `HfApi.presign`."""
+
+    write: str  # the file is uploaded to this url
+    access: str  # handed back to the caller once the upload is done
+
+
+class HfApi:
+    """Thin client for the HF API; failed requests raise `HTTPError`."""
+
+    def __init__(self, endpoint=None):
+        self.endpoint = endpoint if endpoint is not None else ENDPOINT
+
+    def login(self, username: str, password: str) -> str:
+        """
+        Call HF API to sign in a user and get a token if credentials are valid.
+
+        Outputs:
+            token if credentials are valid
+
+        Throws:
+            requests.exceptions.HTTPError if credentials are invalid
+        """
+        path = "{}/api/login".format(self.endpoint)
+        r = requests.post(path, json={"username": username, "password": password})
+        r.raise_for_status()
+        d = r.json()
+        return d["token"]
+
+    def whoami(self, token: str) -> str:
+        """
+        Call HF API to know "whoami"
+        """
+        path = "{}/api/whoami".format(self.endpoint)
+        r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
+        r.raise_for_status()
+        d = r.json()
+        return d["user"]
+
+    def logout(self, token: str):
+        """
+        Call HF API to log out.
+        """
+        path = "{}/api/logout".format(self.endpoint)
+        r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
+        r.raise_for_status()
+
+    def presign(self, token: str, filename: str) -> PresignedUrl:
+        """
+        Call HF API to get a presigned url to upload `filename` to S3.
+        """
+        path = "{}/api/presign".format(self.endpoint)
+        r = requests.post(
+            path,
+            headers={"authorization": "Bearer {}".format(token)},
+            json={"filename": filename},
+        )
+        r.raise_for_status()
+        d = r.json()
+        return PresignedUrl(**d)
+
+    def presign_and_upload(self, token: str, filename: str, filepath: str) -> str:
+        """
+        Get a presigned url, then upload file to S3.
+
+        Outputs:
+            url: Read-only url for the stored file on S3.
+        """
+        urls = self.presign(token, filename=filename)
+        # streaming upload:
+        # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
+        with open(filepath, "rb") as f:
+            r = requests.put(urls.write, data=f)
+            r.raise_for_status()
+        return urls.access
+
+    def list_objs(self, token: str) -> List[S3Obj]:
+        """
+        Call HF API to list all stored files for user.
+        """
+        path = "{}/api/listObjs".format(self.endpoint)
+        r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
+        r.raise_for_status()
+        d = r.json()
+        return [S3Obj(**x) for x in d]
+
+
+class HfFolder:
+    """Keeps the API token in a file under the user's home directory."""
+
+    path_token = expanduser("~/.huggingface/token")
+
+    @classmethod
+    def save_token(cls, token: str):
+        """
+        Save token, creating folder as needed.
+        """
+        os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
+        with open(cls.path_token, 'w+') as f:
+            f.write(token)
+
+    @classmethod
+    def get_token(cls):
+        """
+        Get token or None if not existent.
+        """
+        try:
+            with open(cls.path_token, 'r') as f:
+                return f.read()
+        except FileNotFoundError:
+            return None
+
+    @classmethod
+    def delete_token(cls):
+        """
+        Delete token.
+        Do not fail if token does not exist.
+        """
+        try:
+            os.remove(cls.path_token)
+        except FileNotFoundError:
+            # Only a missing token file is expected; other errors propagate.
+            pass
diff --git a/transformers/tests/hf_api_test.py b/transformers/tests/hf_api_test.py
new file mode 100644
index 000000000..59822344b
--- /dev/null
+++ b/transformers/tests/hf_api_test.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import, division, print_function
+
+import os
+import time
+import unittest
+
+from transformers.hf_api import HfApi, S3Obj, PresignedUrl, HfFolder, HTTPError
+
+USER = "__DUMMY_TRANSFORMERS_USER__"
+PASS = "__DUMMY_TRANSFORMERS_PASS__"
+FILE_KEY = "Test-{}.txt".format(int(time.time()))
+FILE_PATH = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "fixtures/input.txt"
+)
+
+
+
+class HfApiCommonTest(unittest.TestCase):
+    _api = HfApi(endpoint="https://moon-staging.huggingface.co")
+
+
+class HfApiLoginTest(HfApiCommonTest):
+    def test_login_invalid(self):
+        with self.assertRaises(HTTPError):
+            self._api.login(username=USER, password="fake")
+
+    def test_login_valid(self):
+        token = self._api.login(username=USER, password=PASS)
+        self.assertIsInstance(token, str)
+
+
+class HfApiEndpointsTest(HfApiCommonTest):
+    @classmethod
+    def setUpClass(cls):
+        """
+        Share this valid token in all tests below.
+        """
+        cls._token = cls._api.login(username=USER, password=PASS)
+
+    def test_whoami(self):
+        user = self._api.whoami(token=self._token)
+        self.assertEqual(user, USER)
+
+    def test_presign(self):
+        url = self._api.presign(token=self._token, filename=FILE_KEY)
+        self.assertIsInstance(url, PresignedUrl)
+
+    def test_presign_and_upload(self):
+        access_url = self._api.presign_and_upload(
+            token=self._token, filename=FILE_KEY, filepath=FILE_PATH
+        )
+        self.assertIsInstance(access_url, str)
+
+    def test_list_objs(self):
+        objs = self._api.list_objs(token=self._token)
+        o = objs[-1]
+        self.assertIsInstance(o, S3Obj)
+
+
+
+class HfFolderTest(unittest.TestCase):
+    def test_token_workflow(self):
+        """
+        Test the whole token save/get/delete workflow,
+        with the desired behavior with respect to non-existent tokens.
+        """
+        token = "token-{}".format(int(time.time()))
+        HfFolder.save_token(token)
+        self.assertEqual(
+            HfFolder.get_token(),
+            token
+        )
+        HfFolder.delete_token()
+        HfFolder.delete_token()
+        # ^^ not an error, we test that the
+        # second call does not fail.
+        self.assertEqual(
+            HfFolder.get_token(),
+            None
+        )