fix 2
parent 6ba13f577b
commit 421bf8611a
3 changed files with 24 additions and 7 deletions
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional, Tuple, TypedDict
 
 import torch
 
@@ -62,3 +62,17 @@ def sdpa_attention_forward(
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, None
+
+
+class SdpaAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for sdpa Attention.
+
+    Attributes:
+        is_causal (`bool`, *optional*)
+            The value for the argument `is_causal` that is passed to `torch.nn.functional.scaled_dot_product_attention`.
+            An error is thrown if both attention_mask and is_causal are set. If `None`, it is inferred in
+            `sdpa_attention_forward`.
+    """
+
+    is_causal: Optional[bool]
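The class added above follows the PEP 692 pattern: a `TypedDict` with `total=False` lists optional keyword arguments, so a `**kwargs` parameter annotated with `Unpack[...]` can be checked statically while remaining an ordinary dict at runtime. Below is a minimal, self-contained sketch of that pattern, assuming PyTorch 2.x and `typing_extensions`; `toy_attention` is a hypothetical stand-in, not the library's `sdpa_attention_forward`:

import torch
from typing import Optional, TypedDict
from typing_extensions import Unpack  # Unpack is in typing on Python 3.11+; typing_extensions works everywhere


class SdpaAttentionKwargs(TypedDict, total=False):
    # Optional flag forwarded to torch.nn.functional.scaled_dot_product_attention.
    is_causal: Optional[bool]


def toy_attention(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
    **kwargs: Unpack[SdpaAttentionKwargs],
) -> torch.Tensor:
    # Simplified version of the inference the docstring describes:
    # if is_causal is not given, infer it from the query length.
    is_causal = kwargs.get("is_causal")
    if is_causal is None:
        is_causal = query.shape[2] > 1
    return torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal)


q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
out = toy_attention(q, k, v, is_causal=True)  # a type checker now knows which keys are legal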
@@ -30,7 +30,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypedDict, TypeVar, Union
 from zipfile import is_zipfile
 
 import torch
@@ -48,8 +48,9 @@ from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
-from .integrations.sdpa_attention import sdpa_attention_forward
+from .integrations.sdpa_attention import sdpa_attention_forward, SdpaAttentionKwargs
 from .loss.loss_utils import LOSS_MAPPING
+from .modeling_flash_attention_utils import FlashAttentionKwargs
 from .pytorch_utils import (  # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
@@ -5702,3 +5703,6 @@ ALL_ATTENTION_FUNCTIONS.update(
         "sdpa": sdpa_attention_forward,
     }
 )
+
+
+AttentionKwargs = Union[FlashAttentionKwargs, SdpaAttentionKwargs]
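`AttentionKwargs` is a `Union` of the two per-backend `TypedDict`s, so one annotation can describe whichever attention implementation the config selects. At runtime the annotation is erased and `**kwargs` stays an ordinary dict; each backend simply reads the keys it understands. A small sketch of that idea, using stand-in field names rather than the library's real definitions:

from typing import Optional, TypedDict, Union


class FlashAttentionKwargs(TypedDict, total=False):
    max_length_q: Optional[int]  # stand-in for the real flash-attention fields


class SdpaAttentionKwargs(TypedDict, total=False):
    is_causal: Optional[bool]


AttentionKwargs = Union[FlashAttentionKwargs, SdpaAttentionKwargs]


def run_attention(backend: str, **kwargs) -> str:
    # Each backend reads only the keys it knows about and ignores the rest.
    if backend == "sdpa":
        return f"sdpa(is_causal={kwargs.get('is_causal')})"
    return f"flash(max_length_q={kwargs.get('max_length_q')})"


print(run_attention("sdpa", is_causal=True))
print(run_attention("flash_attention_2", max_length_q=128))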
@@ -27,7 +27,6 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -36,7 +35,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionKwargs, PreTrainedModel
 from ...processing_utils import Unpack
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
@@ -262,7 +261,7 @@ class LlamaAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[AttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -528,7 +527,7 @@ class LlamaModel(LlamaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **flash_attn_kwargs: Unpack[AttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
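With `LlamaAttention.forward` and `LlamaModel.forward` now annotated with `Unpack[AttentionKwargs]`, backend-specific flags such as `is_causal` can be passed straight through the model call. A hedged usage sketch follows, using a tiny randomly initialised config so nothing is downloaded; the config values are only examples, and combining `is_causal` with a padded `attention_mask` is subject to the error the new docstring warns about:

import torch
from transformers import LlamaConfig, LlamaModel

# Tiny config so the example runs quickly with random weights.
config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    vocab_size=128,
    attn_implementation="sdpa",  # assumes sdpa is available in this torch build
)
model = LlamaModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    # is_causal travels through LlamaModel.forward -> LlamaAttention -> the sdpa backend.
    out = model(input_ids, is_causal=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 8, 64])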