diff --git a/caffe2/python/checkpoint.py b/caffe2/python/checkpoint.py index 9d0f5684c58..8a81a52c48b 100644 --- a/caffe2/python/checkpoint.py +++ b/caffe2/python/checkpoint.py @@ -26,7 +26,7 @@ class Job(object): and data file lists. The `epoch_group` will be run in a loop after init_group. The loop will - exit when any of the stop signals added with `add_stop_signal` is True + exit when any of the stop conditions added with `add_stop_condition` is True at the end of an epoch. The download_group will be run only once, after all the executions of @@ -48,7 +48,7 @@ class Job(object): with Job.current().epoch_group: limited_reader = ReaderWithLimit(reader, num_iter=10000) data_queue = pipe(limited_reader, num_threads=8) - Job.current().add_stop_signal(limited_reader.data_finished()) + Job.current().add_stop_condition(limited_reader.data_finished()) return data_queue def build_hogwild_trainer(reader, model): @@ -67,13 +67,13 @@ class Job(object): def __init__(self, init_group=None, epoch_group=None, download_group=None, exit_group=None, - stop_signals=None, nodes_to_checkpoint=None): + stop_conditions=None, nodes_to_checkpoint=None): self.init_group = init_group or TaskGroup( workspace_type=WorkspaceType.GLOBAL) self.epoch_group = epoch_group or TaskGroup() self.download_group = download_group or TaskGroup() self.exit_group = exit_group or TaskGroup() - self.stop_signals = stop_signals or [] + self.stop_conditions = stop_conditions or [] self._nodes_to_checkpoint = nodes_to_checkpoint def nodes_to_checkpoint(self): @@ -96,12 +96,12 @@ class Job(object): def __exit__(self, *args): self.epoch_group.__exit__() - def add_stop_signal(self, output): + def add_stop_condition(self, output): if isinstance(output, core.BlobReference): t = Task(outputs=[output], group=self.epoch_group) output = t.outputs()[0] assert isinstance(output, TaskOutput) - self.stop_signals.append(output) + self.stop_conditions.append(output) def get_ckpt_filename(node_name, epoch): @@ -638,12 +638,12 @@ 
class JobRunner(object): logger.info('Starting epoch %d' % epoch) session.run(self.job.epoch_group) logger.info('Finished epoch %d' % epoch) - stop_signals = [o.fetch() for o in self.job.stop_signals] + stop_conditions = [o.fetch() for o in self.job.stop_conditions] if self.checkpoint_manager: self.save_checkpoints(epoch, session) - if any(stop_signals): + if any(stop_conditions): logger.info('Stopping') break epoch += 1 @@ -733,4 +733,4 @@ def epoch_limiter(job, num_epochs): epoch_net = core.Net('epoch_countdown') finished = epoch_net.CountDown(counter) output = Task(step=epoch_net, outputs=finished).outputs()[0] - job.add_stop_signal(output) + job.add_stop_condition(output) diff --git a/caffe2/python/checkpoint_test.py b/caffe2/python/checkpoint_test.py index 75d9a77fb50..f66af4d3a61 100644 --- a/caffe2/python/checkpoint_test.py +++ b/caffe2/python/checkpoint_test.py @@ -36,7 +36,7 @@ def build_pipeline(node_id): epoch_reader = ReaderWithLimit(full_reader, num_iter=3) pipe(epoch_reader, processor=inc_total) - Job.current().add_stop_signal(epoch_reader.data_finished()) + Job.current().add_stop_condition(epoch_reader.data_finished()) return [total] diff --git a/caffe2/python/net_printer.py b/caffe2/python/net_printer.py index d5a1e83d063..4b5cddb61d2 100644 --- a/caffe2/python/net_printer.py +++ b/caffe2/python/net_printer.py @@ -389,8 +389,8 @@ def print_task_group(text, tg, header=None): def print_job(text, job): text(job.init_group, 'Job.current().init_group') text(job.epoch_group, 'Job.current().epoch_group') - with text.context('Job.current().stop_signals'): - for out in job.stop_signals: + with text.context('Job.current().stop_conditions'): + for out in job.stop_conditions: text.add(_print_task_output(out)) text(job.download_group, 'Job.current().download_group') text(job.exit_group, 'Job.current().exit_group')