Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
tgi update
This commit is contained in:
parent 95cb944ee6
commit caaa5e5508
3 changed files with 12 additions and 4 deletions
@@ -4,15 +4,20 @@ from ..modeling_flash_attention_utils import _flash_attention_forward
 def flash_attention_forward(
-    config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
+    config, query, key, value, attention_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
 ):
-    if attentions_mask is not None:
-        seq_len = attentions_mask.shape[1]
+    if attention_mask is not None:
+        seq_len = attention_mask.shape[1]
         query = query[:, :, :seq_len]
         value = value[:, :, :seq_len]
     else:
         seq_len = query.shape[1]
 
     # Re-transpose them
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)
 
     dropout_rate = config.attention_dropout if training else 0.0
 
     input_dtype = query.dtype
@@ -25,7 +30,7 @@ def flash_attention_forward(
         query,
         key,
         value,
-        attentions_mask,
+        attention_mask,
         seq_len,
         config=config,
         dropout=dropout_rate,
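The rename above only changes the spelling of the mask argument; the surrounding logic still derives the effective sequence length from the mask and trims the padded query/value tensors before handing them to flash attention. A minimal sketch of that behaviour, assuming toy tensor shapes and a simple padding layout (neither is taken from the diff):

import torch

batch, heads, padded_len, head_dim = 2, 4, 16, 8
query = torch.randn(batch, heads, padded_len, head_dim)
value = torch.randn(batch, heads, padded_len, head_dim)
attention_mask = torch.ones(batch, 10, dtype=torch.long)  # 10 real tokens out of 16

# Same branch as in the patched function: the mask's second dimension gives the
# true sequence length, and query/value are sliced down to it before the
# transpose that the flash attention kernel expects.
if attention_mask is not None:
    seq_len = attention_mask.shape[1]
    query = query[:, :, :seq_len]
    value = value[:, :, :seq_len]
else:
    seq_len = query.shape[1]

query = query.transpose(1, 2)  # (batch, seq_len, heads, head_dim)
value = value.transpose(1, 2)
print(query.shape, value.shape)  # torch.Size([2, 10, 4, 8]) for both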
@@ -25,6 +25,8 @@ class AutoForCausalLM(PreTrainedModel, GenerationMixin):
     _no_split_modules = []
     _supports_cache_class = True
+    config_class = AutoConfig
+    _supports_flash_attn_2 = True
     _supports_sdpa = True
 
     def __init__(self, config):
         super().__init__(config)
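The two added class attributes are what let the usual transformers loading machinery resolve a config through AutoConfig and accept a Flash Attention 2 request for this model. A rough sketch of the intended call site, assuming the standard from_pretrained path; the checkpoint name is a placeholder and AutoForCausalLM is the class from the changed module, not an importable transformers symbol:

import torch
from transformers import AutoConfig

# "some-org/some-checkpoint" is a placeholder, not a real repository.
config = AutoConfig.from_pretrained("some-org/some-checkpoint")  # resolved via config_class = AutoConfig
model = AutoForCausalLM.from_pretrained(
    "some-org/some-checkpoint",
    config=config,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # accepted because _supports_flash_attn_2 = True
)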
@@ -269,6 +269,7 @@ class LlamaAttention(nn.Module):
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
+        kwargs["layer_idx"] = self.layer_idx
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
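The hunk above threads the layer index into the kwargs that the selected attention backend receives. A self-contained toy version of that dispatch pattern, with stand-in functions in place of the real registry entries:

def eager_attention_forward(module, query, **kwargs):
    return f"eager attention, layer_idx={kwargs.get('layer_idx')}"

def flash_attention_forward(module, query, **kwargs):
    return f"flash attention, layer_idx={kwargs.get('layer_idx')}"

# Stand-in for transformers' ALL_ATTENTION_FUNCTIONS registry.
ALL_ATTENTION_FUNCTIONS = {
    "eager": eager_attention_forward,
    "flash_attention_2": flash_attention_forward,
}

class ToyAttention:
    def __init__(self, attn_implementation, layer_idx):
        self.attn_implementation = attn_implementation
        self.layer_idx = layer_idx

    def forward(self, query_states, **kwargs):
        attention_interface = eager_attention_forward
        if self.attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.attn_implementation]
        # The line this commit adds: pass the layer index along with the other kwargs.
        kwargs["layer_idx"] = self.layer_idx
        return attention_interface(self, query_states, **kwargs)

print(ToyAttention("flash_attention_2", layer_idx=3).forward(query_states=None))
# -> flash attention, layer_idx=3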