diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index df6e96131..b29d56cd9 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional, Tuple, TypedDict
 
 import torch
 
@@ -62,3 +62,17 @@ def sdpa_attention_forward(
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, None
+
+
+class SdpaAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for SDPA attention.
+
+    Attributes:
+        is_causal (`bool`, *optional*):
+            The value for the argument `is_causal` that is passed to
+            `torch.nn.functional.scaled_dot_product_attention`. An error is thrown if both `attention_mask` and
+            `is_causal` are set. If `None`, it is inferred in `sdpa_attention_forward`.
+    """
+
+    is_causal: Optional[bool]
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 1c67ee1f8..1861a99cb 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -30,7 +30,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypedDict, TypeVar, Union
 from zipfile import is_zipfile
 
 import torch
@@ -48,8 +48,9 @@ from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
-from .integrations.sdpa_attention import sdpa_attention_forward
+from .integrations.sdpa_attention import sdpa_attention_forward, SdpaAttentionKwargs
 from .loss.loss_utils import LOSS_MAPPING
+from .modeling_flash_attention_utils import FlashAttentionKwargs
 from .pytorch_utils import (  # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
@@ -5702,3 +5703,6 @@ ALL_ATTENTION_FUNCTIONS.update(
         "sdpa": sdpa_attention_forward,
     }
 )
+
+
+AttentionKwargs = Union[FlashAttentionKwargs, SdpaAttentionKwargs]
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 195a078ad..0b2cf053d 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -27,7 +27,6 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -36,7 +35,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionKwargs, PreTrainedModel
 from ...processing_utils import Unpack
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
@@ -262,7 +261,7 @@ class LlamaAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[AttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -528,7 +527,7 @@ class LlamaModel(LlamaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **flash_attn_kwargs: Unpack[AttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
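
For context, a minimal usage sketch of what the typed kwargs in this patch enable (not part of the diff): `LlamaModel.forward` now annotates its trailing keyword arguments as `Unpack[AttentionKwargs]`, so with the SDPA backend an explicit `is_causal` can be passed at call time and is forwarded down to `torch.nn.functional.scaled_dot_product_attention`. The checkpoint id below is a placeholder; any Llama checkpoint would do.

```python
# Minimal sketch, assuming the diff above is applied; the checkpoint id is a placeholder.
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, attn_implementation="sdpa")

inputs = tokenizer("Typed attention kwargs", return_tensors="pt")
with torch.no_grad():
    # `is_causal` is declared in `SdpaAttentionKwargs`, so a type checker accepts it
    # against `**flash_attn_kwargs: Unpack[AttentionKwargs]`; at runtime it is passed
    # through the decoder layers into `sdpa_attention_forward`. Leaving it unset keeps
    # the previous behaviour, where causality is inferred in `sdpa_attention_forward`.
    outputs = model(**inputs, is_causal=False)
```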