Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
new first tp loading version
This commit is contained in:
parent c3e818561e
commit ff1078387e

1 changed file with 86 additions and 20 deletions

@@ -766,6 +768,8 @@ def _load_state_dict_into_meta_model(
     is_safetensors: bool = False,
     keep_in_fp32_modules: Optional[List[str]] = None,
     unexpected_keys: Optional[Dict] = None,  # passing `unexpected` for cleanup from quantization items
+    device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+    tp_key_registry: Optional[Dict] = None,
 ) -> Tuple[List[str], Optional[Dict], Optional[Dict]]:
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its

@@ -775,6 +777,7 @@ def _load_state_dict_into_meta_model(
     `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
     `bert.pooler.dense.weight`

+    It also initializes tensor parallelism according to `tp_key_registry` if needed.
     """

     # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
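
For reference, `tp_key_registry` is the bookkeeping structure built later in
`_load_pretrained_model` (see the hunk at 4842 below): each entry maps a module prefix
to the set of its parameter keys that have not been materialized yet, plus the torch TP
plan to apply once that set empties. A hypothetical entry for a Llama-style layer (key
names and plan contents are illustrative, not taken from this commit):

    tp_key_registry = {
        "model.layers.0": {
            # removed one-by-one as each parameter is loaded onto its device
            "children": {
                "model.layers.0.self_attn.q_proj.weight",
                "model.layers.0.mlp.down_proj.weight",
                # ...the rest of this layer's parameter keys
            },
            # translated plan, e.g. {"self_attn.q_proj": ColwiseParallel(), ...}
            "plan": layer_tp_plan,
        },
    }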

@@ -891,6 +894,31 @@ def _load_state_dict_into_meta_model(
             setattr(module, tensor_name, value)
         # TODO: consider removing used param_parts from state_dict before return
+
+        # In this case, let's parallelize the modules!
+        if tp_key_registry is not None:
+            plan = None
+            prefix = None
+            for module_prefix in tp_key_registry.keys():
+                if f"{module_prefix}." in param_name:
+                    tp_key_registry[module_prefix]["children"].remove(param_name)
+                    if len(tp_key_registry[module_prefix]["children"]) == 0:
+                        plan = tp_key_registry[module_prefix]["plan"]
+                        prefix = module_prefix
+                    break
+
+            if plan is not None:
+                del tp_key_registry[prefix]
+                parent_module = model
+                for name in prefix.split("."):
+                    parent_module = getattr(parent_module, name)
+
+                torch.distributed.tensor.parallel.parallelize_module(
+                    parent_module,
+                    device_mesh=device_mesh,
+                    parallelize_plan=plan,
+                )
+
     return error_msgs, disk_offload_index, cpu_offload_index
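
The hunk above defers sharding until the last parameter of a module lands, so
`parallelize_module` always sees real tensors rather than meta tensors. For context, a
minimal sketch of the torch API being called here, on a toy block with illustrative
names (the call itself is commented out because it needs an initialized process group
and a device mesh, as set up in the hunk at 4145 below):

    import torch
    import torch.nn as nn
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    class Block(nn.Module):
        # Stand-in for one decoder layer.
        def __init__(self):
            super().__init__()
            self.up_proj = nn.Linear(64, 256, bias=False)
            self.down_proj = nn.Linear(256, 64, bias=False)

        def forward(self, x):
            return self.down_proj(torch.relu(self.up_proj(x)))

    # Colwise then rowwise sharding means the block needs a single all-reduce on its
    # output instead of re-gathering the intermediate activation.
    plan = {"up_proj": ColwiseParallel(), "down_proj": RowwiseParallel()}
    # parallelize_module(Block().cuda(), device_mesh=mesh, parallelize_plan=plan)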

@@ -4101,12 +4129,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         )

         # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
-        # `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all
-        # childs processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs.
-        # And temporarily setting the default device to current process rank result in the following error
-        # `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group`
-        tp_device = None
+        # `device_map` pointing to the correct device
+        device_mesh = None
         if tp_plan is not None:
             if not is_torch_greater_or_equal("2.5"):
                 raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
             if not torch.distributed.is_initialized():
                 raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

@@ -4118,6 +4145,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             # This is the easiest way to dispatch to the current process device
             device_map = tp_device
+
+            # Assuming sharding the model onto the world
+            world_size = torch.distributed.get_world_size()
+            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))

         if is_fsdp_enabled():
             low_cpu_mem_usage = True
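
The checks above mean the caller must initialize `torch.distributed` before calling
`from_pretrained`, and the loader then spans a 1-D mesh over the whole world. A minimal
launch sketch under stated assumptions (torchrun, one CUDA device per rank; the script
name is illustrative):

    # run with: torchrun --nproc-per-node 2 demo.py
    import torch

    torch.distributed.init_process_group("nccl")  # required before from_pretrained
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    world_size = torch.distributed.get_world_size()
    # Same 1-D mesh the loader builds: a single dimension covering every rank.
    mesh = torch.distributed.init_device_mesh("cuda", (world_size,))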

@@ -4210,7 +4241,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             if low_cpu_mem_usage is None:
                 low_cpu_mem_usage = True
             elif not low_cpu_mem_usage:
-                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
+                raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`")

         if low_cpu_mem_usage:
             if is_deepspeed_zero3_enabled():

@@ -4219,7 +4250,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 )
             elif not is_accelerate_available():
                 raise ImportError(
-                    f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+                    f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
                 )

         # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.

@@ -4432,6 +4463,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             # Let's make sure we don't run the init function of buffer modules
             model = cls(config, *model_args, **model_kwargs)

+        # Last check for tp
+        if device_mesh is not None and not model.supports_tp_plan:
+            raise NotImplementedError("This model does not have a tensor parallel plan.")
+
         # make sure we use the model's config since the __init__ call might have copied it
         config = model.config

@@ -4491,6 +4526,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 gguf_file=gguf_file,
+                device_mesh=device_mesh,
                 weights_only=weights_only,
                 _fast_init=_fast_init,
             )

@@ -4526,8 +4562,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 )
                 pass

-        # Dispatch model with hooks on all devices if necessary
-        if device_map is not None:
+        # Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it
+        # slightly harms performance)
+        if device_map is not None and device_mesh is None:
             device_map_kwargs = {
                 "device_map": device_map,
                 "offload_dir": offload_folder,

@@ -4554,6 +4591,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
                 dispatch_model(model, **device_map_kwargs)

+        # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
+        # not part of the state_dict (persistent=False)
+        if device_mesh is not None:
+            for buffer in model.buffers():
+                if buffer.device != tp_device:
+                    buffer.data = buffer.to(tp_device)
+
         if hf_quantizer is not None:
             hf_quantizer.postprocess_model(model, config=config)
             model.hf_quantizer = hf_quantizer
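
The buffer pass above exists because buffers registered with `persistent=False` (such
as rotary-embedding inverse frequencies) never appear in the state dict, so the
key-by-key loading loop never moves them to the TP device. A small illustration (class
and buffer names are hypothetical):

    import torch
    import torch.nn as nn

    class Rotary(nn.Module):
        def __init__(self):
            super().__init__()
            # persistent=False: initialized in __init__, never saved to checkpoints
            self.register_buffer("inv_freq", torch.arange(4.0), persistent=False)

    m = Rotary()
    print("inv_freq" in m.state_dict())  # False, hence the manual .to(tp_device) above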

@@ -4566,16 +4610,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 adapter_kwargs=adapter_kwargs,
             )

-        if tp_plan is not None:
-            assert tp_device is not None, "tp_device not set!"
-            if not model.supports_tp_plan:
-                raise NotImplementedError("This model does not have a tensor parallel plan.")
-            # Assuming sharding the model onto the world
-            world_size = torch.distributed.get_world_size()
-            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
-            # Apply Tensor Parallelism
-            model.tensor_parallel(device_mesh)
-
         if output_loading_info:
             loading_info = {
                 "missing_keys": missing_keys,

@@ -4675,6 +4709,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         hf_quantizer: Optional[HfQuantizer] = None,
         keep_in_fp32_modules: Optional[List[str]] = None,
         gguf_file: Optional[str] = None,
+        device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
         weights_only: bool = True,
         _fast_init: bool = True,
     ):

@@ -4807,6 +4842,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         else:
             disk_offload_index = {}

+        # Prepare inputs if using tensor parallel
+        tp_key_registry = None
+        if device_mesh is not None:
+            tp_key_registry = {}
+            layer_tp_plan = {".".join(k.split(".")[2:]): v for k, v in model.config.base_model_tp_plan.items()}
+            layer_tp_plan = torch.utils._pytree.tree_map(
+                translate_to_torch_parallel_style,
+                layer_tp_plan,
+            )
+            for key in model_to_load.state_dict().keys():
+                pattern = r"^(.*layers\.[0-9]+)"
+                match = re.match(pattern, key)
+                if match is not None:
+                    layer = match.group(1)
+                    if layer not in tp_key_registry:
+                        tp_key_registry[layer] = {"children": {key}, "plan": layer_tp_plan}
+                    else:
+                        tp_key_registry[layer]["children"].add(key)
+                elif "lm_head." in key:
+                    if "lm_head" not in tp_key_registry:
+                        plan = translate_to_torch_parallel_style(model._tp_plan["lm_head"])
+                        tp_key_registry["lm_head"] = {"children": {key}, "plan": plan}
+                    else:
+                        tp_key_registry["lm_head"]["children"].add(key)
+
         # This offload index is for params that are supposed to be on the "cpu", either with or without a device_map
         # It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2x state_dict_size,
         # i.e. 1x to load it, and 1x to copy it to model
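
The registry above groups checkpoint keys by decoder layer with a regex, and strips the
per-layer prefix from the config's TP plan so that plan keys become relative to a single
layer module. A quick sketch of both transformations on Llama-style names (the key
strings are illustrative):

    import re

    # Plan keys like "layers.*.self_attn.q_proj" become relative to one layer:
    k = "layers.*.self_attn.q_proj"
    print(".".join(k.split(".")[2:]))  # -> "self_attn.q_proj"

    # State-dict keys are bucketed by their "...layers.N" prefix:
    pattern = r"^(.*layers\.[0-9]+)"
    for key in [
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.11.mlp.up_proj.weight",
        "lm_head.weight",
    ]:
        match = re.match(pattern, key)
        # -> "model.layers.0", "model.layers.11", None (lm_head is handled separately)
        print(match.group(1) if match else None)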

@@ -4878,6 +4938,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     is_safetensors=is_offloaded_safetensors,
                     keep_in_fp32_modules=keep_in_fp32_modules,
                     unexpected_keys=unexpected_keys,
+                    device_mesh=device_mesh,
+                    tp_key_registry=tp_key_registry,
                 )
                 error_msgs += new_error_msgs
             else:

@@ -5192,7 +5254,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):

     def tensor_parallel(self, device_mesh):
         """
-        Tensor parallelize the model across the given device mesh.
+        Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model
+        was already loaded in memory. Note, however, that this means each process will first initialize the whole model,
+        then parallelize it across devices: GPU memory is hugely wasted, and this can lead to OOM at loading time.
+        Calling `from_pretrained(..., tp_plan="auto")` is preferred, and will parallelize module-by-module during
+        initialization, so that the expected per-device memory spike at loading time is not larger than the final model size on each device.

         Args:
             device_mesh (`torch.distributed.DeviceMesh`):
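
As the amended docstring notes, sharding at load time is the memory-safe path, while
`tensor_parallel` remains available for a model that is already materialized. A usage
sketch (checkpoint id is illustrative):

    # run with: torchrun --nproc-per-node 4 tp_demo.py
    import os
    import torch
    from transformers import AutoModelForCausalLM

    rank = int(os.environ["RANK"])
    torch.distributed.init_process_group("nccl", device_id=torch.device(f"cuda:{rank}"))

    # Preferred: shard module-by-module while loading; no rank holds the full model.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        tp_plan="auto",
        torch_dtype=torch.bfloat16,
    )

    # Fallback covered by this docstring: parallelize an already-loaded model,
    # at the cost of a full-size copy per rank before sharding.
    # mesh = torch.distributed.init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
    # model.tensor_parallel(mesh)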