mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Expose offload_buffers parameter of accelerate to PreTrainedModel.from_pretrained method (#28755)
Expose offload_buffers parameter to from_pretrained method
This commit is contained in:
parent
0ad770c373
commit
5ee0868a4b
1 changed file with 4 additions and 0 deletions
|
|
@@ -2745,6 +2745,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
|
||||
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
|
||||
`True` when there is some disk offload.
|
||||
offload_buffers (`bool`, *optional*):
|
||||
Whether or not to offload the buffers with the model parameters.
|
||||
quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*):
|
||||
A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g
|
||||
bitsandbytes, gptq). There may be other quantization-related kwargs, including `load_in_4bit` and
|
||||
|
|
@@ -2835,6 +2837,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
max_memory = kwargs.pop("max_memory", None)
|
||||
offload_folder = kwargs.pop("offload_folder", None)
|
||||
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
||||
offload_buffers = kwargs.pop("offload_buffers", False)
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
load_in_4bit = kwargs.pop("load_in_4bit", False)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
|
|
@@ -3554,6 +3557,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
"device_map": device_map,
|
||||
"offload_dir": offload_folder,
|
||||
"offload_index": offload_index,
|
||||
"offload_buffers": offload_buffers,
|
||||
}
|
||||
if "skip_keys" in inspect.signature(dispatch_model).parameters:
|
||||
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
|
||||
|
|
|
|||
Loading…
Reference in a new issue