Support batching for UsefulSensors Moonshine (#35922)

* Add support for attention masking in moonshine.

Tested against Open ASR Leaderboard with batch size 256.

* Update comments and ensure attention masks are passed everywhere.

Perform attention mask downsampling inside of moonshine forward call.

* Hide padding behind conditional. Fix encoder/decoder masking.

- Correctly pipe encoder attention mask into decoder
- Add correct scaling factor if one is not already provided.
- Fix formatting with ruff

* Add auto generated modeling_moonshine file.

* Update formatting in generated model file.

* Address review comments.

* Fix typo.

* Add `pad_head_dim_to_multiple_of` to moonshine config.

* Correct args order for MooonshineConfig.

* Update configuration moonshine too.

* Update src/transformers/models/moonshine/modular_moonshine.py

* Update src/transformers/models/moonshine/configuration_moonshine.py

---------

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
This commit is contained in:
Nat Jeffries 2025-01-30 08:08:07 -08:00 committed by GitHub
parent 5757681837
commit 693328f2bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 181 additions and 15 deletions

View file

@ -64,6 +64,9 @@ class MoonshineConfig(PretrainedConfig):
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`decoder_num_attention_heads`.
pad_head_dim_to_multiple_of (`int`, *optional*):
Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
optimized attention implementations.
encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder.
decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
@ -164,6 +167,7 @@ class MoonshineConfig(PretrainedConfig):
decoder_num_attention_heads=8,
encoder_num_key_value_heads=None,
decoder_num_key_value_heads=None,
pad_head_dim_to_multiple_of=None,
encoder_hidden_act="gelu",
decoder_hidden_act="silu",
max_position_embeddings=512,
@ -196,6 +200,8 @@ class MoonshineConfig(PretrainedConfig):
decoder_num_key_value_heads = decoder_num_attention_heads
self.decoder_num_key_value_heads = decoder_num_key_value_heads
self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
self.encoder_hidden_act = encoder_hidden_act
self.decoder_hidden_act = decoder_hidden_act
self.max_position_embeddings = max_position_embeddings

View file

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional, Tuple, Union
import numpy as np
@ -27,7 +28,11 @@ import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
@ -270,6 +275,23 @@ class MoonshineAttention(nn.Module):
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
# Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
head_dim_padding = 0
if self.config.pad_head_dim_to_multiple_of is not None:
head_dim = query_states.shape[-1]
target_multiple = self.config.pad_head_dim_to_multiple_of
target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
head_dim_padding = target_head_dim - head_dim
if head_dim_padding > 0:
# Ensure scaling is correct even with padding.
if self.scaling is None:
self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
attn_output, attn_weights = attention_interface(
self,
query_states,
@ -282,6 +304,10 @@ class MoonshineAttention(nn.Module):
**kwargs,
)
# Remove head size padding.
if head_dim_padding > 0:
attn_output = attn_output[:, :, :, :-head_dim_padding]
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
@ -603,9 +629,11 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
attention_mask (`torch.Tensor`)`, *optional*):
Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
but it is not used.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@ -632,6 +660,22 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
hidden_states = nn.functional.gelu(self.conv3(hidden_states))
hidden_states = hidden_states.permute(0, 2, 1)
# attention mask downsampling
if attention_mask is not None:
mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
downsample_stride = 64 * 3 * 2 # conv strides
attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask == 0.0).any() else None
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
# create position embeddings to be shared across the decoder layers
@ -649,7 +693,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
None,
attention_mask,
position_ids,
None,
output_attentions,
@ -660,6 +704,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
@ -810,6 +855,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
@ -817,6 +863,11 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -865,6 +916,26 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# attention mask downsampling
if encoder_attention_mask is not None:
mask_len = encoder_hidden_states.shape[-2]
downsample_stride = 64 * 3 * 2 # conv strides
encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
if self.config._attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
)
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
@ -886,6 +957,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
@ -1168,9 +1240,11 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
attention_mask (`torch.Tensor`)`, *optional*):
Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
but it is not used.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
@ -1371,6 +1445,7 @@ class MoonshineModel(MoonshinePreTrainedModel):
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@ -1387,6 +1462,7 @@ class MoonshineModel(MoonshinePreTrainedModel):
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_attention_mask=attention_mask,
encoder_hidden_states=encoder_outputs[0],
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
@ -1517,6 +1593,7 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi
outputs = self.model(
input_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional, Tuple, Union
import torch
@ -21,6 +22,10 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
@ -91,6 +96,9 @@ class MoonshineConfig(PretrainedConfig):
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`decoder_num_attention_heads`.
pad_head_dim_to_multiple_of (`int`, *optional*):
Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
optimized attention implementations.
encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder.
decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
@ -191,6 +199,7 @@ class MoonshineConfig(PretrainedConfig):
decoder_num_attention_heads=8,
encoder_num_key_value_heads=None,
decoder_num_key_value_heads=None,
pad_head_dim_to_multiple_of=None,
encoder_hidden_act="gelu",
decoder_hidden_act="silu",
max_position_embeddings=512,
@ -223,6 +232,8 @@ class MoonshineConfig(PretrainedConfig):
decoder_num_key_value_heads = decoder_num_attention_heads
self.decoder_num_key_value_heads = decoder_num_key_value_heads
self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
self.encoder_hidden_act = encoder_hidden_act
self.decoder_hidden_act = decoder_hidden_act
self.max_position_embeddings = max_position_embeddings
@ -360,6 +371,23 @@ class MoonshineAttention(GlmAttention):
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
# Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
head_dim_padding = 0
if self.config.pad_head_dim_to_multiple_of is not None:
head_dim = query_states.shape[-1]
target_multiple = self.config.pad_head_dim_to_multiple_of
target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
head_dim_padding = target_head_dim - head_dim
if head_dim_padding > 0:
# Ensure scaling is correct even with padding.
if self.scaling is None:
self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
attn_output, attn_weights = attention_interface(
self,
query_states,
@ -372,6 +400,10 @@ class MoonshineAttention(GlmAttention):
**kwargs,
)
# Remove head size padding.
if head_dim_padding > 0:
attn_output = attn_output[:, :, :, :-head_dim_padding]
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
@ -593,9 +625,11 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
attention_mask (`torch.Tensor`)`, *optional*):
Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
but it is not used.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@ -622,6 +656,22 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
hidden_states = nn.functional.gelu(self.conv3(hidden_states))
hidden_states = hidden_states.permute(0, 2, 1)
# attention mask downsampling
if attention_mask is not None:
mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
downsample_stride = 64 * 3 * 2 # conv strides
attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask == 0.0).any() else None
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
# create position embeddings to be shared across the decoder layers
@ -639,7 +689,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
None,
attention_mask,
position_ids,
None,
output_attentions,
@ -650,6 +700,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
@ -698,6 +749,7 @@ class MoonshineDecoder(LlamaModel):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
@ -705,6 +757,11 @@ class MoonshineDecoder(LlamaModel):
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -753,6 +810,26 @@ class MoonshineDecoder(LlamaModel):
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# attention mask downsampling
if encoder_attention_mask is not None:
mask_len = encoder_hidden_states.shape[-2]
downsample_stride = 64 * 3 * 2 # conv strides
encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
if self.config._attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
)
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
@ -774,6 +851,7 @@ class MoonshineDecoder(LlamaModel):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
@ -816,9 +894,11 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
attention_mask (`torch.Tensor`)`, *optional*):
Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
but it is not used.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
@ -945,6 +1025,7 @@ class MoonshineModel(WhisperModel):
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@ -961,6 +1042,7 @@ class MoonshineModel(WhisperModel):
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_attention_mask=attention_mask,
encoder_hidden_states=encoder_outputs[0],
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
@ -1075,6 +1157,7 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi
outputs = self.model(
input_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,