diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 510f3fe1a..95dcaea95 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2042,16 +2042,10 @@ class GenerationTesterMixin: with self.assertRaises(ValueError): model.generate(**generation_kwargs, **inputs_dict) - @parameterized.expand( - [ - ("forward_only", False), # TODO (@joao): a few models failing. After fixed, this should not be "@slow" - ("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix - ] - ) @pytest.mark.generate @require_torch_gpu @slow - def test_generate_compile(self, _, end_to_end): + def test_generate_compile_model_forward(self): """ Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests end-to-end compilation and forward pass compilation only. @@ -2061,14 +2055,7 @@ class GenerationTesterMixin: if not model_class._supports_static_cache: self.skipTest("This model doesn't support static cache") - # TODO (joao) -- fix and enable me :) - if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]): - self.skipTest("whisper model end-to-end generate compile not yet supported") - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - # TODO (joao) -- fix and enable me :) - if end_to_end and config.is_encoder_decoder: - self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported") model = model_class(config).to(torch_device) model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time @@ -2084,10 +2071,8 @@ class GenerationTesterMixin: "max_new_tokens": 10, "return_dict_in_generate": True, "output_scores": True, + "cache_implementation": "static", } - # end-to-end works best with dynamic cache, forward compilation works best with static cache - if not end_to_end: - generation_kwargs["cache_implementation"] = "static" # get eager + dynamic cache results for future comparison dynamic_outputs = [] @@ -2098,10 +2083,8 @@ class GenerationTesterMixin: generation_config = copy.deepcopy(model.generation_config) generation_config.update(**generation_kwargs) torch.compiler.reset() - if end_to_end: - model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") - else: - model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") + + model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") compiled_outputs = [] for model_inputs in input_ids_sets: diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index bb2ba8b34..01d4ef720 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -333,7 +333,7 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow @unittest.skip("Chameleon is not compatible with end-to-end generation compilation") - def test_generate_compile_fullgraph(self): + def test_generate_compile_model_forward(self): pass diff --git a/tests/models/dbrx/test_modeling_dbrx.py b/tests/models/dbrx/test_modeling_dbrx.py index d38a479ab..dee93109d 100644 --- a/tests/models/dbrx/test_modeling_dbrx.py +++ b/tests/models/dbrx/test_modeling_dbrx.py @@ -369,7 +369,7 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin pass @unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.") - def test_generate_compile_fullgraph(self): + def test_generate_compile_model_forward(self): pass diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py index d1c4501c5..007207c06 100644 --- a/tests/models/emu3/test_modeling_emu3.py +++ b/tests/models/emu3/test_modeling_emu3.py @@ -176,10 +176,6 @@ class Emu3Text2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTe def test_custom_4d_attention_mask(self): pass - @unittest.skip("Fails with unknown error only on end-to-end compile") # TODO raushan fixme - def test_generate_compile_1_end_to_end(self): - pass - class Emu3Vision2TextModelTester: def __init__( @@ -398,10 +394,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline def test_initialization(self): pass - @unittest.skip("End-to-end compilation is not supported due to dynamic control in `prepare_inputs_for_generation`") - def test_generate_compile_1_end_to_end(self): - pass - @require_torch class Emu3IntegrationTest(unittest.TestCase): diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 94229b13d..a8f1304b6 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -781,7 +781,7 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni pass @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs") - def test_generate_compile_fullgraph(self): + def test_generate_compile_model_forward(self): pass @unittest.skip(reason="We only test the model that takes in multiple images") diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index f973e1211..587d46064 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -348,7 +348,7 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow @unittest.skip("PaliGemma is not compatible with end-to-end generation compilation") - def test_generate_compile_fullgraph(self): + def test_generate_compile_model_forward(self): pass diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index 1b2891fe6..fc9adcebf 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -333,7 +333,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas pass @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`") - def test_generate_compile_fullgraph(self): + def test_generate_compile_model_forward(self): pass