Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
Add accelerate support for BART-like models (#19927)

* forward contrib credits from suggestion
* add `accelerate` support for BART-like models

Co-authored-by: sgugger <sgugger@users.noreply.github.com>

parent ebfd7229d2
commit 4cef546ffc

3 changed files with 34 additions and 21 deletions
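With `_no_split_modules` defined on these models, `from_pretrained` can hand weight placement over to `accelerate`. A minimal usage sketch (the checkpoint name is only an example, and `accelerate` must be installed):

    from transformers import BartForConditionalGeneration

    # device_map="auto" asks accelerate to spread the weights over the available
    # GPUs (and CPU, if needed) instead of loading everything onto one device.
    model = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-large-cnn",
        device_map="auto",
    )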
src/transformers/models/bart/modeling_bart.py

@@ -500,6 +500,7 @@ class BartPretrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
+    _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]

     def _init_weights(self, module):
         std = self.config.init_std
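`_no_split_modules` lists the layer classes that `accelerate` must keep whole on a single device when it computes a device map, so a residual block is never cut in half across devices. A rough sketch of how the attribute is consumed (the memory budget below is made up for illustration):

    from accelerate import infer_auto_device_map

    # Modules whose class name appears in no_split_module_classes are placed whole.
    device_map = infer_auto_device_map(
        model,
        max_memory={0: "10GiB", "cpu": "30GiB"},
        no_split_module_classes=model._no_split_modules,
    )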
@@ -712,10 +713,10 @@ class BartEncoder(BartPretrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
-
         if embed_tokens is not None:
-            self.embed_tokens.weight = embed_tokens.weight
+            self.embed_tokens = embed_tokens
+        else:
+            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
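The constructor change above reuses the shared embedding module directly instead of first allocating a fresh `nn.Embedding` and then re-pointing its weight; presumably this keeps encoder and decoder token embeddings as one module (and therefore one device placement) and avoids a throwaway allocation when weights are dispatched. A toy before/after sketch (sizes are illustrative only):

    import torch.nn as nn

    shared = nn.Embedding(50265, 1024, padding_idx=1)

    # old pattern: allocate a new embedding, then tie its weight afterwards
    embed_tokens_old = nn.Embedding(50265, 1024, padding_idx=1)
    embed_tokens_old.weight = shared.weight

    # new pattern: reuse the shared module itself, fall back to a fresh one otherwise
    embed_tokens_new = shared if shared is not None else nn.Embedding(50265, 1024, padding_idx=1)
    assert embed_tokens_new is shared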
@@ -801,6 +802,7 @@ class BartEncoder(BartPretrainedModel):
         inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input)
+        embed_pos = embed_pos.to(inputs_embeds.device)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
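Under a device map, the position-embedding table and the token-embedding table may end up on different devices, and `inputs_embeds + embed_pos` would then raise a device-mismatch error; the added `.to(inputs_embeds.device)` moves the positional embeddings next to the token embeddings first (the same pattern is applied in the decoders and to `final_logits_bias` below). A toy reproduction of the failure mode (assumes at least one CUDA device):

    import torch

    token_embeds = torch.randn(2, 8, device="cuda:0")  # stand-in for inputs_embeds
    pos_embeds = torch.randn(2, 8, device="cpu")       # stand-in for embed_pos living elsewhere

    # token_embeds + pos_embeds                        # RuntimeError: tensors on different devices
    hidden = token_embeds + pos_embeds.to(token_embeds.device)  # the fix used in this commit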
@@ -884,10 +886,10 @@ class BartDecoder(BartPretrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
         if embed_tokens is not None:
-            self.embed_tokens.weight = embed_tokens.weight
+            self.embed_tokens = embed_tokens
+        else:
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)

         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -1043,6 +1045,7 @@ class BartDecoder(BartPretrainedModel):

         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1373,7 +1376,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

         masked_lm_loss = None
         if labels is not None:
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py

@@ -1595,6 +1595,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     config_class = BigBirdPegasusConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]

     def _init_weights(self, module):
         std = self.config.init_std
@@ -1788,10 +1789,10 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
-
         if embed_tokens is not None:
-            self.embed_tokens.weight = embed_tokens.weight
+            self.embed_tokens = embed_tokens
+        else:
+            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

         self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -2082,10 +2083,10 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
         if embed_tokens is not None:
-            self.embed_tokens.weight = embed_tokens.weight
+            self.embed_tokens = embed_tokens
+        else:
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)

         self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -2240,6 +2241,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):

         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)

         hidden_states = inputs_embeds + positions

@@ -2573,7 +2575,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

         masked_lm_loss = None
         if labels is not None:
src/transformers/models/plbart/modeling_plbart.py

@@ -506,6 +506,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
     config_class = PLBartConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]

     def _init_weights(self, module):
         std = self.config.init_std
@@ -683,10 +684,10 @@ class PLBartEncoder(PLBartPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
-
         if embed_tokens is not None:
-            self.embed_tokens.weight = embed_tokens.weight
+            self.embed_tokens = embed_tokens
+        else:
+            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

         self.embed_positions = PLBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -772,6 +773,7 @@ class PLBartEncoder(PLBartPreTrainedModel):
         inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input)
+        embed_pos = embed_pos.to(inputs_embeds.device)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -856,10 +858,10 @@ class PLBartDecoder(PLBartPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
         if embed_tokens is not None:
-            self.embed_tokens.weight = embed_tokens.weight
+            self.embed_tokens = embed_tokens
+        else:
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)

         self.embed_positions = PLBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -1015,6 +1017,7 @@ class PLBartDecoder(PLBartPreTrainedModel):

         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1334,7 +1337,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

         masked_lm_loss = None
         if labels is not None: