diff --git a/caffe2/python/checkpoint.py b/caffe2/python/checkpoint.py index 199c9cb44b9..fbe87c8e326 100644 --- a/caffe2/python/checkpoint.py +++ b/caffe2/python/checkpoint.py @@ -44,9 +44,12 @@ class Job(object): exit when any of the stop signals added with `add_stop_signal` is True at the end of an epoch. - The `exit_group` will be run only once at the very end of the job, when one - of the stopping criterias for `epoch_group` was met. The role of this group - is save the results of training in the end of the job. + The download_group will be run only once, after all the executions of + epoch_group finish. Its role is to collect the distribute scattered + parameters back after training. + + The `exit_group` will be run only once at the very end of the job, the + role of this group is to save the results of training in the end of the job. Jobs are context-driven, so that Tasks can be added to the active Job without having to explicitly pass the job object around. @@ -78,11 +81,12 @@ class Job(object): """ def __init__(self, init_group=None, epoch_group=None, - exit_group=None, stop_signals=None, - nodes_to_checkpoint=None): + download_group=None, exit_group=None, + stop_signals=None, nodes_to_checkpoint=None): self.init_group = init_group or TaskGroup( workspace_type=WorkspaceType.GLOBAL) self.epoch_group = epoch_group or TaskGroup() + self.download_group = download_group or TaskGroup() self.exit_group = exit_group or TaskGroup() self.stop_signals = stop_signals or [] self._nodes_to_checkpoint = nodes_to_checkpoint @@ -97,6 +101,7 @@ class Job(object): return Job( init_group=session_class.compile(self.init_group), epoch_group=session_class.compile(self.epoch_group), + download_group=session_class.compile(self.download_group), exit_group=session_class.compile(self.exit_group), stop_signals=self.stop_signals, nodes_to_checkpoint=self.nodes_to_checkpoint()) @@ -570,6 +575,12 @@ class JobRunner(object): epoch, self.checkpoint_manager) session.run(upload_task_group) logger.info('Finished uploading the checkpoints') + + # Download the parameters to save + session.run(self.job.download_group) + logger.info('Finished downloading the parameters') + + # Finally run the exit step to save nets session.run(self.job.exit_group) logger.info('Finished running the exit group') return epoch diff --git a/caffe2/python/checkpoint_test.py b/caffe2/python/checkpoint_test.py index ccd108c185d..6149f2cbd48 100644 --- a/caffe2/python/checkpoint_test.py +++ b/caffe2/python/checkpoint_test.py @@ -19,12 +19,12 @@ from __future__ import print_function from __future__ import unicode_literals from caffe2.python.schema import Struct, ConstRecord -from caffe2.python import core, workspace +from caffe2.python import core, workspace, model_helper from caffe2.python.session import LocalSession from caffe2.python.dataset import Dataset from caffe2.python.pipeline import pipe from caffe2.python.checkpoint import ( - CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner, + CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner, epoch_limiter, UploadTaskGroupBuilder, db_name) from caffe2.python.net_builder import ops from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType, Cluster @@ -273,3 +273,45 @@ class TestCheckpoint(TestCase): # make sure all epochs are executed even though saving the checkpoint failed # Saving checkpoint failure should not cause job failure self.assertEquals(num_epochs, len(EXPECTED_TOTALS)) + + def test_download_group_simple(self): + """ + A simple test that ensures we have download task group + executed between epoch_group and exit_group. + """ + model = model_helper.ModelHelper(name="test_model") + download_net = core.Net("download_net") + + for name in ["input1", "input2", "output", "download_result"]: + model.param_init_net.ConstantFill([], + [name], + shape=[8, ], + value=1.0, + run_once=0) + model.net.Add(["input1", "input2"], ["output"]) + download_net.Copy(["output"], ["download_result"]) + + # All blob values are initialized as 1.0, after download_net executed + # we expect to see download result is the same as training result. + with Job() as job: + with Node("trainer:0"): + epoch_limiter(1) + with job.init_group: + Task(step=model.param_init_net) + with job.epoch_group: + with Task(): + with ops.loop(1): + ops.net(model.net) + with job.download_group: + Task(step=download_net) + + ws = workspace.C.Workspace() + session = LocalSession(ws) + job_runner = JobRunner(job) + job_runner(session) + + expected_result = np.full(8, 2.0).astype(np.float32) + self.assertTrue(np.array_equal(expected_result, + ws.fetch_blob("output"))) + self.assertTrue(np.array_equal(expected_result, + ws.fetch_blob("download_result"))) diff --git a/caffe2/python/net_printer.py b/caffe2/python/net_printer.py index ec1cce000f8..8d127f6b566 100644 --- a/caffe2/python/net_printer.py +++ b/caffe2/python/net_printer.py @@ -400,6 +400,7 @@ def print_job(text, job): with text.context('Job.current().stop_signals'): for out in job.stop_signals: text.add(_print_task_output(out)) + text(job.download_group, 'Job.current().download_group') text(job.exit_group, 'Job.current().exit_group')