From e5fe66d7ea2515ccab0b95808a899ce7b037232d Mon Sep 17 00:00:00 2001
From: Junjie Bai
Date: Tue, 24 Jul 2018 14:35:12 -0700
Subject: [PATCH] Add support for specifying device_option in Functional (#9619)

Summary:
e.g.
```
Functional.Add(x, y, device_option=DeviceOption(HIP, 0))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9619

Differential Revision: D8966599

Pulled By: bddppq

fbshipit-source-id: 22235e42f19278e79802642798bf0ee70a1202f6
---
 caffe2/python/functional.py      | 18 +++++++++++-------
 caffe2/python/functional_test.py | 24 +++++++++++++++---------
 2 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/caffe2/python/functional.py b/caffe2/python/functional.py
index 6ec25ab4bb2..ec6847cb4e5 100644
--- a/caffe2/python/functional.py
+++ b/caffe2/python/functional.py
@@ -4,6 +4,8 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from caffe2.python import core, workspace
+from caffe2.proto import caffe2_pb2
+from caffe2.python.onnx.workspace import Workspace
 from collections import namedtuple
 from six import string_types
 
@@ -28,7 +30,7 @@ def namedtupledict(typename, field_names, *args, **kwargs):
 class _Functional(object):
     def __getattribute__(self, op_type):
         def op_func(*inputs, **args):
-            ws = workspace.C.Workspace()
+            ws = Workspace()
             schema = OpSchema.get(op_type)
             input_prefix = 'input_'
             output_prefix = 'output_'
@@ -86,16 +88,18 @@ class _Functional(object):
             output_names = get_name_list(
                 output_prefix, max_output, max_output
             )
-            for i, input_blob in enumerate(inputs):
-                ws.create_blob(input_names[i]).feed(input_blob)
             op = core.CreateOperator(
                 op_type, input_names, output_names, **args
             )
-            ws._run_operator(op.SerializeToString())
-            # RunOperator
-            output_values = [ws.fetch_blob(x) for x in output_names]
-            return namedtupledict('output', output_names)(*output_values)
+            device_option = args.get('device_option', core.DeviceOption(caffe2_pb2.CPU))
+            with core.DeviceScope(device_option):
+                for i, input_blob in enumerate(inputs):
+                    ws.FeedBlob(input_names[i], input_blob)
+                # RunOperator
+                ws.RunOperatorOnce(op)
+                output_values = [ws.FetchBlob(x) for x in output_names]
+                return namedtupledict('output', output_names)(*output_values)
 
         return op_func
 
 
diff --git a/caffe2/python/functional_test.py b/caffe2/python/functional_test.py
index 731c1a7eaa2..e7803e829bb 100644
--- a/caffe2/python/functional_test.py
+++ b/caffe2/python/functional_test.py
@@ -3,6 +3,8 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
+import unittest
+
 from caffe2.python import core
 from hypothesis import given
 import hypothesis.strategies as st
@@ -44,11 +46,11 @@ def _tensor_splits(draw, add_axis=False):
 
 
 class TestFunctional(hu.HypothesisTestCase):
-    @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]))
-    def test_relu(self, X, engine):
+    @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
+    def test_relu(self, X, engine, gc, dc):
         X += 0.02 * np.sign(X)
         X[X == 0.0] += 0.02
-        output = Functional.Relu(X)
+        output = Functional.Relu(X, device_option=gc)
 
         Y_l = output[0]
         Y_d = output["output_0"]
@@ -66,11 +68,11 @@
             Y_d, Y_ref, err_msg='Functional Relu result mismatch'
         )
 
-    @given(tensor_splits=_tensor_splits())
-    def test_concat(self, tensor_splits):
+    @given(tensor_splits=_tensor_splits(), **hu.gcs)
+    def test_concat(self, tensor_splits, gc, dc):
         # Input Size: 1 -> inf
         axis, _, splits = tensor_splits
-        concat_result, split_info = Functional.Concat(*splits, axis=axis)
+        concat_result, split_info = Functional.Concat(*splits, axis=axis, device_option=gc)
 
         concat_result_ref = np.concatenate(splits, axis=axis)
         split_info_ref = np.array([a.shape[axis] for a in splits])
@@ -87,8 +89,8 @@
             err_msg='Functional Concat split info mismatch'
         )
 
-    @given(tensor_splits=_tensor_splits(), split_as_arg=st.booleans())
-    def test_split(self, tensor_splits, split_as_arg):
+    @given(tensor_splits=_tensor_splits(), split_as_arg=st.booleans(), **hu.gcs)
+    def test_split(self, tensor_splits, split_as_arg, gc, dc):
         # Output Size: 1 - inf
         axis, split_info, splits = tensor_splits
 
@@ -100,7 +102,7 @@
         else:
             input_tensors = [np.concatenate(splits, axis=axis), split_info]
             kwargs = dict(axis=axis, num_output=len(splits))
-        result = Functional.Split(*input_tensors, **kwargs)
+        result = Functional.Split(*input_tensors, device_option=gc, **kwargs)
 
         def split_ref(input, split=split_info):
             s = np.cumsum([0] + list(split))
@@ -114,3 +116,7 @@
             np.testing.assert_array_equal(
                 result[i], ref, err_msg='Functional Relu result mismatch'
             )
+
+
+if __name__ == '__main__':
+    unittest.main()