Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
more refactoring
This commit is contained in: parent f446bd4c00, commit 0384db9c0c
6 changed files with 407 additions and 505 deletions
src/transformers/integrations/flash_attention.py (new file, 37 lines)

@@ -0,0 +1,37 @@
import torch


def flash_attention_forward(
    config, query, key, value, mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
):
    if mask is not None:
        seq_len = mask.shape[1]
        query = query[:, :, :seq_len]
        value = value[:, :, :seq_len]
    else:
        seq_len = query.shape[1]

    dropout_rate = config.attention_dropout if training else 0.0

    input_dtype = query.dtype
    if input_dtype == torch.float32:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        mask,
        seq_len,
        dropout=dropout_rate,
        softmax_scale=getattr(config, "scaling", 1.0),
        is_causal=getattr(config, "is_causal", False),
        sliding_window=getattr(config, "sliding_window", None),
        use_top_left_mask=getattr(config, "_flash_attn_uses_top_left_mask", False),
        layer_idx=layer_idx,
        **kwargs,
    )

    return attn_output, None
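Note on the new file above: `_flash_attention_forward` is called but not imported in this 37-line module; presumably it resolves from transformers' flash-attention utilities, which this diff does not show. The snippet below is only a hedged, standalone sketch of the float32 downcast pattern used in the function (FlashAttention kernels accept only fp16/bf16 inputs); the helper name `maybe_downcast` is made up for illustration.

```python
# Hedged sketch (not part of the commit): the float32 -> half-precision downcast
# pattern used by flash_attention_forward above, shown in isolation.
import torch


def maybe_downcast(query, key, value, target_dtype=torch.float16):
    # FlashAttention kernels reject float32, so cast back if layers were upcast.
    if query.dtype == torch.float32:
        query, key, value = (t.to(target_dtype) for t in (query, key, value))
    return query, key, value


q = torch.randn(1, 8, 16, 64)                      # float32 by default
q, k, v = maybe_downcast(q, q.clone(), q.clone())
print(q.dtype)                                      # torch.float16
```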
src/transformers/integrations/flex_attention.py (new file, 23 lines)

@@ -0,0 +1,23 @@
import torch


def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
    def tanh_softcap(score, b, h, q_idx, kv_idx):
        soft_cap = config.attn_logit_softcapping
        score = soft_cap * torch.tanh(score / soft_cap)
        if mask is not None:
            return score + mask[b][0][q_idx][kv_idx]
        return score

    attn_output = flex_attention(
        query,
        key,
        value,
        score_mod=tanh_softcap,
        enable_gqa=True,
        scale=config.scaling,
        return_lse=output_attentions,
    )
    if not output_attentions:
        return attn_output, None
    else:
        return attn_output[0], attn_output[1]
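`flex_attention` is likewise used without an import here; it is presumably PyTorch's `torch.nn.attention.flex_attention.flex_attention` (available from torch 2.5 onward), but that is an assumption. The standalone sketch below only illustrates what the `tanh_softcap` score modifier does to raw attention scores.

```python
# Hedged sketch: tanh soft-capping squashes raw scores smoothly into (-soft_cap, +soft_cap),
# which is what the score_mod above applies per (batch, head, q_idx, kv_idx) entry.
import torch

soft_cap = 50.0
scores = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
capped = soft_cap * torch.tanh(scores / soft_cap)
print(capped)  # large magnitudes are bounded near +/-50, small ones pass almost unchanged
```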
src/transformers/integrations/sdpa_attention.py (new file, 32 lines)

@@ -0,0 +1,32 @@
import torch


def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
    key = repeat_kv(key, config.num_key_value_groups)
    value = repeat_kv(value, config.num_key_value_groups)

    causal_mask = mask
    if mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query.device.type == "cuda" and causal_mask is not None:
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    is_causal = True if causal_mask is None and query.shape[1] > 1 else False

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=config.attention_dropout if config.training else 0.0,
        is_causal=is_causal,
        scale=config.scaling,
    )
    return attn_output, None
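`repeat_kv` is also referenced without an import in this new module. As a hedged reference, the sketch below mirrors the llama helper of the same name (its final `reshape` line appears in the modeling_llama.py hunk further down), expanding each key/value head `n_rep` times for grouped-query attention.

```python
# Hedged sketch of the repeat_kv semantics assumed by sdpa_attention_forward above.
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


kv = torch.randn(2, 4, 16, 64)      # 4 key/value heads
print(repeat_kv(kv, 8).shape)        # torch.Size([2, 32, 16, 64]) -> matches 32 query heads
```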
src/transformers/modeling_utils.py

@@ -45,6 +45,9 @@ from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.flash_attention import *
from .integrations.flex_attention import *
from .integrations.sdpa_attention import *
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import (  # noqa: F401
    Conv1D,
@@ -5484,6 +5487,19 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]


ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, function]] = {}

ALL_ATTENTION_FUNCTIONS.update(
    {
        "flash_attention_2": flash_attention_forward,
        "flex_attention": flex_attention_forward,
        "sdpa": sdpa_attention_forward,
    }
)


class GradientCheckpointLayer(torch.nn.Module):
    def __call__(self, *args, **kwargs):
        """
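This hunk registers the three backend forwards in a module-level dictionary keyed by implementation name (note the `Dict[str, Dict[str, function]]` annotation does not match the stored values, which are plain callables). Below is a standalone, hedged sketch of the dictionary-dispatch idea; the toy registry and `toy_sdpa` names are illustrative only, not the library's API.

```python
# Hedged sketch of registry-based attention dispatch, self-contained for illustration.
from typing import Callable, Dict

import torch


def toy_sdpa(config, query, key, value, mask, **_kwargs):
    out = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=mask)
    return out, None


ATTENTION_REGISTRY: Dict[str, Callable] = {"sdpa": toy_sdpa}

q = k = v = torch.randn(1, 2, 4, 8)
attn_output, attn_weights = ATTENTION_REGISTRY["sdpa"](None, q, k, v, mask=None)
print(attn_output.shape)  # torch.Size([1, 2, 4, 8])
```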
@@ -5599,6 +5615,3 @@ class GradientCheckpointLayer(torch.nn.Module):

        if getattr(self, "_hf_peft_config_loaded", False):
            self.disable_input_require_grads()


ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, function]] = {}
src/transformers/models/auto/modeling_task.py (new file, 296 lines)

@@ -0,0 +1,296 @@
import torch


class AutoForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _embeding_layer = "model.embed_tokens"
    _output_embedding = "lm_head"

    def __init__(self, config):
        super().__init__(config)
        self.model = AutoModel.from_config(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class AutoForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


class AutoForQuestionAnswering(LlamaPreTrainedModel):
    base_model_prefix = "transformer"

    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
    def __init__(self, config):
        super().__init__(config)
        self.transformer = LlamaModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class AutoForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @validate_config_kwargs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
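The slicing `hidden_states[:, -num_logits_to_keep:, :]` in `AutoForCausalLM.forward` relies on `-0 == 0`, so the default of `0` keeps every position while a positive value keeps only the trailing tokens. A minimal standalone sketch of that behavior (the shapes are made up for illustration):

```python
# Hedged sketch of the num_logits_to_keep slicing used above.
import torch

hidden_states = torch.randn(1, 8, 64)       # (batch, seq, hidden)
print(hidden_states[:, -0:, :].shape)         # torch.Size([1, 8, 64]) -> all positions projected
print(hidden_states[:, -1:, :].shape)         # torch.Size([1, 1, 64]) -> only the last token, saving memory
```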
src/transformers/models/llama/modeling_llama.py

@@ -37,7 +37,7 @@ from ...modeling_outputs import (
)
from ...utils.generic import validate_config_kwargs
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, GradientCheckpointLayer
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
@@ -53,7 +53,6 @@ from .configuration_llama import LlamaConfig

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
_CONFIG_FOR_DOC = "LlamaConfig"

@@ -229,105 +228,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def flash_attention_forward(
    config, query, key, value, mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
):
    if mask is not None:
        seq_len = mask.shape[1]
        query = query[:, :, :seq_len]
        value = value[:, :, :seq_len]
    else:
        seq_len = query.shape[1]

    dropout_rate = config.attention_dropout if training else 0.0

    input_dtype = query.dtype
    if input_dtype == torch.float32:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        mask,
        seq_len,
        dropout=dropout_rate,
        softmax_scale=getattr(config, "scaling", 1.0),
        is_causal=getattr(config, "is_causal", False),
        sliding_window=getattr(config, "sliding_window", None),
        use_top_left_mask=getattr(config, "_flash_attn_uses_top_left_mask", False),
        layer_idx=layer_idx,
        **kwargs,
    )

    return attn_output, None


def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
    def tanh_softcap(score, b, h, q_idx, kv_idx):
        soft_cap = config.attn_logit_softcapping
        score = soft_cap * torch.tanh(score / soft_cap)
        if mask is not None:
            return score + mask[b][0][q_idx][kv_idx]
        return score

    attn_output = flex_attention(
        query,
        key,
        value,
        score_mod=tanh_softcap,
        enable_gqa=True,
        scale=config.scaling,
        return_lse=output_attentions,
    )
    if not output_attentions:
        return attn_output, None
    else:
        return attn_output[0], attn_output[1]


def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
    key = repeat_kv(key, config.num_key_value_groups)
    value = repeat_kv(value, config.num_key_value_groups)

    causal_mask = mask
    if mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query.device.type == "cuda" and causal_mask is not None:
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    is_causal = True if causal_mask is None and query.shape[1] > 1 else False

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=config.attention_dropout if config.training else 0.0,
        is_causal=is_causal,
        scale=config.scaling,
    )
    return attn_output, None


ALL_ATTENTION_FUNCTIONS.update(
    {
        "llama.flash_attention_2": flash_attention_forward,
        "llama.flex_attention": flex_attention_forward,
        "llama.sdpa": sdpa_attention_forward,
    }
)


def eager_attention_forward(attention_class:nn.Module, query, key, value, mask, **_kwargs):
    config = attention_class.config
    key_states = repeat_kv(key, config.num_key_value_groups)
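The hunk above ends inside `eager_attention_forward`, whose remaining body is not shown. For orientation only, here is a hedged sketch of the standard eager attention pattern (scaled matmul, mask add, softmax, dropout, matmul); this is the generic recipe, not necessarily the exact continuation of the function in this file.

```python
# Hedged sketch of a generic eager attention step; illustrative only.
import math

import torch


def eager_attention(query, key, value, mask=None, scaling=None, dropout_p=0.0, training=False):
    if scaling is None:
        scaling = 1.0 / math.sqrt(query.size(-1))
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if mask is not None:
        attn_weights = attn_weights + mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_p, training=training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights


q = k = v = torch.randn(1, 4, 8, 16)
out, w = eager_attention(q, k, v)
print(out.shape, w.shape)  # torch.Size([1, 4, 8, 16]) torch.Size([1, 4, 8, 8])
```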
@@ -585,6 +485,7 @@ class LlamaModel(LlamaPreTrainedModel):
    Args:
        config: LlamaConfig
    """

    _input_embedding = "embed_tokens"  # no need for set and get, take them from PreTrainedModel

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
@@ -598,16 +499,9 @@ class LlamaModel(LlamaPreTrainedModel):
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @validate_config_kwargs
    def forward(
        self,
@@ -623,8 +517,6 @@ class LlamaModel(LlamaPreTrainedModel):
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

@@ -693,10 +585,7 @@ class LlamaModel(LlamaPreTrainedModel):
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

        if not return_dict:
            return output.to_tuple()
        return output
        return output if return_dict else output.to_tuple()

    def _update_causal_mask(
        self,
@@ -819,391 +708,3 @@ class LlamaModel(LlamaPreTrainedModel):

        return causal_mask


class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
    base_model_prefix = "transformer"

    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
    def __init__(self, config):
        super().__init__(config)
        self.transformer = LlamaModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @validate_config_kwargs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )