mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Specify dtype=torch.bool to avoid xla error (#31191)
The StoppingCriteriaList allocates is_done without specifying dtype=torch.bool. On XLA this allocates a float tensor and causes a failure on the following line: is_done = is_done | criteria(input_ids, scores, **kwargs) by attempting to OR float with bool.
This commit is contained in:
parent
8685b3c5d2
commit
66875ac070
1 changed files with 1 additions and 1 deletions
|
|
@ -502,7 +502,7 @@ class EosTokenCriteria(StoppingCriteria):
|
|||
class StoppingCriteriaList(list):
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
|
||||
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool)
|
||||
for criteria in self:
|
||||
is_done = is_done | criteria(input_ids, scores, **kwargs)
|
||||
return is_done
|
||||
|
|
|
|||
Loading…
Reference in a new issue