mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
parent
bf5d7c3fa3
commit
b1d5d6dd65
1 changed files with 6 additions and 5 deletions
|
|
@ -2251,7 +2251,7 @@ class Trainer:
|
|||
else:
|
||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
||||
|
||||
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
|
||||
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
|
||||
|
||||
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
||||
if self._created_lr_scheduler:
|
||||
|
|
@ -2304,12 +2304,13 @@ class Trainer:
|
|||
# In case of auto_find_batch_size=True
|
||||
# Remove FSDP wrapping from sub-models.
|
||||
self.model = unwrap_model(self.model, recursive=True)
|
||||
# configure fsdp plugin for qlora if any
|
||||
self._fsdp_qlora_plugin_updates()
|
||||
|
||||
if delay_optimizer_creation:
|
||||
if use_accelerator_prepare:
|
||||
self.model = self.accelerator.prepare(self.model)
|
||||
# configure fsdp plugin for qlora if any
|
||||
self._fsdp_qlora_plugin_updates()
|
||||
if self.accelerator.mixed_precision != "fp8":
|
||||
self.model = self.accelerator.prepare(self.model)
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
|
||||
# prepare using `accelerator` prepare
|
||||
|
|
@ -4172,7 +4173,7 @@ class Trainer:
|
|||
start_time = time.time()
|
||||
model = (
|
||||
self.accelerator.prepare(model)
|
||||
if self.is_deepspeed_enabled or self.is_fsdp_enabled
|
||||
if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
|
||||
else self.accelerator.prepare_model(model, evaluation_mode=True)
|
||||
)
|
||||
self.model_preparation_time = round(time.time() - start_time, 4)
|
||||
|
|
|
|||
Loading…
Reference in a new issue