diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index 69e674a21..791528a62 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import importlib import inspect import warnings @@ -525,3 +526,64 @@ class PeftAdapterMixin: offload_dir=offload_folder, **dispatch_model_kwargs, ) + + def delete_adapter(self, adapter_names: Union[List[str], str]) -> None: + """ + Delete an adapter's LoRA layers from the underlying model. + + Args: + adapter_names (`Union[List[str], str]`): + The name(s) of the adapter(s) to delete. + + Example: + + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights( + "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic" + ) + pipeline.delete_adapters("cinematic") + ``` + """ + + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + + # Check that all adapter names are present in the config + missing_adapters = [name for name in adapter_names if name not in self.peft_config] + if missing_adapters: + raise ValueError( + f"The following adapter(s) are not present and cannot be deleted: {', '.join(missing_adapters)}" + ) + + for adapter_name in adapter_names: + for module in self.modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "delete_adapter"): + module.delete_adapter(adapter_name) + else: + raise ValueError( + "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1" + ) + + # For transformers integration - we need to pop the adapter from the config + if getattr(self, "_hf_peft_config_loaded", False) and hasattr(self, "peft_config"): + self.peft_config.pop(adapter_name, None) + + # In case all adapters are deleted, we need to delete the config + # and make sure to set the flag to False + if len(self.peft_config) == 0: + del self.peft_config + self._hf_peft_config_loaded = False diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index bdbccee5a..6d6330d3d 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -350,7 +350,6 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): self.assertFalse( torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6) ) - self.assertFalse( torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6) ) @@ -359,6 +358,70 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) + def test_delete_adapter(self): + """ + Enhanced test for `delete_adapter` to handle multiple adapters, + edge cases, and proper error handling. + """ + from peft import LoraConfig + + for model_id in self.transformers_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + model = transformers_class.from_pretrained(model_id).to(torch_device) + + # Add multiple adapters + peft_config_1 = LoraConfig(init_lora_weights=False) + peft_config_2 = LoraConfig(init_lora_weights=False) + model.add_adapter(peft_config_1, adapter_name="adapter_1") + model.add_adapter(peft_config_2, adapter_name="adapter_2") + + # Ensure adapters were added + self.assertIn("adapter_1", model.peft_config) + self.assertIn("adapter_2", model.peft_config) + + # Delete a single adapter + model.delete_adapter("adapter_1") + self.assertNotIn("adapter_1", model.peft_config) + self.assertIn("adapter_2", model.peft_config) + + # Delete remaining adapter + model.delete_adapter("adapter_2") + self.assertNotIn("adapter_2", model.peft_config) + self.assertFalse(model._hf_peft_config_loaded) + + # Re-add adapters for edge case tests + model.add_adapter(peft_config_1, adapter_name="adapter_1") + model.add_adapter(peft_config_2, adapter_name="adapter_2") + + # Attempt to delete multiple adapters at once + model.delete_adapter(["adapter_1", "adapter_2"]) + self.assertNotIn("adapter_1", model.peft_config) + self.assertNotIn("adapter_2", model.peft_config) + self.assertFalse(model._hf_peft_config_loaded) + + # Test edge cases + with self.assertRaisesRegex(ValueError, "The following adapter\\(s\\) are not present"): + model.delete_adapter("nonexistent_adapter") + + with self.assertRaisesRegex(ValueError, "The following adapter\\(s\\) are not present"): + model.delete_adapter(["adapter_1", "nonexistent_adapter"]) + + # Deleting with an empty list or None should not raise errors + model.add_adapter(peft_config_1, adapter_name="adapter_1") + model.add_adapter(peft_config_2, adapter_name="adapter_2") + model.delete_adapter([]) # No-op + self.assertIn("adapter_1", model.peft_config) + self.assertIn("adapter_2", model.peft_config) + + model.delete_adapter(None) # No-op + self.assertIn("adapter_1", model.peft_config) + self.assertIn("adapter_2", model.peft_config) + + # Deleting duplicate adapter names in the list + model.delete_adapter(["adapter_1", "adapter_1"]) + self.assertNotIn("adapter_1", model.peft_config) + self.assertIn("adapter_2", model.peft_config) + @require_torch_gpu @require_bitsandbytes def test_peft_from_pretrained_kwargs(self):