From 73fcebf7ec122e68b93f50fc770f0515502eb025 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 20 Dec 2019 13:47:35 +0100 Subject: [PATCH] update serving command --- setup.py | 6 +++--- transformers-cli | 2 +- transformers/commands/serving.py | 35 +++++++++++++++++++++----------- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 6560cc496..4bfb77415 100644 --- a/setup.py +++ b/setup.py @@ -38,9 +38,9 @@ from setuptools import find_packages, setup extras = { - 'serving': ['uvicorn', 'fastapi'], - 'serving-tf': ['uvicorn', 'fastapi', 'tensorflow'], - 'serving-torch': ['uvicorn', 'fastapi', 'torch'] + 'serving': ['pydantic', 'uvicorn', 'fastapi'], + 'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'], + 'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch'] } extras['all'] = [package for package in extras.values()] diff --git a/transformers-cli b/transformers-cli index db2bd0e2a..0a980a357 100755 --- a/transformers-cli +++ b/transformers-cli @@ -3,9 +3,9 @@ from argparse import ArgumentParser from transformers.commands.download import DownloadCommand from transformers.commands.run import RunCommand -from transformers.commands.serving import ServeCommand from transformers.commands.user import UserCommands from transformers.commands.convert import ConvertCommand +from transformers.commands.serving import ServeCommand if __name__ == '__main__': parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli []') diff --git a/transformers/commands/serving.py b/transformers/commands/serving.py index a7321470c..3c3f85280 100644 --- a/transformers/commands/serving.py +++ b/transformers/commands/serving.py @@ -1,16 +1,23 @@ from argparse import ArgumentParser, Namespace from typing import List, Optional, Union, Any -from fastapi import FastAPI, HTTPException, Body -from logging import getLogger +import logging -from pydantic import BaseModel -from uvicorn import run +try: + from uvicorn import run + from fastapi import FastAPI, HTTPException, Body + from pydantic import BaseModel + _serve_dependancies_installed = True +except (ImportError, AttributeError): + BaseModel = object + Body = lambda *x, **y: None + _serve_dependancies_installed = False from transformers import Pipeline from transformers.commands import BaseTransformersCLICommand from transformers.pipelines import SUPPORTED_TASKS, pipeline +logger = logging.getLogger('transformers-cli/serving') def serve_command_factory(args: Namespace): """ @@ -70,20 +77,24 @@ class ServeCommand(BaseTransformersCLICommand): serve_parser.set_defaults(func=serve_command_factory) def __init__(self, pipeline: Pipeline, host: str, port: int): - self._logger = getLogger('transformers-cli/serving') self._pipeline = pipeline - self._logger.info('Serving model over {}:{}'.format(host, port)) self._host = host self._port = port - self._app = FastAPI() + if not _serve_dependancies_installed: + raise ImportError("Using serve command requires FastAPI and unicorn. " + "Please install transformers with [serving]: pip install transformers[serving]." + "Or install FastAPI and unicorn separatly.") + else: + logger.info('Serving model over {}:{}'.format(host, port)) + self._app = FastAPI() - # Register routes - self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET']) - self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST']) - self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST']) - self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST']) + # Register routes + self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET']) + self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST']) + self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST']) + self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST']) def run(self): run(self._app, host=self._host, port=self._port)