Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)

commit 90586d925f
parent a986b85afd
3 changed files with 12 additions and 12 deletions

@@ -26,7 +26,7 @@ class Job(object):
     and data file lists.
 
     The `epoch_group` will be run in a loop after init_group. The loop will
-    exit when any of the stop signals added with `add_stop_signal` is True
+    exit when any of the stop signals added with `add_stop_condition` is True
     at the end of an epoch.
 
     The download_group will be run only once, after all the executions of

@@ -48,7 +48,7 @@ class Job(object):
         with Job.current().epoch_group:
             limited_reader = ReaderWithLimit(reader, num_iter=10000)
             data_queue = pipe(limited_reader, num_threads=8)
-            Job.current().add_stop_signal(limited_reader.data_finished())
+            Job.current().add_stop_condition(limited_reader.data_finished())
         return data_queue
 
     def build_hogwild_trainer(reader, model):

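For readers skimming the diff, the renamed call in the docstring example above amounts to the following standalone sketch: cap each epoch at a fixed number of reader iterations and stop the job once the reader runs dry. The import paths and the `my_reader` parameter are assumptions for illustration; only the individual calls appear in this diff.

    from caffe2.python.checkpoint import Job
    from caffe2.python.dataio import ReaderWithLimit
    from caffe2.python.pipeline import pipe

    def build_epoch_queue(my_reader):
        with Job.current().epoch_group:
            # Each epoch consumes at most 10000 iterations of the underlying reader.
            limited_reader = ReaderWithLimit(my_reader, num_iter=10000)
            data_queue = pipe(limited_reader, num_threads=8)
            # End the epoch loop once the reader reports its data is exhausted.
            Job.current().add_stop_condition(limited_reader.data_finished())
        return data_queue
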
@@ -67,13 +67,13 @@ class Job(object):
     def __init__(self,
                  init_group=None, epoch_group=None,
                  download_group=None, exit_group=None,
-                 stop_signals=None, nodes_to_checkpoint=None):
+                 stop_conditions=None, nodes_to_checkpoint=None):
         self.init_group = init_group or TaskGroup(
             workspace_type=WorkspaceType.GLOBAL)
         self.epoch_group = epoch_group or TaskGroup()
         self.download_group = download_group or TaskGroup()
         self.exit_group = exit_group or TaskGroup()
-        self.stop_signals = stop_signals or []
+        self.stop_conditions = stop_conditions or []
         self._nodes_to_checkpoint = nodes_to_checkpoint
 
     def nodes_to_checkpoint(self):

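A minimal construction sketch after the rename, assuming `Job` and `TaskGroup` are importable from `caffe2.python.checkpoint` and `caffe2.python.task` (the imports are not part of this diff): every constructor argument is optional, so the common case is a bare `Job()`.

    from caffe2.python.checkpoint import Job
    from caffe2.python.task import TaskGroup

    # Defaults: a GLOBAL-workspace init_group, fresh epoch/download/exit groups,
    # and an empty stop_conditions list.
    job = Job()
    assert job.stop_conditions == []

    # Groups and conditions can also be supplied up front, e.g. to share a
    # prebuilt epoch group between builders.
    job2 = Job(epoch_group=TaskGroup(), stop_conditions=[])
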
@@ -96,12 +96,12 @@ class Job(object):
     def __exit__(self, *args):
         self.epoch_group.__exit__()
 
-    def add_stop_signal(self, output):
+    def add_stop_condition(self, output):
         if isinstance(output, core.BlobReference):
             t = Task(outputs=[output], group=self.epoch_group)
             output = t.outputs()[0]
         assert isinstance(output, TaskOutput)
-        self.stop_signals.append(output)
+        self.stop_conditions.append(output)
 
 
 def get_ckpt_filename(node_name, epoch):

@@ -638,12 +638,12 @@ class JobRunner(object):
             logger.info('Starting epoch %d' % epoch)
             session.run(self.job.epoch_group)
             logger.info('Finished epoch %d' % epoch)
-            stop_signals = [o.fetch() for o in self.job.stop_signals]
+            stop_conditions = [o.fetch() for o in self.job.stop_conditions]
 
             if self.checkpoint_manager:
                 self.save_checkpoints(epoch, session)
 
-            if any(stop_signals):
+            if any(stop_conditions):
                 logger.info('Stopping')
                 break
             epoch += 1

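To spell out the control flow above: each stop condition is a `TaskOutput` that the runner `fetch()`es only after `epoch_group` has finished, so a condition that turns true mid-epoch takes effect only at the next epoch boundary. A dependency-free analogue of the loop (all names illustrative, not Caffe2 API):

    def run_epochs(run_one_epoch, stop_condition_fetchers, max_epochs=None):
        """Run epochs until any stop condition fetches truthy (or max_epochs is hit)."""
        epoch = 0
        while max_epochs is None or epoch < max_epochs:
            run_one_epoch(epoch)                                # session.run(job.epoch_group)
            stop_conditions = [fetch() for fetch in stop_condition_fetchers]
            if any(stop_conditions):                            # checked only between epochs
                break
            epoch += 1
        return epoch

    # Example: three epochs run, then the condition trips and the loop exits.
    state = {'epochs_seen': 0}
    def train(_epoch): state['epochs_seen'] += 1
    run_epochs(train, [lambda: state['epochs_seen'] >= 3])
    assert state['epochs_seen'] == 3
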
@@ -733,4 +733,4 @@ def epoch_limiter(job, num_epochs):
     epoch_net = core.Net('epoch_countdown')
     finished = epoch_net.CountDown(counter)
     output = Task(step=epoch_net, outputs=finished).outputs()[0]
-    job.add_stop_signal(output)
+    job.add_stop_condition(output)

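And a hedged usage sketch for the renamed helper: `epoch_limiter` registers a CountDown-based stop condition so the `any(stop_conditions)` check in `JobRunner` trips after a fixed number of epochs. The import path, and `Job` being usable as a context manager, are inferred from the hunks above rather than shown by this diff.

    from caffe2.python.checkpoint import Job, epoch_limiter

    with Job() as job:
        # ... populate job.init_group / job.epoch_group with the actual pipeline ...
        epoch_limiter(job, num_epochs=5)

    # epoch_limiter appended exactly one TaskOutput to the job's stop conditions.
    assert len(job.stop_conditions) == 1
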
@@ -36,7 +36,7 @@ def build_pipeline(node_id):
 
         epoch_reader = ReaderWithLimit(full_reader, num_iter=3)
         pipe(epoch_reader, processor=inc_total)
-        Job.current().add_stop_signal(epoch_reader.data_finished())
+        Job.current().add_stop_condition(epoch_reader.data_finished())
     return [total]
 
 

@@ -389,8 +389,8 @@ def print_task_group(text, tg, header=None):
 def print_job(text, job):
     text(job.init_group, 'Job.current().init_group')
     text(job.epoch_group, 'Job.current().epoch_group')
-    with text.context('Job.current().stop_signals'):
-        for out in job.stop_signals:
+    with text.context('Job.current().stop_conditions'):
+        for out in job.stop_conditions:
             text.add(_print_task_output(out))
     text(job.download_group, 'Job.current().download_group')
     text(job.exit_group, 'Job.current().exit_group')