From 6ff05fd49dcafbc3bc36bca537da269e20595b6f Mon Sep 17 00:00:00 2001 From: Alisson Gusatti Azzolini Date: Tue, 21 Feb 2017 20:42:35 -0800 Subject: [PATCH] Fix issues pickling jobs Summary: We were running into a problem where a Job could not be pickled. It needs to be pickled in order for the master flow operator to execute it using the session. This creates a concept of "compiled" Job, that pretty much only stores protobufs with the Jobs to be executed, avoiding any issue with pickling. Reviewed By: dzhulgakov Differential Revision: D4554799 fbshipit-source-id: 2ee9877ca49a796d51925e5ec917436e3d930984 --- caffe2/python/checkpoint.py | 31 +++++++++--- caffe2/python/checkpoint_test.py | 7 ++- caffe2/python/session.py | 84 ++++++++++++++++++++++---------- caffe2/python/task.py | 52 ++++++++++---------- 4 files changed, 114 insertions(+), 60 deletions(-) diff --git a/caffe2/python/checkpoint.py b/caffe2/python/checkpoint.py index 4bd8c672a22..21e3ab673a1 100644 --- a/caffe2/python/checkpoint.py +++ b/caffe2/python/checkpoint.py @@ -58,11 +58,30 @@ class Job(object): model = build_model(params) build_hogwild_trainer(reader, model) """ - def __init__(self): - self.init_group = TaskGroup(workspace_type=WorkspaceType.GLOBAL) - self.epoch_group = TaskGroup() - self.exit_group = TaskGroup() - self.stop_signals = [] + def __init__(self, + init_group=None, epoch_group=None, + exit_group=None, stop_signals=None, + nodes_to_checkpoint=None): + self.init_group = init_group or TaskGroup( + workspace_type=WorkspaceType.GLOBAL) + self.epoch_group = epoch_group or TaskGroup() + self.exit_group = exit_group or TaskGroup() + self.stop_signals = stop_signals or [] + self._nodes_to_checkpoint = nodes_to_checkpoint + + def nodes_to_checkpoint(self): + if self._nodes_to_checkpoint: + return self._nodes_to_checkpoint + else: + return self.init_group.used_nodes() + + def compile(self, session_class): + return Job( + init_group=session_class.compile(self.init_group), + 
epoch_group=session_class.compile(self.epoch_group), + exit_group=session_class.compile(self.exit_group), + stop_signals=self.stop_signals, + nodes_to_checkpoint=self.nodes_to_checkpoint()) def __enter__(self): self.epoch_group.__enter__() @@ -225,7 +244,7 @@ class JobRunner(object): if self.checkpoint: logger.info('Preparing checkpoint ...') client.run(self.checkpoint.init( - self.job.init_group.used_nodes(), + self.job.nodes_to_checkpoint(), retrieve_from_epoch=self.resume_from_epoch)) if from_scratch: logger.info('Saving first checkpoint ...') diff --git a/caffe2/python/checkpoint_test.py b/caffe2/python/checkpoint_test.py index 01d2b1f2fb6..9c8dc4d9f5b 100644 --- a/caffe2/python/checkpoint_test.py +++ b/caffe2/python/checkpoint_test.py @@ -55,13 +55,16 @@ class TestCheckpoint(TestCase): return output_fetcher.outputs()[0].fetch() session, checkpoint = builder() - num_epochs = JobRunner(job, checkpoint)(session) + compiled_job = job.compile(LocalSession) + num_epochs = JobRunner(compiled_job, checkpoint)(session) self.assertEquals(num_epochs, len(EXPECTED_TOTALS)) self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1]) for initial_epoch in range(1, num_epochs + 1): session, checkpoint = builder() - JobRunner(job, checkpoint, resume_from_epoch=initial_epoch)(session) + JobRunner( + compiled_job, + checkpoint, resume_from_epoch=initial_epoch)(session) self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1]) for epoch in range(1, num_epochs + 1): diff --git a/caffe2/python/session.py b/caffe2/python/session.py index 08361e6c548..5f0b979c494 100644 --- a/caffe2/python/session.py +++ b/caffe2/python/session.py @@ -5,7 +5,14 @@ from __future__ import unicode_literals from caffe2.python import core, workspace -from caffe2.python.task import Task, TaskGroup, WorkspaceType +from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType + + +class CompiledRunnable(object): + """ Wrapper for compiled runnable returned from session.compile() """ + def 
__init__(self, obj, session_class): + self.obj = obj + self.session_class = session_class class Session(object): @@ -62,29 +69,46 @@ class Session(object): access each other's blobs. On the other hand, tasks running on the same node are guaranteed to run on the same workspace within a run. """ + + _compiled_cache = {} + def __init__(self): self._open = True - self._runnable_cache = {} def is_open(self): return self._open + @classmethod + def compile(cls, runnable): + if isinstance(runnable, CompiledRunnable): + assert cls == runnable.session_class, ( + 'Runnable was compiled for different session type. ' + + 'Need: %s, got: %s' % ( + cls.__name__, runnable.session_class.__name__)) + return runnable + + if runnable in cls._compiled_cache: + return cls._compiled_cache[runnable] + + if isinstance(runnable, TaskGroup): + tg = runnable + else: + tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL) + if isinstance(runnable, Task): + tg.add(runnable) + elif isinstance(runnable, core.ExecutionStep): + tg.add(Task(step=runnable)) + else: + step = core.execution_step('runnable', runnable) + tg.add(Task(step=step)) + compiled = CompiledRunnable( + cls._compile_task_group(tg), session_class=cls) + cls._compiled_cache[runnable] = compiled + return compiled + def run(self, runnable): assert self.is_open(), 'Session is closed.' 
- if runnable not in self._runnable_cache: - if isinstance(runnable, TaskGroup): - tg = runnable - else: - tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL) - if isinstance(runnable, Task): - tg.add(runnable) - elif isinstance(runnable, core.ExecutionStep): - tg.add(Task(step=runnable)) - else: - step = core.execution_step('runnable', runnable) - tg.add(Task(step=step)) - self._runnable_cache[runnable] = tg - self._run_task_group(self._runnable_cache[runnable]) + self._run_compiled(self.compile(runnable).obj) def close(self): if self.is_open(): @@ -94,9 +118,13 @@ class Session(object): def fetch_output(self, output): raise NotImplementedError() - def _run_task_group(self, task_group): + def _run_compiled(self, task_group): raise NotImplementedError() + @classmethod + def _compile_task_group(cls, task_group): + return task_group + def _do_close(self): pass @@ -121,25 +149,27 @@ class LocalSession(Session): def __init__(self, ws=None): Session.__init__(self) self._ws = ws or workspace.C.Workspace() - self._plan_caches = {} - def _run_task_group(self, task_group): - if task_group not in self._plan_caches: + @classmethod + def _compile_task_group(cls, task_group): + with Cluster(): task = task_group.to_task() - plan = core.Plan('task_group_plan') - plan.AddStep(task.get_step()) - self._plan_caches[task_group] = (plan, task) - plan, task = self._plan_caches[task_group] + plan = core.Plan('task_group_plan') + plan.AddStep(task.get_step()) + return (plan, task.output_list(), task.workspace_type) + + def _run_compiled(self, compiled): + plan, output_list, workspace_type = compiled # make sure the output blobs belong to the parent workspace outputs = [] - for name in task.output_names(): + for name in output_list.names(): self._ws.create_blob(str(name)) outputs.append(core.BlobReference(str(name))) - task.set_outputs(outputs, _fetch_func=self._fetch_output) + output_list.set_values(outputs, _fetch_func=self._fetch_output) task_ws = ( workspace.C.Workspace(self._ws) - if 
task.workspace_type == WorkspaceType.PRIVATE else self._ws) + if workspace_type == WorkspaceType.PRIVATE else self._ws) with workspace.WorkspaceGuard(task_ws): task_ws.run(plan) diff --git a/caffe2/python/task.py b/caffe2/python/task.py index 1cedbe70680..869b6e2d3aa 100644 --- a/caffe2/python/task.py +++ b/caffe2/python/task.py @@ -378,6 +378,30 @@ def final_output(blob_or_record): return cur_task.add_output(blob_or_record) +class TaskOutputList(object): + """ Keeps a list of outputs for a task """ + def __init__(self, outputs=None): + self.outputs = outputs or [] + + def names(self): + """ + Retrieve the output names. + TODO(azzolini): make this schema-based. + """ + names = [] + for o in self.outputs: + names += o.names + return names + + def set_values(self, values, _fetch_func=None): + offset = 0 + for o in self.outputs: + num = len(o.names) + o.set(values[offset:offset + num], _fetch_func) + offset += num + assert offset == len(values), 'Wrong number of output values.' + + @context.define_context() class Task(object): """ @@ -515,34 +539,12 @@ class Task(object): self._step_with_setup = core.execution_step(self.name, []) return self._step_with_setup + def output_list(self): + return TaskOutputList(self._outputs) + def outputs(self): return self._outputs - def output_names(self): - """ - Retrive the output names. - TODO(azzolini): make this schema-based. - """ - names = [] - for o in self._outputs: - names += o.names - return names - - def set_outputs(self, values, _fetch_func): - """ - Set output values. - TODO(azzolini): make this schema-based. - """ - offset = 0 - for o in self._outputs: - num = len(o.names) - o.set(values[offset:offset + num], _fetch_func) - offset += num - assert offset == len(values), 'Wrong number of output values.' - - def resolved_outputs(self): - return [output.get() for output in self._outputs] - def _notify_used(self): self.get_step() self._already_used = True