mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-15 21:01:19 +00:00
[DeepSpeed] simplify init (#10762)
This commit is contained in:
parent
0486ccdd3d
commit
01c7fb04be
1 changed files with 0 additions and 5 deletions
|
|
@ -22,7 +22,6 @@ import os
|
|||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from .utils import logging
|
||||
from .utils.versions import require_version
|
||||
|
|
@ -430,16 +429,12 @@ def init_deepspeed(trainer, num_training_steps):
|
|||
"enabled": True,
|
||||
}
|
||||
|
||||
# for clarity extract the specific cl args that are being passed to deepspeed
|
||||
ds_args = dict(local_rank=args.local_rank)
|
||||
|
||||
# keep for quick debug:
|
||||
# from pprint import pprint; pprint(config)
|
||||
|
||||
# init that takes part of the config via `args`, and the bulk of it via `config_params`
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
model, optimizer, _, lr_scheduler = deepspeed.initialize(
|
||||
args=SimpleNamespace(**ds_args), # expects an obj
|
||||
model=model,
|
||||
model_parameters=model_parameters,
|
||||
config_params=config,
|
||||
|
|
|
|||
Loading…
Reference in a new issue