Model parallelism: Moving labels to same devices as the logits are (#22691)
Model parallelism correct labels device
parent 6daa9cb515
commit 151425ddb2
4 changed files with 26 additions and 0 deletions
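Why every hunk below makes the same two-line change: under model parallelism the LM or classification head can sit on a different GPU than the one the labels were loaded to, and CrossEntropyLoss refuses to mix devices. A minimal sketch of the failure and the fix, assuming two CUDA devices (the shapes and vocabulary size are illustrative, not taken from the diff):

import torch
from torch.nn import CrossEntropyLoss

# Logits come off the last pipeline stage; labels are still on the input device.
logits = torch.randn(8, 32000, device="cuda:1")
labels = torch.randint(0, 32000, (8,), device="cuda:0")

loss_fct = CrossEntropyLoss()
# loss_fct(logits, labels) here would raise:
# RuntimeError: Expected all tensors to be on the same device
labels = labels.to(logits.device)  # the fix this commit applies before each loss
loss = loss_fct(logits, labels)

The target tensor differs per head (prediction_scores, logits, reshaped_logits, lm_logits), but the pattern is identical across all four files.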
src/transformers/models/data2vec/modeling_data2vec_text.py

@@ -999,6 +999,8 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
             shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
             labels = labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(shifted_prediction_scores.device)
             lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
 
         if not return_dict:
@@ -1114,6 +1116,8 @@ class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(prediction_scores.device)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
 
         if not return_dict:
@@ -1224,6 +1228,8 @@ class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
 
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
+
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
@@ -1337,6 +1343,8 @@ class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(reshaped_logits.device)
             loss = loss_fct(reshaped_logits, labels)
 
         if not return_dict:
@@ -1421,6 +1429,8 @@ class Data2VecTextForTokenClassification(Data2VecTextPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(logits.device)
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         if not return_dict:
src/transformers/models/esm/modeling_esm.py

@@ -1032,6 +1032,8 @@ class EsmForMaskedLM(EsmPreTrainedModel):
         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(prediction_scores.device)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
 
         if not return_dict:
@@ -1131,6 +1133,8 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
 
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
+
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
@@ -1228,6 +1232,8 @@ class EsmForTokenClassification(EsmPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(logits.device)
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         if not return_dict:
src/transformers/models/longformer/modeling_longformer.py

@@ -1863,6 +1863,8 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(prediction_scores.device)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
 
         if not return_dict:
@@ -1952,6 +1954,8 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
 
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
+
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
@@ -2217,6 +2221,8 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(logits.device)
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         if not return_dict:
@@ -2329,6 +2335,8 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(reshaped_logits.device)
             loss = loss_fct(reshaped_logits, labels)
 
         if not return_dict:
src/transformers/models/longt5/modeling_longt5.py

@@ -2074,6 +2074,8 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-100)
+
+            labels = labels.to(lm_logits.device)
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
             # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
 