Mirror of
https://github.com/saymrwulf/pytorch.git
Synced 2026-05-15 21:00:47 +00:00
Summary: Per title. Test Plan: Fixes existing tests. Reviewed By: robieta. Differential Revision: D28690296. fbshipit-source-id: d7b5b5065517373b75d501872814c89b24ec8cfc
43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
|
|
|
|
|
|
|
|
from caffe2.python import core, workspace
|
|
from caffe2.python.core import CreatePythonOperator
|
|
import caffe2.python.hypothesis_test_util as hu
|
|
from hypothesis import given, settings
|
|
import hypothesis.strategies as st
|
|
import numpy as np
|
|
import unittest
|
|
|
|
class PythonOpTest(hu.HypothesisTestCase):
    """Tests for Python operators created via ``CreatePythonOperator``."""

    @given(x=hu.tensor(),
           n=st.integers(min_value=1, max_value=20),
           w=st.integers(min_value=1, max_value=20))
    @settings(deadline=10000)
    def test_simple_python_op(self, x, n, w):
        """Copy blob "x" into ``n`` output blobs using ``n`` Python operators.

        The operators are placed in a net of type "dag" with ``num_workers``
        set to ``w``. The net runs inside a plan step that repeats it 100
        times. Afterwards each output blob ``str(i)`` must match ``x``.
        """
        def copy_into(src, dst):
            # Overwrite the destination array's contents in place.
            dst[...] = src

        def python_op(inputs, outputs):
            src_tensor = inputs[0]
            dst_tensor = outputs[0]
            # Give the output the input's shape before writing through .data.
            dst_tensor.reshape(src_tensor.shape)
            copy_into(src_tensor.data, dst_tensor.data)

        output_names = [str(idx) for idx in range(n)]
        copy_ops = [
            CreatePythonOperator(python_op, ["x"], [name])
            for name in output_names
        ]

        net = core.Net("net")
        net_def = net.Proto()
        net_def.op.extend(copy_ops)
        net_def.type = "dag"
        net_def.num_workers = w

        num_iters = 100
        plan = core.Plan("plan")
        plan.AddStep(core.ExecutionStep("test-step", net, num_iters))

        workspace.FeedBlob("x", x)
        workspace.RunPlan(plan.Proto().SerializeToString())

        for name in output_names:
            fetched = workspace.FetchBlob(name)
            np.testing.assert_almost_equal(x, fetched)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when it is executed as a script.
    # ``unittest`` is already imported at module level, so it is not
    # re-imported here.
    unittest.main()
|