2017-09-28 23:00:15 +00:00
|
|
|
# Copyright (c) 2016-present, Facebook, Inc.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
##############################################################################
|
|
|
|
|
|
2016-11-14 22:58:04 +00:00
|
|
|
from __future__ import absolute_import
|
|
|
|
|
from __future__ import division
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
from caffe2.python import core, workspace
|
|
|
|
|
from caffe2.python.core import CreatePythonOperator
|
|
|
|
|
import caffe2.python.hypothesis_test_util as hu
|
|
|
|
|
from hypothesis import given
|
|
|
|
|
import hypothesis.strategies as st
|
|
|
|
|
import numpy as np
|
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
# numba is an optional dependency; HAS_NUMBA records whether it could be
# imported so that numba-based tests can be skipped when it is missing.
try:
    import numba
except ImportError:
    HAS_NUMBA = False
else:
    HAS_NUMBA = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PythonOpTest(hu.HypothesisTestCase):
    """Tests for Caffe2 Python operators built with ``CreatePythonOperator``."""

    @unittest.skipIf(not HAS_NUMBA, "numba is not installed")
    @given(x=hu.tensor(),
           n=st.integers(min_value=1, max_value=20),
           w=st.integers(min_value=1, max_value=20))
    def test_multithreaded_evaluation_numba_nogil(self, x, n, w):
        """Run many numba-backed Python ops in a multi-worker DAG net.

        Builds ``n`` Python operators that each copy blob ``"x"`` into blob
        ``str(i)`` through a ``numba.jit(nopython=True, nogil=True)``
        function. The operators are placed in a ``"dag"`` net with
        ``num_workers = w``, which runs as a plan step for 100 iterations.
        Each output blob is then checked against the input.

        Args:
            x: input tensor drawn from ``hu.tensor()``.
            n: number of Python operators (and output blobs) to create.
            w: value assigned to the DAG net's ``num_workers``.
        """
        # nogil=True makes the compiled copy release the GIL while it runs,
        # so concurrently scheduled Python ops are not serialized on it.
        @numba.jit(nopython=True, nogil=True)
        def g(input_, output):
            output[...] = input_

        def f(inputs, outputs):
            # Give the output the input's shape before writing into its buffer.
            outputs[0].reshape(inputs[0].shape)
            g(inputs[0].data, outputs[0].data)

        ops = [CreatePythonOperator(f, ["x"], [str(i)]) for i in range(n)]
        net = core.Net("net")
        net.Proto().op.extend(ops)
        net.Proto().type = "dag"
        net.Proto().num_workers = w

        iters = 100
        plan = core.Plan("plan")
        plan.AddStep(core.ExecutionStep("test-step", net, iters))
        workspace.FeedBlob("x", x)
        workspace.RunPlan(plan.Proto().SerializeToString())

        # Every operator copied the same input, so each output must match it.
        for i in range(n):
            y = workspace.FetchBlob(str(i))
            np.testing.assert_almost_equal(x, y)
|
2016-11-23 02:31:47 +00:00
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when executed directly as a script;
    # ``unittest`` is already imported at the top of the module.
    unittest.main()
|