mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Make xray net_type configurable
Summary: Make xray net_type configub a command line argument Differential Revision: D4262076 fbshipit-source-id: e2ecb9cd5bee5d6aaebe0ea8d2d4d9b378058cba
This commit is contained in:
parent
6c13dc3dd0
commit
1aba4280d8
1 changed files with 3 additions and 1 deletions
|
|
@ -21,6 +21,7 @@ def Parallelize_GPU(
|
|||
param_update_builder_fun,
|
||||
devices=range(0, workspace.NumCudaDevices()),
|
||||
rendezvous=None,
|
||||
net_type='dag',
|
||||
):
|
||||
'''
|
||||
Function to create a model that can run on many GPUs.
|
||||
|
|
@ -48,12 +49,13 @@ def Parallelize_GPU(
|
|||
rendezvous: used for rendezvous in distributed computation, if None
|
||||
then only one node is used. To create rendezvous,
|
||||
use <TBD>.
|
||||
net_type: Network type
|
||||
|
||||
'''
|
||||
log.info("Parallelizing model for devices: {}".format(devices))
|
||||
extra_workers = 8 if rendezvous is not None else 0 # best-guess
|
||||
model_helper_obj.net.Proto().num_workers = len(devices) * 2 + extra_workers
|
||||
model_helper_obj.net.Proto().type = 'dag'
|
||||
model_helper_obj.net.Proto().type = net_type
|
||||
|
||||
# Store some information in the model -- a bit ugly
|
||||
model_helper_obj._devices = devices
|
||||
|
|
|
|||
Loading…
Reference in a new issue