fix labels (#6213)

This commit is contained in:
Suraj Patil 2020-08-03 19:49:35 +05:30 committed by GitHub
parent cedc547e7e
commit 0b41867357
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -87,7 +87,8 @@ class DataCollatorForLanguageModeling:
return {"input_ids": inputs, "labels": labels}
else:
labels = batch.clone().detach()
- labels[labels == self.tokenizer.pad_token_id] = -100
+ if self.tokenizer.pad_token_id is not None:
+ labels[labels == self.tokenizer.pad_token_id] = -100
return {"input_ids": batch, "labels": labels}
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: