Specify dtype=torch.bool to avoid xla error (#31191)

The StoppingCriteriaList allocates is_done without specifying dtype=torch.bool. On XLA this allocates a float tensor and causes a failure on the following line:

is_done = is_done | criteria(input_ids, scores, **kwargs)

by attempting to OR float with bool.
This commit is contained in:
Yury Sulsky 2024-06-05 06:50:54 +01:00 committed by GitHub
parent 8685b3c5d2
commit 66875ac070
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -502,7 +502,7 @@ class EosTokenCriteria(StoppingCriteria):
class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool)
for criteria in self:
is_done = is_done | criteria(input_ids, scores, **kwargs)
return is_done