Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
Support batching for UsefulSensors Moonshine (#35922)
* Add support for attention masking in moonshine. Tested against Open ASR Leaderboard with batch size 256.
* Update comments and ensure attention masks are passed everywhere. Perform attention mask downsampling inside of moonshine forward call.
* Hide padding behind conditional. Fix encoder/decoder masking.
  - Correctly pipe encoder attention mask into decoder
  - Add correct scaling factor if one is not already provided.
  - Fix formatting with ruff
* Add auto generated modeling_moonshine file.
* Update formatting in generated model file.
* Address review comments.
* Fix typo.
* Add `pad_head_dim_to_multiple_of` to moonshine config.
* Correct args order for MoonshineConfig.
* Update configuration moonshine too.
* Update src/transformers/models/moonshine/modular_moonshine.py
* Update src/transformers/models/moonshine/configuration_moonshine.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
This commit is contained in:
parent
5757681837
commit
693328f2bc
3 changed files with 181 additions and 15 deletions
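In user-facing terms, this change lets padded batches of audio run through Moonshine in a single call. A minimal sketch of batched transcription, assuming the `UsefulSensors/moonshine-tiny` checkpoint and that the feature extractor returns an `attention_mask` when `padding=True` (neither is spelled out in this commit):

import torch
from transformers import AutoProcessor, MoonshineForConditionalGeneration

processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")

# Two clips of different lengths; padding aligns them and yields a 1/0 attention_mask.
clips = [torch.randn(16000).numpy(), torch.randn(24000).numpy()]
inputs = processor(clips, sampling_rate=16000, padding=True, return_tensors="pt")

# With this commit the mask is honored end to end (encoder self-attention and
# decoder cross-attention), so padded batches decode the same as single clips.
generated_ids = model.generate(**inputs)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))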
src/transformers/models/moonshine/configuration_moonshine.py

@@ -64,6 +64,9 @@ class MoonshineConfig(PretrainedConfig):
             by meanpooling all the original heads within that group. For more details checkout [this
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
             `decoder_num_attention_heads`.
+        pad_head_dim_to_multiple_of (`int`, *optional*):
+            Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
+            optimized attention implementations.
         encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
             The non-linear activation function (function or string) in the encoder.
         decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
@@ -164,6 +167,7 @@ class MoonshineConfig(PretrainedConfig):
         decoder_num_attention_heads=8,
         encoder_num_key_value_heads=None,
         decoder_num_key_value_heads=None,
+        pad_head_dim_to_multiple_of=None,
         encoder_hidden_act="gelu",
         decoder_hidden_act="silu",
         max_position_embeddings=512,
@@ -196,6 +200,8 @@ class MoonshineConfig(PretrainedConfig):
             decoder_num_key_value_heads = decoder_num_attention_heads
         self.decoder_num_key_value_heads = decoder_num_key_value_heads

+        self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+
         self.encoder_hidden_act = encoder_hidden_act
         self.decoder_hidden_act = decoder_hidden_act
         self.max_position_embeddings = max_position_embeddings
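For reference, a minimal sketch of setting the new option on a config; the value 8 is purely illustrative, not a recommended default:

from transformers import MoonshineConfig

# Round every attention head dimension up to a multiple of 8 (illustrative value).
config = MoonshineConfig(pad_head_dim_to_multiple_of=8)
print(config.pad_head_dim_to_multiple_of)  # 8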
src/transformers/models/moonshine/modeling_moonshine.py
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 from typing import Callable, Optional, Tuple, Union

 import numpy as np
@@ -27,7 +28,11 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+    _prepare_4d_attention_mask,
+    _prepare_4d_attention_mask_for_sdpa,
+)
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -270,6 +275,23 @@ class MoonshineAttention(nn.Module):
         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
+        # Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
+        head_dim_padding = 0
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            head_dim = query_states.shape[-1]
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
+            head_dim_padding = target_head_dim - head_dim
+            if head_dim_padding > 0:
+                # Ensure scaling is correct even with padding.
+                if self.scaling is None:
+                    self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
+
+                query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
+                key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
+                value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
+
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
@@ -282,6 +304,10 @@ class MoonshineAttention(nn.Module):
             **kwargs,
         )

+        # Remove head size padding.
+        if head_dim_padding > 0:
+            attn_output = attn_output[:, :, :, :-head_dim_padding]
+
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
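The rounding in the two hunks above is plain ceiling division, and the zero padding is value-preserving: zero columns appended to K contribute nothing to the QK^T logits, and the zero channels appended to V are sliced off again after the kernel runs. Only the softmax scale needs care, which is why it is pinned to the original head_dim before padding. A standalone sketch of the same arithmetic (names are illustrative, not the model's):

import math
import torch
import torch.nn.functional as F

def pad_head_dim(x: torch.Tensor, multiple: int) -> torch.Tensor:
    # Next multiple via ceiling division, exactly as in the diff.
    head_dim = x.shape[-1]
    target = multiple * ((head_dim + multiple - 1) // multiple)
    return F.pad(x, (0, target - head_dim))  # right-pad the last dim with zeros

q = torch.randn(1, 8, 10, 72)            # head_dim 72
scaling = 1.0 / math.sqrt(q.shape[-1])   # scale from the ORIGINAL head_dim
q_padded = pad_head_dim(q, 64)
print(q_padded.shape[-1], round(scaling, 4))  # 128 0.1179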
@@ -603,9 +629,11 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
                 `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                 `input_values`, the [`AutoFeatureExtractor`] should be used for padding
                 and conversion into a tensor of type `torch.FloatTensor`.
-            attention_mask (`torch.Tensor`)`, *optional*):
-                Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-                but it is not used.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+                [What are attention masks?](../glossary#attention-mask)
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
                 tensors for more detail.
@@ -632,6 +660,22 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
         hidden_states = nn.functional.gelu(self.conv3(hidden_states))
         hidden_states = hidden_states.permute(0, 2, 1)

+        # attention mask downsampling
+        if attention_mask is not None:
+            mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
+            downsample_stride = 64 * 3 * 2  # conv strides
+            attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if (attention_mask == 0.0).any() else None
+
+            # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+            elif self.config._attn_implementation == "sdpa" and not output_attentions:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
         position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)

         # create position embeddings to be shared across the decoder layers
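The three encoder convolutions stride by 64, 3, and 2, so one encoder frame covers 384 input samples; the sample-level mask is therefore thinned by strided slicing and clipped to the exact post-conv length. A sketch with made-up lengths (the real clip length comes from `_get_feat_extract_output_lengths`, whose kernel arithmetic is not shown in this diff):

import torch

downsample_stride = 64 * 3 * 2  # product of the conv strides, as above

# 16000 samples after batch padding; the last 4000 are padding.
attention_mask = torch.ones(1, 16000, dtype=torch.long)
attention_mask[:, 12000:] = 0

mask_len = 40  # stand-in for self._get_feat_extract_output_lengths(16000)
downsampled = attention_mask[..., ::downsample_stride][..., :mask_len]
print(downsampled.shape, int(downsampled.sum()))  # torch.Size([1, 40]) 32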
@@ -649,7 +693,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
-                    None,
+                    attention_mask,
                     position_ids,
                     None,
                     output_attentions,
@@ -660,6 +704,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
+                    attention_mask=attention_mask,
                     position_ids=position_ids,
                     output_attentions=output_attentions,
                     position_embeddings=position_embeddings,
@@ -810,6 +855,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         """
@@ -817,6 +863,11 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
             encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                 of the decoder.
+            encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+                [What are attention masks?](../glossary#attention-mask)
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -865,6 +916,26 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
         all_self_attns = () if output_attentions else None
         all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

+        # attention mask downsampling
+        if encoder_attention_mask is not None:
+            mask_len = encoder_hidden_states.shape[-2]
+            downsample_stride = 64 * 3 * 2  # conv strides
+            encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
+            if self.config._attn_implementation == "flash_attention_2":
+                encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
+
+            # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+            elif self.config._attn_implementation == "sdpa" and not output_attentions:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+                )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+                )
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
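For cross-attention, the thinned mask is then expanded into the additive 4D form attention kernels consume: a `[bsz, src_len]` 0/1 mask becomes `[bsz, 1, tgt_len, src_len]` with masked positions set to the dtype minimum (the SDPA variant also returns `None` when nothing is masked, so the kernel can take its fast path). A rough pure-torch equivalent of that expansion, not the library helper itself:

import torch

def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # [bsz, src_len] 0/1 -> additive [bsz, 1, tgt_len, src_len]
    bsz, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # 1 -> 0.0 (attend), 0 -> dtype min (blocked)
    return (1.0 - expanded) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])
print(expand_mask(mask, torch.float32, tgt_len=2)[0, 0])
# two rows of zeros with -3.4028e+38 in the masked last column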
@@ -886,6 +957,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=causal_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
@@ -1168,9 +1240,11 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
             `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
             `input_values`, the [`AutoFeatureExtractor`] should be used for padding
             and conversion into a tensor of type `torch.FloatTensor`.
-        attention_mask (`torch.Tensor`)`, *optional*):
-            Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-            but it is not used.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
@@ -1371,6 +1445,7 @@ class MoonshineModel(MoonshinePreTrainedModel):
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 input_values,
+                attention_mask=attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
@@ -1387,6 +1462,7 @@ class MoonshineModel(MoonshinePreTrainedModel):
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
+            encoder_attention_mask=attention_mask,
             encoder_hidden_states=encoder_outputs[0],
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
@@ -1517,6 +1593,7 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin):

         outputs = self.model(
             input_values,
+            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
src/transformers/models/moonshine/modular_moonshine.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 from typing import Callable, Optional, Tuple, Union

 import torch
@@ -21,6 +22,10 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...configuration_utils import PretrainedConfig
 from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask,
+    _prepare_4d_attention_mask_for_sdpa,
+)
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -91,6 +96,9 @@ class MoonshineConfig(PretrainedConfig):
             by meanpooling all the original heads within that group. For more details checkout [this
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
             `decoder_num_attention_heads`.
+        pad_head_dim_to_multiple_of (`int`, *optional*):
+            Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
+            optimized attention implementations.
         encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
             The non-linear activation function (function or string) in the encoder.
         decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
@@ -191,6 +199,7 @@ class MoonshineConfig(PretrainedConfig):
         decoder_num_attention_heads=8,
         encoder_num_key_value_heads=None,
         decoder_num_key_value_heads=None,
+        pad_head_dim_to_multiple_of=None,
         encoder_hidden_act="gelu",
         decoder_hidden_act="silu",
         max_position_embeddings=512,
@@ -223,6 +232,8 @@ class MoonshineConfig(PretrainedConfig):
             decoder_num_key_value_heads = decoder_num_attention_heads
         self.decoder_num_key_value_heads = decoder_num_key_value_heads

+        self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+
         self.encoder_hidden_act = encoder_hidden_act
         self.decoder_hidden_act = decoder_hidden_act
         self.max_position_embeddings = max_position_embeddings
@@ -360,6 +371,23 @@ class MoonshineAttention(GlmAttention):
         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
+        # Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
+        head_dim_padding = 0
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            head_dim = query_states.shape[-1]
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
+            head_dim_padding = target_head_dim - head_dim
+            if head_dim_padding > 0:
+                # Ensure scaling is correct even with padding.
+                if self.scaling is None:
+                    self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
+
+                query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
+                key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
+                value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
+
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
@@ -372,6 +400,10 @@ class MoonshineAttention(GlmAttention):
             **kwargs,
         )

+        # Remove head size padding.
+        if head_dim_padding > 0:
+            attn_output = attn_output[:, :, :, :-head_dim_padding]
+
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
@@ -593,9 +625,11 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
                 `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                 `input_values`, the [`AutoFeatureExtractor`] should be used for padding
                 and conversion into a tensor of type `torch.FloatTensor`.
-            attention_mask (`torch.Tensor`)`, *optional*):
-                Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-                but it is not used.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+                [What are attention masks?](../glossary#attention-mask)
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
                 tensors for more detail.
@@ -622,6 +656,22 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
         hidden_states = nn.functional.gelu(self.conv3(hidden_states))
         hidden_states = hidden_states.permute(0, 2, 1)

+        # attention mask downsampling
+        if attention_mask is not None:
+            mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
+            downsample_stride = 64 * 3 * 2  # conv strides
+            attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if (attention_mask == 0.0).any() else None
+
+            # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+            elif self.config._attn_implementation == "sdpa" and not output_attentions:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
         position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)

         # create position embeddings to be shared across the decoder layers
@@ -639,7 +689,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
-                    None,
+                    attention_mask,
                     position_ids,
                     None,
                     output_attentions,
@@ -650,6 +700,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
+                    attention_mask=attention_mask,
                     position_ids=position_ids,
                     output_attentions=output_attentions,
                     position_embeddings=position_embeddings,
@@ -698,6 +749,7 @@ class MoonshineDecoder(LlamaModel):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         """
@@ -705,6 +757,11 @@ class MoonshineDecoder(LlamaModel):
             encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                 of the decoder.
+            encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+                [What are attention masks?](../glossary#attention-mask)
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -753,6 +810,26 @@ class MoonshineDecoder(LlamaModel):
         all_self_attns = () if output_attentions else None
         all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

+        # attention mask downsampling
+        if encoder_attention_mask is not None:
+            mask_len = encoder_hidden_states.shape[-2]
+            downsample_stride = 64 * 3 * 2  # conv strides
+            encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
+            if self.config._attn_implementation == "flash_attention_2":
+                encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
+
+            # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+            elif self.config._attn_implementation == "sdpa" and not output_attentions:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+                )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+                )
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -774,6 +851,7 @@ class MoonshineDecoder(LlamaModel):
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=causal_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
@@ -816,9 +894,11 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
             `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
             `input_values`, the [`AutoFeatureExtractor`] should be used for padding
             and conversion into a tensor of type `torch.FloatTensor`.
-        attention_mask (`torch.Tensor`)`, *optional*):
-            Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-            but it is not used.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
@@ -945,6 +1025,7 @@ class MoonshineModel(WhisperModel):
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 input_values,
+                attention_mask=attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
@@ -961,6 +1042,7 @@ class MoonshineModel(WhisperModel):
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
+            encoder_attention_mask=attention_mask,
             encoder_hidden_states=encoder_outputs[0],
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
@@ -1075,6 +1157,7 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin):

         outputs = self.model(
             input_values,
+            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,