mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Add DeBERTa model (#5929)
* Add DeBERTa model * Remove dependency of deberta * Address comments * Patch DeBERTa Documentation Style * Add final tests * Style * Enable tests + nitpicks * position IDs * BERT -> DeBERTa * Quality * Style * Tokenization * Last updates. * @patrickvonplaten's comments * Not everything can be a copy * Apply most of @sgugger's review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Last reviews * DeBERTa -> Deberta Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr> Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
44a93c981f
commit
7a0cf0ec93
16 changed files with 2350 additions and 3 deletions
|
|
@ -187,8 +187,9 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
|||
25. **[LXMERT](https://github.com/airsplay/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
||||
26. **[Funnel Transformer](https://github.com/laiguokun/Funnel-Transformer)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
|
||||
27. **[LayoutLM](https://github.com/microsoft/unilm/tree/master/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
|
||||
28. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
|
||||
29. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
|
||||
28. **[DeBERTa](https://huggingface.co/transformers/model_doc/deberta.html)** (from Microsoft Research) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
29. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
|
||||
30. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
|
||||
|
||||
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations. You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
|
||||
|
||||
|
|
|
|||
|
|
@ -214,6 +214,7 @@ conversion utilities for the following models:
|
|||
model_doc/bertgeneration
|
||||
model_doc/camembert
|
||||
model_doc/ctrl
|
||||
model_doc/deberta
|
||||
model_doc/dialogpt
|
||||
model_doc/distilbert
|
||||
model_doc/dpr
|
||||
|
|
|
|||
62
docs/source/model_doc/deberta.rst
Normal file
62
docs/source/model_doc/deberta.rst
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
DeBERTa
|
||||
----------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention <https://arxiv.org/abs/2006.03654>`__
|
||||
by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen
|
||||
It is based on Google's BERT model released in 2018 and Facebook's RoBERTa model released in 2019.
|
||||
|
||||
It builds on RoBERTa with disentangled attention and enhanced mask decoder training with half of the data used in RoBERTa.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Recent progress in pre-trained neural language models has significantly improved the performance of many natural language processing (NLP) tasks.
|
||||
In this paper we propose a new model architecture DeBERTa (Decoding-enhanced BERT with disentangled attention) that improves the BERT and RoBERTa
|
||||
models using two novel techniques. The first is the disentangled attention mechanism, where each word is represented using two vectors that encode
|
||||
its content and position, respectively, and the attention weights among words are computed using disentangled matrices on their contents and
|
||||
relative positions. Second, an enhanced mask decoder is used to replace the output softmax layer to predict the masked tokens for model pretraining.
|
||||
We show that these two techniques significantly improve the efficiency of model pre-training and performance of downstream tasks. Compared to
|
||||
RoBERTa-Large, a DeBERTa model trained on half of the training data performs consistently better on a wide range of NLP tasks, achieving improvements
|
||||
on MNLI by +0.9% (90.2% vs. 91.1%), on SQuAD v2.0 by +2.3% (88.4% vs. 90.7%) and RACE by +3.6% (83.2% vs. 86.8%). The DeBERTa code and pre-trained
|
||||
models will be made publicly available at https://github.com/microsoft/DeBERTa.*
|
||||
|
||||
|
||||
The original code can be found `here <https://github.com/microsoft/DeBERTa>`__.
|
||||
|
||||
|
||||
DebertaConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DebertaConfig
|
||||
:members:
|
||||
|
||||
|
||||
DebertaTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DebertaTokenizer
|
||||
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
|
||||
create_token_type_ids_from_sequences, save_vocabulary
|
||||
|
||||
|
||||
DebertaModel
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DebertaModel
|
||||
:members:
|
||||
|
||||
|
||||
DebertaPreTrainedModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DebertaPreTrainedModel
|
||||
:members:
|
||||
|
||||
|
||||
DebertaForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DebertaForSequenceClassification
|
||||
:members:
|
||||
|
|
@ -415,4 +415,15 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
|
|||
| | ``microsoft/layoutlm-large-uncased`` | | 24 layers, 1024-hidden, 16-heads, 343M parameters |
|
||||
| | | |
|
||||
| | | (see `details <https://github.com/microsoft/unilm/tree/master/layoutlm>`__) |
|
||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| DeBERTa | ``microsoft/deberta-base`` | | 12-layer, 768-hidden, 12-heads, ~125M parameters |
|
||||
| | | | DeBERTa using the BERT-base architecture |
|
||||
| | | |
|
||||
| | | (see `details <https://github.com/microsoft/DeBERTa>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``microsoft/deberta-large`` | | 24-layer, 1024-hidden, 16-heads, ~390M parameters |
|
||||
| | | | DeBERTa using the BERT-large architecture |
|
||||
| | | |
|
||||
| | | (see `details <https://github.com/microsoft/DeBERTa>`__) |
|
||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
|
||||
|
|
|
|||
36
model_cards/microsoft/DeBERTa-base/README.md
Normal file
36
model_cards/microsoft/DeBERTa-base/README.md
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
---
|
||||
thumbnail: https://huggingface.co/front/thumbnails/microsoft.png
|
||||
license: mit
|
||||
---
|
||||
|
||||
## DeBERTa: Decoding-enhanced BERT with Disentangled Attention
|
||||
|
||||
[DeBERTa](https://arxiv.org/abs/2006.03654) improves the BERT and RoBERTa models using disentangled attention and enhanced mask decoder. With those two improvements, DeBERTa out perform RoBERTa on a majority of NLU tasks with 80GB training data.
|
||||
|
||||
Please check the [official repository](https://github.com/microsoft/DeBERTa) for more details and updates.
|
||||
|
||||
|
||||
#### Fine-tuning on NLU tasks
|
||||
|
||||
We present the dev results on SQuAD 1.1/2.0 and MNLI tasks.
|
||||
|
||||
| Model | SQuAD 1.1 | SQuAD 2.0 | MNLI-m |
|
||||
|-------------------|-----------|-----------|--------|
|
||||
| RoBERTa-base | 91.5/84.6 | 83.7/80.5 | 87.6 |
|
||||
| XLNet-Large | -/- | -/80.2 | 86.8 |
|
||||
| **DeBERTa-base** | 93.1/87.2 | 86.2/83.1 | 88.8 |
|
||||
|
||||
### Citation
|
||||
|
||||
If you find DeBERTa useful for your work, please cite the following paper:
|
||||
|
||||
``` latex
|
||||
@misc{he2020deberta,
|
||||
title={DeBERTa: Decoding-enhanced BERT with Disentangled Attention},
|
||||
author={Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen},
|
||||
year={2020},
|
||||
eprint={2006.03654},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
||||
37
model_cards/microsoft/DeBERTa-large/README.md
Normal file
37
model_cards/microsoft/DeBERTa-large/README.md
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
---
|
||||
thumbnail: https://huggingface.co/front/thumbnails/microsoft.png
|
||||
license: mit
|
||||
---
|
||||
|
||||
## DeBERTa: Decoding-enhanced BERT with Disentangled Attention
|
||||
|
||||
[DeBERTa](https://arxiv.org/abs/2006.03654) improves the BERT and RoBERTa models using disentangled attention and enhanced mask decoder. With those two improvements, DeBERTa out perform RoBERTa on a majority of NLU tasks with 80GB training data.
|
||||
|
||||
Please check the [official repository](https://github.com/microsoft/DeBERTa) for more details and updates.
|
||||
|
||||
|
||||
#### Fine-tuning on NLU tasks
|
||||
|
||||
We present the dev results on SQuAD 1.1/2.0 and several GLUE benchmark tasks.
|
||||
|
||||
| Model | SQuAD 1.1 | SQuAD 2.0 | MNLI-m | SST-2 | QNLI | CoLA | RTE | MRPC | QQP |STS-B|
|
||||
|-------------------|-----------|-----------|--------|-------|------|------|------|------|------|-----|
|
||||
| BERT-Large | 90.9/84.1 | 81.8/79.0 | 86.6 | 93.2 | 92.3 | 60.6 | 70.4 | 88.0 | 91.3 |90.0 |
|
||||
| RoBERTa-Large | 94.6/88.9 | 89.4/86.5 | 90.2 | 96.4 | 93.9 | 68.0 | 86.6 | 90.9 | 92.2 |92.4 |
|
||||
| XLNet-Large | 95.1/89.7 | 90.6/87.9 | 90.8 | 97.0 | 94.9 | 69.0 | 85.9 | 90.8 | 92.3 |92.5 |
|
||||
| **DeBERTa-Large** | 95.5/90.1 | 90.7/88.0 | 91.1 | 96.5 | 95.3 | 69.5 | 88.1 | 92.5 | 92.3 |92.5 |
|
||||
|
||||
### Citation
|
||||
|
||||
If you find DeBERTa useful for your work, please cite the following paper:
|
||||
|
||||
``` latex
|
||||
@misc{he2020deberta,
|
||||
title={DeBERTa: Decoding-enhanced BERT with Disentangled Attention},
|
||||
author={Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen},
|
||||
year={2020},
|
||||
eprint={2006.03654},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
||||
|
|
@ -35,6 +35,7 @@ from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
|||
from .configuration_bert_generation import BertGenerationConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
|
||||
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
||||
from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
|
||||
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
|
||||
|
|
@ -156,6 +157,7 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenize
|
|||
from .tokenization_bertweet import BertweetTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
from .tokenization_deberta import DebertaTokenizer
|
||||
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
|
||||
from .tokenization_dpr import (
|
||||
DPRContextEncoderTokenizer,
|
||||
|
|
@ -310,6 +312,12 @@ if is_torch_available():
|
|||
CamembertModel,
|
||||
)
|
||||
from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel
|
||||
from .modeling_deberta import (
|
||||
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DebertaForSequenceClassification,
|
||||
DebertaModel,
|
||||
DebertaPreTrainedModel,
|
||||
)
|
||||
from .modeling_distilbert import (
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DistilBertForMaskedLM,
|
||||
|
|
|
|||
|
|
@ -44,6 +44,10 @@ def mish(x):
|
|||
return x * torch.tanh(torch.nn.functional.softplus(x))
|
||||
|
||||
|
||||
def linear_act(x):
|
||||
return x
|
||||
|
||||
|
||||
ACT2FN = {
|
||||
"relu": F.relu,
|
||||
"swish": swish,
|
||||
|
|
@ -52,6 +56,8 @@ ACT2FN = {
|
|||
"gelu_new": gelu_new,
|
||||
"gelu_fast": gelu_fast,
|
||||
"mish": mish,
|
||||
"linear": linear_act,
|
||||
"sigmoid": torch.sigmoid,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
|||
from .configuration_bert_generation import BertGenerationConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
|
||||
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
||||
from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
|
||||
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
|
||||
|
|
@ -78,6 +79,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
|||
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
]
|
||||
for key, value, in pretrained_map.items()
|
||||
)
|
||||
|
|
@ -100,6 +102,7 @@ CONFIG_MAPPING = OrderedDict(
|
|||
("reformer", ReformerConfig),
|
||||
("longformer", LongformerConfig),
|
||||
("roberta", RobertaConfig),
|
||||
("deberta", DebertaConfig),
|
||||
("flaubert", FlaubertConfig),
|
||||
("fsmt", FSMTConfig),
|
||||
("bert", BertConfig),
|
||||
|
|
@ -149,6 +152,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||
("encoder-decoder", "Encoder decoder"),
|
||||
("funnel", "Funnel Transformer"),
|
||||
("lxmert", "LXMERT"),
|
||||
("deberta", "DeBERTa"),
|
||||
("layoutlm", "LayoutLM"),
|
||||
("dpr", "DPR"),
|
||||
("rag", "RAG"),
|
||||
|
|
|
|||
132
src/transformers/configuration_deberta.py
Normal file
132
src/transformers/configuration_deberta.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020, Microsoft and the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" DeBERTa model configuration """
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"microsoft/deberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-base/config.json",
|
||||
"microsoft/deberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-large/config.json",
|
||||
}
|
||||
|
||||
|
||||
class DebertaConfig(PretrainedConfig):
|
||||
r"""
|
||||
:class:`~transformers.DebertaConfig` is the configuration class to store the configuration of a
|
||||
:class:`~transformers.DebertaModel`.
|
||||
|
||||
Arguments:
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 30522):
|
||||
Vocabulary size of the DeBERTa model. Defines the number of different tokens that can be represented by the
|
||||
:obj:`inputs_ids` passed when calling :class:`~transformers.DebertaModel` or
|
||||
:class:`~transformers.TFDebertaModel`.
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"`, :obj:`"gelu"`, :obj:`"tanh"`, :obj:`"gelu_fast"`,
|
||||
:obj:`"mish"`, :obj:`"linear"`, :obj:`"sigmoid"` and :obj:`"gelu_new"` are supported.
|
||||
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
|
||||
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.DebertaModel` or
|
||||
:class:`~transformers.TFDebertaModel`.
|
||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
relative_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether use relative position encoding.
|
||||
max_relative_positions (:obj:`int`, `optional`, defaults to 1):
|
||||
The range of relative positions :obj:`[-max_position_embeddings, max_position_embeddings]`.
|
||||
Use the same value as :obj:`max_position_embeddings`.
|
||||
pad_token_id (:obj:`int`, `optional`, defaults to 0):
|
||||
The value used to pad input_ids.
|
||||
position_biased_input (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether add absolute position embedding to content embedding.
|
||||
pos_att_type (:obj:`List[str]`, `optional`):
|
||||
The type of relative position attention, it can be a combination of :obj:`["p2c", "c2p", "p2p"]`,
|
||||
e.g. :obj:`["p2c"]`, :obj:`["p2c", "c2p"]`, :obj:`["p2c", "c2p", 'p2p"]`.
|
||||
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
"""
|
||||
model_type = "deberta"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50265,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=0,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-7,
|
||||
relative_attention=False,
|
||||
max_relative_positions=-1,
|
||||
pad_token_id=0,
|
||||
position_biased_input=True,
|
||||
pos_att_type=None,
|
||||
pooler_dropout=0,
|
||||
pooler_hidden_act="gelu",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.relative_attention = relative_attention
|
||||
self.max_relative_positions = max_relative_positions
|
||||
self.pad_token_id = pad_token_id
|
||||
self.position_biased_input = position_biased_input
|
||||
|
||||
# Backwards compatibility
|
||||
if type(pos_att_type) == str:
|
||||
pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")]
|
||||
|
||||
self.pos_att_type = pos_att_type
|
||||
self.vocab_size = vocab_size
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
|
||||
self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size)
|
||||
self.pooler_dropout = pooler_dropout
|
||||
self.pooler_hidden_act = pooler_hidden_act
|
||||
|
|
@ -26,6 +26,7 @@ from .configuration_auto import (
|
|||
BertGenerationConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DebertaConfig,
|
||||
DistilBertConfig,
|
||||
DPRConfig,
|
||||
ElectraConfig,
|
||||
|
|
@ -90,6 +91,7 @@ from .modeling_camembert import (
|
|||
CamembertModel,
|
||||
)
|
||||
from .modeling_ctrl import CTRLLMHeadModel, CTRLModel
|
||||
from .modeling_deberta import DebertaForSequenceClassification, DebertaModel
|
||||
from .modeling_distilbert import (
|
||||
DistilBertForMaskedLM,
|
||||
DistilBertForMultipleChoice,
|
||||
|
|
@ -231,6 +233,7 @@ MODEL_MAPPING = OrderedDict(
|
|||
(FunnelConfig, FunnelModel),
|
||||
(LxmertConfig, LxmertModel),
|
||||
(BertGenerationConfig, BertGenerationEncoder),
|
||||
(DebertaConfig, DebertaModel),
|
||||
(DPRConfig, DPRQuestionEncoder),
|
||||
]
|
||||
)
|
||||
|
|
@ -359,6 +362,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
|||
(XLMConfig, XLMForSequenceClassification),
|
||||
(ElectraConfig, ElectraForSequenceClassification),
|
||||
(FunnelConfig, FunnelForSequenceClassification),
|
||||
(DebertaConfig, DebertaForSequenceClassification),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
|||
1032
src/transformers/modeling_deberta.py
Normal file
1032
src/transformers/modeling_deberta.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -25,6 +25,7 @@ from .configuration_auto import (
|
|||
BertGenerationConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DebertaConfig,
|
||||
DistilBertConfig,
|
||||
DPRConfig,
|
||||
ElectraConfig,
|
||||
|
|
@ -61,6 +62,7 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer
|
|||
from .tokenization_bertweet import BertweetTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
from .tokenization_deberta import DebertaTokenizer
|
||||
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
|
||||
from .tokenization_dpr import DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast
|
||||
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
||||
|
|
@ -125,6 +127,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
|||
(CTRLConfig, (CTRLTokenizer, None)),
|
||||
(FSMTConfig, (FSMTTokenizer, None)),
|
||||
(BertGenerationConfig, (BertGenerationTokenizer, None)),
|
||||
(DebertaConfig, (DebertaTokenizer, None)),
|
||||
(LayoutLMConfig, (LayoutLMTokenizer, None)),
|
||||
(RagConfig, (RagTokenizer, None)),
|
||||
]
|
||||
|
|
|
|||
663
src/transformers/tokenization_deberta.py
Normal file
663
src/transformers/tokenization_deberta.py
Normal file
|
|
@ -0,0 +1,663 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 Microsoft and the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Tokenization class for model DeBERTa."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
import unicodedata
|
||||
from functools import lru_cache
|
||||
from zipfile import ZipFile
|
||||
|
||||
import tqdm
|
||||
|
||||
import requests
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
|
||||
try:
|
||||
import regex as re
|
||||
except ImportError:
|
||||
raise ImportError("Please install regex with: pip install regex")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"microsoft/deberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-base/bpe_encoder.bin",
|
||||
"microsoft/deberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-large/bpe_encoder.bin",
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"microsoft/deberta-base": 512,
|
||||
"microsoft/deberta-large": 512,
|
||||
}
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
"microsoft/deberta-base": {"do_lower_case": False},
|
||||
"microsoft/deberta-large": {"do_lower_case": False},
|
||||
}
|
||||
|
||||
__all__ = ["DebertaTokenizer"]
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2 ** 8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2 ** 8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
class Encoder:
|
||||
def __init__(self, encoder, bpe_merges, errors="replace"):
|
||||
self.encoder = encoder
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges))))
|
||||
self.cache = {}
|
||||
self.random = random.Random(0)
|
||||
|
||||
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except Exception:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = " ".join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def split_to_words(self, text):
|
||||
return list(re.findall(self.pat, text))
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
for token in self.split_to_words(text):
|
||||
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = "".join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
||||
return text
|
||||
|
||||
|
||||
def get_encoder(encoder, vocab):
|
||||
return Encoder(
|
||||
encoder=encoder,
|
||||
bpe_merges=vocab,
|
||||
)
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically contorl characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def download_asset(name, tag=None, no_cache=False, cache_dir=None):
|
||||
_tag = tag
|
||||
if _tag is None:
|
||||
_tag = "latest"
|
||||
if not cache_dir:
|
||||
cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
output = os.path.join(cache_dir, name)
|
||||
if os.path.exists(output) and (not no_cache):
|
||||
return output
|
||||
|
||||
repo = "https://api.github.com/repos/microsoft/DeBERTa/releases"
|
||||
releases = requests.get(repo).json()
|
||||
if tag and tag != "latest":
|
||||
release = [r for r in releases if r["name"].lower() == tag.lower()]
|
||||
if len(release) != 1:
|
||||
raise Exception(f"{tag} can't be found in the repository.")
|
||||
else:
|
||||
release = releases[0]
|
||||
asset = [s for s in release["assets"] if s["name"].lower() == name.lower()]
|
||||
if len(asset) != 1:
|
||||
raise Exception(f"{name} can't be found in the release.")
|
||||
url = asset[0]["url"]
|
||||
headers = {}
|
||||
headers["Accept"] = "application/octet-stream"
|
||||
resp = requests.get(url, stream=True, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"Request for {url} return {resp.status_code}, {resp.text}")
|
||||
try:
|
||||
with open(output, "wb") as fs:
|
||||
progress = tqdm(
|
||||
total=int(resp.headers["Content-Length"]) if "Content-Length" in resp.headers else -1,
|
||||
ncols=80,
|
||||
desc=f"Downloading {name}",
|
||||
)
|
||||
for c in resp.iter_content(chunk_size=1024 * 1024):
|
||||
fs.write(c)
|
||||
progress.update(len(c))
|
||||
progress.close()
|
||||
except Exception:
|
||||
os.remove(output)
|
||||
raise
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def load_vocab(name=None, tag=None, no_cache=False, cache_dir=None):
|
||||
import torch
|
||||
|
||||
if name is None:
|
||||
name = "bpe_encoder"
|
||||
|
||||
model_path = name
|
||||
if model_path and (not os.path.exists(model_path)) and not (("/" in model_path) or ("\\" in model_path)):
|
||||
_tag = tag
|
||||
if _tag is None:
|
||||
_tag = "latest"
|
||||
if not cache_dir:
|
||||
cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
out_dir = os.path.join(cache_dir, name)
|
||||
model_path = os.path.join(out_dir, "bpe_encoder.bin")
|
||||
if (not os.path.exists(model_path)) or no_cache:
|
||||
asset = download_asset(name + ".zip", tag=tag, no_cache=no_cache, cache_dir=cache_dir)
|
||||
with ZipFile(asset, "r") as zipf:
|
||||
for zip_info in zipf.infolist():
|
||||
if zip_info.filename[-1] == "/":
|
||||
continue
|
||||
zip_info.filename = os.path.basename(zip_info.filename)
|
||||
zipf.extract(zip_info, out_dir)
|
||||
elif not model_path:
|
||||
return None, None
|
||||
|
||||
encoder_state = torch.load(model_path)
|
||||
return encoder_state
|
||||
|
||||
|
||||
class GPT2Tokenizer(object):
|
||||
""" A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer
|
||||
|
||||
Args:
|
||||
vocab_file (:obj:`str`, optional):
|
||||
The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, \
|
||||
e.g. "bpe_encoder", default: `None`.
|
||||
|
||||
If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file is a \
|
||||
state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. \
|
||||
The difference between our wrapped GPT2 tokenizer and RoBERTa wrapped tokenizer are,
|
||||
|
||||
- Special tokens, unlike `RoBERTa` which use `<s>`, `</s>` as the `start` token and `end` token of a sentence. We use `[CLS]` and `[SEP]` as the `start` and `end`\
|
||||
token of input sentence which is the same as `BERT`.
|
||||
|
||||
- We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264
|
||||
|
||||
do_lower_case (:obj:`bool`, optional):
|
||||
Whether to convert inputs to lower case. **Not used in GPT2 tokenizer**.
|
||||
|
||||
special_tokens (:obj:`list`, optional):
|
||||
List of special tokens to be added to the end of the vocabulary.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None):
|
||||
self.pad_token = "[PAD]"
|
||||
self.sep_token = "[SEP]"
|
||||
self.unk_token = "[UNK]"
|
||||
self.cls_token = "[CLS]"
|
||||
|
||||
self.symbols = []
|
||||
self.count = []
|
||||
self.indices = {}
|
||||
self.pad_token_id = self.add_symbol(self.pad_token)
|
||||
self.cls_token_id = self.add_symbol(self.cls_token)
|
||||
self.sep_token_id = self.add_symbol(self.sep_token)
|
||||
self.unk_token_id = self.add_symbol(self.unk_token)
|
||||
|
||||
self.gpt2_encoder = load_vocab(vocab_file)
|
||||
self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"])
|
||||
for w, n in self.gpt2_encoder["dict_map"]:
|
||||
self.add_symbol(w, n)
|
||||
|
||||
self.mask_token = "[MASK]"
|
||||
self.mask_id = self.add_symbol(self.mask_token)
|
||||
self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"]
|
||||
if special_tokens is not None:
|
||||
for t in special_tokens:
|
||||
self.add_special_token(t)
|
||||
|
||||
self.vocab = self.indices
|
||||
self.ids_to_tokens = self.symbols
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Convert an input text to tokens.
|
||||
|
||||
Args:
|
||||
text (:obj:`str`): input text to be tokenized.
|
||||
|
||||
Returns:
|
||||
A list of byte tokens where each token represent the byte id in GPT2 byte dictionary
|
||||
|
||||
Example::
|
||||
>>> tokenizer = GPT2Tokenizer()
|
||||
>>> text = "Hello world!"
|
||||
>>> tokens = tokenizer.tokenize(text)
|
||||
>>> print(tokens)
|
||||
['15496', '995', '0']
|
||||
"""
|
||||
bpe = self._encode(text)
|
||||
|
||||
return [t for t in bpe.split(" ") if t]
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
"""Convert list of tokens to ids.
|
||||
Args:
|
||||
tokens (:obj:`list<str>`): list of tokens
|
||||
|
||||
Returns:
|
||||
List of ids
|
||||
"""
|
||||
|
||||
return [self.vocab[t] for t in tokens]
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
"""Convert list of ids to tokens.
|
||||
Args:
|
||||
ids (:obj:`list<int>`): list of ids
|
||||
|
||||
Returns:
|
||||
List of tokens
|
||||
"""
|
||||
|
||||
tokens = []
|
||||
for i in ids:
|
||||
tokens.append(self.ids_to_tokens[i])
|
||||
return tokens
|
||||
|
||||
def split_to_words(self, text):
|
||||
return self.bpe.split_to_words(text)
|
||||
|
||||
def decode(self, tokens):
|
||||
"""Decode list of tokens to text strings.
|
||||
Args:
|
||||
tokens (:obj:`list<str>`): list of tokens.
|
||||
|
||||
Returns:
|
||||
Text string corresponds to the input tokens.
|
||||
|
||||
Example::
|
||||
>>> tokenizer = GPT2Tokenizer()
|
||||
>>> text = "Hello world!"
|
||||
>>> tokens = tokenizer.tokenize(text)
|
||||
>>> print(tokens)
|
||||
['15496', '995', '0']
|
||||
>>> tokenizer.decode(tokens)
|
||||
'Hello world!'
|
||||
"""
|
||||
return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens])
|
||||
|
||||
def add_special_token(self, token):
|
||||
"""Adds a special token to the dictionary.
|
||||
Args:
|
||||
token (:obj:`str`): Tthe new token/word to be added to the vocabulary.
|
||||
|
||||
Returns:
|
||||
The id of new token in the vocabulary.
|
||||
|
||||
"""
|
||||
self.special_tokens.append(token)
|
||||
return self.add_symbol(token)
|
||||
|
||||
def part_of_whole_word(self, token, is_bos=False):
|
||||
if is_bos:
|
||||
return True
|
||||
s = self._decode(token)
|
||||
if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])):
|
||||
return False
|
||||
|
||||
return not s.startswith(" ")
|
||||
|
||||
def sym(self, id):
|
||||
return self.ids_to_tokens[id]
|
||||
|
||||
def id(self, sym):
|
||||
return self.vocab[sym]
|
||||
|
||||
def _encode(self, x: str) -> str:
|
||||
return " ".join(map(str, self.bpe.encode(x)))
|
||||
|
||||
def _decode(self, x: str) -> str:
|
||||
return self.bpe.decode(map(int, x.split()))
|
||||
|
||||
def add_symbol(self, word, n=1):
|
||||
"""Adds a word to the dictionary.
|
||||
Args:
|
||||
word (:obj:`str`): Tthe new token/word to be added to the vocabulary.
|
||||
n (int, optional): The frequency of the word.
|
||||
|
||||
Returns:
|
||||
The id of the new word.
|
||||
|
||||
"""
|
||||
if word in self.indices:
|
||||
idx = self.indices[word]
|
||||
self.count[idx] = self.count[idx] + n
|
||||
return idx
|
||||
else:
|
||||
idx = len(self.symbols)
|
||||
self.indices[word] = idx
|
||||
self.symbols.append(word)
|
||||
self.count.append(n)
|
||||
return idx
|
||||
|
||||
def save_pretrained(self, path: str):
|
||||
import torch
|
||||
|
||||
torch.save(self.gpt2_encoder, path)
|
||||
|
||||
|
||||
class DebertaTokenizer(PreTrainedTokenizer):
|
||||
r"""
|
||||
Constructs a DeBERTa tokenizer, which runs end-to-end tokenization: punctuation
|
||||
splitting + wordpiece
|
||||
|
||||
Args:
|
||||
vocab_file (:obj:`str`):
|
||||
File containing the vocabulary.
|
||||
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to lowercase the input when tokenizing.
|
||||
unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
|
||||
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
|
||||
for sequence classification or for a text and a question for question answering.
|
||||
It is also used as the last token of a sequence built with special tokens.
|
||||
pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
|
||||
The classifier token which is used when doing sequence classification (classification of the whole
|
||||
sequence instead of per-token classification). It is the first token of the sequence when built with
|
||||
special tokens.
|
||||
mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
do_lower_case=False,
|
||||
unk_token="[UNK]",
|
||||
sep_token="[SEP]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
mask_token="[MASK]",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
unk_token=unk_token,
|
||||
sep_token=sep_token,
|
||||
pad_token=pad_token,
|
||||
cls_token=cls_token,
|
||||
mask_token=mask_token,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not os.path.isfile(vocab_file):
|
||||
raise ValueError(
|
||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
||||
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
|
||||
)
|
||||
self.do_lower_case = do_lower_case
|
||||
self.gpt2_tokenizer = GPT2Tokenizer(vocab_file)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.vocab)
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self.gpt2_tokenizer.vocab
|
||||
|
||||
def get_vocab(self):
|
||||
vocab = self.vocab.copy()
|
||||
vocab.update(self.get_added_vocab())
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text):
|
||||
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
||||
if self.do_lower_case:
|
||||
text = text.lower()
|
||||
return self.gpt2_tokenizer.tokenize(text)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.gpt2_tokenizer.sym(index) if index < self.vocab_size else self.unk_token
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
return self.gpt2_tokenizer.decode(tokens)
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
||||
by concatenating and adding special tokens.
|
||||
A BERT sequence has the following format:
|
||||
|
||||
- single sequence: [CLS] X [SEP]
|
||||
- pair of sequences: [CLS] A [SEP] B [SEP]
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||
"""
|
||||
|
||||
if token_ids_1 is None:
|
||||
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
sep = [self.sep_token_id]
|
||||
return cls + token_ids_0 + sep + token_ids_1 + sep
|
||||
|
||||
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
|
||||
"""
|
||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
||||
|
||||
Args:
|
||||
token_ids_0: list of ids (must not contain special tokens)
|
||||
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
|
||||
for sequence pairs
|
||||
already_has_special_tokens: (default False) Set to True if the token list is already formated with
|
||||
special tokens for the model
|
||||
|
||||
Returns:
|
||||
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
|
||||
if already_has_special_tokens:
|
||||
if token_ids_1 is not None:
|
||||
raise ValueError(
|
||||
"You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formated with special tokens for the model."
|
||||
)
|
||||
return list(
|
||||
map(
|
||||
lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
|
||||
token_ids_0,
|
||||
)
|
||||
)
|
||||
|
||||
if token_ids_1 is not None:
|
||||
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||
|
||||
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
|
||||
"""
|
||||
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
||||
A BERT sequence pair mask has the following format:
|
||||
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
|
||||
| first sequence | second sequence
|
||||
|
||||
if token_ids_1 is None, only returns the first portion of the mask (0's).
|
||||
~
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
||||
sequence(s).
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
if token_ids_1 is None:
|
||||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
|
||||
add_prefix_space = kwargs.pop("add_prefix_space", False)
|
||||
if is_split_into_words or add_prefix_space:
|
||||
text = " " + text
|
||||
return (text, kwargs)
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
"""Save the tokenizer vocabulary to a directory or file."""
|
||||
if os.path.isdir(vocab_path):
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
|
||||
else:
|
||||
vocab_file = vocab_path
|
||||
self.gpt2_tokenizer.save_pretrained(vocab_file)
|
||||
return (vocab_file,)
|
||||
273
tests/test_modeling_deberta.py
Normal file
273
tests/test_modeling_deberta.py
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2018 Microsoft Authors and the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import ( # XxxForMaskedLM,; XxxForQuestionAnswering,; XxxForTokenClassification,
|
||||
DebertaConfig,
|
||||
DebertaForSequenceClassification,
|
||||
DebertaModel,
|
||||
)
|
||||
from transformers.modeling_deberta import DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@require_torch
|
||||
class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
DebertaModel,
|
||||
DebertaForSequenceClassification,
|
||||
) # , DebertaForMaskedLM, DebertaForQuestionAnswering, DebertaForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
test_torchscript = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
is_encoder_decoder = False
|
||||
|
||||
class DebertaModelTester(object):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
relative_attention=False,
|
||||
position_biased_input=True,
|
||||
pos_att_type="None",
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.relative_attention = relative_attention
|
||||
self.position_biased_input = position_biased_input
|
||||
self.pos_att_type = pos_att_type
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = DebertaConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
relative_attention=self.relative_attention,
|
||||
position_biased_input=self.position_biased_input,
|
||||
pos_att_type=self.pos_att_type,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
|
||||
def create_and_check_deberta_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = DebertaModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
|
||||
sequence_output = model(input_ids, token_type_ids=token_type_ids)[0]
|
||||
sequence_output = model(input_ids)[0]
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_deberta_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = DebertaForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DebertaModelTest.DebertaModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=DebertaConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_deberta_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_deberta_model(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_deberta_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Model not available yet")
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_deberta_for_masked_lm(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Model not available yet")
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_deberta_for_question_answering(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Model not available yet")
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_deberta_for_token_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = DebertaModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_torch
|
||||
class DebertaModelIntegrationTest(unittest.TestCase):
|
||||
@unittest.skip(reason="Model not available yet")
|
||||
def test_inference_masked_lm(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed_all(0)
|
||||
DebertaModel.base_model_prefix = "bert"
|
||||
model = DebertaModel.from_pretrained("microsoft/deberta-base")
|
||||
|
||||
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
||||
output = model(input_ids)[0]
|
||||
# compare the actual values for a slice.
|
||||
expected_slice = torch.tensor(
|
||||
[[[-0.0218, -0.6641, -0.3665], [-0.3907, -0.4716, -0.6640], [0.7461, 1.2570, -0.9063]]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4), f"{output[:, :3, :3]}")
|
||||
|
||||
@slow
|
||||
def test_inference_classification_head(self):
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed_all(0)
|
||||
model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base")
|
||||
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size((1, 2))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_tensor = torch.tensor([[0.0884, -0.1047]])
|
||||
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4), f"{output}")
|
||||
74
tests/test_tokenization_deberta.py
Normal file
74
tests/test_tokenization_deberta.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2018 Microsoft.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import re
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.tokenization_deberta import DebertaTokenizer
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
@require_torch
|
||||
class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = DebertaTokenizer
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
def get_tokenizer(self, name="microsoft/deberta-base", **kwargs):
|
||||
return DebertaTokenizer.from_pretrained(name, **kwargs)
|
||||
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
return input_text, output_text
|
||||
|
||||
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20) -> Tuple[str, list]:
|
||||
toks = [
|
||||
(i, tokenizer.decode([i], clean_up_tokenization_spaces=False))
|
||||
for i in range(5, min(len(tokenizer), 50260))
|
||||
]
|
||||
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
|
||||
toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
|
||||
if max_length is not None and len(toks) > max_length:
|
||||
toks = toks[:max_length]
|
||||
# toks_str = [t[1] for t in toks]
|
||||
toks_ids = [t[0] for t in toks]
|
||||
|
||||
# Ensure consistency
|
||||
output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
|
||||
if " " not in output_txt and len(toks_ids) > 1:
|
||||
output_txt = (
|
||||
tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
|
||||
+ " "
|
||||
+ tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
|
||||
)
|
||||
if with_prefix_space and not output_txt.startswith(" "):
|
||||
output_txt = " " + output_txt
|
||||
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
|
||||
return output_txt, output_ids
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.get_tokenizer("microsoft/deberta-base")
|
||||
input_str = "UNwant\u00E9d,running"
|
||||
tokens = tokenizer.tokenize(input_str)
|
||||
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
self.assertEqual(tokenizer.decode(token_ids), input_str)
|
||||
Loading…
Reference in a new issue