From 171f265d80b6245cc5a71d26ddcf39d1a12e8f07 Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Thu, 14 Jan 2021 15:13:34 -0800
Subject: [PATCH] Back out "Revert D25717510: Clean up some type annotations
 in benchmarks/fastrnns" (#50556)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50556

Original commit changeset: 2bcc19cd4340

Test Plan: Soft revert hammer

Reviewed By: walterddr, seemethere

Differential Revision: D25917129

fbshipit-source-id: e5caad77655789d607b84eee820aa7c960e00f51
---
 benchmarks/fastrnns/cells.py        | 15 +++++++--------
 benchmarks/fastrnns/custom_lstms.py | 27 +++++++++------------------
 benchmarks/fastrnns/factory.py      | 21 +++++++++++----------
 3 files changed, 27 insertions(+), 36 deletions(-)

diff --git a/benchmarks/fastrnns/cells.py b/benchmarks/fastrnns/cells.py
index fe9e67a0df2..6e797b9e2d1 100644
--- a/benchmarks/fastrnns/cells.py
+++ b/benchmarks/fastrnns/cells.py
@@ -24,8 +24,8 @@ def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
     return hy, cy
 
 
-def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
-    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+def lstm_cell(input: Tensor, hidden: Tuple[Tensor, Tensor], w_ih: Tensor,
+              w_hh: Tensor, b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
 
@@ -42,8 +42,8 @@ def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
     return hy, cy
 
 
-def flat_lstm_cell(input, hx, cx, w_ih, w_hh, b_ih, b_hh):
-    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+def flat_lstm_cell(input: Tensor, hx: Tensor, cx: Tensor, w_ih: Tensor,
+                   w_hh: Tensor, b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]:
     gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
 
     ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
@@ -59,8 +59,8 @@ def flat_lstm_cell(input, hx, cx, w_ih, w_hh, b_ih, b_hh):
     return hy, cy
 
 
-def premul_lstm_cell(igates, hidden, w_hh, b_ih, b_hh):
-    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+def premul_lstm_cell(igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor,
+                     b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh
 
@@ -77,8 +77,7 @@ def premul_lstm_cell(igates, hidden, w_hh, b_ih, b_hh):
     return hy, cy
 
 
-def premul_lstm_cell_no_bias(igates, hidden, w_hh, b_hh):
-    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor) -> Tuple[Tensor, Tensor]
+def premul_lstm_cell_no_bias(igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = igates + torch.mm(hx, w_hh.t()) + b_hh
 
diff --git a/benchmarks/fastrnns/custom_lstms.py b/benchmarks/fastrnns/custom_lstms.py
index d835b3e533f..60abb1ac574 100644
--- a/benchmarks/fastrnns/custom_lstms.py
+++ b/benchmarks/fastrnns/custom_lstms.py
@@ -86,8 +86,7 @@ def script_lnlstm(input_size, hidden_size, num_layers, bias=True,
 LSTMState = namedtuple('LSTMState', ['hx', 'cx'])
 
 
-def reverse(lst):
-    # type: (List[Tensor]) -> List[Tensor]
+def reverse(lst: List[Tensor]) -> List[Tensor]:
     return lst[::-1]
 
 
@@ -102,8 +101,7 @@ class LSTMCell(jit.ScriptModule):
         self.bias_hh = Parameter(torch.randn(4 * hidden_size))
 
     @jit.script_method
-    def forward(self, input, state):
-        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         hx, cx = state
         gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                  torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
@@ -165,8 +163,7 @@ class LayerNormLSTMCell(jit.ScriptModule):
         self.layernorm_c = ln(hidden_size)
 
     @jit.script_method
-    def forward(self, input, state):
-        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         hx, cx = state
         igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
         hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
@@ -190,8 +187,7 @@ class LSTMLayer(jit.ScriptModule):
         self.cell = cell(*cell_args)
 
     @jit.script_method
-    def forward(self, input, state):
-        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         inputs = input.unbind(0)
         outputs = torch.jit.annotate(List[Tensor], [])
         for i in range(len(inputs)):
@@ -206,8 +202,7 @@ class ReverseLSTMLayer(jit.ScriptModule):
         self.cell = cell(*cell_args)
 
     @jit.script_method
-    def forward(self, input, state):
-        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         inputs = reverse(input.unbind(0))
         outputs = jit.annotate(List[Tensor], [])
         for i in range(len(inputs)):
@@ -227,8 +222,7 @@ class BidirLSTMLayer(jit.ScriptModule):
         ])
 
     @jit.script_method
-    def forward(self, input, states):
-        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+    def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
         # List[LSTMState]: [forward LSTMState, backward LSTMState]
         outputs = jit.annotate(List[Tensor], [])
         output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
@@ -258,8 +252,7 @@ class StackedLSTM(jit.ScriptModule):
                                         other_layer_args)
 
     @jit.script_method
-    def forward(self, input, states):
-        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+    def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
         # List[LSTMState]: One state per layer
         output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
         output = input
@@ -286,8 +279,7 @@ class StackedLSTM2(jit.ScriptModule):
                                         other_layer_args)
 
     @jit.script_method
-    def forward(self, input, states):
-        # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
+    def forward(self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
         # List[List[LSTMState]]: The outer list is for layers,
         #                        inner list is for directions.
         output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
@@ -322,8 +314,7 @@ class StackedLSTMWithDropout(jit.ScriptModule):
         self.dropout_layer = nn.Dropout(0.4)
 
     @jit.script_method
-    def forward(self, input, states):
-        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+    def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
         # List[LSTMState]: One state per layer
         output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
         output = input
diff --git a/benchmarks/fastrnns/factory.py b/benchmarks/fastrnns/factory.py
index bb59a172dc1..91ac39b06a8 100644
--- a/benchmarks/fastrnns/factory.py
+++ b/benchmarks/fastrnns/factory.py
@@ -236,8 +236,10 @@ def varlen_pytorch_lstm_creator(**kwargs):
 
 
 def varlen_lstm_factory(cell, script):
-    def dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh):
-        # type: (List[Tensor], Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]  # noqa
+    def dynamic_rnn(sequences: List[Tensor], hiddens: Tuple[Tensor, Tensor], wih: Tensor,
+                    whh: Tensor, bih: Tensor, bhh: Tensor
+                    ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
+        # noqa
         hx, cx = hiddens
         hxs = hx.unbind(1)
         cxs = cx.unbind(1)
@@ -361,8 +363,8 @@ def lstm_inputs(seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
 
 
 def lstm_factory(cell, script):
-    def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
-        # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor,
+                    bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inputs = input.unbind(0)
@@ -381,8 +383,8 @@ def lstm_factory(cell, script):
 
 # premul: we're going to premultiply the inputs & weights
 def lstm_factory_premul(premul_cell, script):
-    def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
-        # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor,
+                    bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inputs = torch.matmul(input, wih.t()).unbind(0)
@@ -401,8 +403,8 @@ def lstm_factory_premul(premul_cell, script):
 
 # premul: we're going to premultiply the inputs & weights, and add bias
 def lstm_factory_premul_bias(premul_cell, script):
-    def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
-        # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor,
+                    bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inpSize = input.size()
@@ -444,8 +446,7 @@ def lstm_factory_simple(cell, script):
 
 
 def lstm_factory_multilayer(cell, script):
-    def dynamic_rnn(input, hidden, params):
-        # type: (Tensor, Tuple[Tensor, Tensor], List[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         params_stride = 4  # NB: this assumes that biases are there
         hx, cx = hidden
         hy, cy = hidden  # for scoping
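---

Editor's note on the pattern the patch applies throughout: the Python 2-era
"# type: (...) -> ..." comment on the first body line is replaced with inline
Python 3 annotations, which both mypy and the TorchScript frontend read
directly. Below is a minimal runnable sketch of the new style, reusing the
lstm_cell body shown in the cells.py hunk above; scripting it with
torch.jit.script and the smoke-test sizes at the end are illustrative
additions, not part of the benchmark code.

    from typing import Tuple

    import torch
    from torch import Tensor


    @torch.jit.script
    def lstm_cell(input: Tensor, hidden: Tuple[Tensor, Tensor], w_ih: Tensor,
                  w_hh: Tensor, b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]:
        # Inline annotations above replace the legacy two-line form, where the
        # signature carried no types and a "# type:" comment sat here instead.
        hx, cx = hidden
        gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh

        # Split the fused gate pre-activations into the four LSTM gates.
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)
        return hy, cy


    # Smoke test with hypothetical sizes (batch=3, input=10, hidden=20);
    # weights are (4 * hidden, input) and (4 * hidden, hidden).
    x = torch.randn(3, 10)
    state = (torch.randn(3, 20), torch.randn(3, 20))
    w_ih, w_hh = torch.randn(80, 10), torch.randn(80, 20)
    b_ih, b_hh = torch.randn(80), torch.randn(80)
    hy, cy = lstm_cell(x, state, w_ih, w_hh, b_ih, b_hh)  # both (3, 20)

Because the scripted function's types now come from the signature itself, the
same source type-checks under mypy and compiles under TorchScript without the
duplicated comment the old style required.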