Add support for specifying device_option in Functional (#9619)

Summary:
e.g.
```
Functional.Add(x, y, device_option=DeviceOption(HIP, 0))

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9619

Differential Revision: D8966599

Pulled By: bddppq

fbshipit-source-id: 22235e42f19278e79802642798bf0ee70a1202f6
This commit is contained in:
Junjie Bai 2018-07-24 14:35:12 -07:00 committed by Facebook Github Bot
parent 37fc58f1d3
commit e5fe66d7ea
2 changed files with 26 additions and 16 deletions

View file

@ -4,6 +4,8 @@ from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2
from caffe2.python.onnx.workspace import Workspace
from collections import namedtuple
from six import string_types
@ -28,7 +30,7 @@ def namedtupledict(typename, field_names, *args, **kwargs):
class _Functional(object):
def __getattribute__(self, op_type):
def op_func(*inputs, **args):
ws = workspace.C.Workspace()
ws = Workspace()
schema = OpSchema.get(op_type)
input_prefix = 'input_'
output_prefix = 'output_'
@ -86,16 +88,18 @@ class _Functional(object):
output_names = get_name_list(
output_prefix, max_output, max_output
)
for i, input_blob in enumerate(inputs):
ws.create_blob(input_names[i]).feed(input_blob)
op = core.CreateOperator(
op_type, input_names, output_names, **args
)
ws._run_operator(op.SerializeToString())
# RunOperator
output_values = [ws.fetch_blob(x) for x in output_names]
return namedtupledict('output', output_names)(*output_values)
device_option = args.get('device_option', core.DeviceOption(caffe2_pb2.CPU))
with core.DeviceScope(device_option):
for i, input_blob in enumerate(inputs):
ws.FeedBlob(input_names[i], input_blob)
# RunOperator
ws.RunOperatorOnce(op)
output_values = [ws.FetchBlob(x) for x in output_names]
return namedtupledict('output', output_names)(*output_values)
return op_func

View file

@ -3,6 +3,8 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
from caffe2.python import core
from hypothesis import given
import hypothesis.strategies as st
@ -44,11 +46,11 @@ def _tensor_splits(draw, add_axis=False):
class TestFunctional(hu.HypothesisTestCase):
@given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]))
def test_relu(self, X, engine):
@given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
def test_relu(self, X, engine, gc, dc):
X += 0.02 * np.sign(X)
X[X == 0.0] += 0.02
output = Functional.Relu(X)
output = Functional.Relu(X, device_option=gc)
Y_l = output[0]
Y_d = output["output_0"]
@ -66,11 +68,11 @@ class TestFunctional(hu.HypothesisTestCase):
Y_d, Y_ref, err_msg='Functional Relu result mismatch'
)
@given(tensor_splits=_tensor_splits())
def test_concat(self, tensor_splits):
@given(tensor_splits=_tensor_splits(), **hu.gcs)
def test_concat(self, tensor_splits, gc, dc):
# Input Size: 1 -> inf
axis, _, splits = tensor_splits
concat_result, split_info = Functional.Concat(*splits, axis=axis)
concat_result, split_info = Functional.Concat(*splits, axis=axis, device_option=gc)
concat_result_ref = np.concatenate(splits, axis=axis)
split_info_ref = np.array([a.shape[axis] for a in splits])
@ -87,8 +89,8 @@ class TestFunctional(hu.HypothesisTestCase):
err_msg='Functional Concat split info mismatch'
)
@given(tensor_splits=_tensor_splits(), split_as_arg=st.booleans())
def test_split(self, tensor_splits, split_as_arg):
@given(tensor_splits=_tensor_splits(), split_as_arg=st.booleans(), **hu.gcs)
def test_split(self, tensor_splits, split_as_arg, gc, dc):
# Output Size: 1 - inf
axis, split_info, splits = tensor_splits
@ -100,7 +102,7 @@ class TestFunctional(hu.HypothesisTestCase):
else:
input_tensors = [np.concatenate(splits, axis=axis), split_info]
kwargs = dict(axis=axis, num_output=len(splits))
result = Functional.Split(*input_tensors, **kwargs)
result = Functional.Split(*input_tensors, device_option=gc, **kwargs)
def split_ref(input, split=split_info):
s = np.cumsum([0] + list(split))
@ -114,3 +116,7 @@ class TestFunctional(hu.HypothesisTestCase):
np.testing.assert_array_equal(
result[i], ref, err_msg='Functional Relu result mismatch'
)
if __name__ == '__main__':
unittest.main()