From 84a6789145c3d728f2e405d31e9a35df5d74f05c Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Mon, 13 Jan 2025 13:42:08 +0100
Subject: [PATCH] Enable different torch dtype in sub models (#34873)

* fix
* fix test
* add tests
* add more tests
* fix tests
* supposed to be a torch.dtype test
* handle BC and make fp32 default
---
 src/transformers/configuration_utils.py       |   7 +-
 src/transformers/modeling_utils.py            |  46 +++++--
 .../models/chameleon/modeling_chameleon.py    | 114 +++++++++---------
 .../models/qwen2_vl/test_modeling_qwen2_vl.py |   1 +
 tests/utils/test_modeling_utils.py            |  55 +++++++++
 5 files changed, 155 insertions(+), 68 deletions(-)

diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 648877c8d..dfb64fcd0 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -994,8 +994,11 @@ class PretrainedConfig(PushToHubMixin):
         converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into
         *"float32"* string, which can then be stored in the json format.
         """
-        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
-            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
+        if d.get("torch_dtype", None) is not None:
+            if isinstance(d["torch_dtype"], dict):
+                d["torch_dtype"] = {k: str(v).split(".")[-1] for k, v in d["torch_dtype"].items()}
+            elif not isinstance(d["torch_dtype"], str):
+                d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
         for value in d.values():
             if isinstance(value, dict):
                 self.dict_torch_dtype_to_str(value)
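Reviewer note: a minimal standalone sketch of what the updated `dict_torch_dtype_to_str` branch produces for a composite config dict. The helper name `to_json_ready` and the sample keys are illustrative, not part of this patch:

import torch

def to_json_ready(d: dict) -> dict:
    # Mirrors the new logic above: a dict-valued `torch_dtype` is converted
    # entry-by-entry to bare strings so it can be stored in config.json.
    if d.get("torch_dtype", None) is not None:
        if isinstance(d["torch_dtype"], dict):
            d["torch_dtype"] = {k: str(v).split(".")[-1] for k, v in d["torch_dtype"].items()}
        elif not isinstance(d["torch_dtype"], str):
            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
    return d

print(to_json_ready({"torch_dtype": {"text_config": torch.float32, "vision_config": torch.float16}}))
# {'torch_dtype': {'text_config': 'float32', 'vision_config': 'float16'}}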
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 8eb2d7439..c09c11050 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1312,11 +1312,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             "`PretrainedConfig`. To create a model from a pretrained model use "
             f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
         )
-        # Save config and origin of the pretrained weights if given in model
         if not getattr(config, "_attn_implementation_autoset", False):
-            config = self._autoset_attn_implementation(
-                config, torch_dtype=torch.get_default_dtype(), check_device_map=False
-            )
+            # config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests
+            dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype()
+            config = self._autoset_attn_implementation(config, torch_dtype=dtype, check_device_map=False)
         self.config = config
 
         # for initialization of the loss
@@ -1411,7 +1410,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
         # a warning is raised that dtype should be fp16. Since we never pass dtype from within
         # modeling code, we can try to infer it here same way as done in `from_pretrained`
-        torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
+        torch_dtype = kwargs.pop("torch_dtype", config.torch_dtype)
+        if isinstance(torch_dtype, str):
+            torch_dtype = getattr(torch, torch_dtype)
+
         use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
 
         # override default dtype if needed
@@ -4020,11 +4022,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
                 elif hasattr(torch, torch_dtype):
                     torch_dtype = getattr(torch, torch_dtype)
-                else:
-                    raise ValueError(
-                        f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
-                    )
+                    for sub_config_key in config.sub_configs.keys():
+                        sub_config = getattr(config, sub_config_key)
+                        sub_config.torch_dtype = torch_dtype
+            elif isinstance(torch_dtype, torch.dtype):
+                pass
+            elif isinstance(torch_dtype, dict):
+                for key, curr_dtype in torch_dtype.items():
+                    if hasattr(config, key):
+                        value = getattr(config, key)
+                        value.torch_dtype = curr_dtype
+                # main torch dtype for modules that aren't part of any sub-config
+                torch_dtype = torch_dtype.get("")
+                config.torch_dtype = torch_dtype
+                if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
+                    torch_dtype = getattr(torch, torch_dtype)
+                elif torch_dtype is None:
+                    torch_dtype = torch.float32
+            else:
+                raise ValueError(
+                    f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
+                    f"for each sub-config in composite configs, but received {torch_dtype}"
+                )
+
+            dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+        else:
+            # set fp32 as the default dtype for BC
+            default_dtype = str(torch.get_default_dtype()).split(".")[-1]
+            config.torch_dtype = default_dtype
+            for key in config.sub_configs.keys():
+                value = getattr(config, key)
+                value.torch_dtype = default_dtype
 
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
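In practice, the new resolution logic enables calls like the one below, mirroring the tests added at the end of this patch (the tiny checkpoint is the one those tests use; the `""` key sets the dtype for modules that don't belong to any sub-config):

import torch
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-LlavaForConditionalGeneration",
    torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"},
)
assert model.language_model.dtype == torch.float32
assert model.vision_tower.dtype == torch.float16
assert model.multi_modal_projector.linear_1.weight.dtype == torch.bfloat16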
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 90a02dd5b..edbac91bb 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -967,62 +967,6 @@ class ChameleonVQVAEEncoder(nn.Module):
         return last_hidden_state
 
 
-CHAMELEON_VQ_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`ChameleonVQVAEConfig`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
-            load the weights associated with the model, only the configuration. Check out the
-            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
-    """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
-    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
-    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
-    """,
-    CHAMELEON_VQ_START_DOCSTRING,
-)
-class ChameleonVQVAE(PreTrainedModel):
-    config_class = ChameleonVQVAEConfig
-    _no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-        elif isinstance(module, nn.GroupNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-
-    def __init__(self, config: ChameleonVQVAEConfig):
-        super().__init__(config)
-
-        self.encoder = ChameleonVQVAEEncoder(config)
-        self.quantize = ChameleonVQVAEVectorQuantizer(config)
-        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
-        self.eval()  # Chameleon's VQ model is frozen
-
-    def encode(self, pixel_values: torch.LongTensor):
-        hidden_states = self.encoder(pixel_values)
-        hidden_states = self.quant_conv(hidden_states)
-        quant, emb_loss, indices = self.quantize(hidden_states)
-        return quant, emb_loss, indices
-
-
 class ChameleonImageVocabularyMapping:
     """
     A class for mapping discrete image tokens from VQGAN to BPE tokens.
@@ -1118,6 +1062,62 @@ class ChameleonPreTrainedModel(PreTrainedModel):
                 module.weight.data[module.padding_idx].zero_()
 
 
+CHAMELEON_VQ_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`ChameleonVQVAEConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
+    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
+    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
+    """,
+    CHAMELEON_VQ_START_DOCSTRING,
+)
+class ChameleonVQVAE(ChameleonPreTrainedModel):
+    config_class = ChameleonVQVAEConfig
+    _no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+        elif isinstance(module, nn.GroupNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def __init__(self, config: ChameleonVQVAEConfig):
+        super().__init__(config)
+
+        self.encoder = ChameleonVQVAEEncoder(config)
+        self.quantize = ChameleonVQVAEVectorQuantizer(config)
+        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
+        self.eval()  # Chameleon's VQ model is frozen
+
+    def encode(self, pixel_values: torch.LongTensor):
+        hidden_states = self.encoder(pixel_values)
+        hidden_states = self.quant_conv(hidden_states)
+        quant, emb_loss, indices = self.quantize(hidden_states)
+        return quant, emb_loss, indices
+
+
 CHAMELEON_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1211,7 +1211,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
             [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.vqmodel = ChameleonVQVAE(config.vq_config)
+        self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
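Why `_from_config` rather than direct construction: it routes sub-model creation through the dtype handling added to `_from_config` above, so the VQ model picks up `vq_config.torch_dtype` instead of silently inheriting the process-wide default. A hedged sketch of the same mechanism through the public `from_config` wrapper (the tiny Mistral checkpoint is illustrative; the printed dtype is the expected behavior with this patch, not a guarantee):

import torch
from transformers import AutoConfig, AutoModel

cfg = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
cfg.torch_dtype = torch.float16  # a sub-config would carry this field the same way
model = AutoModel.from_config(cfg)  # `from_config` wraps `_from_config`
print(model.dtype)  # expected: torch.float16 (previously fp32, the process default)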
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index aedd37992..1b2891fe6 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -227,6 +227,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
     pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration}
     test_pruning = False
     test_head_masking = False
+    _is_composite = True
 
     def setUp(self):
        self.model_tester = Qwen2VLVisionText2TextModelTester(self)
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 383f0cbe6..b8e10ff8a 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -37,6 +37,7 @@ from transformers import (
     AutoModel,
     AutoModelForImageClassification,
     AutoModelForSequenceClassification,
+    LlavaForConditionalGeneration,
     OwlViTForObjectDetection,
     PretrainedConfig,
     is_torch_available,
@@ -300,6 +301,7 @@ TINY_T5 = "patrickvonplaten/t5-tiny-random"
 TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
 TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"
 TINY_IMAGE_CLASSIF = "hf-internal-testing/tiny-random-SiglipForImageClassification"
+TINY_LLAVA = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
 
 LOG = logging.get_logger(__name__)
 
@@ -460,6 +462,59 @@ class ModelUtilsTest(TestCasePlus):
         with self.assertRaises(ValueError):
             model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
 
+    def test_model_from_config_torch_dtype_composite(self):
+        """
+        Test that from_pretrained works with torch_dtype passed as a dict with one entry per sub-config in a composite config
+        """
+        # should be able to set torch_dtype as a simple string and the model loads it correctly
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float32)
+
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16")
+        self.assertEqual(model.language_model.dtype, torch.float16)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+
+        # should be able to set torch_dtype as a dict for each sub-config
+        model = LlavaForConditionalGeneration.from_pretrained(
+            TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
+        )
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+
+        # should be able to set the values as torch.dtype (not str)
+        model = LlavaForConditionalGeneration.from_pretrained(
+            TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
+        )
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+
+        # should be able to set the values in configs directly and pass it to `from_pretrained`
+        config = copy.deepcopy(model.config)
+        config.text_config.torch_dtype = torch.float32
+        config.vision_config.torch_dtype = torch.bfloat16
+        config.torch_dtype = torch.float16
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
+
+        # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
+        LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
+
+        # torch.set_default_dtype() supports only float dtypes, so it will fail with a non-float type
+        with self.assertRaises(ValueError):
+            model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64")
+            model = LlavaForConditionalGeneration.from_pretrained(
+                TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
+            )
+
     @require_torch
     def test_model_from_pretrained_meta_device(self):
         def is_on_meta(model_id, dtype):
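A possible follow-up check, not part of this patch's test suite: because the chosen dtypes are written back onto the config and its sub-configs and serialized as strings (via `dict_torch_dtype_to_str`), a save/reload round-trip with `torch_dtype="auto"` should restore the per-sub-model dtypes. Sketch only, under that assumption:

import tempfile

import torch
from transformers import LlavaForConditionalGeneration

ckpt = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
model = LlavaForConditionalGeneration.from_pretrained(
    ckpt, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
)
with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)  # sub-config dtypes are serialized as strings in config.json
    reloaded = LlavaForConditionalGeneration.from_pretrained(tmp, torch_dtype="auto")
assert reloaded.language_model.dtype == torch.float32
assert reloaded.vision_tower.dtype == torch.float16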