Fix set_device_map docs (#53508)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53508

closes #53501

Differential Revision: D26885263

Test Plan: Imported from OSS

Reviewed By: H-Huang

Pulled By: mrshenli

fbshipit-source-id: dd0493e6f179d93b518af8f082399cacb1c7cba6
This commit is contained in:
Shen Li 2021-03-08 10:52:52 -08:00 committed by Facebook GitHub Bot
parent 93f1b10f72
commit 115df4fa28

View file

@ -73,16 +73,16 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
>>> # on worker 0
>>> options = TensorPipeRpcBackendOptions(
>>> num_worker_threads=8,
>>> device_maps={"worker1": {0, 1}}
>>> device_maps={"worker1": {0: 1}}
>>> # maps worker0's cuda:0 to worker1's cuda:1
>>> )
>>> options.set_device_map("worker1", {1, 2})
>>> options.set_device_map("worker1", {1: 2})
>>> # maps worker0's cuda:1 to worker1's cuda:2
>>>
>>> rpc.init_rpc(
>>> "worker0",
>>> rank=0,
>>> world_size=2
>>> world_size=2,
>>> backend=rpc.BackendType.TENSORPIPE,
>>> rpc_backend_options=options
>>> )
@ -94,7 +94,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
>>> # the device map, and hence will be moved back to cuda:0 and
>>> # cuda:1 on worker0
>>> print(rets[0]) # tensor([2., 2.], device='cuda:0')
>>> print(rets[0]) # tensor([2., 2.], device='cuda:1')
>>> print(rets[1]) # tensor([2., 2.], device='cuda:1')
"""
device_index_map = {}
curr_device_maps = super().device_maps