mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
data_parallel_model names fix
Summary: Updated usage of deprecated functions in data_parallel_model.py Reviewed By: akyrola Differential Revision: D5738512 fbshipit-source-id: a7767e518da777ece058bcad480e5df1d91e9b42
This commit is contained in:
parent
ae5101c137
commit
a7ec5def7b
1 changed files with 2 additions and 2 deletions
|
|
@ -779,9 +779,9 @@ def GetLearningRateBlobNames(model):
|
|||
'''
|
||||
if model._optimizer is not None:
|
||||
if model._device_type == caffe2_pb2.CPU:
|
||||
return [model._optimizer.get_cpu_lr_blob_name()]
|
||||
return [model._optimizer.get_cpu_blob_name('lr')]
|
||||
elif model._device_type == caffe2_pb2.CUDA:
|
||||
return [model._optimizer.get_gpu_lr_blob_name(gpu)
|
||||
return [model._optimizer.get_gpu_blob_name('lr', gpu)
|
||||
for gpu in model._devices]
|
||||
else:
|
||||
raise Exception(
|
||||
|
|
|
|||
Loading…
Reference in a new issue