From fbe04423b6fc5ca2b7d28e423264e50505dbdf45 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Thu, 4 Jul 2019 00:25:30 +0200
Subject: [PATCH] Common SequenceSummary class

---
 pytorch_pretrained_bert/__init__.py        |   2 +-
 pytorch_pretrained_bert/model_utils.py     | 108 +++++++++++++++++----
 pytorch_pretrained_bert/modeling_gpt2.py   |  48 +++------
 pytorch_pretrained_bert/modeling_openai.py |  50 +++-------
 pytorch_pretrained_bert/modeling_xlnet.py  |  47 ++-------
 5 files changed, 130 insertions(+), 125 deletions(-)

diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py
index e14b8b27a..23346967b 100644
--- a/pytorch_pretrained_bert/__init__.py
+++ b/pytorch_pretrained_bert/__init__.py
@@ -17,7 +17,7 @@ from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
 from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel,
                                   TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl)
 from .modeling_gpt2 import (GPT2Config, GPT2Model,
-                            GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2MultipleChoiceHead,
+                            GPT2LMHeadModel, GPT2DoubleHeadsModel,
                             load_tf_weights_in_gpt2)
 from .modeling_xlnet import (XLNetConfig, XLNetPreTrainedModel,
                              XLNetModel, XLNetLMHeadModel,
diff --git a/pytorch_pretrained_bert/model_utils.py b/pytorch_pretrained_bert/model_utils.py
index ec735c3e0..0496e41bb 100644
--- a/pytorch_pretrained_bert/model_utils.py
+++ b/pytorch_pretrained_bert/model_utils.py
@@ -282,6 +282,95 @@ class PreTrainedModel(nn.Module):
         return model
 
 
+class Conv1D(nn.Module):
+    def __init__(self, nf, nx):
+        """ Conv1D layer as defined by Alec Radford for GPT (and also used in GPT-2)
+            Basically works like a Linear layer but the weights are transposed
+        """
+        super(Conv1D, self).__init__()
+        self.nf = nf
+        w = torch.empty(nx, nf)
+        nn.init.normal_(w, std=0.02)
+        self.weight = nn.Parameter(w)
+        self.bias = nn.Parameter(torch.zeros(nf))
+
+    def forward(self, x):
+        size_out = x.size()[:-1] + (self.nf,)
+        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+        x = x.view(*size_out)
+        return x
+
+
+class SequenceSummary(nn.Module):
+    def __init__(self, config):
+        """ Compute a single vector summary of a sequence's hidden states according to various possibilities:
+            Args of the config class:
+                summary_type:
+                    - 'last' => [default] take the last token hidden state (like XLNet)
+                    - 'first' => take the first token hidden state (like Bert)
+                    - 'mean' => take the mean of all tokens' hidden states
+                    - 'token_ids' => supply a Tensor of classification token indices (GPT/GPT-2)
+                    - 'attn' => not implemented for now, would use multi-head attention
+                summary_use_proj: add a projection after the vector extraction
+                summary_num_classes: if > 0, the projection outputs num_classes logits (otherwise hidden_size)
+                summary_activation:
+                    'tanh' => add a tanh activation to the output
+                    None => no activation
+        """
+        super(SequenceSummary, self).__init__()
+
+        self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
+        if self.summary_type == 'attn':
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
+            if hasattr(config, 'summary_num_classes') and config.summary_num_classes > 0:
+                num_classes = config.summary_num_classes
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        self.activation = nn.Identity()
+        if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
+            self.activation = nn.Tanh()
+
+        self.dropout = nn.Dropout(config.summary_dropout)
+
+    def forward(self, hidden_states, token_ids=None):
+        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
+            token_ids: [optional] index of the classification token if summary_type == 'token_ids',
+                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
+                if summary_type == 'token_ids' and token_ids is None:
+                    we take the last token of the sequence as classification token
+        """
+        if self.summary_type == 'last':
+            output = hidden_states[:, -1]
+        elif self.summary_type == 'first':
+            output = hidden_states[:, 0]
+        elif self.summary_type == 'mean':
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == 'token_ids':
+            if token_ids is None:
+                token_ids = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
+            else:
+                token_ids = token_ids.unsqueeze(-1).unsqueeze(-1)
+                token_ids = token_ids.expand((-1,) * (token_ids.dim()-1) + (hidden_states.size(-1),))
+            # shape of token_ids: (bsz, XX, 1, hidden_size) where XX are optional leading dims of hidden_states
+            output = hidden_states.gather(-2, token_ids).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == 'attn':
+            raise NotImplementedError
+
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.dropout(output)
+
+        return output
+
+
 def prune_linear_layer(layer, index, dim=0):
     """ Prune a linear layer (a model parameter) to keep only entries in index.
         Return the pruned layer as a new layer with requires_grad=True.
@@ -307,25 +396,6 @@ def prune_linear_layer(layer, index, dim=0):
     return new_layer
 
 
-class Conv1D(nn.Module):
-    """ Conv1D layer as defined by Alec Radford for GPT (and also used in GPT-2)
-        Basically works like a Linear layer but the weights are transposed
-    """
-    def __init__(self, nf, nx):
-        super(Conv1D, self).__init__()
-        self.nf = nf
-        w = torch.empty(nx, nf)
-        nn.init.normal_(w, std=0.02)
-        self.weight = nn.Parameter(w)
-        self.bias = nn.Parameter(torch.zeros(nf))
-
-    def forward(self, x):
-        size_out = x.size()[:-1] + (self.nf,)
-        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(*size_out)
-        return x
-
-
 def prune_conv1d_layer(layer, index, dim=1):
     """ Prune a Conv1D layer (a model parameter) to keep only entries in index.
         A Conv1D works as a Linear layer (see e.g. BERT) but the weights are transposed.
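
A quick usage sketch of the new class (illustrative only, not part of the patch): the stand-in config below mirrors the summary_* fields this patch adds to the GPT/GPT-2 configs, and all names and sizes here are hypothetical.

    # Exercise SequenceSummary in its GPT-2 style 'token_ids' configuration:
    # gather one classification-token hidden state per choice, project to one logit.
    from types import SimpleNamespace

    import torch

    from pytorch_pretrained_bert.model_utils import SequenceSummary

    config = SimpleNamespace(
        summary_type='token_ids',   # gather classification tokens by index (GPT/GPT-2)
        summary_use_proj=True,
        summary_num_classes=1,      # project each summary vector to a single logit
        summary_activation=None,
        summary_dropout=0.1,
        hidden_size=768,            # hypothetical size; GPT-2 small uses n_embd=768
    )
    summary = SequenceSummary(config)

    bsz, num_choices, seq_len = 2, 4, 16
    hidden_states = torch.rand(bsz, num_choices, seq_len, config.hidden_size)
    mc_token_ids = torch.full((bsz, num_choices), seq_len - 1, dtype=torch.long)

    mc_logits = summary(hidden_states, mc_token_ids).squeeze(-1)  # shape (bsz, num_choices)

This is exactly the shape contract the GPT2DoubleHeadsModel/OpenAIGPTDoubleHeadsModel changes below rely on, which is why they add the .squeeze(-1) at the call site.
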
diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py
index fa5766f4e..c16ad2f76 100644
--- a/pytorch_pretrained_bert/modeling_gpt2.py
+++ b/pytorch_pretrained_bert/modeling_gpt2.py
@@ -31,7 +31,8 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
 
 from .file_utils import cached_path
-from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
+from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
+                          PreTrainedModel, prune_conv1d_layer, SequenceSummary)
 from .modeling_bert import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
@@ -119,6 +120,11 @@ class GPT2Config(PretrainedConfig):
                  layer_norm_epsilon=1e-5,
                  initializer_range=0.02,
                  predict_special_tokens=True,
+                 summary_type='token_ids',
+                 summary_use_proj=True,
+                 summary_num_classes=1,
+                 summary_activation=None,
+                 summary_dropout=0.1,
                  **kwargs
     ):
         """Constructs GPT2Config.
@@ -164,6 +170,11 @@ class GPT2Config(PretrainedConfig):
             self.layer_norm_epsilon = layer_norm_epsilon
             self.initializer_range = initializer_range
             self.predict_special_tokens = predict_special_tokens
+            self.summary_type = summary_type
+            self.summary_use_proj = summary_use_proj
+            self.summary_num_classes = summary_num_classes
+            self.summary_activation = summary_activation
+            self.summary_dropout = summary_dropout
         else:
             raise ValueError(
                 "First argument must be either a vocabulary size (int)"
@@ -342,37 +353,6 @@ class GPT2LMHead(nn.Module):
         return lm_logits
 
 
-class GPT2MultipleChoiceHead(nn.Module):
-    """ Classifier Head for the transformer """
-
-    def __init__(self, config):
-        super(GPT2MultipleChoiceHead, self).__init__()
-        self.n_embd = config.n_embd
-        self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
-        self.linear = nn.Linear(config.n_embd, 1)
-
-        nn.init.normal_(self.linear.weight, std=0.02)
-        nn.init.normal_(self.linear.bias, 0)
-
-    def forward(self, hidden_states, mc_token_ids=None):
-        """ Extract classification token hidden state and project it using self.linear
-            hidden_state: shape (bsz, num_choices, seq_length, hidden_size)
-            mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
-                if mc_token_ids=None we take the last token of the sequence as classification token
-        """
-        if mc_token_ids is None:
-            mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
-        else:
-            mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
-        # mc_token_ids has shape (bsz, num_choices, 1, hidden_size)
-        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
-        # (bsz, num_choices, hidden_size)
-        multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
-        multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
-        # (bsz, num_choices)
-        return multiple_choice_logits
-
-
 class GPT2PreTrainedModel(PreTrainedModel):
     """ An abstract class to handle weights initialization and a simple interface
         for downloading and loading pretrained models.
@@ -735,7 +715,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         super(GPT2DoubleHeadsModel, self).__init__(config)
         self.transformer = GPT2Model(config)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
-        self.multiple_choice_head = GPT2MultipleChoiceHead(config)
+        self.multiple_choice_head = SequenceSummary(config)
 
         self.apply(self.init_weights)
 
@@ -753,7 +733,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         hidden_states = transformer_outputs[0]
 
         lm_logits = self.lm_head(hidden_states)
-        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
+        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
 
         outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py
index 6a182526e..1a3e7fbbb 100644
--- a/pytorch_pretrained_bert/modeling_openai.py
+++ b/pytorch_pretrained_bert/modeling_openai.py
@@ -31,7 +31,8 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
 
 from .file_utils import cached_path
-from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
+from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
+                          PreTrainedModel, prune_conv1d_layer, SequenceSummary)
 from .modeling_bert import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
@@ -147,6 +148,11 @@ class OpenAIGPTConfig(PretrainedConfig):
                  layer_norm_epsilon=1e-5,
                  initializer_range=0.02,
                  predict_special_tokens=True,
+                 summary_type='token_ids',
+                 summary_use_proj=True,
+                 summary_num_classes=1,
+                 summary_activation=None,
+                 summary_dropout=0.1,
                  **kwargs
     ):
         """Constructs OpenAIGPTConfig.
@@ -195,6 +201,11 @@ class OpenAIGPTConfig(PretrainedConfig):
             self.layer_norm_epsilon = layer_norm_epsilon
             self.initializer_range = initializer_range
             self.predict_special_tokens = predict_special_tokens
+            self.summary_type = summary_type
+            self.summary_use_proj = summary_use_proj
+            self.summary_num_classes = summary_num_classes
+            self.summary_activation = summary_activation
+            self.summary_dropout = summary_dropout
         else:
             raise ValueError(
                 "First argument must be either a vocabulary size (int)"
@@ -368,37 +379,6 @@ class OpenAIGPTLMHead(nn.Module):
         return lm_logits
 
 
-class OpenAIGPTMultipleChoiceHead(nn.Module):
-    """ Classifier Head for the transformer """
-
-    def __init__(self, config):
-        super(OpenAIGPTMultipleChoiceHead, self).__init__()
-        self.n_embd = config.n_embd
-        self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
-        self.linear = nn.Linear(config.n_embd, 1)
-
-        nn.init.normal_(self.linear.weight, std=0.02)
-        nn.init.normal_(self.linear.bias, 0)
-
-    def forward(self, hidden_states, mc_token_ids=None):
-        """ Extract classification token hidden state and project it using self.linear
-            hidden_state: hidden state of shape (bsz, num_choices, seq_length, hidden_size)
-            mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
-                if mc_token_ids=None we take the last token of the sequence as classification token
-        """
-        if mc_token_ids is None:
-            mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
-        else:
-            mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
-        # (bsz, num_choices, 1, hidden_size)
-        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
-        # (bsz, num_choices, hidden_size)
-        multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
-        multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
-        # (bsz, num_choices)
-        return multiple_choice_logits
-
-
 class OpenAIGPTPreTrainedModel(PreTrainedModel):
     """ An abstract class to handle weights initialization and a simple interface
         for downloading and loading pretrained models.
@@ -768,9 +748,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
 
     def __init__(self, config):
         super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
+
         self.transformer = OpenAIGPTModel(config)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
-        self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
+        self.multiple_choice_head = SequenceSummary(config)
+
         self.apply(self.init_weights)
 
     def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
@@ -787,7 +769,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         hidden_states = transformer_outputs[0]
 
         lm_logits = self.lm_head(hidden_states)
-        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
+        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
 
         outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
diff --git a/pytorch_pretrained_bert/modeling_xlnet.py b/pytorch_pretrained_bert/modeling_xlnet.py
index 2771ba7ca..fb3d72954 100644
--- a/pytorch_pretrained_bert/modeling_xlnet.py
+++ b/pytorch_pretrained_bert/modeling_xlnet.py
@@ -32,7 +32,8 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss
 
 from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
+from .model_utils import (CONFIG_NAME, WEIGHTS_NAME,
+                          PretrainedConfig, PreTrainedModel, SequenceSummary)
 
 logger = logging.getLogger(__name__)
@@ -223,8 +224,10 @@ class XLNetConfig(PretrainedConfig):
                  finetuning_task=None,
                  num_labels=2,
-                 summary_type="last",
-                 use_proj=True,
+                 summary_type='last',
+                 summary_use_proj=True,
+                 summary_activation='tanh',
+                 summary_dropout=0.1,
                  **kwargs):
         """Constructs XLNetConfig.
@@ -307,7 +310,9 @@ class XLNetConfig(PretrainedConfig):
             self.finetuning_task = finetuning_task
             self.num_labels = num_labels
             self.summary_type = summary_type
-            self.use_proj = use_proj
+            self.summary_use_proj = summary_use_proj
+            self.summary_activation = summary_activation
+            self.summary_dropout = summary_dropout
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")
@@ -1042,38 +1047,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
 
 
-class XLNetSequenceSummary(nn.Module):
-    def __init__(self, config):
-        super(XLNetSequenceSummary, self).__init__()
-        self.summary_type = config.summary_type
-        if config.use_proj:
-            self.summary = nn.Linear(config.d_model, config.d_model)
-        else:
-            self.summary = None
-        if config.summary_type == 'attn':
-            # We should use a standard multi-head attention module with absolute positional embedding for that.
-            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
-            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
-            raise NotImplementedError
-        self.dropout = nn.Dropout(config.dropout)
-        self.activation = nn.Tanh()
-
-    def forward(self, hidden_states):
-        """ hidden_states: float Tensor in shape [bsz, seq_len, d_model], the hidden-states of the last layer."""
-        if self.summary_type == 'last':
-            output = hidden_states[:, -1]
-        elif self.summary_type == 'first':
-            output = hidden_states[:, 0]
-        elif self.summary_type == 'mean':
-            output = hidden_states.mean(dim=1)
-        elif self.summary_type == 'attn':
-            raise NotImplementedError
-
-        output = self.summary(output)
-        output = self.activation(output)
-        output = self.dropout(output)
-        return output
-
-
 class XLNetForSequenceClassification(XLNetPreTrainedModel):
     """XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").
 
@@ -1143,7 +1116,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         super(XLNetForSequenceClassification, self).__init__(config)
 
         self.transformer = XLNetModel(config)
-        self.sequence_summary = XLNetSequenceSummary(config)
+        self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, config.num_labels)
 
         self.apply(self.init_weights)
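
For the XLNet path, the removed XLNetSequenceSummary corresponds to SequenceSummary configured with summary_type='last', a projection back to the hidden size, and a tanh activation. A minimal sketch under the same stand-in-config assumption as above (illustrative only; names and sizes are hypothetical):

    # XLNet-style summary: take the last token's hidden state, project back to
    # hidden_size (summary_num_classes <= 0), then apply tanh, as the removed class did.
    from types import SimpleNamespace

    import torch

    from pytorch_pretrained_bert.model_utils import SequenceSummary

    config = SimpleNamespace(
        summary_type='last',
        summary_use_proj=True,
        summary_num_classes=0,      # <= 0 => projection outputs hidden_size, not classes
        summary_activation='tanh',
        summary_dropout=0.1,
        hidden_size=1024,           # hypothetical; XLNet-large uses d_model=1024
    )
    sequence_summary = SequenceSummary(config)

    hidden_states = torch.rand(2, 16, config.hidden_size)  # (bsz, seq_len, hidden_size)
    summary_vec = sequence_summary(hidden_states)           # (bsz, hidden_size)

A downstream projection, like the logits_proj layer kept in XLNetForSequenceClassification above, then maps summary_vec to num_labels logits.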