data_parallel_model names fix

Summary: Updated usage of deprecated functions in data_parallel_model.py

Reviewed By: akyrola

Differential Revision: D5738512

fbshipit-source-id: a7767e518da777ece058bcad480e5df1d91e9b42
This commit is contained in:
Wojciech Glogowski 2017-08-30 12:37:06 -07:00 committed by Facebook Github Bot
parent ae5101c137
commit a7ec5def7b

View file

@ -779,9 +779,9 @@ def GetLearningRateBlobNames(model):
'''
if model._optimizer is not None:
if model._device_type == caffe2_pb2.CPU:
return [model._optimizer.get_cpu_lr_blob_name()]
return [model._optimizer.get_cpu_blob_name('lr')]
elif model._device_type == caffe2_pb2.CUDA:
return [model._optimizer.get_gpu_lr_blob_name(gpu)
return [model._optimizer.get_gpu_blob_name('lr', gpu)
for gpu in model._devices]
else:
raise Exception(