Fix mistral ONNX export (#31696)

* use bitwise or

* why is the CI not triggered?
This commit is contained in:
fxmarty 2024-07-02 13:54:10 +02:00 committed by GitHub
parent 93cd94b79d
commit 57d7594a79
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -1089,8 +1089,9 @@ class MistralModel(MistralPreTrainedModel):
 exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
 if self.config.sliding_window is not None:
     if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
-        exclude_mask |= torch.arange(target_length, device=device) <= (
-            cache_position.reshape(-1, 1) - self.config.sliding_window
+        exclude_mask.bitwise_or_(
+            torch.arange(target_length, device=device)
+            <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
         )
 causal_mask *= exclude_mask
 causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)