Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-15 21:00:47 +00:00.
Summary: GRU differs from LSTM in that it has only hidden states, but no cell states. Reusing the code of _LSTM is therefore problematic: we need to delete the part that creates the cell state, and change the many places that use a hard-coded 4 (hidden_all, hidden, cell_all, cell) to 2 (hidden_all, hidden). Otherwise GRU breaks during the backward pass, when the optimizer tries to apply a gradient to each of the parameters: because the cell state is never used, there are no gradients for the corresponding parameters (i.e., cell_state_w, cell_state_b). Differential Revision: D5589309. fbshipit-source-id: f5af67dfe0842acd68223f6da3e96a81639e8049
150 lines · 4.3 KiB · Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import functools
|
|
from caffe2.python import brew, rnn_cell
|
|
|
|
|
|
class GRUCell(rnn_cell.RNNCell):
    """A single GRU (Gated Recurrent Unit) recurrent cell.

    Unlike an LSTM cell, a GRU carries only a hidden state and no cell
    state (see ``get_state_names``, which returns just ``hidden_t``).
    The methods here do not compute values directly: they add operators
    to ``model.net``, defining one timestep of the GRU as a Caffe2 graph.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,  # Currently unused! Values here will be ignored.
        memory_optimization,
        drop_states=False,
        **kwargs
    ):
        # input_size: dimensionality of the raw input blob fed to
        #     prepare_input.
        # hidden_size: dimensionality of the hidden state.
        # forget_bias: stored (as float) and forwarded to the GRUUnit op,
        #     but ignored there — see the parameter comment above.
        # memory_optimization: stored but not read anywhere in this class;
        #     presumably consumed by the surrounding rnn_cell machinery —
        #     TODO(review) confirm against rnn_cell.RNNCell.
        # drop_states: forwarded to the GRUUnit op; semantics are defined
        #     by that op (gru_unit_op.{h,cc}).
        super(GRUCell, self).__init__(**kwargs)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states

    # Unlike LSTMCell, GRUCell needs the output of one gate to feed into another.
    # (reset gate -> output_gate)
    # So, much of the logic to calculate the reset gate output and modified
    # output gate input is set here, in the graph definition.
    # The remaining logic lives in in gru_unit_op.{h,cc}.
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        """Add one GRU timestep to ``model.net`` and return the new state.

        ``input_t`` is expected to be the output of ``prepare_input``,
        i.e. the raw input already projected to ``3 * hidden_size`` along
        axis 2 (one slice per gate). ``states`` holds a single blob — the
        previous hidden state — matching ``get_state_names``. Returns a
        1-tuple containing the new hidden-state blob.
        """
        hidden_t_prev = states[0]

        # Split input tensors to get inputs for each gate.
        input_t_reset, input_t_update, input_t_output = model.net.Split(
            [
                input_t,
            ],
            [
                self.scope('input_t_reset'),
                self.scope('input_t_update'),
                self.scope('input_t_output'),
            ],
            axis=2,
        )

        # Fully connected layers for reset and update gates.
        reset_gate_t = brew.fc(
            model,
            hidden_t_prev,
            self.scope('reset_gate_t'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )
        update_gate_t = brew.fc(
            model,
            hidden_t_prev,
            self.scope('update_gate_t'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        # Calculating the modified hidden state going into output gate.
        # NOTE: this Sum writes its result back into the 'reset_gate_t'
        # blob, so the input contribution is added in place.
        reset_gate_t = model.net.Sum(
            [reset_gate_t, input_t_reset],
            self.scope('reset_gate_t')
        )
        reset_gate_t_sigmoid = model.net.Sigmoid(
            reset_gate_t,
            self.scope('reset_gate_t_sigmoid')
        )
        # Elementwise r_t * h_{t-1}: the reset gate scales how much of the
        # previous hidden state feeds the output-gate projection below.
        modified_hidden_t_prev = model.net.Mul(
            [reset_gate_t_sigmoid, hidden_t_prev],
            self.scope('modified_hidden_t_prev')
        )
        output_gate_t = brew.fc(
            model,
            modified_hidden_t_prev,
            self.scope('output_gate_t'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        # Add input contributions to update and output gate.
        # We already (in-place) added input contributions to the reset gate.
        update_gate_t = model.net.Sum(
            [update_gate_t, input_t_update],
            self.scope('update_gate_t'),
        )
        output_gate_t = model.net.Sum(
            [output_gate_t, input_t_output],
            self.scope('output_gate_t'),
        )

        # Join gate outputs and add input contributions
        gates_t, _gates_t_concat_dims = model.net.Concat(
            [
                reset_gate_t,
                update_gate_t,
                output_gate_t,
            ],
            [
                self.scope('gates_t'),
                self.scope('_gates_t_concat_dims'),
            ],
            axis=2,
        )

        # The GRUUnit op combines the concatenated gate pre-activations
        # with the previous hidden state; the remaining GRU math lives in
        # gru_unit_op.{h,cc}.
        hidden_t = model.net.GRUUnit(
            [
                hidden_t_prev,
                gates_t,
                seq_lengths,
                timestep,
            ],
            list(self.get_state_names()),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(hidden_t)
        return (hidden_t,)

    def prepare_input(self, model, input_blob):
        """Project the raw input to ``3 * hidden_size`` along axis 2.

        The three hidden_size-wide slices are later separated by the
        Split in ``_apply`` into reset / update / output gate inputs.
        """
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=3 * self.hidden_size,
            axis=2,
        )

    def get_state_names(self):
        """Return the single recurrent state blob name.

        A GRU keeps only a hidden state — no cell state, unlike LSTM.
        """
        return (self.scope('hidden_t'),)
|
|
|
|
|
|
# GRU reuses the generic recurrent-network scaffolding (rnn_cell._RNN)
# shared with LSTM, but passes no_cell_state=True: a GRU has no cell
# state, so _RNN must not create cell-state blobs/parameters — otherwise
# the backward pass fails because those parameters never receive
# gradients.
GRU = functools.partial(rnn_cell._RNN, GRUCell, no_cell_state=True)
|