Add OnnxConfig for SqueezeBert iss17314 (#17315)

* add onnx config for SqueezeBert

* add test for onnx config for SqueezeBert

* add automatically updated doc for onnx config for SqueezeBert

* Update src/transformers/onnx/features.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/transformers/models/squeezebert/configuration_squeezebert.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Ruihua Fang 2022-06-01 03:16:15 -07:00 committed by GitHub
parent ba286fe7d5
commit 4f38808e9e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 42 additions and 2 deletions

View file

@ -73,6 +73,7 @@ Ready-made configurations include the following architectures:
- PLBart
- RoBERTa
- RoFormer
- SqueezeBERT
- T5
- ViT
- XLM

View file

@ -22,7 +22,11 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_
_import_structure = {
"configuration_squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig"],
"configuration_squeezebert": [
"SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"SqueezeBertConfig",
"SqueezeBertOnnxConfig",
],
"tokenization_squeezebert": ["SqueezeBertTokenizer"],
}
@ -54,7 +58,11 @@ else:
if TYPE_CHECKING:
from .configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from .configuration_squeezebert import (
SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
SqueezeBertConfig,
SqueezeBertOnnxConfig,
)
from .tokenization_squeezebert import SqueezeBertTokenizer
try:

View file

@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" SqueezeBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
@ -154,3 +157,20 @@ class SqueezeBertConfig(PretrainedConfig):
self.post_attention_groups = post_attention_groups
self.intermediate_groups = intermediate_groups
self.output_groups = output_groups
# # Copied from transformers.models.bert.configuration_bert.BertOnxxConfig with Bert->SqueezeBert
class SqueezeBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)

View file

@ -28,6 +28,7 @@ from ..models.mbart import MBartOnnxConfig
from ..models.mobilebert import MobileBertOnnxConfig
from ..models.roberta import RobertaOnnxConfig
from ..models.roformer import RoFormerOnnxConfig
from ..models.squeezebert import SqueezeBertOnnxConfig
from ..models.t5 import T5OnnxConfig
from ..models.vit import ViTOnnxConfig
from ..models.xlm import XLMOnnxConfig
@ -352,6 +353,15 @@ class FeaturesManager:
"token-classification",
onnx_config_cls=RoFormerOnnxConfig,
),
"squeezebert": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=SqueezeBertOnnxConfig,
),
"t5": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
),

View file

@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
("electra", "google/electra-base-generator"),
("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"),
("squeezebert", "squeezebert/squeezebert-uncased"),
("mobilebert", "google/mobilebert-uncased"),
("xlm", "xlm-clm-ende-1024"),
("xlm-roberta", "xlm-roberta-base"),