pytorch/caffe2/python/operator_test/sparse_ops_test.py
Yangqing Jia 5eb836880d Add unittest.main() lines to test scripts under python/operator_test
Summary:
Needed for the open-source (OSS) release.

This is done by running the following line:

  find . -name "*_test.py" -exec sed -i '$ a \\nif __name__ == "__main__":\n    import unittest\n    unittest.main()' {} \;

Reviewed By: ajtulloch

Differential Revision: D4223848

fbshipit-source-id: ef4696e9701d45962134841165c53e76a2e19233
2016-11-29 15:18:37 -08:00

82 lines
3.2 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase, rand_array
class TestScatterOps(TestCase):
    """Compare the in-place Scatter* operators against NumPy references."""

    # NOTE: deliberately NOT prefixed with ``test``: unittest (and pytest)
    # collect every ``test*`` callable on a TestCase, so a helper named
    # ``test_configs`` would be run as a spurious, always-passing test.
    @staticmethod
    def _test_configs():
        """Return the shape configurations exercised by every test.

        Each entry is ``(first_dim, index_num, data_dims)``: the leading
        dimension of DATA, the number of entries in INDICES, and the
        trailing dimensions shared by DATA and each per-index slice tensor.
        """
        return [
            # first_dim, index_num, data_dims
            (1, 2, []),
            (5, 5, []),
            (2, 5, []),
            (1, 1, []),
            (13, 7, []),
            (13, 7, [2]),
            (1, 5, [3, 4]),
            (17, 8, [2, 2, 2]),
        ]

    # TODO(dzhulgakov): add test cases for failure scenarios
    def testScatterWeightedSum(self):
        """Check ScatterWeightedSum with one and two (X, w) input pairs.

        The NumPy reference scales the row of DATA addressed by each entry
        of INDICES by w0. It then adds ``w_j * X_j[k]`` to row
        ``INDICES[k]`` for every input pair j. INDICES may contain
        duplicates; a repeated row is scaled and accumulated once per
        occurrence.
        """
        for num_args in [1, 2]:
            ins = ['data', 'w0', 'indices']
            for arg in range(1, num_args + 1):
                ins.extend(['x' + str(arg), 'w' + str(arg)])
            op = core.CreateOperator('ScatterWeightedSum', ins, ['data'])
            for first_dim, index_dim, extra_dims in self._test_configs():
                for dtype in [np.int32, np.int64]:
                    context = ('num_args=%d first_dim=%d index_num=%d '
                               'data_dims=%s dtype=%s' % (
                                   num_args, first_dim, index_dim,
                                   extra_dims, np.dtype(dtype).name))
                    d = rand_array(first_dim, *extra_dims)
                    # Sampled with replacement, so duplicates are expected.
                    ind = np.random.randint(
                        0, first_dim, index_dim).astype(dtype)
                    w0 = rand_array()
                    # ufunc.at is unbuffered: a duplicated index is scaled
                    # (and below, accumulated into) once per occurrence,
                    # exactly like an explicit per-index Python loop.
                    r = d.copy()
                    np.multiply.at(r, ind, w0)
                    # forward
                    workspace.FeedBlob('data', d)
                    workspace.FeedBlob('w0', w0)
                    workspace.FeedBlob('indices', ind)
                    for arg in range(1, num_args + 1):
                        w = rand_array()
                        x = rand_array(index_dim, *extra_dims)
                        workspace.FeedBlob('x' + str(arg), x)
                        workspace.FeedBlob('w' + str(arg), w)
                        np.add.at(r, ind, w * x)
                    workspace.RunOperatorOnce(op)
                    out = workspace.FetchBlob('data')
                    np.testing.assert_allclose(
                        out, r, rtol=1e-3, err_msg=context)

    def testScatterAssign(self):
        """Check ScatterAssign: row INDICES[k] of DATA becomes SLICES[k].

        The operator writes its result back into the 'data' blob.
        """
        op = core.CreateOperator('ScatterAssign',
                                 ['data', 'indices', 'slices'], ['data'])
        for first_dim, index_dim, extra_dims in self._test_configs():
            # Keep indices unique so the NumPy reference ``r[ind] = x`` is
            # unambiguous. Sampling without replacement needs at least
            # index_dim candidate rows, hence the swap.
            if first_dim < index_dim:
                first_dim, index_dim = index_dim, first_dim
            for dtype in [np.int32, np.int64]:
                context = ('first_dim=%d index_num=%d data_dims=%s '
                           'dtype=%s' % (first_dim, index_dim, extra_dims,
                                         np.dtype(dtype).name))
                d = rand_array(first_dim, *extra_dims)
                ind = np.random.choice(first_dim, index_dim,
                                       replace=False).astype(dtype)
                x = rand_array(index_dim, *extra_dims)
                r = d.copy()
                r[ind] = x
                # forward
                workspace.FeedBlob('data', d)
                workspace.FeedBlob('indices', ind)
                workspace.FeedBlob('slices', x)
                workspace.RunOperatorOnce(op)
                out = workspace.FetchBlob('data')
                np.testing.assert_allclose(
                    out, r, rtol=1e-3, err_msg=context)
if __name__ == "__main__":
    # Only when executed directly (not on import): hand control to the
    # standard-library test runner for the TestCases defined in this module.
    from unittest import main as run_tests
    run_tests()