[DT] [38/n] Rename add_stop_signal to add_stop_condition (#6825)

As the title says: rename `add_stop_signal` to `add_stop_condition`.
This commit is contained in:
Qinqing Zheng 2018-04-23 10:39:37 -07:00 committed by Jerry Zhang
parent a986b85afd
commit 90586d925f
3 changed files with 12 additions and 12 deletions

View file

@ -26,7 +26,7 @@ class Job(object):
and data file lists.
The `epoch_group` will be run in a loop after init_group. The loop will
exit when any of the stop signals added with `add_stop_signal` is True
exit when any of the stop conditions added with `add_stop_condition` is True
at the end of an epoch.
The download_group will be run only once, after all the executions of
@ -48,7 +48,7 @@ class Job(object):
with Job.current().epoch_group:
limited_reader = ReaderWithLimit(reader, num_iter=10000)
data_queue = pipe(limited_reader, num_threads=8)
Job.current().add_stop_signal(limited_reader.data_finished())
Job.current().add_stop_condition(limited_reader.data_finished())
return data_queue
def build_hogwild_trainer(reader, model):
@ -67,13 +67,13 @@ class Job(object):
def __init__(self,
init_group=None, epoch_group=None,
download_group=None, exit_group=None,
stop_signals=None, nodes_to_checkpoint=None):
stop_conditions=None, nodes_to_checkpoint=None):
self.init_group = init_group or TaskGroup(
workspace_type=WorkspaceType.GLOBAL)
self.epoch_group = epoch_group or TaskGroup()
self.download_group = download_group or TaskGroup()
self.exit_group = exit_group or TaskGroup()
self.stop_signals = stop_signals or []
self.stop_conditions = stop_conditions or []
self._nodes_to_checkpoint = nodes_to_checkpoint
def nodes_to_checkpoint(self):
@ -96,12 +96,12 @@ class Job(object):
def __exit__(self, *args):
self.epoch_group.__exit__()
def add_stop_signal(self, output):
def add_stop_condition(self, output):
if isinstance(output, core.BlobReference):
t = Task(outputs=[output], group=self.epoch_group)
output = t.outputs()[0]
assert isinstance(output, TaskOutput)
self.stop_signals.append(output)
self.stop_conditions.append(output)
def get_ckpt_filename(node_name, epoch):
@ -638,12 +638,12 @@ class JobRunner(object):
logger.info('Starting epoch %d' % epoch)
session.run(self.job.epoch_group)
logger.info('Finished epoch %d' % epoch)
stop_signals = [o.fetch() for o in self.job.stop_signals]
stop_conditions = [o.fetch() for o in self.job.stop_conditions]
if self.checkpoint_manager:
self.save_checkpoints(epoch, session)
if any(stop_signals):
if any(stop_conditions):
logger.info('Stopping')
break
epoch += 1
@ -733,4 +733,4 @@ def epoch_limiter(job, num_epochs):
epoch_net = core.Net('epoch_countdown')
finished = epoch_net.CountDown(counter)
output = Task(step=epoch_net, outputs=finished).outputs()[0]
job.add_stop_signal(output)
job.add_stop_condition(output)

View file

@ -36,7 +36,7 @@ def build_pipeline(node_id):
epoch_reader = ReaderWithLimit(full_reader, num_iter=3)
pipe(epoch_reader, processor=inc_total)
Job.current().add_stop_signal(epoch_reader.data_finished())
Job.current().add_stop_condition(epoch_reader.data_finished())
return [total]

View file

@ -389,8 +389,8 @@ def print_task_group(text, tg, header=None):
def print_job(text, job):
text(job.init_group, 'Job.current().init_group')
text(job.epoch_group, 'Job.current().epoch_group')
with text.context('Job.current().stop_signals'):
for out in job.stop_signals:
with text.context('Job.current().stop_conditions'):
for out in job.stop_conditions:
text.add(_print_task_output(out))
text(job.download_group, 'Job.current().download_group')
text(job.exit_group, 'Job.current().exit_group')