mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[PEFT] introducing adapter_kwargs for loading adapters from different Hub location (subfolder, revision) than the base model (#26270)
* make use of adapter_revision * v1 adapter kwargs * fix CI * fix CI * fix CI * fixup * add BC * Update src/transformers/integrations/peft.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup * change it to error * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py * fixup * change * Update src/transformers/integrations/peft.py --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
52e2c13da3
commit
38e96324ef
6 changed files with 68 additions and 9 deletions
|
|
@ -77,6 +77,7 @@ class PeftAdapterMixin:
|
|||
offload_index: Optional[int] = None,
|
||||
peft_config: Dict[str, Any] = None,
|
||||
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
|
||||
adapter_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
|
||||
|
|
@ -128,10 +129,15 @@ class PeftAdapterMixin:
|
|||
adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
||||
The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
|
||||
dicts
|
||||
adapter_kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
|
||||
`find_adapter_config_file` method.
|
||||
"""
|
||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||
|
||||
adapter_name = adapter_name if adapter_name is not None else "default"
|
||||
if adapter_kwargs is None:
|
||||
adapter_kwargs = {}
|
||||
|
||||
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
|
||||
from peft.utils import set_peft_model_state_dict
|
||||
|
|
@ -144,11 +150,20 @@ class PeftAdapterMixin:
|
|||
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
|
||||
)
|
||||
|
||||
# We keep `revision` in the signature for backward compatibility
|
||||
if revision is not None and "revision" not in adapter_kwargs:
|
||||
adapter_kwargs["revision"] = revision
|
||||
elif revision is not None and "revision" in adapter_kwargs and revision != adapter_kwargs["revision"]:
|
||||
logger.error(
|
||||
"You passed a `revision` argument both in `adapter_kwargs` and as a standalone argument. "
|
||||
"The one in `adapter_kwargs` will be used."
|
||||
)
|
||||
|
||||
if peft_config is None:
|
||||
adapter_config_file = find_adapter_config_file(
|
||||
peft_model_id,
|
||||
revision=revision,
|
||||
token=token,
|
||||
**adapter_kwargs,
|
||||
)
|
||||
|
||||
if adapter_config_file is None:
|
||||
|
|
@ -159,8 +174,8 @@ class PeftAdapterMixin:
|
|||
|
||||
peft_config = PeftConfig.from_pretrained(
|
||||
peft_model_id,
|
||||
revision=revision,
|
||||
use_auth_token=token,
|
||||
**adapter_kwargs,
|
||||
)
|
||||
|
||||
# Create and add fresh new adapters into the model.
|
||||
|
|
@ -170,7 +185,7 @@ class PeftAdapterMixin:
|
|||
self._hf_peft_config_loaded = True
|
||||
|
||||
if peft_model_id is not None:
|
||||
adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
|
||||
adapter_state_dict = load_peft_weights(peft_model_id, use_auth_token=token, **adapter_kwargs)
|
||||
|
||||
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
|
||||
processed_adapter_state_dict = {}
|
||||
|
|
|
|||
|
|
@ -623,6 +623,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
# Not relevant for Flax Models
|
||||
_ = kwargs.pop("adapter_kwargs", None)
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
||||
|
|
|
|||
|
|
@ -2645,6 +2645,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
|
||||
|
||||
# Not relevant for TF models
|
||||
_ = kwargs.pop("adapter_kwargs", None)
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
||||
|
|
|
|||
|
|
@ -2463,7 +2463,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
_adapter_model_path = kwargs.pop("_adapter_model_path", None)
|
||||
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
|
||||
adapter_name = kwargs.pop("adapter_name", "default")
|
||||
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
|
||||
|
||||
|
|
@ -2516,6 +2516,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
commit_hash = getattr(config, "_commit_hash", None)
|
||||
|
||||
if is_peft_available():
|
||||
_adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
|
||||
|
||||
if _adapter_model_path is None:
|
||||
_adapter_model_path = find_adapter_config_file(
|
||||
pretrained_model_name_or_path,
|
||||
|
|
@ -2525,14 +2527,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_commit_hash=commit_hash,
|
||||
**adapter_kwargs,
|
||||
)
|
||||
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
|
||||
with open(_adapter_model_path, "r", encoding="utf-8") as f:
|
||||
_adapter_model_path = pretrained_model_name_or_path
|
||||
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
|
||||
else:
|
||||
_adapter_model_path = None
|
||||
|
||||
# change device_map into a map if we passed an int, a str or a torch.device
|
||||
if isinstance(device_map, torch.device):
|
||||
|
|
@ -3371,8 +3374,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
model.load_adapter(
|
||||
_adapter_model_path,
|
||||
adapter_name=adapter_name,
|
||||
revision=revision,
|
||||
token=token,
|
||||
adapter_kwargs=adapter_kwargs,
|
||||
)
|
||||
|
||||
if output_loading_info:
|
||||
|
|
|
|||
|
|
@ -469,6 +469,7 @@ class _BaseAutoModelClass:
|
|||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||
code_revision = kwargs.pop("code_revision", None)
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
adapter_kwargs = kwargs.pop("adapter_kwargs", None)
|
||||
|
||||
revision = hub_kwargs.pop("revision", None)
|
||||
hub_kwargs["revision"] = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code)
|
||||
|
|
@ -503,15 +504,18 @@ class _BaseAutoModelClass:
|
|||
commit_hash = getattr(config, "_commit_hash", None)
|
||||
|
||||
if is_peft_available():
|
||||
if adapter_kwargs is None:
|
||||
adapter_kwargs = {}
|
||||
|
||||
maybe_adapter_path = find_adapter_config_file(
|
||||
pretrained_model_name_or_path, _commit_hash=commit_hash, **hub_kwargs
|
||||
pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
|
||||
)
|
||||
|
||||
if maybe_adapter_path is not None:
|
||||
with open(maybe_adapter_path, "r", encoding="utf-8") as f:
|
||||
adapter_config = json.load(f)
|
||||
|
||||
kwargs["_adapter_model_path"] = pretrained_model_name_or_path
|
||||
adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
|
||||
pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]
|
||||
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
|
|
@ -545,6 +549,10 @@ class _BaseAutoModelClass:
|
|||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
|
||||
)
|
||||
|
||||
# Set the adapter kwargs
|
||||
kwargs["adapter_kwargs"] = adapter_kwargs
|
||||
|
||||
if has_remote_code and trust_remote_code:
|
||||
class_ref = config.auto_map[cls.__name__]
|
||||
model_class = get_class_from_dynamic_module(
|
||||
|
|
|
|||
|
|
@ -351,3 +351,30 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||
|
||||
# dummy generation
|
||||
_ = model.generate(input_ids=dummy_input)
|
||||
|
||||
def test_peft_from_pretrained_hub_kwargs(self):
|
||||
"""
|
||||
Tests different combinations of PEFT model + from_pretrained + hub kwargs
|
||||
"""
|
||||
peft_model_id = "peft-internal-testing/tiny-opt-lora-revision"
|
||||
|
||||
# This should not work
|
||||
with self.assertRaises(OSError):
|
||||
_ = AutoModelForCausalLM.from_pretrained(peft_model_id)
|
||||
|
||||
adapter_kwargs = {"revision": "test"}
|
||||
|
||||
# This should work
|
||||
model = AutoModelForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
|
||||
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||
|
||||
model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
|
||||
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||
|
||||
adapter_kwargs = {"revision": "main", "subfolder": "test_subfolder"}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
|
||||
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||
|
||||
model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
|
||||
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||
|
|
|
|||
Loading…
Reference in a new issue