mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[Peft] modules_to_save support for peft integration (#27466)
* `modules_to_save` support for peft integration * Update docs/source/en/peft.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * slightly elaborate test --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
721d1c8ca6
commit
d71fa9f618
3 changed files with 63 additions and 3 deletions
|
|
@ -98,7 +98,7 @@ You can use [`~peft.PeftModel.add_adapter`] to add a new adapter to a model with
|
|||
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
|
||||
from peft import PeftConfig
|
||||
from peft import LoraConfig
|
||||
|
||||
model_id = "facebook/opt-350m"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
|
@ -208,6 +208,26 @@ model.save_pretrained(save_dir)
|
|||
model = AutoModelForCausalLM.from_pretrained(save_dir)
|
||||
```
|
||||
|
||||
## Add additional trainable layers to a PEFT adapter
|
||||
|
||||
You can also fine-tune additional trainable adapters on top of a model that has adapters attached by passing `modules_to_save` in your PEFT config. For example, if you want to also fine-tune the lm_head on top of a model with a LoRA adapter:
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
|
||||
model_id = "facebook/opt-350m"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
lora_config = LoraConfig(
|
||||
target_modules=["q_proj", "k_proj"],
|
||||
modules_to_save=["lm_head"],
|
||||
)
|
||||
|
||||
model.add_adapter(lora_config)
|
||||
```
|
||||
|
||||
|
||||
<!--
|
||||
TODO: (@younesbelkada @stevhliu)
|
||||
- Link to PEFT docs for further details
|
||||
|
|
|
|||
|
|
@ -292,11 +292,12 @@ class PeftAdapterMixin:
|
|||
)
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
from peft.utils import ModulesToSaveWrapper
|
||||
|
||||
_adapters_has_been_set = False
|
||||
|
||||
for _, module in self.named_modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
|
||||
# For backward compatbility with previous PEFT versions
|
||||
if hasattr(module, "set_adapter"):
|
||||
module.set_adapter(adapter_name)
|
||||
|
|
@ -322,9 +323,10 @@ class PeftAdapterMixin:
|
|||
raise ValueError("No adapter loaded. Please load an adapter first.")
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
from peft.utils import ModulesToSaveWrapper
|
||||
|
||||
for _, module in self.named_modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
|
||||
# The recent version of PEFT need to call `enable_adapters` instead
|
||||
if hasattr(module, "enable_adapters"):
|
||||
module.enable_adapters(enabled=False)
|
||||
|
|
|
|||
|
|
@ -182,6 +182,44 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||
model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
|
||||
self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
|
||||
|
||||
def test_peft_add_adapter_modules_to_save(self):
|
||||
"""
|
||||
Simple test that tests if `add_adapter` works as expected when training with
|
||||
modules to save.
|
||||
"""
|
||||
from peft import LoraConfig
|
||||
from peft.utils import ModulesToSaveWrapper
|
||||
|
||||
for model_id in self.transformers_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
|
||||
|
||||
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||
peft_config = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"])
|
||||
model.add_adapter(peft_config)
|
||||
self._check_lora_correctly_converted(model)
|
||||
|
||||
_has_modules_to_save_wrapper = False
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, ModulesToSaveWrapper):
|
||||
_has_modules_to_save_wrapper = True
|
||||
self.assertTrue(module.modules_to_save.default.weight.requires_grad)
|
||||
self.assertTrue("lm_head" in name)
|
||||
break
|
||||
|
||||
self.assertTrue(_has_modules_to_save_wrapper)
|
||||
state_dict = model.get_adapter_state_dict()
|
||||
|
||||
self.assertTrue("lm_head.weight" in state_dict.keys())
|
||||
|
||||
logits = model(dummy_input).logits
|
||||
loss = logits.mean()
|
||||
loss.backward()
|
||||
|
||||
for _, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.assertTrue(param.grad is not None)
|
||||
|
||||
def test_peft_add_adapter_training_gradient_checkpointing(self):
|
||||
"""
|
||||
Simple test that tests if `add_adapter` works as expected when training with
|
||||
|
|
|
|||
Loading…
Reference in a new issue