diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index c6ce74281..484467c0c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -73,6 +73,7 @@ from .utils import (
     logging,
     replace_return_docstrings,
 )
+from .utils.import_utils import importlib_metadata
 from .utils.versions import require_version_core
 
 
@@ -2439,6 +2440,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
             )
 
+            # training in 8-bit is only available in 0.37.0+
+            model._is_int8_training_enabled = version.parse(
+                importlib_metadata.version("bitsandbytes")
+            ) >= version.parse("0.37.0")
+
         if isinstance(device_map, str):
             if model._no_split_modules is None:
                 raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f6c51e6b8..05cfa103e 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -368,10 +368,18 @@ class Trainer:
 
         # At this stage the model is already loaded
         if getattr(model, "is_loaded_in_8bit", False):
-            raise ValueError(
-                "The model you want to train is loaded in 8-bit precision. "
-                "Training an 8-bit model is not supported yet. "
-            )
+            if getattr(model, "_is_int8_training_enabled", False):
+                logger.info(
+                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
+                    " inside the model such as adapters using the `peft` library and freeze the model weights. Please"
+                    " check"
+                    " the examples in https://github.com/huggingface/peft for more details."
+                )
+            else:
+                raise ValueError(
+                    "The model you want to train is loaded in 8-bit precision. If you want to fine-tune an 8-bit"
+                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`."
+                )
 
         # Setup Sharded DDP training
         self.sharded_ddp = None
@@ -458,7 +466,7 @@ class Trainer:
         self.eval_dataset = eval_dataset
         self.tokenizer = tokenizer
 
-        if self.place_model_on_device:
+        if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
             self._move_model_to_device(model, args.device)
 
         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py
index b1e8ab1a3..abd5199ed 100644
--- a/tests/mixed_int8/test_mixed_int8.py
+++ b/tests/mixed_int8/test_mixed_int8.py
@@ -16,6 +16,8 @@ import gc
 import tempfile
 import unittest
 
+from packaging import version
+
 from transformers import (
     AutoModel,
     AutoModelForCausalLM,
@@ -33,10 +35,30 @@ from transformers.testing_utils import (
     require_torch_multi_gpu,
     slow,
 )
+from transformers.utils.versions import importlib_metadata
 
 
 if is_torch_available():
     import torch
+    import torch.nn as nn
+
+    class LoRALayer(nn.Module):
+        """Wraps a linear layer with a LoRA-like adapter - used for testing purposes only"""
+
+        def __init__(self, module: nn.Module, rank: int):
+            super().__init__()
+            self.module = module
+            self.adapter = nn.Sequential(
+                nn.Linear(module.in_features, rank, bias=False),
+                nn.Linear(rank, module.out_features, bias=False),
+            )
+            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+            nn.init.normal_(self.adapter[0].weight, std=small_std)
+            nn.init.zeros_(self.adapter[1].weight)
+            self.adapter.to(module.weight.device)
+
+        def forward(self, input, *args, **kwargs):
+            return self.module(input, *args, **kwargs) + self.adapter(input)
 
 
 @require_bitsandbytes
@@ -335,3 +357,44 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
         # Second real batch
         output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
         self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
+
+class MixedInt8TestTraining(BaseMixedInt8Test):
+    def setUp(self):
+        self.model_name = "facebook/opt-350m"
+        super().setUp()
+
+    def test_training(self):
+        if version.parse(importlib_metadata.version("bitsandbytes")) < version.parse("0.37.0"):
+            return
+
+        # Step 1: freeze all parameters
+        model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
+
+        for param in model.parameters():
+            param.requires_grad = False  # freeze the model - train adapters later
+            if param.ndim == 1:
+                # cast the small parameters (e.g. layernorm) to fp32 for stability
+                param.data = param.data.to(torch.float32)
+
+        # Step 2: add adapters
+        for _, module in model.named_modules():
+            if "OPTAttention" in repr(type(module)):
+                module.q_proj = LoRALayer(module.q_proj, rank=16)
+                module.k_proj = LoRALayer(module.k_proj, rank=16)
+                module.v_proj = LoRALayer(module.v_proj, rank=16)
+
+        # Step 3: dummy batch
+        batch = self.tokenizer("Test batch ", return_tensors="pt").to(0)
+
+        # Step 4: check that the adapter gradients are not None
+        with torch.cuda.amp.autocast():
+            out = model.forward(**batch)
+            out.logits.norm().backward()
+
+        for module in model.modules():
+            if isinstance(module, LoRALayer):
+                self.assertTrue(module.adapter[1].weight.grad is not None)
+                self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
+            elif isinstance(module, nn.Embedding):
+                self.assertTrue(module.weight.grad is None)