tgi update

Cyril Vallez 2024-12-12 18:29:26 +00:00
parent 95cb944ee6
commit caaa5e5508
3 changed files with 12 additions and 4 deletions

@@ -4,15 +4,20 @@ from ..modeling_flash_attention_utils import _flash_attention_forward
 def flash_attention_forward(
-    config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
+    config, query, key, value, attention_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
 ):
-    if attentions_mask is not None:
-        seq_len = attentions_mask.shape[1]
+    if attention_mask is not None:
+        seq_len = attention_mask.shape[1]
         query = query[:, :, :seq_len]
         value = value[:, :, :seq_len]
     else:
         seq_len = query.shape[1]
     # Re-transpose them
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)
     dropout_rate = config.attention_dropout if training else 0.0
     input_dtype = query.dtype
@@ -25,7 +30,7 @@ def flash_attention_forward(
         query,
         key,
         value,
-        attentions_mask,
+        attention_mask,
         seq_len,
         config=config,
         dropout=dropout_rate,
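
For context on the hunk above: the mask argument is renamed to attention_mask, query and value are trimmed to the mask's sequence length, and q/k/v are transposed into the layout the flash-attention kernel expects before _flash_attention_forward is called. Below is a minimal standalone sketch of that preprocessing step. The helper name, the assumed (batch, num_heads, seq_len, head_dim) input layout, and the toy shapes are illustrative assumptions, not code from this repository.

    import torch

    def slice_to_mask_length(query, key, value, attention_mask):
        # Hypothetical helper mirroring the hunk above: with a 2D padding mask,
        # trim query/value to the mask's length; otherwise use the query length.
        if attention_mask is not None:
            seq_len = attention_mask.shape[1]
            query = query[:, :, :seq_len]
            value = value[:, :, :seq_len]
        else:
            seq_len = query.shape[-2]  # sequence dimension under the assumed layout
        # Flash-attention kernels take (batch, seq_len, num_heads, head_dim),
        # so transpose back from the (batch, num_heads, seq_len, head_dim) layout.
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        return query, key, value, seq_len

    # Toy shapes: batch=2, heads=4, seq=8, head_dim=16, mask covering 6 positions.
    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)
    mask = torch.ones(2, 6, dtype=torch.long)
    q, k, v, seq_len = slice_to_mask_length(q, k, v, mask)
    print(q.shape, v.shape, seq_len)  # torch.Size([2, 6, 4, 16]) torch.Size([2, 6, 4, 16]) 6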

@@ -25,6 +25,8 @@ class AutoForCausalLM(PreTrainedModel, GenerationMixin):
     _no_split_modules = []
     _supports_cache_class = True
     config_class = AutoConfig
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
 
     def __init__(self, config):
         super().__init__(config)
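
The two added class attributes advertise to PreTrainedModel that this model class can run with the flash-attention-2 and SDPA backends; transformers consults these flags when a caller requests a specific attention implementation at load time. A hedged sketch of the effect, assuming AutoForCausalLM is loaded through the standard from_pretrained path; the module name and checkpoint location are placeholders, not paths from this repository:

    # Hypothetical import path and checkpoint; only the keyword argument is the point.
    from modeling_auto import AutoForCausalLM

    model = AutoForCausalLM.from_pretrained(
        "path/to/checkpoint",
        attn_implementation="flash_attention_2",  # accepted because _supports_flash_attn_2 = True
    )
    # "sdpa" is gated by _supports_sdpa in the same way; explicitly requesting a
    # backend the class does not advertise makes PreTrainedModel raise during loading.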

@@ -269,6 +269,7 @@ class LlamaAttention(nn.Module):
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+        kwargs["layer_idx"] = self.layer_idx
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
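
The single added line threads the layer index into the kwargs handed to whichever attention backend was looked up in ALL_ATTENTION_FUNCTIONS, matching the new layer_idx parameter on flash_attention_forward in the first file. Below is a self-contained toy version of that name-to-callable dispatch pattern; the registry contents and both function bodies are illustrative stand-ins, not this repository's implementations.

    import torch

    def eager_attention(module, query, key, value, attention_mask, **kwargs):
        # Plain scaled dot-product attention, only here to make the toy runnable.
        scores = query @ key.transpose(-2, -1) / query.shape[-1] ** 0.5
        if attention_mask is not None:
            scores = scores + attention_mask
        weights = scores.softmax(dim=-1)
        return weights @ value, weights

    def verbose_attention(module, query, key, value, attention_mask, layer_idx=None, **kwargs):
        # A backend that actually consumes the extra kwarg, like flash_attention_forward above.
        print(f"running attention for layer {layer_idx}")
        return eager_attention(module, query, key, value, attention_mask)

    ALL_ATTENTION_FUNCTIONS = {"eager": eager_attention, "verbose": verbose_attention}

    kwargs = {"layer_idx": 3}                    # what the added line populates
    attention_interface = ALL_ATTENTION_FUNCTIONS["verbose"]
    q = k = v = torch.randn(1, 2, 5, 8)          # (batch, heads, seq, head_dim)
    attn_output, attn_weights = attention_interface(None, q, k, v, None, **kwargs)
    print(attn_output.shape)                     # torch.Size([1, 2, 5, 8])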