mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
[Dist Autograd] Functional API for Dist Autograd and Dist Optimizer (#33711)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33711 Fixed #33480 This makes `dist_autograd.backward` and `dist_optimizer.step` functional by making the user explicitly pass in the `context_id` as opposed to relying on the confusing thread_local context_id. This diff incorporates these API changes and all places where these functions are called. More concretely, this code: ``` with dist_autograd.context(): # Forward pass. dist_autograd.backward([loss.sum()]) dist_optim.step() ``` should now be written as follows: ``` with dist_autograd.context() as context_id: # Forward pass. dist_autograd.backward(context_id, [loss.sum()]) dist_optim.step(context_id) ``` Test Plan: Ensuring all existing dist_autograd and dist_optimizer tests pass with the new API. Also added a new test case for input checking. Differential Revision: D20011710 fbshipit-source-id: 216e12207934a2a79c7223332b97c558d89d4d65
This commit is contained in:
parent
4c33222c51
commit
24dd800e6a
12 changed files with 134 additions and 60 deletions
|
|
@ -113,7 +113,7 @@ From the user's perspective the autograd context is setup as follows:
|
|||
import torch.distributed.autograd as dist_autograd
|
||||
with dist_autograd.context() as context_id:
|
||||
loss = model.forward()
|
||||
dist_autograd.backward(loss)
|
||||
dist_autograd.backward(context_id, loss)
|
||||
|
||||
Distributed Backward Pass
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
@ -258,7 +258,7 @@ As an example the complete code with distributed autograd would be as follows:
|
|||
loss = t5.sum()
|
||||
|
||||
# Run the backward pass.
|
||||
dist_autograd.backward([loss])
|
||||
dist_autograd.backward(context_id, [loss])
|
||||
|
||||
# Retrieve the gradients from the context.
|
||||
dist_autograd.get_gradients(context_id)
|
||||
|
|
@ -350,7 +350,7 @@ file called "dist_autograd_simple.py", it can be run with the command
|
|||
loss = rref1.to_here() + rref2.to_here()
|
||||
|
||||
# Backward pass (run distributed autograd).
|
||||
dist_autograd.backward([loss.sum()])
|
||||
dist_autograd.backward(context_id, [loss.sum()])
|
||||
|
||||
# Build DistributedOptimizer.
|
||||
dist_optim = DistributedOptimizer(
|
||||
|
|
@ -360,7 +360,7 @@ file called "dist_autograd_simple.py", it can be run with the command
|
|||
)
|
||||
|
||||
# Run the distributed optimizer step.
|
||||
dist_optim.step()
|
||||
dist_optim.step(context_id)
|
||||
|
||||
def run_process(rank, world_size):
|
||||
dst_rank = (rank + 1) % world_size
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) {
|
|||
|
||||
TEST_F(DistAutogradTest, TestInitializedContextCleanup) {
|
||||
autogradContainer_->newContext();
|
||||
auto contextId = autogradContainer_->currentContext()->contextId();
|
||||
auto& engine = DistEngine::getInstance();
|
||||
ASSERT_EQ(0, engine.numBackwardPasses());
|
||||
|
||||
|
|
@ -67,7 +68,7 @@ TEST_F(DistAutogradTest, TestInitializedContextCleanup) {
|
|||
ASSERT_NE(nullptr, t.grad_fn());
|
||||
|
||||
// Execute engine.
|
||||
engine.execute({t}, /* retainGraph */ false);
|
||||
engine.execute(contextId, {t}, /* retainGraph */ false);
|
||||
|
||||
// Validate appropriate cleanup.
|
||||
ASSERT_EQ(0, engine.numBackwardPasses());
|
||||
|
|
|
|||
|
|
@ -177,6 +177,14 @@ void DistAutogradContainer::eraseContextIdAndReset(int64_t context_id) {
|
|||
}
|
||||
}
|
||||
|
||||
void DistAutogradContainer::isValidContext(int64_t context_id) {
|
||||
std::lock_guard<std::mutex> guard(autograd_context_lock_);
|
||||
TORCH_CHECK(
|
||||
autograd_context_.find(context_id) != autograd_context_.end(),
|
||||
"Could not find autograd context with id: ",
|
||||
context_id);
|
||||
}
|
||||
|
||||
ContextPtr DistAutogradContainer::retrieveContext(int64_t context_id) {
|
||||
std::lock_guard<std::mutex> guard(autograd_context_lock_);
|
||||
TORCH_CHECK(
|
||||
|
|
|
|||
|
|
@ -46,6 +46,9 @@ class TORCH_API DistAutogradContainer {
|
|||
// context. Does nothing if it is not present.
|
||||
void releaseContextIfPresent(int64_t context_id);
|
||||
|
||||
// Checks if the passed in context_id is valid.
|
||||
void isValidContext(int64_t context_id);
|
||||
|
||||
// Retrieve the autograd context for a given context_id.
|
||||
ContextPtr retrieveContext(int64_t context_id);
|
||||
|
||||
|
|
|
|||
|
|
@ -297,10 +297,14 @@ std::shared_ptr<rpc::FutureMessage> DistEngine::executeSendFunctionAsync(
|
|||
}
|
||||
}
|
||||
|
||||
void DistEngine::execute(const variable_list& roots, bool retainGraph) {
|
||||
// Get the current context, if exists. This will throw if we don't have a
|
||||
// valid context.
|
||||
auto autogradContext = DistAutogradContainer::getInstance().currentContext();
|
||||
void DistEngine::execute(
|
||||
int64_t contextId,
|
||||
const variable_list& roots,
|
||||
bool retainGraph) {
|
||||
// Retrieve the context for the given context_id. This will throw if the
|
||||
// context_id is invalid.
|
||||
auto autogradContext =
|
||||
DistAutogradContainer::getInstance().retrieveContext(contextId);
|
||||
|
||||
// Perform initial pre-processing.
|
||||
edge_list rootEdges;
|
||||
|
|
|
|||
|
|
@ -33,7 +33,10 @@ class TORCH_API DistEngine {
|
|||
// these variables and accumulate all the gradients in the current autograd
|
||||
// context on each node. This method is used to kickoff distributed autograd
|
||||
// on a single node.
|
||||
void execute(const torch::autograd::variable_list& roots, bool retainGraph);
|
||||
void execute(
|
||||
int64_t context_id,
|
||||
const torch::autograd::variable_list& roots,
|
||||
bool retainGraph);
|
||||
|
||||
// Given a send function to execute in the autograd engine, ensures we compute
|
||||
// dependencies once for this node and enqueues the send function for execute
|
||||
|
|
|
|||
|
|
@ -77,6 +77,13 @@ PyObject* dist_autograd_init(PyObject* /* unused */) {
|
|||
return DistAutogradContainer::getInstance().getMaxId();
|
||||
});
|
||||
|
||||
module.def(
|
||||
"_is_valid_context",
|
||||
[](int64_t worker_id) {
|
||||
DistAutogradContainer::getInstance().isValidContext(worker_id);
|
||||
},
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
module.def(
|
||||
"_retrieve_context",
|
||||
[](int64_t context_id) -> const ContextPtr {
|
||||
|
|
@ -106,20 +113,22 @@ PyObject* dist_autograd_init(PyObject* /* unused */) {
|
|||
|
||||
module.def(
|
||||
"backward",
|
||||
[](const std::vector<torch::Tensor>& roots, bool retainGraph = false) {
|
||||
[](int64_t contextId,
|
||||
const std::vector<torch::Tensor>& roots,
|
||||
bool retainGraph = false) {
|
||||
torch::autograd::variable_list variables;
|
||||
for (const auto& root : roots) {
|
||||
variables.emplace_back(root);
|
||||
}
|
||||
try {
|
||||
DistEngine::getInstance().execute(variables, retainGraph);
|
||||
} catch (python_error & e) {
|
||||
DistEngine::getInstance().execute(contextId, variables, retainGraph);
|
||||
} catch (python_error& e) {
|
||||
// FIXME: crashes if exception type is not RuntimeError
|
||||
throw std::runtime_error(e.what());
|
||||
}
|
||||
},
|
||||
R"(
|
||||
backward(roots: List[Tensor], retain_graph = False) -> None
|
||||
backward(context_id: int, roots: List[Tensor], retain_graph = False) -> None
|
||||
|
||||
Kicks off the distributed backward pass using the provided roots. This
|
||||
currently implements the :ref:`fast-mode-algorithm` which
|
||||
|
|
@ -138,6 +147,7 @@ autograd context, we throw an error. You can retrieve the accumulated
|
|||
gradients using the :meth:`~torch.distributed.autograd.get_gradients` API.
|
||||
|
||||
Arguments:
|
||||
context_id (int): The autograd context id for which we should retrieve the gradients.
|
||||
roots (list): Tensors which represent the roots of the autograd
|
||||
computation. All the tensors should be scalars.
|
||||
retain_graph(bool, optional): If False, the graph used to compute the grad
|
||||
|
|
@ -152,8 +162,9 @@ Example::
|
|||
>> with dist_autograd.context() as context_id:
|
||||
>> pred = model.forward()
|
||||
>> loss = loss_func(pred, loss)
|
||||
>> dist_autograd.backward(loss)
|
||||
>> dist_autograd.backward(context_id, loss)
|
||||
)",
|
||||
py::arg("contextId"),
|
||||
py::arg("roots"),
|
||||
py::arg("retain_graph") = false,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
|
@ -186,7 +197,7 @@ Example::
|
|||
>> t1 = torch.rand((3, 3), requires_grad=True)
|
||||
>> t2 = torch.rand((3, 3), requires_grad=True)
|
||||
>> loss = t1 + t2
|
||||
>> dist_autograd.backward([loss.sum()])
|
||||
>> dist_autograd.backward(context_id, [loss.sum()])
|
||||
>> grads = dist_autograd.get_gradients(context_id)
|
||||
>> print (grads[t1])
|
||||
>> print (grads[t2])
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class context(object):
|
|||
>> t1 = torch.rand((3, 3), requires_grad=True)
|
||||
>> t2 = torch.rand((3, 3), requires_grad=True)
|
||||
>> loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
|
||||
>> dist_autograd.backward([loss])
|
||||
>> dist_autograd.backward(context_id, [loss])
|
||||
'''
|
||||
def __enter__(self):
|
||||
self.autograd_context = _new_context()
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ class DistributedOptimizer:
|
|||
>> loss = rref1.to_here() + rref2.to_here()
|
||||
>>
|
||||
>> # Backward pass.
|
||||
>> dist_autograd.backward([loss.sum()])
|
||||
>> dist_autograd.backward(context_id, [loss.sum()])
|
||||
>>
|
||||
>> # Optimizer.
|
||||
>> dist_optim = DistributedOptimizer(
|
||||
|
|
@ -103,7 +103,7 @@ class DistributedOptimizer:
|
|||
>> [rref1, rref2],
|
||||
>> lr=0.05,
|
||||
>> )
|
||||
>> dist_optim.step()
|
||||
>> dist_optim.step(context_id)
|
||||
"""
|
||||
def __init__(self, optimizer_class, params_rref, *args, **kwargs):
|
||||
per_worker_params_rref = defaultdict(list)
|
||||
|
|
@ -122,7 +122,7 @@ class DistributedOptimizer:
|
|||
|
||||
self.remote_optimizers = _wait_for_all(remote_optim_futs)
|
||||
|
||||
def step(self):
|
||||
def step(self, context_id):
|
||||
"""
|
||||
Performs a single optimization step.
|
||||
|
||||
|
|
@ -130,13 +130,17 @@ class DistributedOptimizer:
|
|||
containing parameters to be optimized, and will block until all workers
|
||||
return. The current distributed autograd
|
||||
:class:`~torch.distributed.autograd.context` will be used globally.
|
||||
|
||||
Args:
|
||||
context_id: the autograd context id for which we should run the
|
||||
optimizer step.
|
||||
"""
|
||||
autograd_ctx_id = dist_autograd._current_context()._context_id()
|
||||
dist_autograd._is_valid_context(context_id)
|
||||
rpc_futs = []
|
||||
for optim in self.remote_optimizers:
|
||||
rpc_futs.append(rpc.rpc_async(
|
||||
optim.owner(),
|
||||
_local_optimizer_step,
|
||||
args=(optim, autograd_ctx_id),
|
||||
args=(optim, context_id),
|
||||
))
|
||||
_wait_for_all(rpc_futs)
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ def _all_contexts_cleaned_up(timeout_seconds=10):
|
|||
def _run_trainer(rref_t1, t2, ps, rank_diff):
|
||||
with dist_autograd.context() as context_id:
|
||||
ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
|
||||
dist_autograd.backward([ret.sum()])
|
||||
dist_autograd.backward(context_id, [ret.sum()])
|
||||
# prevent deleting dist autograd context
|
||||
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
|
||||
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
|
||||
|
|
@ -154,7 +154,7 @@ def _run_trainer(rref_t1, t2, ps, rank_diff):
|
|||
def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff):
|
||||
with dist_autograd.context() as context_id:
|
||||
ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
|
||||
dist_autograd.backward([ret.sum()])
|
||||
dist_autograd.backward(context_id, [ret.sum()])
|
||||
# prevent deleting dist autograd context
|
||||
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
|
||||
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
|
||||
|
|
@ -645,7 +645,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
else:
|
||||
raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
|
||||
|
||||
dist_autograd.backward([ret.sum()])
|
||||
dist_autograd.backward(context_id, [ret.sum()])
|
||||
|
||||
rpc.rpc_sync(
|
||||
"worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
|
||||
|
|
@ -829,7 +829,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
self._verify_backwards_remote(tensors, context_id, local_grads, *args)
|
||||
|
||||
def _verify_backwards_remote(self, tensors, context_id, local_grads, *args):
|
||||
dist_autograd.backward(tensors)
|
||||
dist_autograd.backward(context_id, tensors)
|
||||
|
||||
# Verify grads were accumulated appropriately.
|
||||
grads = dist_autograd.get_gradients(context_id)
|
||||
|
|
@ -885,7 +885,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2)
|
||||
)
|
||||
ret = rref.to_here()
|
||||
dist_autograd.backward([ret.sum()])
|
||||
dist_autograd.backward(context_id, [ret.sum()])
|
||||
|
||||
# verify grads on caller
|
||||
grads = dist_autograd.get_gradients(context_id)
|
||||
|
|
@ -1091,7 +1091,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
val = torch.mul(t1, t2)
|
||||
|
||||
# Run backward, this would hang forever.
|
||||
dist_autograd.backward([val.sum()])
|
||||
dist_autograd.backward(context_id, [val.sum()])
|
||||
|
||||
@dist_init
|
||||
def test_backward_unused_send_function(self):
|
||||
|
|
@ -1134,7 +1134,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
RuntimeError, "Simulate error on backward pass"
|
||||
):
|
||||
# Run backwards, and validate we receive an error.
|
||||
dist_autograd.backward([val.sum()])
|
||||
dist_autograd.backward(context_id, [val.sum()])
|
||||
|
||||
@unittest.skipIf(
|
||||
torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name
|
||||
|
|
@ -1168,7 +1168,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
with self.assertRaisesRegex(RuntimeError, get_shutdown_error_regex()):
|
||||
# Run backwards, and validate we receive an error since all
|
||||
# other nodes are dead.
|
||||
dist_autograd.backward([res.sum()])
|
||||
dist_autograd.backward(context_id, [res.sum()])
|
||||
else:
|
||||
# Exit all other nodes.
|
||||
pass
|
||||
|
|
@ -1178,13 +1178,15 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
t1 = torch.rand((3, 3), requires_grad=True)
|
||||
t2 = torch.rand((3, 3), requires_grad=True)
|
||||
|
||||
context_id = 100 # dummy context_id
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Current thread doesn't have a valid autograd context"
|
||||
RuntimeError,
|
||||
"Could not find autograd context with id: {}".format(context_id),
|
||||
):
|
||||
res = rpc.rpc_sync(
|
||||
"worker{}".format(self._next_rank()), torch.add, args=(t1, t2)
|
||||
)
|
||||
dist_autograd.backward([res.sum()])
|
||||
dist_autograd.backward(context_id, [res.sum()])
|
||||
|
||||
@dist_init
|
||||
def test_backward_without_rpc(self):
|
||||
|
|
@ -1194,7 +1196,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
t2 = torch.rand((3, 3), requires_grad=True)
|
||||
t3 = torch.add(t1, t2)
|
||||
|
||||
dist_autograd.backward([t3.sum()])
|
||||
dist_autograd.backward(context_id, [t3.sum()])
|
||||
grads = dist_autograd.get_gradients(context_id)
|
||||
self.assertEqual(2, len(grads))
|
||||
self.assertIn(t1, grads)
|
||||
|
|
@ -1207,28 +1209,31 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
with dist_autograd.context() as context_id:
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
|
||||
dist_autograd.backward(None)
|
||||
dist_autograd.backward(context_id, None)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
|
||||
dist_autograd.backward(None, None)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "No tensors provided for gradient computation"
|
||||
):
|
||||
dist_autograd.backward([])
|
||||
dist_autograd.backward(context_id, [])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
|
||||
t = torch.rand(3, 3)
|
||||
dist_autograd.backward([t])
|
||||
dist_autograd.backward(context_id, [t])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "is not a scalar, all roots need to be scalar"
|
||||
):
|
||||
t = torch.rand(3, 3, requires_grad=True)
|
||||
dist_autograd.backward([t])
|
||||
dist_autograd.backward(context_id, [t])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "does not have a valid gradient function"
|
||||
):
|
||||
t = torch.rand(1, requires_grad=True)
|
||||
dist_autograd.backward([t])
|
||||
dist_autograd.backward(context_id, [t])
|
||||
|
||||
@dist_init
|
||||
def test_backward_multiple_roots(self):
|
||||
|
|
@ -1349,7 +1354,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Simulate error on backward pass"
|
||||
):
|
||||
dist_autograd.backward([loss.sum()])
|
||||
dist_autograd.backward(context_id, [loss.sum()])
|
||||
|
||||
_backward_done = False
|
||||
|
||||
|
|
@ -1398,7 +1403,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
# we might see any error given by get_shutdown_error_regex().
|
||||
with self.assertRaisesRegex(RuntimeError, get_shutdown_error_regex()):
|
||||
# Run backwards, and validate we receive an error since rank 2 is dead.
|
||||
dist_autograd.backward([res.sum()])
|
||||
dist_autograd.backward(context_id, [res.sum()])
|
||||
|
||||
# Tell other nodes RPC is done.
|
||||
for i in range(self.world_size):
|
||||
|
|
@ -1437,7 +1442,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
DistAutogradTest._nested_python_udf,
|
||||
args=(t1, t2, self._next_rank()),
|
||||
)
|
||||
dist_autograd.backward([loss.sum()])
|
||||
dist_autograd.backward(context_id, [loss.sum()])
|
||||
|
||||
grads = dist_autograd.get_gradients(context_id)
|
||||
self.assertEqual(t1.grad, grads[t1])
|
||||
|
|
@ -1510,10 +1515,12 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
t1 = DistAutogradTest.MyBackwardFunc.apply(t1)
|
||||
self.assertEqual(100, len(context._send_functions()))
|
||||
|
||||
context_id = 100 # dummy context_id
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Could not find autograd context with id"
|
||||
RuntimeError,
|
||||
"Could not find autograd context with id: {}".format(context_id),
|
||||
):
|
||||
dist_autograd.backward([t1.sum()])
|
||||
dist_autograd.backward(context_id, [t1.sum()])
|
||||
|
||||
# HACK: Killing workers since otherwise the autograd engine gets stuck on
|
||||
# other nodes. The proper fix would be addressing:
|
||||
|
|
@ -1569,8 +1576,8 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
)
|
||||
|
||||
# Run backward twice to test accumulation of sparse gradients.
|
||||
dist_autograd.backward([res.sum()], retain_graph=True)
|
||||
dist_autograd.backward([res.sum()])
|
||||
dist_autograd.backward(context_id, [res.sum()], retain_graph=True)
|
||||
dist_autograd.backward(context_id, [res.sum()])
|
||||
|
||||
remote_grad = rpc.rpc_sync(
|
||||
"worker{}".format(dst),
|
||||
|
|
@ -1597,7 +1604,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
exec_mode, DistAutogradTest._mixed_requires_grad, t1, t2
|
||||
)
|
||||
self.assertEqual(t1 * t2, ret)
|
||||
dist_autograd.backward([ret.sum()])
|
||||
dist_autograd.backward(context_id, [ret.sum()])
|
||||
self.assertTrue(t1.requires_grad)
|
||||
self.assertFalse(t2.requires_grad)
|
||||
grads = dist_autograd.get_gradients(context_id)
|
||||
|
|
@ -1652,7 +1659,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
)
|
||||
i += 1
|
||||
|
||||
dist_autograd.backward([res[i].sum()])
|
||||
dist_autograd.backward(context_id, [res[i].sum()])
|
||||
|
||||
debug_info = dist_autograd._get_debug_info()
|
||||
num_autograd_context = int(debug_info["num_autograd_contexts"])
|
||||
|
|
@ -1690,7 +1697,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
t5 = rpc.rpc_sync("worker0", torch.matmul, args=(t3, t4))
|
||||
t6 = rpc.rpc_sync("worker0", torch.add, args=(t4, t5))
|
||||
|
||||
dist_autograd.backward([t6.sum()])
|
||||
dist_autograd.backward(context_id, [t6.sum()])
|
||||
|
||||
@dist_init
|
||||
def test_async_dist_autograd(self):
|
||||
|
|
@ -1728,8 +1735,8 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
"worker{}".format(self._next_rank()), torch.matmul, args=(t1, t2)
|
||||
)
|
||||
# Run backward twice.
|
||||
dist_autograd.backward([t3.sum()], retain_graph=True)
|
||||
dist_autograd.backward([t3.sum()])
|
||||
dist_autograd.backward(context_id, [t3.sum()], retain_graph=True)
|
||||
dist_autograd.backward(context_id, [t3.sum()])
|
||||
|
||||
# Verify the gradients are same for local and remote execution.
|
||||
grads = dist_autograd.get_gradients(context_id)
|
||||
|
|
@ -1755,8 +1762,8 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
).sum()
|
||||
|
||||
# Run backward twice.
|
||||
dist_autograd.backward([loss], retain_graph=True)
|
||||
dist_autograd.backward([loss])
|
||||
dist_autograd.backward(context_id, [loss], retain_graph=True)
|
||||
dist_autograd.backward(context_id, [loss])
|
||||
|
||||
@dist_init
|
||||
def test_multiple_backward(self):
|
||||
|
|
@ -1771,4 +1778,37 @@ class DistAutogradTest(RpcAgentTestFixture):
|
|||
|
||||
# Run backward in a loop multiple times.
|
||||
for i in range(1000):
|
||||
dist_autograd.backward([loss], retain_graph=True)
|
||||
dist_autograd.backward(context_id, [loss], retain_graph=True)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch._six.PY3,
|
||||
"Pytorch distributed autograd package " "does not support python2",
|
||||
)
|
||||
class DistAutogradJitTest(RpcAgentTestFixture):
|
||||
@dist_init
|
||||
def test_get_gradients(self):
|
||||
dst_rank = self.rank
|
||||
|
||||
@torch.jit.script
|
||||
def dist_get_gradients(context_id):
|
||||
# type: (int) -> (Dict[Tensor, Tensor])
|
||||
return dist_autograd.get_gradients(context_id)
|
||||
|
||||
FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
|
||||
with dist_autograd.context() as context_id:
|
||||
t1 = torch.rand((3, 3), requires_grad=True)
|
||||
t2 = torch.rand((3, 3), requires_grad=True)
|
||||
t3 = torch.add(t1, t2)
|
||||
|
||||
dist_autograd.backward(context_id, [t3.sum()])
|
||||
grads = dist_get_gradients(context_id)
|
||||
|
||||
self.assertEqual(2, len(grads))
|
||||
self.assertIn(t1, grads)
|
||||
self.assertIn(t2, grads)
|
||||
self.assertEqual(torch.ones(3, 3), grads[t1])
|
||||
self.assertEqual(torch.ones(3, 3), grads[t2])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ class DistOptimizerTest(RpcAgentTestFixture):
|
|||
FailingOptimizer, [remote_param1, remote_param2]
|
||||
)
|
||||
|
||||
with dist_autograd.context():
|
||||
with dist_autograd.context() as context_id:
|
||||
torch.manual_seed(0)
|
||||
t1 = torch.rand((3, 3), requires_grad=True)
|
||||
t2 = torch.rand((3, 3), requires_grad=True)
|
||||
|
|
@ -119,9 +119,9 @@ class DistOptimizerTest(RpcAgentTestFixture):
|
|||
output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
|
||||
loss = torch.add(output2.wait(), t1).sum()
|
||||
|
||||
dist_autograd.backward([loss])
|
||||
dist_autograd.backward(context_id, [loss])
|
||||
with self.assertRaisesRegex(Exception, "Error running optimizer"):
|
||||
dist_optim.step()
|
||||
dist_optim.step(context_id)
|
||||
|
||||
@dist_init()
|
||||
def test_dist_optim_exception_on_constructor(self):
|
||||
|
|
@ -179,7 +179,7 @@ class DistOptimizerTest(RpcAgentTestFixture):
|
|||
optim.SGD, [remote_param1, remote_param2], lr=0.05
|
||||
)
|
||||
|
||||
with dist_autograd.context():
|
||||
with dist_autograd.context() as context_id:
|
||||
torch.manual_seed(0)
|
||||
t1 = torch.rand((3, 3), requires_grad=True)
|
||||
t2 = torch.rand((3, 3), requires_grad=True)
|
||||
|
|
@ -187,8 +187,8 @@ class DistOptimizerTest(RpcAgentTestFixture):
|
|||
output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
|
||||
loss = torch.add(output2.wait(), t1)
|
||||
|
||||
dist_autograd.backward([loss.sum()])
|
||||
dist_optim.step()
|
||||
dist_autograd.backward(context_id, [loss.sum()])
|
||||
dist_optim.step(context_id)
|
||||
|
||||
new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
|
||||
new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class JitDistAutogradTest(RpcAgentTestFixture):
|
|||
t2 = torch.rand((3, 3), requires_grad=True)
|
||||
t3 = torch.add(t1, t2)
|
||||
|
||||
dist_autograd.backward([t3.sum()])
|
||||
dist_autograd.backward(context_id, [t3.sum()])
|
||||
grads = dist_get_gradients(context_id)
|
||||
|
||||
self.assertEqual(2, len(grads))
|
||||
|
|
|
|||
Loading…
Reference in a new issue