Add accelerate support for BART-like models (#19927)

* forward contrib credits from suggestion

* add `accelerate` support for BART-like models

Co-authored-by: sgugger <sgugger@users.noreply.github.com>
This commit is contained in:
Younes Belkada 2022-10-27 23:14:53 +02:00 committed by GitHub
parent ebfd7229d2
commit 4cef546ffc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 21 deletions

View file

@ -500,6 +500,7 @@ class BartPretrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
def _init_weights(self, module):
std = self.config.init_std
@ -712,10 +713,10 @@ class BartEncoder(BartPretrainedModel):
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
@ -801,6 +802,7 @@ class BartEncoder(BartPretrainedModel):
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
@ -884,10 +886,10 @@ class BartDecoder(BartPretrainedModel):
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
@ -1043,6 +1045,7 @@ class BartDecoder(BartPretrainedModel):
# embed positions
positions = self.embed_positions(input, past_key_values_length)
positions = positions.to(inputs_embeds.device)
hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
@ -1373,7 +1376,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
lm_logits = self.lm_head(outputs[0])
lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
masked_lm_loss = None
if labels is not None:

View file

@ -1595,6 +1595,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
config_class = BigBirdPegasusConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
def _init_weights(self, module):
std = self.config.init_std
@ -1788,10 +1789,10 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
config.max_position_embeddings,
@ -2082,10 +2083,10 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
config.max_position_embeddings,
@ -2240,6 +2241,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
positions = positions.to(inputs_embeds.device)
hidden_states = inputs_embeds + positions
@ -2573,7 +2575,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
lm_logits = self.lm_head(outputs[0])
lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
masked_lm_loss = None
if labels is not None:

View file

@ -506,6 +506,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
config_class = PLBartConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
def _init_weights(self, module):
std = self.config.init_std
@ -683,10 +684,10 @@ class PLBartEncoder(PLBartPreTrainedModel):
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = PLBartLearnedPositionalEmbedding(
config.max_position_embeddings,
@ -772,6 +773,7 @@ class PLBartEncoder(PLBartPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
@ -856,10 +858,10 @@ class PLBartDecoder(PLBartPreTrainedModel):
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = PLBartLearnedPositionalEmbedding(
config.max_position_embeddings,
@ -1015,6 +1017,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
# embed positions
positions = self.embed_positions(input, past_key_values_length)
positions = positions.to(inputs_embeds.device)
hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
@ -1334,7 +1337,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
lm_logits = self.lm_head(outputs[0])
lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
masked_lm_loss = None
if labels is not None: