mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Upgrades python/hypothesis_test.py to use brew instead of CNNHelperModel
Summary: Upgrades this file to use brew instead of CNNHelperModel Reviewed By: harouwu Differential Revision: D5252089 fbshipit-source-id: 6df4350717c1d42bc4bcc63d255cd422f085ee05
This commit is contained in:
parent
e9cba7e69f
commit
5ce9cbae70
1 changed files with 15 additions and 11 deletions
|
|
@ -1422,8 +1422,9 @@ class TestOperators(hu.HypothesisTestCase):
|
|||
(["async_dag"] if workspace.has_gpu_support else [])),
|
||||
do=st.sampled_from(hu.device_options))
|
||||
def test_dag_net_forking(self, net_type, num_workers, do):
|
||||
from caffe2.python.cnn import CNNModelHelper
|
||||
m = CNNModelHelper()
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
from caffe2.python import brew
|
||||
m = ModelHelper(name="test_model")
|
||||
n = 10
|
||||
d = 2
|
||||
depth = 2
|
||||
|
|
@ -1437,16 +1438,18 @@ class TestOperators(hu.HypothesisTestCase):
|
|||
mid_1 = "{}_{}_m".format(i + 1, 2 * j)
|
||||
mid_2 = "{}_{}_m".format(i + 1, 2 * j + 1)
|
||||
top = "{}_{}".format(i, j)
|
||||
m.FC(
|
||||
brew.fc(
|
||||
m,
|
||||
bottom_1, mid_1,
|
||||
dim_in=d, dim_out=d,
|
||||
weight_init=m.ConstantInit(np.random.randn()),
|
||||
bias_init=m.ConstantInit(np.random.randn()))
|
||||
m.FC(
|
||||
weight_init=('ConstantFill', dict(value=np.random.randn())),
|
||||
bias_init=('ConstantFill', dict(value=np.random.randn())))
|
||||
brew.fc(
|
||||
m,
|
||||
bottom_2, mid_2,
|
||||
dim_in=d, dim_out=d,
|
||||
weight_init=m.ConstantInit(np.random.randn()),
|
||||
bias_init=m.ConstantInit(np.random.randn()))
|
||||
weight_init=('ConstantFill', dict(value=np.random.randn())),
|
||||
bias_init=('ConstantFill', dict(value=np.random.randn())))
|
||||
m.net.Sum([mid_1, mid_2], top)
|
||||
m.net.SquaredL2Distance(["0_0", "label"], "xent")
|
||||
m.net.AveragedLoss("xent", "loss")
|
||||
|
|
@ -1769,16 +1772,17 @@ class TestOperators(hu.HypothesisTestCase):
|
|||
n=st.integers(1, 5),
|
||||
d=st.integers(1, 5))
|
||||
def test_elman_recurrent_network(self, t, n, d):
|
||||
from caffe2.python import cnn
|
||||
from caffe2.python import model_helper, brew
|
||||
np.random.seed(1701)
|
||||
step_net = cnn.CNNModelHelper(name="Elman")
|
||||
step_net = model_helper.ModelHelper(name="Elman")
|
||||
# TODO: name scope external inputs and outputs
|
||||
step_net.Proto().external_input.extend(
|
||||
["input_t", "seq_lengths", "timestep",
|
||||
"hidden_t_prev", "gates_t_w", "gates_t_b"])
|
||||
step_net.Proto().type = "simple"
|
||||
step_net.Proto().external_output.extend(["hidden_t", "gates_t"])
|
||||
step_net.FC("hidden_t_prev", "gates_t", dim_in=d, dim_out=d, axis=2)
|
||||
brew.fc(step_net,
|
||||
"hidden_t_prev", "gates_t", dim_in=d, dim_out=d, axis=2)
|
||||
step_net.net.Sum(["gates_t", "input_t"], ["gates_t"])
|
||||
step_net.net.Sigmoid(["gates_t"], ["hidden_t"])
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue