mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Separate parameter downloading tasks from training tasks and run them in a different group
Summary: At the end of distributed training, trainer needs to download the parameters back from parameter servers for saving the model. Currently, this parameter downloading happens at the end of job's epoch task group, which creates several problems when checkpointing is enabled for distributed training: 1. When checkpointing is enabled, we run multiple training epochs. At the end of each epoch, the model download tasks will run to collect parameters, but we won't save the model until the true end of training, so there is a big waste of resource. 2. After trainer0 downloads the parameters, these parameters take a lot of memory, so trainer0 can easily run out of memory in the next epoch of training. Our solution is to insert a parameter download task group between the job's training epoch_group and the job's exit_group. Reviewed By: azzolini Differential Revision: D6765393 fbshipit-source-id: 5a4f556fc3c1cd7834a7c406a3c0de3fccd50c49
This commit is contained in:
parent
27f4041738
commit
1d4e996b87
3 changed files with 61 additions and 7 deletions
|
|
@ -44,9 +44,12 @@ class Job(object):
|
|||
exit when any of the stop signals added with `add_stop_signal` is True
|
||||
at the end of an epoch.
|
||||
|
||||
The `exit_group` will be run only once at the very end of the job, when one
|
||||
of the stopping criterias for `epoch_group` was met. The role of this group
|
||||
is save the results of training in the end of the job.
|
||||
The download_group will be run only once, after all the executions of
|
||||
epoch_group finish. Its role is to collect the distribute scattered
|
||||
parameters back after training.
|
||||
|
||||
The `exit_group` will be run only once at the very end of the job, the
|
||||
role of this group is to save the results of training in the end of the job.
|
||||
|
||||
Jobs are context-driven, so that Tasks can be added to the active Job
|
||||
without having to explicitly pass the job object around.
|
||||
|
|
@ -78,11 +81,12 @@ class Job(object):
|
|||
"""
|
||||
def __init__(self,
|
||||
init_group=None, epoch_group=None,
|
||||
exit_group=None, stop_signals=None,
|
||||
nodes_to_checkpoint=None):
|
||||
download_group=None, exit_group=None,
|
||||
stop_signals=None, nodes_to_checkpoint=None):
|
||||
self.init_group = init_group or TaskGroup(
|
||||
workspace_type=WorkspaceType.GLOBAL)
|
||||
self.epoch_group = epoch_group or TaskGroup()
|
||||
self.download_group = download_group or TaskGroup()
|
||||
self.exit_group = exit_group or TaskGroup()
|
||||
self.stop_signals = stop_signals or []
|
||||
self._nodes_to_checkpoint = nodes_to_checkpoint
|
||||
|
|
@ -97,6 +101,7 @@ class Job(object):
|
|||
return Job(
|
||||
init_group=session_class.compile(self.init_group),
|
||||
epoch_group=session_class.compile(self.epoch_group),
|
||||
download_group=session_class.compile(self.download_group),
|
||||
exit_group=session_class.compile(self.exit_group),
|
||||
stop_signals=self.stop_signals,
|
||||
nodes_to_checkpoint=self.nodes_to_checkpoint())
|
||||
|
|
@ -570,6 +575,12 @@ class JobRunner(object):
|
|||
epoch, self.checkpoint_manager)
|
||||
session.run(upload_task_group)
|
||||
logger.info('Finished uploading the checkpoints')
|
||||
|
||||
# Download the parameters to save
|
||||
session.run(self.job.download_group)
|
||||
logger.info('Finished downloading the parameters')
|
||||
|
||||
# Finally run the exit step to save nets
|
||||
session.run(self.job.exit_group)
|
||||
logger.info('Finished running the exit group')
|
||||
return epoch
|
||||
|
|
|
|||
|
|
@ -19,12 +19,12 @@ from __future__ import print_function
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python.schema import Struct, ConstRecord
|
||||
from caffe2.python import core, workspace
|
||||
from caffe2.python import core, workspace, model_helper
|
||||
from caffe2.python.session import LocalSession
|
||||
from caffe2.python.dataset import Dataset
|
||||
from caffe2.python.pipeline import pipe
|
||||
from caffe2.python.checkpoint import (
|
||||
CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner,
|
||||
CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner, epoch_limiter,
|
||||
UploadTaskGroupBuilder, db_name)
|
||||
from caffe2.python.net_builder import ops
|
||||
from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType, Cluster
|
||||
|
|
@ -273,3 +273,45 @@ class TestCheckpoint(TestCase):
|
|||
# make sure all epochs are executed even though saving the checkpoint failed
|
||||
# Saving checkpoint failure should not cause job failure
|
||||
self.assertEquals(num_epochs, len(EXPECTED_TOTALS))
|
||||
|
||||
def test_download_group_simple(self):
|
||||
"""
|
||||
A simple test that ensures we have download task group
|
||||
executed between epoch_group and exit_group.
|
||||
"""
|
||||
model = model_helper.ModelHelper(name="test_model")
|
||||
download_net = core.Net("download_net")
|
||||
|
||||
for name in ["input1", "input2", "output", "download_result"]:
|
||||
model.param_init_net.ConstantFill([],
|
||||
[name],
|
||||
shape=[8, ],
|
||||
value=1.0,
|
||||
run_once=0)
|
||||
model.net.Add(["input1", "input2"], ["output"])
|
||||
download_net.Copy(["output"], ["download_result"])
|
||||
|
||||
# All blob values are initialized as 1.0, after download_net executed
|
||||
# we expect to see download result is the same as training result.
|
||||
with Job() as job:
|
||||
with Node("trainer:0"):
|
||||
epoch_limiter(1)
|
||||
with job.init_group:
|
||||
Task(step=model.param_init_net)
|
||||
with job.epoch_group:
|
||||
with Task():
|
||||
with ops.loop(1):
|
||||
ops.net(model.net)
|
||||
with job.download_group:
|
||||
Task(step=download_net)
|
||||
|
||||
ws = workspace.C.Workspace()
|
||||
session = LocalSession(ws)
|
||||
job_runner = JobRunner(job)
|
||||
job_runner(session)
|
||||
|
||||
expected_result = np.full(8, 2.0).astype(np.float32)
|
||||
self.assertTrue(np.array_equal(expected_result,
|
||||
ws.fetch_blob("output")))
|
||||
self.assertTrue(np.array_equal(expected_result,
|
||||
ws.fetch_blob("download_result")))
|
||||
|
|
|
|||
|
|
@ -400,6 +400,7 @@ def print_job(text, job):
|
|||
with text.context('Job.current().stop_signals'):
|
||||
for out in job.stop_signals:
|
||||
text.add(_print_task_output(out))
|
||||
text(job.download_group, 'Job.current().download_group')
|
||||
text(job.exit_group, 'Job.current().exit_group')
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue