From 1af4bee8965c80e54c0e21aa8aadd035fd1f4189 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 13 Dec 2022 11:59:57 +0100 Subject: [PATCH] Add `keep_in_fp32_modules` support (#20683) * add `keep_in_fp32_modules` support * pass it as class attribute * few modifs - make tests `slow` - fix logic * better logic * fix failing test * `bfloat16` support * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix * simplify tests * simplify tests * fix test * modify message * more checks * fix failing tests * add more conditions - add `is_accelerate_available` - fixes pipeline tests that failed * add suggestions * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix failing `bnb` test * add last safety checker Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_utils.py | 70 +++++++++++++++++++++-- src/transformers/models/t5/modeling_t5.py | 1 + src/transformers/utils/bitsandbytes.py | 2 +- tests/mixed_int8/test_mixed_int8.py | 7 +++ tests/models/t5/test_modeling_t5.py | 53 ++++++++++++++++- 5 files changed, 127 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fbd7b2814..48e437fd7 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -562,6 +562,7 @@ def _load_state_dict_into_meta_model( dtype=None, load_in_8bit=False, is_safetensors=False, + keep_in_fp32_modules=None, ): """ This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its @@ -612,7 +613,14 @@ def _load_state_dict_into_meta_model( # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params # in int/uint/bool and not cast them. 
if dtype is not None and torch.is_floating_point(param): - param = param.to(dtype) + if ( + keep_in_fp32_modules is not None + and any(module_to_keep_in_fp32 in param_name for module_to_keep_in_fp32 in keep_in_fp32_modules) + and dtype == torch.float16 + ): + param = param.to(torch.float32) + else: + param = param.to(dtype) # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model if dtype is None: @@ -974,6 +982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix main_input_name = "input_ids" _auto_class = None _no_split_modules = None + _keep_in_fp32_modules = None # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. @@ -2071,6 +2080,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Load model loading_info = None + # Keep in fp32 modules + keep_in_fp32_modules = None + use_keep_in_fp32_modules = False + if pretrained_model_name_or_path is not None: pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) @@ -2269,6 +2282,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype # we also may have config.torch_dtype available, but we won't rely on it till v5 dtype_orig = None + if torch_dtype is not None: if isinstance(torch_dtype, str): if torch_dtype == "auto": @@ -2286,11 +2300,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) dtype_orig = cls._set_default_torch_dtype(torch_dtype) + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = ( + (cls._keep_in_fp32_modules is not None) and is_accelerate_available() and torch_dtype == torch.float16 + ) + if 
( + (cls._keep_in_fp32_modules is not None) + and not is_accelerate_available() + and torch_dtype == torch.float16 + ): + logger.warning( + "For stability purposes, it is recommended to have accelerate installed when using this model in" + " torch.float16, please install it with `pip install accelerate`" + ) + if is_sharded: loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: loaded_state_dict_keys = [k for k in state_dict.keys()] - if low_cpu_mem_usage: + if low_cpu_mem_usage or use_keep_in_fp32_modules: state_dict = None config.name_or_path = pretrained_model_name_or_path @@ -2309,6 +2337,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) + # Force `low_cpu_mem_usage` when some modules must be kept in fp32 + if use_keep_in_fp32_modules: + low_cpu_mem_usage = True + keep_in_fp32_modules = model._keep_in_fp32_modules + else: + keep_in_fp32_modules = [] + if load_in_8bit: from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear @@ -2319,6 +2354,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix modules_to_not_convert = get_keys_to_not_convert(model) else: modules_to_not_convert = load_in_8bit_skip_modules + + if not isinstance(modules_to_not_convert, list): + modules_to_not_convert = [modules_to_not_convert] + + modules_to_not_convert.extend(keep_in_fp32_modules) + model = replace_8bit_linear( model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert ) @@ -2425,6 +2466,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix offload_state_dict=offload_state_dict, dtype=torch_dtype, load_in_8bit=load_in_8bit, + keep_in_fp32_modules=keep_in_fp32_modules, ) model.is_loaded_in_8bit = load_in_8bit @@ -2468,6 +2510,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix offload_state_dict=None, dtype=None, 
load_in_8bit=False, + keep_in_fp32_modules=None, ): is_safetensors = False if load_in_8bit: @@ -2544,11 +2587,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if key.startswith(prefix): key = ".".join(key.split(".")[1:]) param = model_state_dict[key] + + # upcast in fp32 if any + target_dtype = dtype + if ( + keep_in_fp32_modules is not None + and dtype == torch.float16 + and any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules) + ): + target_dtype = torch.float32 + if param.device == torch.device("meta"): if not load_in_8bit: - set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype)) + set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)) else: - set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype)) + set_module_8bit_tensor_to_device( + model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype) + ) # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. 
if _fast_init: @@ -2558,6 +2613,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix for module in uninitialized_modules: model._init_weights(module) + # Set some modules to fp32 if any + if keep_in_fp32_modules is not None: + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules): + param.data = param.data.to(torch.float32) + # Make sure we are able to load base models as well as derived models (with heads) start_prefix = "" model_to_load = model @@ -2693,6 +2754,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix dtype=dtype, load_in_8bit=load_in_8bit, is_safetensors=is_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, ) error_msgs += new_error_msgs else: diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index a1510af74..96a3dd3c1 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -757,6 +757,7 @@ class T5PreTrainedModel(PreTrainedModel): is_parallelizable = True supports_gradient_checkpointing = True _no_split_modules = ["T5Block"] + _keep_in_fp32_modules = ["wo"] @property def dummy_inputs(self): diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index b2339efd6..4e14dbaf7 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -150,7 +150,7 @@ def get_keys_to_not_convert(model): # Ignore this for base models (BertModel, GPT2Model, etc.) 
if (not has_tied_params) and is_base_model: - return "" + return [] # otherwise they have an attached head list_modules = list(model.named_parameters()) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 67af67e1d..56ce10638 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -155,6 +155,13 @@ class MixedInt8Test(BaseMixedInt8Test): # Check this does not throw an error _ = self.model_fp16.float() + def test_fp32_int8_conversion(self): + r""" + Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly. + """ + model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto") + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + class MixedInt8ModelClassesTest(BaseMixedInt8Test): def setUp(self): diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index ab6a039c9..fe3ce7597 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -19,7 +19,14 @@ import tempfile import unittest from transformers import T5Config, is_torch_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device +from transformers.testing_utils import ( + require_accelerate, + require_sentencepiece, + require_tokenizers, + require_torch, + slow, + torch_device, +) from transformers.utils import cached_property from ...generation.test_utils import GenerationTesterMixin @@ -820,6 +827,50 @@ def use_task_specific_params(model, task): model.config.update(model.config.task_specific_params[task]) +@require_torch +@require_accelerate +@require_tokenizers +@slow +class T5ModelFp16Tests(unittest.TestCase): + def test_fp16_fp32_conversion(self): + r""" + A test to check whether the argument `keep_in_fp32_modules` correctly does its job + """ + # Load without using 
`accelerate` + model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + # Load in bf16 without `device_map` or `low_cpu_mem_usage` + model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load using `accelerate` in bf16 + model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto") + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load with `low_cpu_mem_usage=True` in bf16 + model = T5ForConditionalGeneration.from_pretrained( + "t5-small", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load with `low_cpu_mem_usage=True` in fp16 + model = T5ForConditionalGeneration.from_pretrained( + "t5-small", torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + # Load using `accelerate` + model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16, device_map="auto") + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + 
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + @require_torch @require_sentencepiece @require_tokenizers