[Dist Autograd] Functional API for Dist Autograd and Dist Optimizer (#33711)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33711

Fixed #33480

This makes `dist_autograd.backward` and `dist_optimizer.step` functional by requiring the user to explicitly pass in the `context_id`, rather than relying on the confusing implicit thread-local `context_id`.

This diff implements these API changes and updates every call site of these functions.

More concretely, this code:

```
with dist_autograd.context():
    # Forward pass.
    dist_autograd.backward([loss.sum()])
    dist_optim.step()
```

should now be written as follows:

```
with dist_autograd.context() as context_id:
    # Forward pass.
    dist_autograd.backward(context_id, [loss.sum()])
    dist_optim.step(context_id)
```
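
For context, the end-to-end flow with the new API looks roughly like this; a minimal sketch adapted from the `dist_autograd_simple.py` example updated in this diff (the worker name and the `random_tensor` helper are illustrative, and `rpc.init_rpc` is assumed to have already been called on every worker):

```
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

def random_tensor():
    # Illustrative helper: a parameter-like tensor created on the remote worker.
    return torch.rand((3, 3), requires_grad=True)

with dist_autograd.context() as context_id:
    # Forward pass: the parameters live on a remote worker.
    rref1 = rpc.remote("worker1", random_tensor)
    rref2 = rpc.remote("worker1", random_tensor)
    loss = rref1.to_here() + rref2.to_here()

    # Backward pass (run distributed autograd) with the explicit context_id.
    dist_autograd.backward(context_id, [loss.sum()])

    # Build DistributedOptimizer and step it under the same context_id.
    dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
    )
    dist_optim.step(context_id)
```

The same `context_id` is threaded through both calls, so the backward pass and the optimizer step operate on an explicitly chosen autograd context instead of whatever context happens to be active on the calling thread.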

Test Plan: Ensure all existing dist_autograd and dist_optimizer tests pass with the new API. A new test case for input checking was also added.
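
The input-checking case is along these lines; a rough sketch modeled on the updated `DistAutogradTest` suite (the method name is hypothetical, and the fixture's `@dist_init` RPC setup plus the usual `torch`/`dist_autograd` imports are assumed):

```
@dist_init
def test_backward_invalid_context_id(self):
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)

    # Passing None for either argument is rejected by the pybind signature check.
    with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
        dist_autograd.backward(None, None)

    # A context_id that was never created is rejected when the backward pass
    # tries to look up its autograd context.
    dummy_context_id = 100
    with self.assertRaisesRegex(
        RuntimeError, "Could not find autograd context with id"
    ):
        dist_autograd.backward(dummy_context_id, [(t1 + t2).sum()])
```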

Differential Revision: D20011710

fbshipit-source-id: 216e12207934a2a79c7223332b97c558d89d4d65
Omkar Salpekar 2020-02-26 18:57:25 -08:00 committed by Facebook Github Bot
parent 4c33222c51
commit 24dd800e6a
12 changed files with 134 additions and 60 deletions

@@ -113,7 +113,7 @@ From the user's perspective the autograd context is setup as follows:
import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
loss = model.forward()
dist_autograd.backward(loss)
dist_autograd.backward(context_id, loss)
Distributed Backward Pass
^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -258,7 +258,7 @@ As an example the complete code with distributed autograd would be as follows:
loss = t5.sum()
# Run the backward pass.
dist_autograd.backward([loss])
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
@@ -350,7 +350,7 @@ file called "dist_autograd_simple.py", it can be run with the command
loss = rref1.to_here() + rref2.to_here()
# Backward pass (run distributed autograd).
dist_autograd.backward([loss.sum()])
dist_autograd.backward(context_id, [loss.sum()])
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
@@ -360,7 +360,7 @@ file called "dist_autograd_simple.py", it can be run with the command
)
# Run the distributed optimizer step.
dist_optim.step()
dist_optim.step(context_id)
def run_process(rank, world_size):
dst_rank = (rank + 1) % world_size

@@ -56,6 +56,7 @@ TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) {
TEST_F(DistAutogradTest, TestInitializedContextCleanup) {
autogradContainer_->newContext();
auto contextId = autogradContainer_->currentContext()->contextId();
auto& engine = DistEngine::getInstance();
ASSERT_EQ(0, engine.numBackwardPasses());
@@ -67,7 +68,7 @@ TEST_F(DistAutogradTest, TestInitializedContextCleanup) {
ASSERT_NE(nullptr, t.grad_fn());
// Execute engine.
engine.execute({t}, /* retainGraph */ false);
engine.execute(contextId, {t}, /* retainGraph */ false);
// Validate appropriate cleanup.
ASSERT_EQ(0, engine.numBackwardPasses());

@@ -177,6 +177,14 @@ void DistAutogradContainer::eraseContextIdAndReset(int64_t context_id) {
}
}
void DistAutogradContainer::isValidContext(int64_t context_id) {
std::lock_guard<std::mutex> guard(autograd_context_lock_);
TORCH_CHECK(
autograd_context_.find(context_id) != autograd_context_.end(),
"Could not find autograd context with id: ",
context_id);
}
ContextPtr DistAutogradContainer::retrieveContext(int64_t context_id) {
std::lock_guard<std::mutex> guard(autograd_context_lock_);
TORCH_CHECK(

@@ -46,6 +46,9 @@ class TORCH_API DistAutogradContainer {
// context. Does nothing if it is not present.
void releaseContextIfPresent(int64_t context_id);
// Checks if the passed in context_id is valid.
void isValidContext(int64_t context_id);
// Retrieve the autograd context for a given context_id.
ContextPtr retrieveContext(int64_t context_id);

@@ -297,10 +297,14 @@ std::shared_ptr<rpc::FutureMessage> DistEngine::executeSendFunctionAsync(
}
}
void DistEngine::execute(const variable_list& roots, bool retainGraph) {
// Get the current context, if exists. This will throw if we don't have a
// valid context.
auto autogradContext = DistAutogradContainer::getInstance().currentContext();
void DistEngine::execute(
int64_t contextId,
const variable_list& roots,
bool retainGraph) {
// Retrieve the context for the given context_id. This will throw if the
// context_id is invalid.
auto autogradContext =
DistAutogradContainer::getInstance().retrieveContext(contextId);
// Perform initial pre-processing.
edge_list rootEdges;

@@ -33,7 +33,10 @@ class TORCH_API DistEngine {
// these variables and accumulate all the gradients in the current autograd
// context on each node. This method is used to kickoff distributed autograd
// on a single node.
void execute(const torch::autograd::variable_list& roots, bool retainGraph);
void execute(
int64_t context_id,
const torch::autograd::variable_list& roots,
bool retainGraph);
// Given a send function to execute in the autograd engine, ensures we compute
// dependencies once for this node and enqueues the send function for execute

@@ -77,6 +77,13 @@ PyObject* dist_autograd_init(PyObject* /* unused */) {
return DistAutogradContainer::getInstance().getMaxId();
});
module.def(
"_is_valid_context",
[](int64_t worker_id) {
DistAutogradContainer::getInstance().isValidContext(worker_id);
},
py::call_guard<py::gil_scoped_release>());
module.def(
"_retrieve_context",
[](int64_t context_id) -> const ContextPtr {
@@ -106,20 +113,22 @@ PyObject* dist_autograd_init(PyObject* /* unused */) {
module.def(
"backward",
[](const std::vector<torch::Tensor>& roots, bool retainGraph = false) {
[](int64_t contextId,
const std::vector<torch::Tensor>& roots,
bool retainGraph = false) {
torch::autograd::variable_list variables;
for (const auto& root : roots) {
variables.emplace_back(root);
}
try {
DistEngine::getInstance().execute(variables, retainGraph);
} catch (python_error & e) {
DistEngine::getInstance().execute(contextId, variables, retainGraph);
} catch (python_error& e) {
// FIXME: crashes if exception type is not RuntimeError
throw std::runtime_error(e.what());
}
},
R"(
backward(roots: List[Tensor], retain_graph = False) -> None
backward(context_id: int, roots: List[Tensor], retain_graph = False) -> None
Kicks off the distributed backward pass using the provided roots. This
currently implements the :ref:`fast-mode-algorithm` which
@@ -138,6 +147,7 @@ autograd context, we throw an error. You can retrieve the accumulated
gradients using the :meth:`~torch.distributed.autograd.get_gradients` API.
Arguments:
context_id (int): The autograd context id for which we should retrieve the gradients.
roots (list): Tensors which represent the roots of the autograd
computation. All the tensors should be scalars.
retain_graph(bool, optional): If False, the graph used to compute the grad
@@ -152,8 +162,9 @@ Example::
>> with dist_autograd.context() as context_id:
>> pred = model.forward()
>> loss = loss_func(pred, loss)
>> dist_autograd.backward(loss)
>> dist_autograd.backward(context_id, loss)
)",
py::arg("contextId"),
py::arg("roots"),
py::arg("retain_graph") = false,
py::call_guard<py::gil_scoped_release>());
@@ -186,7 +197,7 @@ Example::
>> t1 = torch.rand((3, 3), requires_grad=True)
>> t2 = torch.rand((3, 3), requires_grad=True)
>> loss = t1 + t2
>> dist_autograd.backward([loss.sum()])
>> dist_autograd.backward(context_id, [loss.sum()])
>> grads = dist_autograd.get_gradients(context_id)
>> print (grads[t1])
>> print (grads[t2])

@@ -28,7 +28,7 @@ class context(object):
>> t1 = torch.rand((3, 3), requires_grad=True)
>> t2 = torch.rand((3, 3), requires_grad=True)
>> loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
>> dist_autograd.backward([loss])
>> dist_autograd.backward(context_id, [loss])
'''
def __enter__(self):
self.autograd_context = _new_context()

@@ -95,7 +95,7 @@ class DistributedOptimizer:
>> loss = rref1.to_here() + rref2.to_here()
>>
>> # Backward pass.
>> dist_autograd.backward([loss.sum()])
>> dist_autograd.backward(context_id, [loss.sum()])
>>
>> # Optimizer.
>> dist_optim = DistributedOptimizer(
@@ -103,7 +103,7 @@ class DistributedOptimizer:
>> [rref1, rref2],
>> lr=0.05,
>> )
>> dist_optim.step()
>> dist_optim.step(context_id)
"""
def __init__(self, optimizer_class, params_rref, *args, **kwargs):
per_worker_params_rref = defaultdict(list)
@@ -122,7 +122,7 @@ class DistributedOptimizer:
self.remote_optimizers = _wait_for_all(remote_optim_futs)
def step(self):
def step(self, context_id):
"""
Performs a single optimization step.
@@ -130,13 +130,17 @@ class DistributedOptimizer:
containing parameters to be optimized, and will block until all workers
return. The current distributed autograd
:class:`~torch.distributed.autograd.context` will be used globally.
Args:
context_id: the autograd context id for which we should run the
optimizer step.
"""
autograd_ctx_id = dist_autograd._current_context()._context_id()
dist_autograd._is_valid_context(context_id)
rpc_futs = []
for optim in self.remote_optimizers:
rpc_futs.append(rpc.rpc_async(
optim.owner(),
_local_optimizer_step,
args=(optim, autograd_ctx_id),
args=(optim, context_id),
))
_wait_for_all(rpc_futs)

@@ -144,7 +144,7 @@ def _all_contexts_cleaned_up(timeout_seconds=10):
def _run_trainer(rref_t1, t2, ps, rank_diff):
with dist_autograd.context() as context_id:
ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
dist_autograd.backward([ret.sum()])
dist_autograd.backward(context_id, [ret.sum()])
# prevent deleting dist autograd context
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
@@ -154,7 +154,7 @@ def _run_trainer(rref_t1, t2, ps, rank_diff):
def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff):
with dist_autograd.context() as context_id:
ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
dist_autograd.backward([ret.sum()])
dist_autograd.backward(context_id, [ret.sum()])
# prevent deleting dist autograd context
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
@@ -645,7 +645,7 @@ class DistAutogradTest(RpcAgentTestFixture):
else:
raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
dist_autograd.backward([ret.sum()])
dist_autograd.backward(context_id, [ret.sum()])
rpc.rpc_sync(
"worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
@@ -829,7 +829,7 @@ class DistAutogradTest(RpcAgentTestFixture):
self._verify_backwards_remote(tensors, context_id, local_grads, *args)
def _verify_backwards_remote(self, tensors, context_id, local_grads, *args):
dist_autograd.backward(tensors)
dist_autograd.backward(context_id, tensors)
# Verify grads were accumulated appropriately.
grads = dist_autograd.get_gradients(context_id)
@@ -885,7 +885,7 @@ class DistAutogradTest(RpcAgentTestFixture):
callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2)
)
ret = rref.to_here()
dist_autograd.backward([ret.sum()])
dist_autograd.backward(context_id, [ret.sum()])
# verify grads on caller
grads = dist_autograd.get_gradients(context_id)
@@ -1091,7 +1091,7 @@ class DistAutogradTest(RpcAgentTestFixture):
val = torch.mul(t1, t2)
# Run backward, this would hang forever.
dist_autograd.backward([val.sum()])
dist_autograd.backward(context_id, [val.sum()])
@dist_init
def test_backward_unused_send_function(self):
@@ -1134,7 +1134,7 @@ class DistAutogradTest(RpcAgentTestFixture):
RuntimeError, "Simulate error on backward pass"
):
# Run backwards, and validate we receive an error.
dist_autograd.backward([val.sum()])
dist_autograd.backward(context_id, [val.sum()])
@unittest.skipIf(
torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name
@@ -1168,7 +1168,7 @@ class DistAutogradTest(RpcAgentTestFixture):
with self.assertRaisesRegex(RuntimeError, get_shutdown_error_regex()):
# Run backwards, and validate we receive an error since all
# other nodes are dead.
dist_autograd.backward([res.sum()])
dist_autograd.backward(context_id, [res.sum()])
else:
# Exit all other nodes.
pass
@@ -1178,13 +1178,15 @@ class DistAutogradTest(RpcAgentTestFixture):
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
context_id = 100 # dummy context_id
with self.assertRaisesRegex(
RuntimeError, "Current thread doesn't have a valid autograd context"
RuntimeError,
"Could not find autograd context with id: {}".format(context_id),
):
res = rpc.rpc_sync(
"worker{}".format(self._next_rank()), torch.add, args=(t1, t2)
)
dist_autograd.backward([res.sum()])
dist_autograd.backward(context_id, [res.sum()])
@dist_init
def test_backward_without_rpc(self):
@@ -1194,7 +1196,7 @@ class DistAutogradTest(RpcAgentTestFixture):
t2 = torch.rand((3, 3), requires_grad=True)
t3 = torch.add(t1, t2)
dist_autograd.backward([t3.sum()])
dist_autograd.backward(context_id, [t3.sum()])
grads = dist_autograd.get_gradients(context_id)
self.assertEqual(2, len(grads))
self.assertIn(t1, grads)
@@ -1207,28 +1209,31 @@ class DistAutogradTest(RpcAgentTestFixture):
with dist_autograd.context() as context_id:
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
dist_autograd.backward(None)
dist_autograd.backward(context_id, None)
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
dist_autograd.backward(None, None)
with self.assertRaisesRegex(
RuntimeError, "No tensors provided for gradient computation"
):
dist_autograd.backward([])
dist_autograd.backward(context_id, [])
with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
t = torch.rand(3, 3)
dist_autograd.backward([t])
dist_autograd.backward(context_id, [t])
with self.assertRaisesRegex(
RuntimeError, "is not a scalar, all roots need to be scalar"
):
t = torch.rand(3, 3, requires_grad=True)
dist_autograd.backward([t])
dist_autograd.backward(context_id, [t])
with self.assertRaisesRegex(
RuntimeError, "does not have a valid gradient function"
):
t = torch.rand(1, requires_grad=True)
dist_autograd.backward([t])
dist_autograd.backward(context_id, [t])
@dist_init
def test_backward_multiple_roots(self):
@@ -1349,7 +1354,7 @@ class DistAutogradTest(RpcAgentTestFixture):
with self.assertRaisesRegex(
RuntimeError, "Simulate error on backward pass"
):
dist_autograd.backward([loss.sum()])
dist_autograd.backward(context_id, [loss.sum()])
_backward_done = False
@@ -1398,7 +1403,7 @@ class DistAutogradTest(RpcAgentTestFixture):
# we might see any error given by get_shutdown_error_regex().
with self.assertRaisesRegex(RuntimeError, get_shutdown_error_regex()):
# Run backwards, and validate we receive an error since rank 2 is dead.
dist_autograd.backward([res.sum()])
dist_autograd.backward(context_id, [res.sum()])
# Tell other nodes RPC is done.
for i in range(self.world_size):
@@ -1437,7 +1442,7 @@ class DistAutogradTest(RpcAgentTestFixture):
DistAutogradTest._nested_python_udf,
args=(t1, t2, self._next_rank()),
)
dist_autograd.backward([loss.sum()])
dist_autograd.backward(context_id, [loss.sum()])
grads = dist_autograd.get_gradients(context_id)
self.assertEqual(t1.grad, grads[t1])
@@ -1510,10 +1515,12 @@ class DistAutogradTest(RpcAgentTestFixture):
t1 = DistAutogradTest.MyBackwardFunc.apply(t1)
self.assertEqual(100, len(context._send_functions()))
context_id = 100 # dummy context_id
with self.assertRaisesRegex(
RuntimeError, "Could not find autograd context with id"
RuntimeError,
"Could not find autograd context with id: {}".format(context_id),
):
dist_autograd.backward([t1.sum()])
dist_autograd.backward(context_id, [t1.sum()])
# HACK: Killing workers since otherwise the autograd engine gets stuck on
# other nodes. The proper fix would be addressing:
@@ -1569,8 +1576,8 @@ class DistAutogradTest(RpcAgentTestFixture):
)
# Run backward twice to test accumulation of sparse gradients.
dist_autograd.backward([res.sum()], retain_graph=True)
dist_autograd.backward([res.sum()])
dist_autograd.backward(context_id, [res.sum()], retain_graph=True)
dist_autograd.backward(context_id, [res.sum()])
remote_grad = rpc.rpc_sync(
"worker{}".format(dst),
@@ -1597,7 +1604,7 @@ class DistAutogradTest(RpcAgentTestFixture):
exec_mode, DistAutogradTest._mixed_requires_grad, t1, t2
)
self.assertEqual(t1 * t2, ret)
dist_autograd.backward([ret.sum()])
dist_autograd.backward(context_id, [ret.sum()])
self.assertTrue(t1.requires_grad)
self.assertFalse(t2.requires_grad)
grads = dist_autograd.get_gradients(context_id)
@@ -1652,7 +1659,7 @@ class DistAutogradTest(RpcAgentTestFixture):
)
i += 1
dist_autograd.backward([res[i].sum()])
dist_autograd.backward(context_id, [res[i].sum()])
debug_info = dist_autograd._get_debug_info()
num_autograd_context = int(debug_info["num_autograd_contexts"])
@@ -1690,7 +1697,7 @@ class DistAutogradTest(RpcAgentTestFixture):
t5 = rpc.rpc_sync("worker0", torch.matmul, args=(t3, t4))
t6 = rpc.rpc_sync("worker0", torch.add, args=(t4, t5))
dist_autograd.backward([t6.sum()])
dist_autograd.backward(context_id, [t6.sum()])
@dist_init
def test_async_dist_autograd(self):
@@ -1728,8 +1735,8 @@ class DistAutogradTest(RpcAgentTestFixture):
"worker{}".format(self._next_rank()), torch.matmul, args=(t1, t2)
)
# Run backward twice.
dist_autograd.backward([t3.sum()], retain_graph=True)
dist_autograd.backward([t3.sum()])
dist_autograd.backward(context_id, [t3.sum()], retain_graph=True)
dist_autograd.backward(context_id, [t3.sum()])
# Verify the gradients are same for local and remote execution.
grads = dist_autograd.get_gradients(context_id)
@@ -1755,8 +1762,8 @@ class DistAutogradTest(RpcAgentTestFixture):
).sum()
# Run backward twice.
dist_autograd.backward([loss], retain_graph=True)
dist_autograd.backward([loss])
dist_autograd.backward(context_id, [loss], retain_graph=True)
dist_autograd.backward(context_id, [loss])
@dist_init
def test_multiple_backward(self):
@@ -1771,4 +1778,37 @@ class DistAutogradTest(RpcAgentTestFixture):
# Run backward in a loop multiple times.
for i in range(1000):
dist_autograd.backward([loss], retain_graph=True)
dist_autograd.backward(context_id, [loss], retain_graph=True)
@unittest.skipIf(
not torch._six.PY3,
"Pytorch distributed autograd package " "does not support python2",
)
class DistAutogradJitTest(RpcAgentTestFixture):
@dist_init
def test_get_gradients(self):
dst_rank = self.rank
@torch.jit.script
def dist_get_gradients(context_id):
# type: (int) -> (Dict[Tensor, Tensor])
return dist_autograd.get_gradients(context_id)
FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
t3 = torch.add(t1, t2)
dist_autograd.backward(context_id, [t3.sum()])
grads = dist_get_gradients(context_id)
self.assertEqual(2, len(grads))
self.assertIn(t1, grads)
self.assertIn(t2, grads)
self.assertEqual(torch.ones(3, 3), grads[t1])
self.assertEqual(torch.ones(3, 3), grads[t2])
if __name__ == "__main__":
unittest.main()

@@ -111,7 +111,7 @@ class DistOptimizerTest(RpcAgentTestFixture):
FailingOptimizer, [remote_param1, remote_param2]
)
with dist_autograd.context():
with dist_autograd.context() as context_id:
torch.manual_seed(0)
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
@@ -119,9 +119,9 @@ class DistOptimizerTest(RpcAgentTestFixture):
output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
loss = torch.add(output2.wait(), t1).sum()
dist_autograd.backward([loss])
dist_autograd.backward(context_id, [loss])
with self.assertRaisesRegex(Exception, "Error running optimizer"):
dist_optim.step()
dist_optim.step(context_id)
@dist_init()
def test_dist_optim_exception_on_constructor(self):
@@ -179,7 +179,7 @@ class DistOptimizerTest(RpcAgentTestFixture):
optim.SGD, [remote_param1, remote_param2], lr=0.05
)
with dist_autograd.context():
with dist_autograd.context() as context_id:
torch.manual_seed(0)
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
@@ -187,8 +187,8 @@ class DistOptimizerTest(RpcAgentTestFixture):
output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
loss = torch.add(output2.wait(), t1)
dist_autograd.backward([loss.sum()])
dist_optim.step()
dist_autograd.backward(context_id, [loss.sum()])
dist_optim.step(context_id)
new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()

@@ -28,7 +28,7 @@ class JitDistAutogradTest(RpcAgentTestFixture):
t2 = torch.rand((3, 3), requires_grad=True)
t3 = torch.add(t1, t2)
dist_autograd.backward([t3.sum()])
dist_autograd.backward(context_id, [t3.sum()])
grads = dist_get_gradients(context_id)
self.assertEqual(2, len(grads))