mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Fix accelerate failing tests (#30836)
* Fix accelerate tests * fix clip * skip dbrx tests * fix GPTSan * fix M2M100Model * same fix as jamba * fix mt5 * Fix T5Model * Fix umt5 model * fix switch_transformers * fix whisper * fix gptsan again * fix siglip recent test * skip siglip tests * wrong place fixed
This commit is contained in:
parent
5a74ae6dbe
commit
8366b57241
14 changed files with 36 additions and 32 deletions
|
|
@ -964,7 +964,7 @@ class BertModel(BertPreTrainedModel):
|
|||
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
||||
"""
|
||||
|
||||
_no_split_modules = ["BertEmbeddings"]
|
||||
_no_split_modules = ["BertEmbeddings", "BertLayer"]
|
||||
|
||||
def __init__(self, config, add_pooling_layer=True):
|
||||
super().__init__(config)
|
||||
|
|
|
|||
|
|
@ -933,7 +933,7 @@ class CLIPVisionModel(CLIPPreTrainedModel):
|
|||
@add_start_docstrings(CLIP_START_DOCSTRING)
|
||||
class CLIPModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPConfig
|
||||
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
|
||||
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
|
||||
|
||||
def __init__(self, config: CLIPConfig):
|
||||
super().__init__(config)
|
||||
|
|
@ -1135,7 +1135,9 @@ class CLIPModel(CLIPPreTrainedModel):
|
|||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * logit_scale.to(
|
||||
text_embeds.device
|
||||
)
|
||||
logits_per_image = logits_per_text.t()
|
||||
|
||||
loss = None
|
||||
|
|
|
|||
|
|
@ -739,7 +739,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
|
|||
config_class = M2M100Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["M2M100Attention"]
|
||||
_no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"]
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
|
|
|||
|
|
@ -244,6 +244,7 @@ class MambaMixer(nn.Module):
|
|||
# 2. Convolution sequence transformation
|
||||
if cache_params is not None:
|
||||
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
|
||||
ssm_state = ssm_state.to(hidden_states.device)
|
||||
if cache_params.seqlen_offset > 0:
|
||||
conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
|
||||
conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
|
||||
|
|
|
|||
|
|
@ -959,7 +959,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
|||
class SiglipVisionModel(SiglipPreTrainedModel):
|
||||
config_class = SiglipVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["SiglipVisionTransformer"]
|
||||
_no_split_modules = ["SiglipVisionEmbeddings", "SiglipEncoderLayer", "SiglipMultiheadAttentionPoolingHead"]
|
||||
|
||||
def __init__(self, config: SiglipVisionConfig):
|
||||
super().__init__(config)
|
||||
|
|
|
|||
|
|
@ -1393,7 +1393,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||
inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
|
|
|||
|
|
@ -465,6 +465,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
else {}
|
||||
)
|
||||
fx_compatible = True
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
|
|
|||
|
|
@ -354,6 +354,20 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
# Offload does not work with Dbrx models because of the forward of DbrxExperts where we chunk the experts.
|
||||
# The issue is that the offloaded weights of the mlp layer are still on meta device (w1_chunked, v1_chunked, w2_chunked)
|
||||
@unittest.skip("Dbrx models do not work with offload")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Dbrx models do not work with offload")
|
||||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Dbrx models do not work with offload")
|
||||
def test_disk_offload_bin(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class DbrxModelIntegrationTest(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -145,12 +145,10 @@ class GPTSanJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
|||
is_encoder_decoder = False
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_cpu_offload = False
|
||||
test_disk_offload = False
|
||||
test_save_load_fast_init_to_base = False
|
||||
test_training = False
|
||||
# The small GPTSAN_JAPANESE model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.8, 0.9]
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
# TODO: Fix the failed tests when this model gets more usage
|
||||
def is_pipeline_test_to_skip(
|
||||
|
|
@ -198,10 +196,8 @@ class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTes
|
|||
is_encoder_decoder = False
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_cpu_offload = False
|
||||
test_disk_offload = False
|
||||
# The small GPTSAN_JAPANESE model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.8, 0.9]
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = GPTSanJapaneseTester(self)
|
||||
|
|
|
|||
|
|
@ -574,7 +574,7 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
test_model_parallel = True
|
||||
is_encoder_decoder = True
|
||||
# The small MT5 model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.8, 0.9]
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MT5ModelTester(self)
|
||||
|
|
@ -886,10 +886,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||
|
||||
@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
|
||||
def test_disk_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -146,6 +146,12 @@ class SiglipVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
# MP works but offload doesn't work when the MultiheadAttention is offloaded
|
||||
# TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
|
||||
# in the dispatch_model function
|
||||
test_cpu_offload = False
|
||||
test_disk_offload_safetensors = False
|
||||
test_disk_offload_bin = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SiglipVisionModelTester(self)
|
||||
|
|
|
|||
|
|
@ -575,7 +575,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
|||
is_encoder_decoder = True
|
||||
test_torchscript = False
|
||||
# The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.8, 0.9]
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SwitchTransformersModelTester(self)
|
||||
|
|
@ -721,10 +721,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
|||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||
|
||||
@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
|
||||
def test_disk_offload(self):
|
||||
pass
|
||||
|
||||
|
||||
class SwitchTransformersEncoderOnlyModelTester:
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -577,7 +577,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
test_model_parallel = True
|
||||
is_encoder_decoder = True
|
||||
# The small T5 model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.8, 0.9]
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = T5ModelTester(self)
|
||||
|
|
@ -889,10 +889,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||
|
||||
@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
|
||||
def test_disk_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -315,7 +315,7 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
test_missing_keys = True
|
||||
test_torchscript = True
|
||||
# The small UMT5 model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.8, 0.9]
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = UMT5ModelTester(self)
|
||||
|
|
@ -536,10 +536,6 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||
|
||||
@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
|
||||
def test_disk_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue