From 4f38808e9e1b7ff9c7b783444c8df8493e44dcda Mon Sep 17 00:00:00 2001 From: Ruihua Fang Date: Wed, 1 Jun 2022 03:16:15 -0700 Subject: [PATCH] Add OnnxConfig for SqueezeBert iss17314 (#17315) * add onnx config for SqueezeBert * add test for onnx config for SqueezeBert * add automatically updated doc for onnx config for SqueezeBert * Update src/transformers/onnx/features.py Co-authored-by: lewtun * Update src/transformers/models/squeezebert/configuration_squeezebert.py Co-authored-by: lewtun Co-authored-by: lewtun --- docs/source/en/serialization.mdx | 1 + .../models/squeezebert/__init__.py | 12 +++++++++-- .../squeezebert/configuration_squeezebert.py | 20 +++++++++++++++++++ src/transformers/onnx/features.py | 10 ++++++++++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 42 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index d12e52627..0cb1ff683 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -73,6 +73,7 @@ Ready-made configurations include the following architectures: - PLBart - RoBERTa - RoFormer +- SqueezeBERT - T5 - ViT - XLM diff --git a/src/transformers/models/squeezebert/__init__.py b/src/transformers/models/squeezebert/__init__.py index 52c001dbd..9f758bebe 100644 --- a/src/transformers/models/squeezebert/__init__.py +++ b/src/transformers/models/squeezebert/__init__.py @@ -22,7 +22,11 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_ _import_structure = { - "configuration_squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig"], + "configuration_squeezebert": [ + "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "SqueezeBertConfig", + "SqueezeBertOnnxConfig", + ], "tokenization_squeezebert": ["SqueezeBertTokenizer"], } @@ -54,7 +58,11 @@ else: if TYPE_CHECKING: - from .configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig + from .configuration_squeezebert import ( + SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + SqueezeBertConfig, + SqueezeBertOnnxConfig, + ) from .tokenization_squeezebert import SqueezeBertTokenizer try: diff --git a/src/transformers/models/squeezebert/configuration_squeezebert.py b/src/transformers/models/squeezebert/configuration_squeezebert.py index b4b707d6c..41b47ff57 100644 --- a/src/transformers/models/squeezebert/configuration_squeezebert.py +++ b/src/transformers/models/squeezebert/configuration_squeezebert.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ SqueezeBERT model configuration""" +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -154,3 +157,20 @@ class SqueezeBertConfig(PretrainedConfig): self.post_attention_groups = post_attention_groups self.intermediate_groups = intermediate_groups self.output_groups = output_groups + + +# # Copied from transformers.models.bert.configuration_bert.BertOnxxConfig with Bert->SqueezeBert +class SqueezeBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index dfb46f89e..9013618e0 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -28,6 +28,7 @@ from ..models.mbart import MBartOnnxConfig from ..models.mobilebert import MobileBertOnnxConfig from ..models.roberta import RobertaOnnxConfig from ..models.roformer import RoFormerOnnxConfig +from ..models.squeezebert import SqueezeBertOnnxConfig from ..models.t5 import T5OnnxConfig from ..models.vit import ViTOnnxConfig from ..models.xlm import XLMOnnxConfig @@ -352,6 +353,15 @@ class FeaturesManager: "token-classification", onnx_config_cls=RoFormerOnnxConfig, ), + "squeezebert": supported_features_mapping( + "default", + "masked-lm", + "sequence-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=SqueezeBertOnnxConfig, + ), "t5": supported_features_mapping( "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig ), diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index bdf08c445..d5115a9b3 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = { ("electra", "google/electra-base-generator"), ("roberta", "roberta-base"), ("roformer", "junnyu/roformer_chinese_base"), + ("squeezebert", "squeezebert/squeezebert-uncased"), ("mobilebert", "google/mobilebert-uncased"), ("xlm", "xlm-clm-ende-1024"), ("xlm-roberta", "xlm-roberta-base"),