diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst
index 65f718c3d..4dd0db178 100644
--- a/docs/source/pretrained_models.rst
+++ b/docs/source/pretrained_models.rst
@@ -283,4 +283,7 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
 |                   +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
 |                   | ``bart-large-cnn``                                         | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base)                                                                      |
 |                   |                                                            | | bart-large base architecture finetuned on cnn summarization task                                                                     |
+|                   +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
+|                   | ``mbart-large-en-ro``                                      | | 12-layer, 1024-hidden, 16-heads, 880M parameters                                                                                     |
+|                   |                                                            | | bart-large architecture pretrained on cc25 multilingual data, finetuned on WMT English-Romanian translation.                         |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index b28cd3f61..206020ddf 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -122,7 +122,7 @@ from .pipelines import (
 )
 from .tokenization_albert import AlbertTokenizer
 from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
-from .tokenization_bart import BartTokenizer
+from .tokenization_bart import BartTokenizer, MBartTokenizer
 from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
 from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
 from .tokenization_camembert import CamembertTokenizer
diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py
index d13aacc87..0281d473b 100644
--- a/src/transformers/configuration_bart.py
+++ b/src/transformers/configuration_bart.py
@@ -27,6 +27,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
     "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
     "bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
+    "mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
 }
 
 
@@ -61,6 +62,9 @@ class BartConfig(PretrainedConfig):
         pad_token_id=1,
         bos_token_id=0,
         eos_token_id=2,
+        normalize_before=False,
+        add_final_layer_norm=False,
+        scale_embedding=False,
         **common_kwargs
     ):
         r"""
@@ -90,6 +94,11 @@
         self.max_position_embeddings = max_position_embeddings
         self.init_std = init_std  # Normal(0, this parameter)
         self.activation_function = activation_function
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+
+        # True for mbart, False otherwise
+        self.normalize_before = normalize_before  # combo of fairseq's encoder_ and decoder_normalize_before
+        self.add_final_layer_norm = add_final_layer_norm
 
         # 3 Types of Dropout
         self.attention_dropout = attention_dropout
@@ -100,9 +109,17 @@
         self.classif_dropout = classifier_dropout
 
     @property
-    def num_attention_heads(self):
+    def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
 
     @property
-    def hidden_size(self):
+    def hidden_size(self) -> int:
         return self.d_model
+
+    def is_valid_mbart(self) -> bool:
+        """Is the configuration aligned with the MBART paper."""
+        if self.normalize_before and self.add_final_layer_norm and self.scale_embedding:
+            return True
+        if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
+            logger.info("This configuration is a mixture of MBART and BART settings")
+        return False
diff --git a/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py
index 22fb047db..4873631b5 100644
--- a/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py
@@ -45,13 +45,24 @@ logger = logging.getLogger(__name__)
 
 SAMPLE_TEXT = " Hello world! cécé herlolip"
 
-rename_keys = [
+mnli_rename_keys = [
     ("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
     ("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
     ("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
     ("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
 ]
-IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version", "_float_tensor"]
+
+
+def remove_ignore_keys_(state_dict):
+    ignore_keys = [
+        "encoder.version",
+        "decoder.version",
+        "model.encoder.version",
+        "model.decoder.version",
+        "_float_tensor",
+    ]
+    for k in ignore_keys:
+        state_dict.pop(k, None)
 
 
 def rename_key(dct, old, new):
@@ -67,6 +78,19 @@ def load_xsum_checkpoint(checkpoint_path):
     return hub_interface
 
 
+def convert_checkpoint_from_disk(checkpoint_path, **config_kwargs):
+    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
+    remove_ignore_keys_(state_dict)
+    vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
+    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
+    mbart_config = BartConfig(vocab_size=vocab_size, **config_kwargs)
+    model = BartForConditionalGeneration(mbart_config)
+    model.model.load_state_dict(state_dict)
+    if hasattr(model, "lm_head"):
+        model.lm_head = _make_linear_from_emb(model.model.shared)
+    return model
+
+
 @torch.no_grad()
 def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
     """
@@ -89,7 +113,7 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
     state_dict = bart.state_dict()
     remove_ignore_keys_(state_dict)
     state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
-    for src, dest in rename_keys:
+    for src, dest in mnli_rename_keys:
         rename_key(state_dict, src, dest)
     model = BartForSequenceClassification(config).eval()
     model.load_state_dict(state_dict)
@@ -118,11 +142,6 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
     model.save_pretrained(pytorch_dump_folder_path)
 
 
-def remove_ignore_keys_(state_dict):
-    for k in IGNORE_KEYS:
-        state_dict.pop(k, None)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     # Required parameters
diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index 73b6d9422..249e2d171 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """PyTorch BART model, ported from the fairseq repo."""
 import logging
+import math
 import random
 from typing import Dict, List, Optional, Tuple
 
@@ -35,6 +36,7 @@ BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
     "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
     "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
     "bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
+    "mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/pytorch_model.bin",
 }
 
 BART_START_DOCSTRING = r"""
@@ -180,6 +182,7 @@ class EncoderLayer(nn.Module):
         self.self_attn = SelfAttention(
             self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
         )
+        self.normalize_before = config.normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
@@ -201,20 +204,26 @@
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
+        if self.normalize_before:
+            x = self.self_attn_layer_norm(x)
         x, attn_weights = self.self_attn(
             query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.self_attn_layer_norm(x)
+        if not self.normalize_before:
+            x = self.self_attn_layer_norm(x)
 
         residual = x
+        if self.normalize_before:
+            x = self.final_layer_norm(x)
         x = self.activation_fn(self.fc1(x))
         x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.final_layer_norm(x)
+        if not self.normalize_before:
+            x = self.final_layer_norm(x)
         return x, attn_weights
@@ -236,6 +245,7 @@
         self.output_hidden_states = config.output_hidden_states
 
         embed_dim = embed_tokens.embedding_dim
+        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
         self.padding_idx = embed_tokens.padding_idx
         self.max_source_positions = config.max_position_embeddings
@@ -244,6 +254,8 @@
         self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = LayerNorm(embed_dim)
+        # mbart has one extra layer_norm
+        self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
 
     def forward(
         self, input_ids, attention_mask=None,
@@ -267,7 +279,7 @@
         if attention_mask is not None:
             attention_mask = invert_mask(attention_mask)
 
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
         embed_pos = self.embed_positions(input_ids)
         x = inputs_embeds + embed_pos
         x = self.layernorm_embedding(x)
@@ -290,6 +302,8 @@
             if self.output_attentions:
                 all_attentions.append(attn)
 
+        if self.layer_norm:
+            x = self.layer_norm(x)
         if self.output_hidden_states:
             encoder_states.append(x)
 
@@ -311,6 +325,7 @@ class DecoderLayer(nn.Module):
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
+        self.normalize_before = config.normalize_before
 
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.encoder_attn = SelfAttention(
@@ -337,21 +352,28 @@
         if layer_state is None:
             layer_state = {}
-        # next line mutates layer state
+        if self.normalize_before:
+            x = self.self_attn_layer_norm(x)
+        # Self Attention
+
         x, self_attn_weights = self.self_attn(
             query=x,
             key=x,
-            layer_state=layer_state,
+            layer_state=layer_state,  # adds keys to layer state
             key_padding_mask=decoder_padding_mask,
             attn_mask=causal_mask,
             need_weights=self.output_attentions,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.self_attn_layer_norm(x)
+        if not self.normalize_before:
+            x = self.self_attn_layer_norm(x)
+
+        # Cross attention
         residual = x
         assert self.encoder_attn.cache_key != self.self_attn.cache_key
-
+        if self.normalize_before:
+            x = self.encoder_attn_layer_norm(x)
         x, _ = self.encoder_attn(
             query=x,
             key=encoder_hidden_states,
@@ -360,16 +382,20 @@
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
+        if not self.normalize_before:
+            x = self.encoder_attn_layer_norm(x)
 
-        x = self.encoder_attn_layer_norm(x)
-
+        # Fully Connected
         residual = x
+        if self.normalize_before:
+            x = self.final_layer_norm(x)
         x = self.activation_fn(self.fc1(x))
         x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.final_layer_norm(x)
+        if not self.normalize_before:
+            x = self.final_layer_norm(x)
         return (
             x,
             self_attn_weights,
@@ -394,6 +420,7 @@ class BartDecoder(nn.Module):
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = embed_tokens.padding_idx
         self.max_target_positions = config.max_position_embeddings
+        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
         self.embed_tokens = embed_tokens
         self.embed_positions = LearnedPositionalEmbedding(
             config.max_position_embeddings, config.d_model, self.padding_idx,
@@ -402,6 +429,7 @@
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
         )  # type: List[DecoderLayer]
         self.layernorm_embedding = LayerNorm(config.d_model)
+        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
 
     def forward(
         self,
@@ -444,9 +472,8 @@
             positions = positions[:, -1:]  # happens after we embed them
             assert input_ids.ne(self.padding_idx).any()
 
-        x = self.embed_tokens(input_ids)
+        x = self.embed_tokens(input_ids) * self.embed_scale
         x += positions
-
         x = self.layernorm_embedding(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
 
@@ -458,14 +485,16 @@
         all_hidden_states = ()
         all_self_attns = ()
         next_decoder_cache = []
-        for i, decoder_layer in enumerate(self.layers):
-            decoder_layer  # type: DecoderLayer
+        for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if self.output_hidden_states:
+                all_hidden_states += (x,)
             dropout_probability = random.uniform(0, 1)
             if self.training and (dropout_probability < self.layerdrop):
                 continue
 
-            layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
+            layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
+
             x, layer_self_attn, layer_past = decoder_layer(
                 x,
                 encoder_hidden_states,
@@ -477,12 +506,13 @@
 
             if use_cache:
                 next_decoder_cache.append(layer_past.copy())
-            if self.output_hidden_states:
-                all_hidden_states += (x,)
+
+            if self.layer_norm and (idx == len(self.layers) - 1):  # last layer of mbart
+                x = self.layer_norm(x)
             if self.output_attentions:
                 all_self_attns += (layer_self_attn,)
 
-        # Convert to standart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
+        # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
         all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
         x = x.transpose(0, 1)
         encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py
index 76f184f50..de3981587 100644
--- a/src/transformers/tokenization_bart.py
+++ b/src/transformers/tokenization_bart.py
@@ -13,7 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
 from .tokenization_roberta import RobertaTokenizer
+from .tokenization_xlm_roberta import XLMRobertaTokenizer
+
+
+logger = logging.getLogger(__name__)
 
 
 # vocab and merges same as roberta
@@ -21,6 +27,8 @@ vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-v
 merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
 _all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn", "bart-large-xsum"]
 
+VOCAB_FILES_NAMES = {"vocab_file": "sentence.bpe.model"}
+
 
 class BartTokenizer(RobertaTokenizer):
     # merges and vocab same as Roberta
@@ -29,3 +37,13 @@ class BartTokenizer(RobertaTokenizer):
         "vocab_file": {m: vocab_url for m in _all_bart_models},
         "merges_file": {m: merges_url for m in _all_bart_models},
     }
+
+
+_all_mbart_models = ["mbart-large-en-ro"]
+SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
+
+
+class MBartTokenizer(XLMRobertaTokenizer):
+    vocab_files_names = VOCAB_FILES_NAMES
+    max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
+    pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index bc17d08a5..4d4507528 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -34,6 +34,8 @@ if is_torch_available():
         BartForConditionalGeneration,
         BartForSequenceClassification,
         BartConfig,
+        BartTokenizer,
+        MBartTokenizer,
     )
     from transformers.modeling_bart import (
         BART_PRETRAINED_MODEL_ARCHIVE_MAP,
@@ -41,7 +43,6 @@ if is_torch_available():
         invert_mask,
         _prepare_bart_decoder_inputs,
     )
-    from transformers.tokenization_bart import BartTokenizer
 
 
 @require_torch
@@ -55,10 +56,10 @@ class ModelTester:
         self.is_training = True
         self.use_labels = False
         self.vocab_size = 99
-        self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.hidden_size = 16
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
-        self.intermediate_size = 37
+        self.intermediate_size = 4
         self.hidden_act = "gelu"
         self.hidden_dropout_prob = 0.1
         self.attention_probs_dropout_prob = 0.1
@@ -105,7 +106,6 @@ def prepare_bart_inputs_dict(
 
 @require_torch
 class BARTModelTest(ModelTesterMixin, unittest.TestCase):
-
     all_model_classes = (
         (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
     )
@@ -196,8 +196,114 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
 
 
 @require_torch
-class BartHeadTests(unittest.TestCase):
+class BartTranslationTests(unittest.TestCase):
+    _model = None
+
+    @classmethod
+    def setUpClass(cls):
+        checkpoint_name = "mbart-large-en-ro"
+        cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
+        cls.pad_token_id = 1
+        net_input = {
+            "input_ids": _long_tensor(
+                [
+                    [3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004],
+                    [64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004],
+                ]
+            ),
+            "decoder_input_ids": _long_tensor(
+                [
+                    [250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1],
+                    [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2],
+                ]
+            ),
+            "generation_mode": False,
+        }
+        net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id)
+        cls.net_input = net_input
+
+        return cls
+
+    @property
+    def model(self):
+        """Only load the model if needed."""
+        if self._model is None:
+            model = BartForConditionalGeneration.from_pretrained("mbart-large-en-ro")
+            self._model = model
+        return self._model
+
+    @slow
+    def test_enro_forward(self):
+        model = self.model
+        with torch.no_grad():
+            logits, *other_stuff = model(**self.net_input)
+
+        expected_slice = torch.tensor([9.0078, 10.1113, 14.4787])
+        result_slice = logits[0][0][:3]
+        self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE))
+
+    @slow
+    def test_enro_generate(self):
+        model = self.model
+        # example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+        # inputs: dict = tokenizer.batch_encode_plus([example_english_phrase], return_tensors="pt",)
+        expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
+
+        inputs = {
+            "input_ids": torch.LongTensor(
+                [[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]]  # 250004
+            )
+        }
+        translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
+        decoded = [
+            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+            for g in translated_tokens
+        ]
+        self.assertEqual(expected_translation_romanian, decoded[0])
+
+    def test_mbart_enro_config(self):
+        mbart_models = ["mbart-large-en-ro"]
+        expected = {"scale_embedding": True, "output_past": True}
+        for name in mbart_models:
+            config = BartConfig.from_pretrained(name)
+            self.assertTrue(config.is_valid_mbart())
+            for k, v in expected.items():
+                try:
+                    self.assertEqual(v, getattr(config, k))
+                except AssertionError as e:
+                    e.args += (name, k)
+                    raise
+
+    def test_enro_tokenizer(self):
+        raw = "UN Chief Says There Is No Military Solution in Syria"
+        ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
+        expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
+        # TODO(SS): should be [8274, ..., 2, 250020]
+        self.assertListEqual(expected_result, ids)
+
+    def test_mbart_fast_forward(self):
+        config = BartConfig(
+            vocab_size=99,
+            d_model=24,
+            encoder_layers=2,
+            decoder_layers=2,
+            encoder_attention_heads=2,
+            decoder_attention_heads=2,
+            encoder_ffn_dim=32,
+            decoder_ffn_dim=32,
+            max_position_embeddings=48,
+            add_final_layer_norm=True,
+        )
+        lm_model = BartForConditionalGeneration(config).to(torch_device)
+        context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
+        summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
+        loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
+        expected_shape = (*summary.shape, config.vocab_size)
+        self.assertEqual(logits.shape, expected_shape)
+
+
+@require_torch
+class BartHeadTests(unittest.TestCase):
     vocab_size = 99
 
     def _get_config_and_data(self):
@@ -263,13 +369,13 @@ class BartHeadTests(unittest.TestCase):
     def test_lm_uneven_forward(self):
         config = BartConfig(
             vocab_size=self.vocab_size,
-            d_model=24,
+            d_model=14,
             encoder_layers=2,
             decoder_layers=2,
             encoder_attention_heads=2,
             decoder_attention_heads=2,
-            encoder_ffn_dim=32,
-            decoder_ffn_dim=32,
+            encoder_ffn_dim=8,
+            decoder_ffn_dim=8,
             max_position_embeddings=48,
         )
         lm_model = BartForConditionalGeneration(config).to(torch_device)
@@ -462,6 +568,7 @@ class BartModelIntegrationTests(unittest.TestCase):
     @slow
     def test_xsum_summarization_same_as_fairseq(self):
         model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").to(torch_device)
+        self.assertFalse(model.config.is_valid_mbart())
         tok = BartTokenizer.from_pretrained("bart-large")
         PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
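For reference, a minimal usage sketch of the checkpoint and tokenizer this diff introduces, mirroring the slow `test_enro_generate` test above. The checkpoint name, beam size, decode flags, and the expected Romanian output are all taken from that test; nothing beyond what the test asserts is claimed here, and loading requires the S3 weights, so it only runs with network access.

```python
from transformers import BartForConditionalGeneration, MBartTokenizer

# Both names are exposed by this diff (tokenization_bart.py / __init__.py).
tokenizer = MBartTokenizer.from_pretrained("mbart-large-en-ro")
model = BartForConditionalGeneration.from_pretrained("mbart-large-en-ro")

english_phrase = " UN Chief Says There Is No Military Solution in Syria"
batch = tokenizer.batch_encode_plus([english_phrase], return_tensors="pt")

# Beam search with the same settings as test_enro_generate.
translated_tokens = model.generate(input_ids=batch["input_ids"], num_beams=5)
print(tokenizer.decode(translated_tokens[0], skip_special_tokens=True))
# Expected, per the test: "Şeful ONU declară că nu există o soluţie militară în Siria"
```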