diff --git a/examples/README.md b/examples/README.md
index 85c4c568f..fb5de20a2 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -97,20 +97,20 @@ Fine-tuning the library models for sequence classification on the GLUE benchmark
 Evaluation](https://gluebenchmark.com/). This script can fine-tune the following models: BERT, XLM, XLNet and RoBERTa.
 
 GLUE is made up of a total of 9 different tasks. We get the following results on the dev set of the benchmark with an
-uncased BERT base model (the checkpoint `bert-base-uncased`). All experiments ran on 8 V100 GPUs with a total train
+uncased BERT base model (the checkpoint `bert-base-uncased`). All experiments ran on 8 V100 GPUs with a total train
 batch size of 24. Some of these tasks have a small dataset and training can lead to high variance in the results
 between different runs. We report the median on 5 runs (with different seeds) for each of the metrics.
 
 | Task  | Metric                       | Result      |
 |-------|------------------------------|-------------|
-| CoLA  | Matthew's corr               | 55.75       |
-| SST-2 | Accuracy                     | 92.09       |
-| MRPC  | F1/Accuracy                  | 90.48/86.27 |
-| STS-B | Person/Spearman corr.        | 89.03/88.64 |
-| QQP   | Accuracy/F1                  | 90.92/87.72 |
-| MNLI  | Matched acc./Mismatched acc. | 83.74/84.06 |
-| QNLI  | Accuracy                     | 91.07       |
-| RTE   | Accuracy                     | 68.59       |
+| CoLA  | Matthew's corr               | 48.87       |
+| SST-2 | Accuracy                     | 91.74       |
+| MRPC  | F1/Accuracy                  | 90.70/86.27 |
+| STS-B | Pearson/Spearman corr.       | 91.39/91.04 |
+| QQP   | Accuracy/F1                  | 90.79/87.66 |
+| MNLI  | Matched acc./Mismatched acc. | 83.70/84.83 |
+| QNLI  | Accuracy                     | 89.31       |
+| RTE   | Accuracy                     | 71.43       |
 | WNLI  | Accuracy                     | 43.66       |
 
 Some of these results are significantly different from the ones reported on the test set
diff --git a/examples/distillation/README.md b/examples/distillation/README.md
index ad4f8989b..4cddbd3a2 100644
--- a/examples/distillation/README.md
+++ b/examples/distillation/README.md
@@ -2,12 +2,21 @@
 This folder contains the original code used to train DistilBERT as well as examples showcasing how to use DistilBERT.
 
+**2019, September 19th - Update:** We fixed bugs in the code and released an updated version of the weights trained with a modification of the distillation loss. DistilBERT now reaches 97% of `BERT-base`'s performance on GLUE, and an 86.9 F1 score on the SQuAD v1.1 dev set (compared to 88.5 for `BERT-base`). We will publish a formal write-up of our approach in the near future!
+
 ## What is DistilBERT
 
-DistilBERT stands for Distillated-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on Bert architecture. It has 40% less parameters than `bert-base-uncased`, runs 60% faster while preserving over 95% of Bert's performances as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique to compress a large model called the teacher into a smaller model called the student. By distillating Bert, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option to put large-scaled trained Transformer model into production.
+DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on the BERT architecture. It has 40% fewer parameters than `bert-base-uncased` and runs 60% faster, while preserving 97% of BERT's performance as measured on the GLUE language understanding benchmark.
+DistilBERT is trained using knowledge distillation, a technique to compress a large model, called the teacher, into a smaller model, called the student. By distilling BERT, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option for putting large-scale pre-trained Transformer models into production.
 
 For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
-).
+). *Please note that we will publish a formal write-up with updated and more complete results in the near future (September 19th).*
+
+Here are the updated results on the dev sets of GLUE:
+
+| Model      | Macro-score | CoLA | MNLI | MRPC | QNLI | QQP  | RTE  | SST-2 | STS-B | WNLI |
+| :---:      | :---:       | :---:| :---:| :---:| :---:| :---:| :---:| :---: | :---: | :---:|
+| BERT-base  | **77.6**    | 48.9 | 84.3 | 88.6 | 89.3 | 89.5 | 71.3 | 91.7  | 91.2  | 43.7 |
+| DistilBERT | **75.2**    | 49.1 | 81.8 | 90.2 | 87.0 | 89.2 | 62.9 | 92.7  | 90.7  | 44.4 |
 
 ## Setup
 
@@ -20,7 +29,7 @@ This part of the library has only be tested with Python3.6+. There are few speci
 Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility to train and release a multilingual version of DistilBERT):
 
 - `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain Bert (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of Bert. The model has 6 layers, 768 dimension and 12 heads, totalizing 66M parameters.
-- `distilbert-base-uncased-distilled-squad`: A finetuned version of `distilbert-base-uncased` finetuned using (a second step of) knwoledge distillation on SQuAD 1.0. This model reaches a F1 score of 86.2 on the dev set (for comparison, Bert `bert-base-uncased` version reaches a 88.5 F1 score).
+- `distilbert-base-uncased-distilled-squad`: A version of `distilbert-base-uncased` fine-tuned with (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 86.9 on the dev set (for comparison, the `bert-base-uncased` version of Bert reaches an 88.5 F1 score).
 
 Using DistilBERT is very similar to using BERT. DistilBERT share the same tokenizer as BERT's `bert-base-uncased` even though we provide a link to this tokenizer under the `DistilBertTokenizer` name to have a consistent naming between the library models.
 
diff --git a/examples/distillation/dataset.py b/examples/distillation/dataset.py
index 89e3f1187..4babf73ea 100644
--- a/examples/distillation/dataset.py
+++ b/examples/distillation/dataset.py
@@ -92,11 +92,11 @@ class Dataset:
         Too short sequences are simply removed. This could be tuned.
""" init_size = len(self) - indices = self.lengths > 5 + indices = self.lengths > 11 self.token_ids = self.token_ids[indices] self.lengths = self.lengths[indices] new_size = len(self) - logger.info(f'Remove {init_size - new_size} too short (<=5 tokens) sequences.') + logger.info(f'Remove {init_size - new_size} too short (<=11 tokens) sequences.') def print_statistics(self): """ diff --git a/examples/distillation/distiller.py b/examples/distillation/distiller.py index 1bfda325d..79755b81e 100644 --- a/examples/distillation/distiller.py +++ b/examples/distillation/distiller.py @@ -18,15 +18,18 @@ import os import math import psutil +import time from tensorboardX import SummaryWriter from tqdm import trange, tqdm import numpy as np +import psutil import torch import torch.nn as nn import torch.nn.functional as F +from torch.optim import AdamW -from transformers import AdamW, WarmupLinearSchedule +from transformers import WarmupLinearSchedule from utils import logger from dataset import Dataset @@ -58,10 +61,12 @@ class Distiller: self.alpha_ce = params.alpha_ce self.alpha_mlm = params.alpha_mlm self.alpha_mse = params.alpha_mse + self.alpha_cos = params.alpha_cos assert self.alpha_ce >= 0. assert self.alpha_mlm >= 0. assert self.alpha_mse >= 0. - assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0. + assert self.alpha_cos >= 0. + assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0. self.mlm_mask_prop = params.mlm_mask_prop assert 0.0 <= self.mlm_mask_prop <= 1.0 @@ -81,17 +86,21 @@ class Distiller: self.last_loss = 0 self.last_loss_ce = 0 self.last_loss_mlm = 0 - self.last_loss_mse = 0 + if self.alpha_mse > 0.: self.last_loss_mse = 0 + if self.alpha_cos > 0.: self.last_loss_cos = 0 + self.last_log = 0 self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean') self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) - self.mse_loss_fct = nn.MSELoss(reduction='sum') + if self.alpha_mse > 0.: + self.mse_loss_fct = nn.MSELoss(reduction='sum') + if self.alpha_cos > 0.: + self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean') logger.info('--- Initializing model optimizer') assert params.gradient_accumulation_steps >= 1 self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1 num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1 - warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop) no_decay = ['bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ @@ -104,9 +113,11 @@ class Distiller: lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)) + + warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop) self.scheduler = WarmupLinearSchedule(self.optimizer, - warmup_steps=warmup_steps, - t_total=num_train_optimization_steps) + warmup_steps=warmup_steps, + t_total=num_train_optimization_steps) if self.fp16: try: @@ -272,11 +283,14 @@ class Distiller: The real training loop. """ if self.is_master: logger.info('Starting training') + self.last_log = time.time() self.student.train() self.teacher.eval() for _ in range(self.params.n_epoch): if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}') + if self.multi_gpu: + torch.distributed.barrier() iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0]) for __ in range(self.num_steps_epoch): @@ -314,9 +328,9 @@ class Distiller: attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention. 
             mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
         """
-        s_logits = self.student(input_ids=input_ids, attention_mask=attention_mask)[0]        # (bs, seq_length, voc_size)
+        s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)  # (bs, seq_length, voc_size)
         with torch.no_grad():
-            t_logits = self.teacher(input_ids=input_ids, attention_mask=attention_mask)[0]    # (bs, seq_length, voc_size)
+            t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)  # (bs, seq_length, voc_size)
         assert s_logits.size() == t_logits.size()
 
         #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
@@ -340,6 +354,22 @@ class Distiller:
         if self.alpha_mse > 0.:
             loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
             loss += self.alpha_mse * loss_mse
+
+        if self.alpha_cos > 0.:
+            s_hidden_states = s_hidden_states[-1]                               # (bs, seq_length, dim)
+            t_hidden_states = t_hidden_states[-1]                               # (bs, seq_length, dim)
+            mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)      # (bs, seq_length, dim)
+            assert s_hidden_states.size() == t_hidden_states.size()
+            dim = s_hidden_states.size(-1)
+
+            s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)   # (bs * seq_length * dim)
+            s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)           # (bs * seq_length, dim)
+            t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)   # (bs * seq_length * dim)
+            t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)           # (bs * seq_length, dim)
+
+            target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
+            loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
+            loss += self.alpha_cos * loss_cos
 
         self.total_loss_epoch += loss.item()
         self.last_loss = loss.item()
@@ -348,6 +378,8 @@ class Distiller:
             self.last_loss_mlm = loss_mlm.item()
         if self.alpha_mse > 0.:
             self.last_loss_mse = loss_mse.item()
+        if self.alpha_cos > 0.:
+            self.last_loss_cos = loss_cos.item()
 
         self.optimize(loss)
 
@@ -396,6 +428,7 @@ class Distiller:
 
         if self.n_total_iter % self.params.log_interval == 0:
             self.log_tensorboard()
+            self.last_log = time.time()
 
         if self.n_total_iter % self.params.checkpoint_interval == 0:
             self.save_checkpoint()
@@ -421,9 +454,12 @@ class Distiller:
             self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
         if self.alpha_mse > 0.:
             self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
+        if self.alpha_cos > 0.:
+            self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter)
 
         self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
 
         self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
+        self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time()-self.last_log, global_step=self.n_total_iter)
 
     def end_epoch(self):
         """
diff --git a/examples/distillation/requirements.txt b/examples/distillation/requirements.txt
index 18146239e..2cf6ee2d8 100644
--- a/examples/distillation/requirements.txt
+++ b/examples/distillation/requirements.txt
@@ -2,3 +2,5 @@ gitpython==3.0.2
 tensorboard>=1.14.0
 tensorboardX==1.8
 psutil==5.6.3
+scipy==1.3.1
+pytorch_transformers==1.2.0
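To make the loss changes above easier to follow, here is a minimal sketch of the per-batch objective that `Distiller` now combines: the soft-target KL term on temperature-scaled logits, the hard MLM cross-entropy, and the new cosine-embedding term that aligns the last hidden states of student and teacher on non-padded positions. The helper name `distillation_loss`, the default weights, and the fact that the KL term here covers all positions (the actual code first restricts logits to masked positions) are simplifications for illustration, not part of the repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def distillation_loss(s_logits, t_logits, s_hidden, t_hidden, mlm_labels, attention_mask,
                      temperature=2.0, alpha_ce=5.0, alpha_mlm=2.0, alpha_cos=1.0):
    # Soft targets: KL between temperature-scaled student log-probs and teacher probs,
    # rescaled by T^2 as in standard knowledge distillation.
    loss_ce = nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(s_logits / temperature, dim=-1),
        F.softmax(t_logits / temperature, dim=-1)) * temperature ** 2

    # Hard targets: usual MLM cross-entropy; positions without a label carry -1.
    loss_mlm = nn.CrossEntropyLoss(ignore_index=-1)(
        s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))

    # Cosine-embedding term: push the student's last hidden states to point in the
    # same direction as the teacher's, only on non-padded positions (target = 1).
    dim = s_hidden.size(-1)
    mask = attention_mask.unsqueeze(-1).expand_as(s_hidden).bool()
    s_slct = torch.masked_select(s_hidden, mask).view(-1, dim)
    t_slct = torch.masked_select(t_hidden, mask).view(-1, dim)
    target = s_slct.new_ones(s_slct.size(0))
    loss_cos = nn.CosineEmbeddingLoss(reduction='mean')(s_slct, t_slct, target)

    return alpha_ce * loss_ce + alpha_mlm * loss_mlm + alpha_cos * loss_cos
```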
diff --git a/examples/distillation/scripts/binarized_data.py b/examples/distillation/scripts/binarized_data.py
index e98662fdc..eb4af08b0 100644
--- a/examples/distillation/scripts/binarized_data.py
+++ b/examples/distillation/scripts/binarized_data.py
@@ -20,7 +20,7 @@ import pickle
 import random
 import time
 import numpy as np
-from transformers import BertTokenizer
+from transformers import BertTokenizer, RobertaTokenizer
 
 import logging
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -32,16 +32,21 @@ def main():
     parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids).")
     parser.add_argument('--file_path', type=str, default='data/dump.txt',
                         help='The path to the data.')
-    parser.add_argument('--bert_tokenizer', type=str, default='bert-base-uncased',
+    parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta'])
+    parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased',
                         help="The tokenizer to use.")
     parser.add_argument('--dump_file', type=str, default='data/dump',
                         help='The dump file prefix.')
     args = parser.parse_args()
 
-    logger.info(f'Loading Tokenizer ({args.bert_tokenizer})')
-    bert_tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)
-
+    logger.info(f'Loading Tokenizer ({args.tokenizer_name})')
+    if args.tokenizer_type == 'bert':
+        tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
+    elif args.tokenizer_type == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
+    bos = tokenizer.special_tokens_map['bos_token'] # `[CLS]` for bert, `<s>` for roberta
+    sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]` for bert, `</s>` for roberta
 
     logger.info(f'Loading text from {args.file_path}')
     with open(args.file_path, 'r', encoding='utf8') as fp:
@@ -56,8 +61,8 @@ def main():
     interval = 10000
     start = time.time()
     for text in data:
-        text = f'[CLS] {text.strip()} [SEP]'
-        token_ids = bert_tokenizer.encode(text)
+        text = f'{bos} {text.strip()} {sep}'
+        token_ids = tokenizer.encode(text)
         rslt.append(token_ids)
 
         iter += 1
@@ -69,7 +74,7 @@ def main():
 
     logger.info(f'{len(data)} examples processed.')
 
-    dp_file = f'{args.dump_file}.{args.bert_tokenizer}.pickle'
+    dp_file = f'{args.dump_file}.{args.tokenizer_name}.pickle'
     rslt_ = [np.uint16(d) for d in rslt]
     random.shuffle(rslt_)
     logger.info(f'Dump to {dp_file}')
diff --git a/examples/distillation/scripts/extract_for_distil.py b/examples/distillation/scripts/extract_for_distil.py
index 1b9e20c38..2e7e5c73d 100644
--- a/examples/distillation/scripts/extract_for_distil.py
+++ b/examples/distillation/scripts/extract_for_distil.py
@@ -15,59 +15,73 @@
 """
 Preprocessing script before training DistilBERT.
""" -from transformers import BertForPreTraining +from transformers import BertForMaskedLM, RobertaForMaskedLM import torch import argparse if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForPreTraining for Transfer Learned Distillation") - parser.add_argument("--bert_model", default='bert-base-uncased', type=str) - parser.add_argument("--dump_checkpoint", default='serialization_dir/transfer_learning_checkpoint_0247911.pth', type=str) + parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation") + parser.add_argument("--model_type", default="bert", choices=["bert", "roberta"]) + parser.add_argument("--model_name", default='bert-base-uncased', type=str) + parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str) parser.add_argument("--vocab_transform", action='store_true') args = parser.parse_args() - model = BertForPreTraining.from_pretrained(args.bert_model) + if args.model_type == 'bert': + model = BertForMaskedLM.from_pretrained(args.model_name) + prefix = 'bert' + elif args.model_type == 'roberta': + model = RobertaForMaskedLM.from_pretrained(args.model_name) + prefix = 'roberta' state_dict = model.state_dict() compressed_sd = {} for w in ['word_embeddings', 'position_embeddings']: compressed_sd[f'distilbert.embeddings.{w}.weight'] = \ - state_dict[f'bert.embeddings.{w}.weight'] + state_dict[f'{prefix}.embeddings.{w}.weight'] for w in ['weight', 'bias']: compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \ - state_dict[f'bert.embeddings.LayerNorm.{w}'] + state_dict[f'{prefix}.embeddings.LayerNorm.{w}'] std_idx = 0 for teacher_idx in [0, 2, 4, 7, 9, 11]: for w in ['weight', 'bias']: compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \ - state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.query.{w}'] + state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}'] compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \ - state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.key.{w}'] + state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}'] compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \ - state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.value.{w}'] + state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}'] compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \ - state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.dense.{w}'] + state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}'] compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \ - state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}'] + state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}'] compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \ - state_dict[f'bert.encoder.layer.{teacher_idx}.intermediate.dense.{w}'] + state_dict[f'{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}'] compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \ - state_dict[f'bert.encoder.layer.{teacher_idx}.output.dense.{w}'] + state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}'] compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \ - 
+                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
         std_idx += 1
 
-    compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
-    compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
-    if args.vocab_transform:
-        for w in ['weight', 'bias']:
-            compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
-            compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
+    if args.model_type == 'bert':
+        compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
+        compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
+        if args.vocab_transform:
+            for w in ['weight', 'bias']:
+                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
+                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
+    elif args.model_type == 'roberta':
+        compressed_sd[f'vocab_projector.weight'] = state_dict[f'lm_head.decoder.weight']
+        compressed_sd[f'vocab_projector.bias'] = state_dict[f'lm_head.bias']
+        if args.vocab_transform:
+            for w in ['weight', 'bias']:
+                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'lm_head.dense.{w}']
+                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
 
     print(f'N layers selected for distillation: {std_idx}')
     print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
diff --git a/examples/distillation/train.py b/examples/distillation/train.py
index 6060cd9eb..f0255d08f 100644
--- a/examples/distillation/train.py
+++ b/examples/distillation/train.py
@@ -23,7 +23,7 @@ import shutil
 import numpy as np
 import torch
 
-from transformers import BertTokenizer, BertForMaskedLM
+from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
 from transformers import DistilBertForMaskedLM, DistilBertConfig
 
 from distiller import Distiller
@@ -70,8 +70,10 @@ def main():
                         help="Load student initialization checkpoint.")
     parser.add_argument("--from_pretrained_config", default=None, type=str,
                         help="Load student initialization architecture config.")
-    parser.add_argument("--bert_model", default='bert-base-uncased', type=str,
-                        help="The teacher BERT model.")
+    parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
+                        help="Teacher type (BERT, RoBERTa).")
+    parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
+                        help="The teacher model.")
 
     parser.add_argument("--temperature", default=2., type=float,
                         help="Temperature for the softmax temperature.")
@@ -81,6 +83,8 @@ def main():
                         help="Linear weight for the MLM loss. Must be >=0.")
     parser.add_argument("--alpha_mse", default=0.0, type=float,
                         help="Linear weight of the MSE loss. Must be >=0.")
+    parser.add_argument("--alpha_cos", default=0.0, type=float,
Must be >=0.") parser.add_argument("--mlm_mask_prop", default=0.15, type=float, help="Proportion of tokens for which we need to make a prediction.") parser.add_argument("--word_mask", default=0.8, type=float, @@ -165,11 +169,14 @@ def main(): ### TOKENIZER ### - bert_tokenizer = BertTokenizer.from_pretrained(args.bert_model) + if args.teacher_type == 'bert': + tokenizer = BertTokenizer.from_pretrained(args.teacher_name) + elif args.teacher_type == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name) special_tok_ids = {} - for tok_name, tok_symbol in bert_tokenizer.special_tokens_map.items(): - idx = bert_tokenizer.all_special_tokens.index(tok_symbol) - special_tok_ids[tok_name] = bert_tokenizer.all_special_ids[idx] + for tok_name, tok_symbol in tokenizer.special_tokens_map.items(): + idx = tokenizer.all_special_tokens.index(tok_symbol) + special_tok_ids[tok_name] = tokenizer.all_special_ids[idx] logger.info(f'Special tokens {special_tok_ids}') args.special_tok_ids = special_tok_ids @@ -197,16 +204,17 @@ def main(): ## STUDENT ## if args.from_pretrained_weights is not None: - assert os.path.isfile(os.path.join(args.from_pretrained_weights)) - assert os.path.isfile(os.path.join(args.from_pretrained_config)) + assert os.path.isfile(args.from_pretrained_weights) + assert os.path.isfile(args.from_pretrained_config) logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}') logger.info(f'Loading pretrained config from {args.from_pretrained_config}') stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config) + stu_architecture_config.output_hidden_states = True student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights, - config=stu_architecture_config) + config=stu_architecture_config) else: args.vocab_size_or_config_json_file = args.vocab_size - stu_architecture_config = DistilBertConfig(**vars(args)) + stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True) student = DistilBertForMaskedLM(stu_architecture_config) @@ -216,10 +224,13 @@ def main(): ## TEACHER ## - teacher = BertForMaskedLM.from_pretrained(args.bert_model) + if args.teacher_type == 'bert': + teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True) + elif args.teacher_type == 'roberta': + teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True) if args.n_gpu > 0: teacher.to(f'cuda:{args.local_rank}') - logger.info(f'Teacher loaded from {args.bert_model}.') + logger.info(f'Teacher loaded from {args.teacher_name}.') ## DISTILLER ## torch.cuda.empty_cache() diff --git a/transformers/configuration_openai.py b/transformers/configuration_openai.py index b27df5689..886b7f5bc 100644 --- a/transformers/configuration_openai.py +++ b/transformers/configuration_openai.py @@ -36,7 +36,6 @@ class OpenAIGPTConfig(PretrainedConfig): Args: vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file. - n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...) n_positions: Number of positional embeddings. n_ctx: Size of the causal mask (usually same as n_positions). n_embd: Dimensionality of the embeddings and hidden states. 
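The `output_hidden_states=True` additions in `train.py` are what make the cosine loss possible: with that flag, both teacher and student return their hidden states alongside the logits. The snippet below is a rough sketch of that forward pass outside the training loop, assuming the pytorch-transformers 1.2-era convention that the hidden states come right after the logits in the returned tuple; the checkpoint names and the toy sentence are placeholders.

```python
import torch
from transformers import BertTokenizer, BertForMaskedLM, DistilBertConfig, DistilBertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
teacher = BertForMaskedLM.from_pretrained('bert-base-uncased', output_hidden_states=True)
student = DistilBertForMaskedLM(DistilBertConfig(output_hidden_states=True))
teacher.eval()

input_ids = torch.tensor([tokenizer.encode('[CLS] hello [MASK] world [SEP]')])
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    t_logits, t_hidden_states = teacher(input_ids=input_ids, attention_mask=attention_mask)[:2]
s_logits, s_hidden_states = student(input_ids=input_ids, attention_mask=attention_mask)[:2]

# hidden_states is a tuple with one tensor per layer (plus the embeddings output);
# distiller.py only aligns the last entry of each tuple.
print(s_logits.shape, len(s_hidden_states), s_hidden_states[-1].shape)
```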
diff --git a/transformers/tokenization_xlnet.py b/transformers/tokenization_xlnet.py
index 941c6c5bc..ad9efdf04 100644
--- a/transformers/tokenization_xlnet.py
+++ b/transformers/tokenization_xlnet.py
@@ -183,8 +183,8 @@ class XLNetTokenizer(PreTrainedTokenizer):
 
     def add_special_tokens_single_sequence(self, token_ids):
         """
-        Adds special tokens to a sequence pair for sequence classification tasks.
-        An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
+        Adds special tokens to a sequence for sequence classification tasks.
+        An XLNet sequence has the following format: X [SEP][CLS]
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
@@ -192,8 +192,8 @@ class XLNetTokenizer(PreTrainedTokenizer):
 
     def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         """
-        Adds special tokens to a sequence for sequence classification tasks.
-        An XLNet sequence has the following format: X [SEP][CLS]
+        Adds special tokens to a sequence pair for sequence classification tasks.
+        An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
         """
         sep = [self.sep_token_id]
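As a sanity check on the corrected XLNet docstrings, here is a small usage sketch showing that both helpers append the special tokens at the end of the sequence (unlike BERT, which wraps the sequence in `[CLS]` ... `[SEP]`). The sentences are arbitrary placeholders; the expected concatenation patterns are the ones described in the fixed docstrings.

```python
from transformers import XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
sep, cls = tokenizer.sep_token_id, tokenizer.cls_token_id

ids_a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize('Hello world'))
ids_b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize('How are you?'))

# Single sequence: X [SEP][CLS]
single = tokenizer.add_special_tokens_single_sequence(ids_a)
assert single == ids_a + [sep, cls]

# Sequence pair: A [SEP] B [SEP][CLS]
pair = tokenizer.add_special_tokens_sequence_pair(ids_a, ids_b)
assert pair == ids_a + [sep] + ids_b + [sep, cls]
```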