mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add support for specifying device_option in Functional (#9619)
Summary: e.g. ``` Functional.Add(x, y, device_option=DeviceOption(HIP, 0)) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/9619 Differential Revision: D8966599 Pulled By: bddppq fbshipit-source-id: 22235e42f19278e79802642798bf0ee70a1202f6
This commit is contained in:
parent
37fc58f1d3
commit
e5fe66d7ea
2 changed files with 26 additions and 16 deletions
|
|
@ -4,6 +4,8 @@ from __future__ import print_function
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python import core, workspace
|
||||
from caffe2.proto import caffe2_pb2
|
||||
from caffe2.python.onnx.workspace import Workspace
|
||||
from collections import namedtuple
|
||||
from six import string_types
|
||||
|
||||
|
|
@ -28,7 +30,7 @@ def namedtupledict(typename, field_names, *args, **kwargs):
|
|||
class _Functional(object):
|
||||
def __getattribute__(self, op_type):
|
||||
def op_func(*inputs, **args):
|
||||
ws = workspace.C.Workspace()
|
||||
ws = Workspace()
|
||||
schema = OpSchema.get(op_type)
|
||||
input_prefix = 'input_'
|
||||
output_prefix = 'output_'
|
||||
|
|
@ -86,16 +88,18 @@ class _Functional(object):
|
|||
output_names = get_name_list(
|
||||
output_prefix, max_output, max_output
|
||||
)
|
||||
for i, input_blob in enumerate(inputs):
|
||||
ws.create_blob(input_names[i]).feed(input_blob)
|
||||
|
||||
op = core.CreateOperator(
|
||||
op_type, input_names, output_names, **args
|
||||
)
|
||||
ws._run_operator(op.SerializeToString())
|
||||
# RunOperator
|
||||
output_values = [ws.fetch_blob(x) for x in output_names]
|
||||
return namedtupledict('output', output_names)(*output_values)
|
||||
device_option = args.get('device_option', core.DeviceOption(caffe2_pb2.CPU))
|
||||
with core.DeviceScope(device_option):
|
||||
for i, input_blob in enumerate(inputs):
|
||||
ws.FeedBlob(input_names[i], input_blob)
|
||||
# RunOperator
|
||||
ws.RunOperatorOnce(op)
|
||||
output_values = [ws.FetchBlob(x) for x in output_names]
|
||||
return namedtupledict('output', output_names)(*output_values)
|
||||
|
||||
return op_func
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
|
||||
from caffe2.python import core
|
||||
from hypothesis import given
|
||||
import hypothesis.strategies as st
|
||||
|
|
@ -44,11 +46,11 @@ def _tensor_splits(draw, add_axis=False):
|
|||
|
||||
|
||||
class TestFunctional(hu.HypothesisTestCase):
|
||||
@given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]))
|
||||
def test_relu(self, X, engine):
|
||||
@given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
|
||||
def test_relu(self, X, engine, gc, dc):
|
||||
X += 0.02 * np.sign(X)
|
||||
X[X == 0.0] += 0.02
|
||||
output = Functional.Relu(X)
|
||||
output = Functional.Relu(X, device_option=gc)
|
||||
Y_l = output[0]
|
||||
Y_d = output["output_0"]
|
||||
|
||||
|
|
@ -66,11 +68,11 @@ class TestFunctional(hu.HypothesisTestCase):
|
|||
Y_d, Y_ref, err_msg='Functional Relu result mismatch'
|
||||
)
|
||||
|
||||
@given(tensor_splits=_tensor_splits())
|
||||
def test_concat(self, tensor_splits):
|
||||
@given(tensor_splits=_tensor_splits(), **hu.gcs)
|
||||
def test_concat(self, tensor_splits, gc, dc):
|
||||
# Input Size: 1 -> inf
|
||||
axis, _, splits = tensor_splits
|
||||
concat_result, split_info = Functional.Concat(*splits, axis=axis)
|
||||
concat_result, split_info = Functional.Concat(*splits, axis=axis, device_option=gc)
|
||||
|
||||
concat_result_ref = np.concatenate(splits, axis=axis)
|
||||
split_info_ref = np.array([a.shape[axis] for a in splits])
|
||||
|
|
@ -87,8 +89,8 @@ class TestFunctional(hu.HypothesisTestCase):
|
|||
err_msg='Functional Concat split info mismatch'
|
||||
)
|
||||
|
||||
@given(tensor_splits=_tensor_splits(), split_as_arg=st.booleans())
|
||||
def test_split(self, tensor_splits, split_as_arg):
|
||||
@given(tensor_splits=_tensor_splits(), split_as_arg=st.booleans(), **hu.gcs)
|
||||
def test_split(self, tensor_splits, split_as_arg, gc, dc):
|
||||
# Output Size: 1 - inf
|
||||
axis, split_info, splits = tensor_splits
|
||||
|
||||
|
|
@ -100,7 +102,7 @@ class TestFunctional(hu.HypothesisTestCase):
|
|||
else:
|
||||
input_tensors = [np.concatenate(splits, axis=axis), split_info]
|
||||
kwargs = dict(axis=axis, num_output=len(splits))
|
||||
result = Functional.Split(*input_tensors, **kwargs)
|
||||
result = Functional.Split(*input_tensors, device_option=gc, **kwargs)
|
||||
|
||||
def split_ref(input, split=split_info):
|
||||
s = np.cumsum([0] + list(split))
|
||||
|
|
@ -114,3 +116,7 @@ class TestFunctional(hu.HypothesisTestCase):
|
|||
np.testing.assert_array_equal(
|
||||
result[i], ref, err_msg='Functional Relu result mismatch'
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue