mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Multilingual BART - (#3602)
- support mbart-en-ro weights - add MBartTokenizer
This commit is contained in:
parent
f98d0ef2a2
commit
7a7fdf71f8
7 changed files with 232 additions and 38 deletions
|
|
@ -283,4 +283,7 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
|
|||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
|
||||
| | | | bart-large base architecture finetuned on cnn summarization task |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``mbart-large-en-ro`` | | 12-layer, 1024-hidden, 16-heads, 880M parameters |
|
||||
| | | | bart-large architecture pretrained on cc25 multilingual data , finetuned on WMT english romanian translation. |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ from .pipelines import (
|
|||
)
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
from .tokenization_bart import BartTokenizer
|
||||
from .tokenization_bart import BartTokenizer, MBartTokenizer
|
||||
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|||
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
|
||||
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
|
||||
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -61,6 +62,9 @@ class BartConfig(PretrainedConfig):
|
|||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
normalize_before=False,
|
||||
add_final_layer_norm=False,
|
||||
scale_embedding=False,
|
||||
**common_kwargs
|
||||
):
|
||||
r"""
|
||||
|
|
@ -90,6 +94,11 @@ class BartConfig(PretrainedConfig):
|
|||
self.max_position_embeddings = max_position_embeddings
|
||||
self.init_std = init_std # Normal(0, this parameter)
|
||||
self.activation_function = activation_function
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
# True for mbart, False otherwise
|
||||
self.normalize_before = normalize_before # combo of fairseq's encoder_ and decoder_normalize_before
|
||||
self.add_final_layer_norm = add_final_layer_norm
|
||||
|
||||
# 3 Types of Dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
|
|
@ -100,9 +109,17 @@ class BartConfig(PretrainedConfig):
|
|||
self.classif_dropout = classifier_dropout
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
def is_valid_mbart(self) -> bool:
|
||||
"""Is the configuration aligned with the MBART paper."""
|
||||
if self.normalize_before and self.add_final_layer_norm and self.scale_embedding:
|
||||
return True
|
||||
if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
|
||||
logger.info("This configuration is a mixture of MBART and BART settings")
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -45,13 +45,24 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
SAMPLE_TEXT = " Hello world! cécé herlolip"
|
||||
|
||||
rename_keys = [
|
||||
mnli_rename_keys = [
|
||||
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
|
||||
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
|
||||
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
|
||||
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
|
||||
]
|
||||
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version", "_float_tensor"]
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"encoder.version",
|
||||
"decoder.version",
|
||||
"model.encoder.version",
|
||||
"model.decoder.version",
|
||||
"_float_tensor",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
|
|
@ -67,6 +78,19 @@ def load_xsum_checkpoint(checkpoint_path):
|
|||
return hub_interface
|
||||
|
||||
|
||||
def convert_checkpoint_from_disk(checkpoint_path, **config_kwargs):
|
||||
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
||||
remove_ignore_keys_(state_dict)
|
||||
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
|
||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||
mbart_config = BartConfig(vocab_size=vocab_size, **config_kwargs)
|
||||
model = BartForConditionalGeneration(mbart_config)
|
||||
model.model.load_state_dict(state_dict)
|
||||
if hasattr(model, "lm_head"):
|
||||
model.lm_head = _make_linear_from_emb(model.model.shared)
|
||||
return model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
||||
"""
|
||||
|
|
@ -89,7 +113,7 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
|
|||
state_dict = bart.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
|
||||
for src, dest in rename_keys:
|
||||
for src, dest in mnli_rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
model = BartForSequenceClassification(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
|
@ -118,11 +142,6 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
|
|||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
for k in IGNORE_KEYS:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
"""PyTorch BART model, ported from the fairseq repo."""
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
|
@ -35,6 +36,7 @@ BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
|
||||
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
|
||||
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/pytorch_model.bin",
|
||||
}
|
||||
|
||||
BART_START_DOCSTRING = r"""
|
||||
|
|
@ -180,6 +182,7 @@ class EncoderLayer(nn.Module):
|
|||
self.self_attn = SelfAttention(
|
||||
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
|
||||
)
|
||||
self.normalize_before = config.normalize_before
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
|
|
@ -201,20 +204,26 @@ class EncoderLayer(nn.Module):
|
|||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, attn_weights = self.self_attn(
|
||||
query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
return x, attn_weights
|
||||
|
||||
|
||||
|
|
@ -236,6 +245,7 @@ class BartEncoder(nn.Module):
|
|||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
embed_dim = embed_tokens.embedding_dim
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
self.padding_idx = embed_tokens.padding_idx
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
|
||||
|
|
@ -244,6 +254,8 @@ class BartEncoder(nn.Module):
|
|||
self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
|
||||
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = LayerNorm(embed_dim)
|
||||
# mbart has one extra layer_norm
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask=None,
|
||||
|
|
@ -267,7 +279,7 @@ class BartEncoder(nn.Module):
|
|||
if attention_mask is not None:
|
||||
attention_mask = invert_mask(attention_mask)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
embed_pos = self.embed_positions(input_ids)
|
||||
x = inputs_embeds + embed_pos
|
||||
x = self.layernorm_embedding(x)
|
||||
|
|
@ -290,6 +302,8 @@ class BartEncoder(nn.Module):
|
|||
if self.output_attentions:
|
||||
all_attentions.append(attn)
|
||||
|
||||
if self.layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
if self.output_hidden_states:
|
||||
encoder_states.append(x)
|
||||
|
||||
|
|
@ -311,6 +325,7 @@ class DecoderLayer(nn.Module):
|
|||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.normalize_before = config.normalize_before
|
||||
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = SelfAttention(
|
||||
|
|
@ -337,21 +352,28 @@ class DecoderLayer(nn.Module):
|
|||
|
||||
if layer_state is None:
|
||||
layer_state = {}
|
||||
# next line mutates layer state
|
||||
if self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
# Self Attention
|
||||
|
||||
x, self_attn_weights = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
layer_state=layer_state,
|
||||
layer_state=layer_state, # adds keys to layer state
|
||||
key_padding_mask=decoder_padding_mask,
|
||||
attn_mask=causal_mask,
|
||||
need_weights=self.output_attentions,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
# Cross attention
|
||||
residual = x
|
||||
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
x, _ = self.encoder_attn(
|
||||
query=x,
|
||||
key=encoder_hidden_states,
|
||||
|
|
@ -360,16 +382,20 @@ class DecoderLayer(nn.Module):
|
|||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
if not self.normalize_before:
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
|
||||
# Fully Connected
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
return (
|
||||
x,
|
||||
self_attn_weights,
|
||||
|
|
@ -394,6 +420,7 @@ class BartDecoder(nn.Module):
|
|||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = embed_tokens.padding_idx
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.embed_tokens = embed_tokens
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, config.d_model, self.padding_idx,
|
||||
|
|
@ -402,6 +429,7 @@ class BartDecoder(nn.Module):
|
|||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||
) # type: List[DecoderLayer]
|
||||
self.layernorm_embedding = LayerNorm(config.d_model)
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -444,9 +472,8 @@ class BartDecoder(nn.Module):
|
|||
positions = positions[:, -1:] # happens after we embed them
|
||||
assert input_ids.ne(self.padding_idx).any()
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
x = self.embed_tokens(input_ids) * self.embed_scale
|
||||
x += positions
|
||||
|
||||
x = self.layernorm_embedding(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
|
||||
|
|
@ -458,14 +485,16 @@ class BartDecoder(nn.Module):
|
|||
all_hidden_states = ()
|
||||
all_self_attns = ()
|
||||
next_decoder_cache = []
|
||||
for i, decoder_layer in enumerate(self.layers):
|
||||
decoder_layer # type: DecoderLayer
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (x,)
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
|
||||
layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
|
||||
|
||||
x, layer_self_attn, layer_past = decoder_layer(
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
|
|
@ -477,12 +506,13 @@ class BartDecoder(nn.Module):
|
|||
|
||||
if use_cache:
|
||||
next_decoder_cache.append(layer_past.copy())
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (x,)
|
||||
|
||||
if self.layer_norm and (idx == len(self.layers) - 1): # last layer of mbart
|
||||
x = self.layer_norm(x)
|
||||
if self.output_attentions:
|
||||
all_self_attns += (layer_self_attn,)
|
||||
|
||||
# Convert to standart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
|
||||
x = x.transpose(0, 1)
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from .tokenization_roberta import RobertaTokenizer
|
||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# vocab and merges same as roberta
|
||||
|
|
@ -21,6 +27,8 @@ vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-v
|
|||
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
|
||||
_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn", "bart-large-xsum"]
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "sentence.bpe.model"}
|
||||
|
||||
|
||||
class BartTokenizer(RobertaTokenizer):
|
||||
# merges and vocab same as Roberta
|
||||
|
|
@ -29,3 +37,13 @@ class BartTokenizer(RobertaTokenizer):
|
|||
"vocab_file": {m: vocab_url for m in _all_bart_models},
|
||||
"merges_file": {m: merges_url for m in _all_bart_models},
|
||||
}
|
||||
|
||||
|
||||
_all_mbart_models = ["mbart-large-en-ro"]
|
||||
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
|
||||
|
||||
|
||||
class MBartTokenizer(XLMRobertaTokenizer):
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
|
||||
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ if is_torch_available():
|
|||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartConfig,
|
||||
BartTokenizer,
|
||||
MBartTokenizer,
|
||||
)
|
||||
from transformers.modeling_bart import (
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
|
|
@ -41,7 +43,6 @@ if is_torch_available():
|
|||
invert_mask,
|
||||
_prepare_bart_decoder_inputs,
|
||||
)
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
|
||||
|
||||
@require_torch
|
||||
|
|
@ -55,10 +56,10 @@ class ModelTester:
|
|||
self.is_training = True
|
||||
self.use_labels = False
|
||||
self.vocab_size = 99
|
||||
self.hidden_size = 32
|
||||
self.num_hidden_layers = 5
|
||||
self.hidden_size = 16
|
||||
self.num_hidden_layers = 2
|
||||
self.num_attention_heads = 4
|
||||
self.intermediate_size = 37
|
||||
self.intermediate_size = 4
|
||||
self.hidden_act = "gelu"
|
||||
self.hidden_dropout_prob = 0.1
|
||||
self.attention_probs_dropout_prob = 0.1
|
||||
|
|
@ -105,7 +106,6 @@ def prepare_bart_inputs_dict(
|
|||
|
||||
@require_torch
|
||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
||||
)
|
||||
|
|
@ -196,8 +196,114 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
|
||||
|
||||
@require_torch
|
||||
class BartHeadTests(unittest.TestCase):
|
||||
class BartTranslationTests(unittest.TestCase):
|
||||
_model = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
checkpoint_name = "mbart-large-en-ro"
|
||||
cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
|
||||
cls.pad_token_id = 1
|
||||
net_input = {
|
||||
"input_ids": _long_tensor(
|
||||
[
|
||||
[3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004],
|
||||
[64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004],
|
||||
]
|
||||
),
|
||||
"decoder_input_ids": _long_tensor(
|
||||
[
|
||||
[250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1],
|
||||
[250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2],
|
||||
]
|
||||
),
|
||||
"generation_mode": False,
|
||||
}
|
||||
net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id)
|
||||
cls.net_input = net_input
|
||||
|
||||
return cls
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""Only load the model if needed."""
|
||||
if self._model is None:
|
||||
model = BartForConditionalGeneration.from_pretrained("mbart-large-en-ro")
|
||||
self._model = model
|
||||
return self._model
|
||||
|
||||
@slow
|
||||
def test_enro_forward(self):
|
||||
model = self.model
|
||||
with torch.no_grad():
|
||||
logits, *other_stuff = model(**self.net_input)
|
||||
|
||||
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787])
|
||||
result_slice = logits[0][0][:3]
|
||||
self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE))
|
||||
|
||||
@slow
|
||||
def test_enro_generate(self):
|
||||
model = self.model
|
||||
# example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
# inputs: dict = tokenizer.batch_encode_plus([example_english_phrase], return_tensors="pt",)
|
||||
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
inputs = {
|
||||
"input_ids": torch.LongTensor(
|
||||
[[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]] # 250004
|
||||
)
|
||||
}
|
||||
translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
|
||||
decoded = [
|
||||
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
for g in translated_tokens
|
||||
]
|
||||
self.assertEqual(expected_translation_romanian, decoded[0])
|
||||
|
||||
def test_mbart_enro_config(self):
|
||||
mbart_models = ["mbart-large-en-ro"]
|
||||
expected = {"scale_embedding": True, "output_past": True}
|
||||
for name in mbart_models:
|
||||
config = BartConfig.from_pretrained(name)
|
||||
self.assertTrue(config.is_valid_mbart())
|
||||
for k, v in expected.items():
|
||||
try:
|
||||
self.assertEqual(v, getattr(config, k))
|
||||
except AssertionError as e:
|
||||
e.args += (name, k)
|
||||
raise
|
||||
|
||||
def test_enro_tokenizer(self):
|
||||
raw = "UN Chief Says There Is No Military Solution in Syria"
|
||||
ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
|
||||
expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
|
||||
# TODO(SS): should be [8274, ..., 2, 250020]
|
||||
self.assertListEqual(expected_result, ids)
|
||||
|
||||
def test_mbart_fast_forward(self):
|
||||
config = BartConfig(
|
||||
vocab_size=99,
|
||||
d_model=24,
|
||||
encoder_layers=2,
|
||||
decoder_layers=2,
|
||||
encoder_attention_heads=2,
|
||||
decoder_attention_heads=2,
|
||||
encoder_ffn_dim=32,
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
add_final_layer_norm=True,
|
||||
)
|
||||
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
|
||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
|
||||
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BartHeadTests(unittest.TestCase):
|
||||
vocab_size = 99
|
||||
|
||||
def _get_config_and_data(self):
|
||||
|
|
@ -263,13 +369,13 @@ class BartHeadTests(unittest.TestCase):
|
|||
def test_lm_uneven_forward(self):
|
||||
config = BartConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=24,
|
||||
d_model=14,
|
||||
encoder_layers=2,
|
||||
decoder_layers=2,
|
||||
encoder_attention_heads=2,
|
||||
decoder_attention_heads=2,
|
||||
encoder_ffn_dim=32,
|
||||
decoder_ffn_dim=32,
|
||||
encoder_ffn_dim=8,
|
||||
decoder_ffn_dim=8,
|
||||
max_position_embeddings=48,
|
||||
)
|
||||
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||
|
|
@ -462,6 +568,7 @@ class BartModelIntegrationTests(unittest.TestCase):
|
|||
@slow
|
||||
def test_xsum_summarization_same_as_fairseq(self):
|
||||
model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").to(torch_device)
|
||||
self.assertFalse(model.config.is_valid_mbart())
|
||||
tok = BartTokenizer.from_pretrained("bart-large")
|
||||
|
||||
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
|
|
|
|||
Loading…
Reference in a new issue