mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Support segformer fx (#19924)
* Support segformer fx * Add fx_compatible attribute to test_modeling_segformer.py * Update glpn model (fx support) glpn model was copied from segformer. * Update utils/fx.py | add semantic-segmentation for SegformerForSemanticSegmentation model * Fix minor import order(isort) * Add random input generation for segformer fx Co-authored-by: noelbird <lduldu00228@gmail.com>
This commit is contained in:
parent
dcca71be61
commit
347ba38cb4
4 changed files with 9 additions and 4 deletions
|
|
@ -149,7 +149,7 @@ class GLPNEfficientSelfAttention(nn.Module):
|
|||
|
||||
def transpose_for_scores(self, hidden_states):
|
||||
new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
hidden_states = hidden_states.view(*new_shape)
|
||||
hidden_states = hidden_states.view(new_shape)
|
||||
return hidden_states.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
|
|
@ -190,7 +190,7 @@ class GLPNEfficientSelfAttention(nn.Module):
|
|||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
|
|
|
|||
|
|
@ -179,7 +179,7 @@ class SegformerEfficientSelfAttention(nn.Module):
|
|||
|
||||
def transpose_for_scores(self, hidden_states):
|
||||
new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
hidden_states = hidden_states.view(*new_shape)
|
||||
hidden_states = hidden_states.view(new_shape)
|
||||
return hidden_states.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
|
|
@ -220,7 +220,7 @@ class SegformerEfficientSelfAttention(nn.Module):
|
|||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ from ..models.auto.modeling_auto import (
|
|||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
|
||||
MODEL_FOR_PRETRAINING_MAPPING_NAMES,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
||||
|
|
@ -80,6 +81,7 @@ def _generate_supported_model_class_names(
|
|||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
|
||||
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||
"semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
|
||||
}
|
||||
|
||||
if supported_tasks is None:
|
||||
|
|
@ -128,6 +130,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
|||
"plbart",
|
||||
"resnet",
|
||||
"roberta",
|
||||
"segformer",
|
||||
"speech_to_text",
|
||||
"speech_to_text_2",
|
||||
"swin",
|
||||
|
|
@ -730,6 +733,7 @@ class HFTracer(Tracer):
|
|||
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES),
|
||||
"GPT2DoubleHeadsModel",
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||
|
|
|
|||
|
|
@ -161,6 +161,7 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
else ()
|
||||
)
|
||||
|
||||
fx_compatible = True
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
|
|
|
|||
Loading…
Reference in a new issue