pytorch/caffe2/python/operator_test/copy_ops_test.py
Yury Zemlyanskiy 40534de705 Gradient for Copy operator
Summary:
The reason I need a gradient for CopyOp is explained in this post: https://fb.facebook.com/groups/1405155842844877/permalink/1639683782725414/

The gradient for CopyOp is trivial when the copy stays on one device (CPU, or the same GPU), but gets a little harder when the copy was made across two different GPUs: the gradient of a copy is itself a copy, and it has to be routed back to the device the forward input lived on.
I introduce a new operator, CopyOnDeviceLike, which takes an additional second input; the op copies the first input to the same device as the second one. The default implementation is exactly the same as CopyOp, but I specialize it for CUDAContext.
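
For illustration, here is a minimal sketch of how the new op composes in a net. The blob names and shapes are made up for the example; CopyOnDeviceLike is the operator this diff adds:

    from caffe2.proto import caffe2_pb2
    from caffe2.python import core

    net = core.Net("copy_on_device_like_example")
    with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 0)):
        src = net.ConstantFill([], "src", shape=[8], value=1.0)
    with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 1)):
        anchor = net.ConstantFill([], "anchor", shape=[1], value=0.0)
        # "src" lives on GPU 0; CopyOnDeviceLike copies it to the
        # device of "anchor", i.e. GPU 1.
        dst = net.CopyOnDeviceLike([src, anchor], "dst")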

Please let me know if I'm doing anything wrong here! This is my first caffe2 diff related to operator definitions.

Reviewed By: Yangqing

Differential Revision: D4557258

fbshipit-source-id: 9494be589cc1e5696bbbfe25b7622aaa4c9efe4a
2017-02-16 06:11:27 -08:00

159 lines
6 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import unittest

from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, core, cnn


class CopyOpsTest(unittest.TestCase):
    def run_test_copy_gradient(self, device_opt):
        model = cnn.CNNModelHelper(name="copy_test")
        with core.DeviceScope(device_opt):
            x = model.net.AddExternalInputs("x")
            y = model.Copy(x, "y")
            loss = model.AveragedLoss(y, "loss")
            gradient_map = model.AddGradientOperators([loss])
            workspace.FeedBlob(x, np.random.rand(32).astype(np.float32))
            workspace.RunNetOnce(model.param_init_net)
            workspace.RunNetOnce(model.net)
            # Copy is an identity: the output must equal the input, and
            # the input's gradient must equal the output's gradient.
            self.assertTrue(np.array_equal(
                workspace.FetchBlob(x),
                workspace.FetchBlob(y),
            ))
            self.assertTrue(np.array_equal(
                workspace.FetchBlob(gradient_map[x]),
                workspace.FetchBlob(gradient_map[y]),
            ))

    def test_copy_gradient_cpu(self):
        self.run_test_copy_gradient(core.DeviceOption(caffe2_pb2.CPU, 0))

    @unittest.skipIf(workspace.NumCudaDevices() < 1, "Need at least 1 GPU.")
    def test_copy_gradient_gpu(self):
        self.run_test_copy_gradient(core.DeviceOption(caffe2_pb2.CUDA, 0))

    @unittest.skipIf(workspace.NumCudaDevices() < 2, "Need at least 2 GPUs.")
    def test_copy_gradient_multiple_gpus(self):
        model = cnn.CNNModelHelper(name="copy_test")

        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 0)):
            x_cpu = model.net.AddExternalInputs("x_cpu")

        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 0)):
            x_gpu_1 = model.CopyCPUToGPU(x_cpu, "x_gpu_1")

        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 1)):
            x_gpu_2 = model.Copy(x_gpu_1, "x_gpu_2")
            loss = model.AveragedLoss(x_gpu_2, "loss")
            gradient_map = model.AddGradientOperators([loss])

        workspace.FeedBlob("x_cpu", np.random.rand(32).astype(np.float32))
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        print(model.net.Proto())

        # The cross-GPU copy is an identity, so the values and their
        # gradients on the two devices must match exactly.
        self.assertTrue(np.array_equal(
            workspace.FetchBlob("x_gpu_1"),
            workspace.FetchBlob("x_gpu_2"),
        ))
        self.assertTrue(np.array_equal(
            workspace.FetchBlob(gradient_map["x_gpu_1"]),
            workspace.FetchBlob(gradient_map["x_gpu_2"]),
        ))

        def get_op_with_output(model, output_blob_name):
            for op in model.net.Proto().op:
                if len(op.output) == 1 and op.output[0] == output_blob_name:
                    return op
            return None

        # Gradient ops inherit the device of their forward ops: the op
        # producing x_gpu_2_grad runs on GPU 1, and x_cpu_grad is produced
        # by the gradient of CopyCPUToGPU, which runs on GPU 0.
        self.assertEqual(
            get_op_with_output(model, "x_gpu_2_grad").device_option,
            core.DeviceOption(caffe2_pb2.CUDA, 1),
        )
        self.assertEqual(
            get_op_with_output(model, "x_cpu_grad").device_option,
            core.DeviceOption(caffe2_pb2.CUDA, 0),
        )

    @unittest.skipIf(workspace.NumCudaDevices() < 1, "Need at least 1 GPU.")
    def test_cpu2gpu_gpu2cpu_gradients(self):
        model = cnn.CNNModelHelper(name="copy_test")

        batch = 32
        cpu_opt = core.DeviceOption(caffe2_pb2.CPU, 0)
        gpu_opt = core.DeviceOption(caffe2_pb2.CUDA, 0)

        with core.NameScope("cpu"):
            with core.DeviceScope(cpu_opt):
                x_cpu = model.FC('data', 'x_cpu', 16, 8)

        with core.NameScope("gpu_0"):
            with core.DeviceScope(gpu_opt):
                x_gpu = model.CopyCPUToGPU(x_cpu, "x_gpu")
                pred_gpu = model.FC(x_gpu, "pred_gpu", 8, 4)
                pred_cpu = model.CopyGPUToCPU(pred_gpu, "pred_cpu")

        with core.DeviceScope(cpu_opt):
            with core.NameScope("cpu"):
                (softmax, loss) = model.SoftmaxWithLoss(
                    [pred_cpu, "label"],
                    ["softmax", "loss"],
                )

        gradient_map = model.AddGradientOperators([loss])

        # Add param updates (for cpu and gpu). WeightedSum computes
        # param <- 1.0 * param + LR * grad, i.e. plain SGD; GetParams()
        # returns only the params of the active name scope.
        init_net = model.param_init_net
        with core.DeviceScope(cpu_opt):
            with core.NameScope("cpu"):
                ONE = init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
                LR = init_net.ConstantFill([], "LR", shape=[1], value=-2.0)
                for param in model.GetParams():
                    model.WeightedSum(
                        [param, ONE, gradient_map[param], LR],
                        param,
                    )
        with core.NameScope("gpu_0"):
            with core.DeviceScope(gpu_opt):
                ONE = init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
                LR = init_net.ConstantFill([], "LR", shape=[1], value=-2.0)
                for param in model.GetParams():
                    model.WeightedSum(
                        [param, ONE, gradient_map[param], LR],
                        param,
                    )

        with core.DeviceScope(cpu_opt):
            workspace.FeedBlob(
                'cpu/data',
                np.random.rand(batch, 16).astype(np.float32),
            )
            workspace.FeedBlob(
                'cpu/label',
                np.random.randint(4, size=batch).astype(np.int32),
            )

        workspace.RunNetOnce(model.param_init_net)
        workspace.CreateNet(model.net)
        initial_params = {p: workspace.FetchBlob(p) for p in model.GetParams()}
        workspace.RunNet(model.net.Proto().name)
        updated_params = {p: workspace.FetchBlob(p) for p in model.GetParams()}

        # Every param must have moved exactly one SGD step:
        # param - 2.0 * grad (LR above is -2.0).
        for p in model.GetParams():
            g = gradient_map[p]
            expected = initial_params[p] - 2.0 * workspace.FetchBlob(g)
            actual = updated_params[p]
            self.assertTrue(
                np.array_equal(expected, actual),
                "Mismatch: {}: {}, {}".format(p, expected, actual),
            )