Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
Back out "Revert D25717510: Clean up some type annotations in benchmarks/fastrnns" (#50556)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50556

Original commit changeset: 2bcc19cd4340

Test Plan: Soft revert hammer

Reviewed By: walterddr, seemethere

Differential Revision: D25917129

fbshipit-source-id: e5caad77655789d607b84eee820aa7c960e00f51
This commit is contained in:
parent 51157e802f
commit 171f265d80
3 changed files with 27 additions and 36 deletions
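The change being restored converts mypy-style `# type:` comments into inline PEP 484 annotations, which TorchScript accepts equally. A minimal sketch of the pattern, using the `reverse` helper touched below (the `_old`/`_new` names are illustrative, not from the diff):

```python
from typing import List
from torch import Tensor

# Old style: the signature is typed via a mypy comment.
def reverse_old(lst):
    # type: (List[Tensor]) -> List[Tensor]
    return lst[::-1]

# New style: the same signature as inline PEP 484 annotations.
def reverse_new(lst: List[Tensor]) -> List[Tensor]:
    return lst[::-1]
```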
benchmarks/fastrnns/cells.py
@@ -24,8 +24,8 @@ def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
     return hy, cy
 
 
-def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
-    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+def lstm_cell(input: Tensor, hidden: Tuple[Tensor, Tensor], w_ih: Tensor,
+              w_hh: Tensor, b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
 
@@ -42,8 +42,8 @@ def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
     return hy, cy
 
 
-def flat_lstm_cell(input, hx, cx, w_ih, w_hh, b_ih, b_hh):
-    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+def flat_lstm_cell(input: Tensor, hx: Tensor, cx: Tensor, w_ih: Tensor,
+                   w_hh: Tensor, b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]:
     gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
 
     ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
@@ -59,8 +59,8 @@ def flat_lstm_cell(input, hx, cx, w_ih, w_hh, b_ih, b_hh):
     return hy, cy
 
 
-def premul_lstm_cell(igates, hidden, w_hh, b_ih, b_hh):
-    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+def premul_lstm_cell(igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor,
+                     b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh
 
@@ -77,8 +77,7 @@ def premul_lstm_cell(igates, hidden, w_hh, b_ih, b_hh):
    return hy, cy
 
 
-def premul_lstm_cell_no_bias(igates, hidden, w_hh, b_hh):
-    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor) -> Tuple[Tensor, Tensor]
+def premul_lstm_cell_no_bias(igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = igates + torch.mm(hx, w_hh.t()) + b_hh
 
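The premul cells above assume the caller has already multiplied the input by the input weights, so only the recurrent matmul remains inside the cell. A quick sketch of that equivalence; the gates expressions come from the hunks above, while all shapes are made up for illustration:

```python
import torch

# Hypothetical sizes, for illustration only.
batch, input_size, hidden_size = 5, 64, 16
input = torch.randn(batch, input_size)
hx = torch.randn(batch, hidden_size)
w_ih = torch.randn(4 * hidden_size, input_size)
w_hh = torch.randn(4 * hidden_size, hidden_size)
b_ih = torch.randn(4 * hidden_size)
b_hh = torch.randn(4 * hidden_size)

# lstm_cell computes the input projection inside the cell...
gates_full = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
# ...while premul_lstm_cell receives igates precomputed by the factory.
igates = torch.mm(input, w_ih.t())
gates_premul = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh
assert torch.allclose(gates_full, gates_premul)
```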
benchmarks/fastrnns/custom_lstms.py
@@ -86,8 +86,7 @@ def script_lnlstm(input_size, hidden_size, num_layers, bias=True,
 LSTMState = namedtuple('LSTMState', ['hx', 'cx'])
 
 
-def reverse(lst):
-    # type: (List[Tensor]) -> List[Tensor]
+def reverse(lst: List[Tensor]) -> List[Tensor]:
     return lst[::-1]
 
 
@@ -102,8 +101,7 @@ class LSTMCell(jit.ScriptModule):
         self.bias_hh = Parameter(torch.randn(4 * hidden_size))
 
     @jit.script_method
-    def forward(self, input, state):
-        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         hx, cx = state
         gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                  torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
@@ -165,8 +163,7 @@ class LayerNormLSTMCell(jit.ScriptModule):
         self.layernorm_c = ln(hidden_size)
 
     @jit.script_method
-    def forward(self, input, state):
-        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         hx, cx = state
         igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
         hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
@@ -190,8 +187,7 @@ class LSTMLayer(jit.ScriptModule):
         self.cell = cell(*cell_args)
 
     @jit.script_method
-    def forward(self, input, state):
-        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         inputs = input.unbind(0)
         outputs = torch.jit.annotate(List[Tensor], [])
         for i in range(len(inputs)):
@@ -206,8 +202,7 @@ class ReverseLSTMLayer(jit.ScriptModule):
         self.cell = cell(*cell_args)
 
     @jit.script_method
-    def forward(self, input, state):
-        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         inputs = reverse(input.unbind(0))
         outputs = jit.annotate(List[Tensor], [])
         for i in range(len(inputs)):
@@ -227,8 +222,7 @@ class BidirLSTMLayer(jit.ScriptModule):
         ])
 
     @jit.script_method
-    def forward(self, input, states):
-        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+    def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
         # List[LSTMState]: [forward LSTMState, backward LSTMState]
         outputs = jit.annotate(List[Tensor], [])
         output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
@@ -258,8 +252,7 @@ class StackedLSTM(jit.ScriptModule):
                                         other_layer_args)
 
     @jit.script_method
-    def forward(self, input, states):
-        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+    def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
         # List[LSTMState]: One state per layer
         output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
         output = input
@@ -286,8 +279,7 @@ class StackedLSTM2(jit.ScriptModule):
                                         other_layer_args)
 
     @jit.script_method
-    def forward(self, input, states):
-        # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
+    def forward(self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
         # List[List[LSTMState]]: The outer list is for layers,
         #                        inner list is for directions.
         output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
@@ -322,8 +314,7 @@ class StackedLSTMWithDropout(jit.ScriptModule):
         self.dropout_layer = nn.Dropout(0.4)
 
     @jit.script_method
-    def forward(self, input, states):
-        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+    def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
         # List[LSTMState]: One state per layer
         output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
         output = input
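Several of the `forward` methods above seed their accumulators with `jit.annotate(List[...], [])`, because TorchScript cannot infer the element type of an empty list literal. A self-contained sketch of the idiom; the `Unbind` module is hypothetical, written to mirror the pattern in these hunks:

```python
import torch
from torch import jit, Tensor
from typing import List

class Unbind(jit.ScriptModule):
    @jit.script_method
    def forward(self, input: Tensor) -> List[Tensor]:
        # A bare [] has no inferable element type under TorchScript;
        # jit.annotate pins it to List[Tensor] so .append type-checks.
        outputs = jit.annotate(List[Tensor], [])
        for x in input.unbind(0):
            outputs.append(x)
        return outputs

m = Unbind()
print(len(m(torch.randn(3, 4))))  # 3
```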
benchmarks/fastrnns/factory.py
@@ -236,8 +236,10 @@ def varlen_pytorch_lstm_creator(**kwargs):
 
 
 def varlen_lstm_factory(cell, script):
-    def dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh):
-        # type: (List[Tensor], Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]  # noqa
+    def dynamic_rnn(sequences: List[Tensor], hiddens: Tuple[Tensor, Tensor], wih: Tensor,
+                    whh: Tensor, bih: Tensor, bhh: Tensor
+                    ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
+        # noqa
         hx, cx = hiddens
         hxs = hx.unbind(1)
         cxs = cx.unbind(1)
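The variable-length `dynamic_rnn` splits the packed hidden state into one tensor per sequence via `unbind(1)`. A tiny sketch of what that produces; the layout shown is assumed for illustration:

```python
import torch

hx = torch.randn(1, 3, 8)  # assumed (num_layers, batch, hidden) layout
hxs = hx.unbind(1)         # one (1, 8) slice per sequence in the batch
assert len(hxs) == 3 and hxs[0].shape == (1, 8)
```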
@@ -361,8 +363,8 @@ def lstm_inputs(seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
 
 
 def lstm_factory(cell, script):
-    def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
-        # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor,
+                    bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inputs = input.unbind(0)
@@ -381,8 +383,8 @@ def lstm_factory(cell, script):
 
 # premul: we're going to premultiply the inputs & weights
 def lstm_factory_premul(premul_cell, script):
-    def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
-        # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor,
+                    bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inputs = torch.matmul(input, wih.t()).unbind(0)
@@ -401,8 +403,8 @@ def lstm_factory_premul(premul_cell, script):
 
 # premul: we're going to premultiply the inputs & weights, and add bias
 def lstm_factory_premul_bias(premul_cell, script):
-    def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
-        # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor,
+                    bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inpSize = input.size()
@@ -444,8 +446,7 @@ def lstm_factory_simple(cell, script):
 
 
 def lstm_factory_multilayer(cell, script):
-    def dynamic_rnn(input, hidden, params):
-        # type: (Tensor, Tuple[Tensor, Tensor], List[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         params_stride = 4  # NB: this assumes that biases are there
         hx, cx = hidden
         hy, cy = hidden  # for scoping...
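With inline annotations in place, these factories can still pass the cells through `torch.jit.script` when `script=True`. A minimal end-to-end sketch; the signature and gates line come from the first hunk, while the body between them and the return is reconstructed from the standard LSTM equations and may not match the file verbatim:

```python
import torch
from torch import Tensor
from typing import Tuple

def lstm_cell(input: Tensor, hidden: Tuple[Tensor, Tensor], w_ih: Tensor,
              w_hh: Tensor, b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]:
    hx, cx = hidden
    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
    # Standard LSTM gating (reconstructed, not copied from the diff).
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    cy = torch.sigmoid(forgetgate) * cx + torch.sigmoid(ingate) * torch.tanh(cellgate)
    hy = torch.sigmoid(outgate) * torch.tanh(cy)
    return hy, cy

# torch.jit.script reads the inline annotations directly;
# no # type: comment is needed for a typed graph.
scripted = torch.jit.script(lstm_cell)
print(scripted.graph)
```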