pytorch/caffe2/python/gru_cell.py
Tao Wu 7b86a34610 modify _LSTM into _RNN to support GRU
Summary: GRU differs from LSTM in that it has only hidden states, not cell states. Reusing the _LSTM code is therefore problematic: we need to delete the part that creates the cell state, and change the many places that hard-code 4 state outputs (hidden_all, hidden, cell_all, cell) to 2 (hidden_all, hidden). Otherwise GRU breaks during the backward pass, when the optimizer tries to apply a gradient to each parameter: the cell state is never used, so there are no gradients for its corresponding parameters (i.e., cell_state_w, cell_state_b).
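
A rough sketch of the shape of the change (hypothetical helper; the state names are illustrative, not verbatim from rnn_cell.py):

    def _recurrent_state_names(scope, no_cell_state):
        # LSTM exposes four recurrent outputs per layer, GRU only two.
        names = [scope + '/hidden_t_all', scope + '/hidden_t_last']
        if not no_cell_state:
            names += [scope + '/cell_t_all', scope + '/cell_t_last']
        return names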

Differential Revision: D5589309

fbshipit-source-id: f5af67dfe0842acd68223f6da3e96a81639e8049
2017-08-09 13:24:45 -07:00

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
from caffe2.python import brew, rnn_cell


class GRUCell(rnn_cell.RNNCell):

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,  # Currently unused! Values here will be ignored.
        memory_optimization,
        drop_states=False,
        **kwargs
    ):
        super(GRUCell, self).__init__(**kwargs)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states

    # Unlike LSTMCell, GRUCell needs the output of one gate to feed into
    # another (reset gate -> output gate), so much of the logic to compute
    # the reset gate output and the modified output gate input is set up
    # here, in the graph definition. The remaining logic lives in
    # gru_unit_op.{h,cc}.
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        hidden_t_prev = states[0]

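        # The graph below assembles the standard GRU gate pre-activations
        # (weights and biases come from the brew.fc calls here and in
        # prepare_input):
        #   r_t = W_r x_t + U_r h_{t-1} + b_r                   (reset; sigmoid applied here)
        #   z_t = W_z x_t + U_z h_{t-1} + b_z                   (update; sigmoid inside GRUUnit)
        #   o_t = W_o x_t + U_o (sigmoid(r_t) * h_{t-1}) + b_o  (candidate; tanh inside GRUUnit)
        # GRUUnit then blends h_{t-1} with the candidate via the update gate;
        # the exact convention lives in gru_unit_op.{h,cc}.
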
        # Split input tensors to get inputs for each gate.
        input_t_reset, input_t_update, input_t_output = model.net.Split(
            [
                input_t,
            ],
            [
                self.scope('input_t_reset'),
                self.scope('input_t_update'),
                self.scope('input_t_output'),
            ],
            axis=2,
        )
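
        # Each split piece is one gate's input contribution: input_t comes
        # from prepare_input() below, whose FC makes the last dimension
        # 3 * hidden_size.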

        # Fully connected layers for reset and update gates.
        reset_gate_t = brew.fc(
            model,
            hidden_t_prev,
            self.scope('reset_gate_t'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )
        update_gate_t = brew.fc(
            model,
            hidden_t_prev,
            self.scope('update_gate_t'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        # Calculate the modified hidden state that goes into the output gate.
        reset_gate_t = model.net.Sum(
            [reset_gate_t, input_t_reset],
            self.scope('reset_gate_t')
        )
        reset_gate_t_sigmoid = model.net.Sigmoid(
            reset_gate_t,
            self.scope('reset_gate_t_sigmoid')
        )
        modified_hidden_t_prev = model.net.Mul(
            [reset_gate_t_sigmoid, hidden_t_prev],
            self.scope('modified_hidden_t_prev')
        )
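
        # The Mul above computes sigmoid(r_t) * h_{t-1}: the reset gate
        # controls how much of the previous hidden state feeds the candidate
        # (output gate) activation.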

        output_gate_t = brew.fc(
            model,
            modified_hidden_t_prev,
            self.scope('output_gate_t'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        # Add input contributions to update and output gate.
        # We already (in-place) added input contributions to the reset gate.
        update_gate_t = model.net.Sum(
            [update_gate_t, input_t_update],
            self.scope('update_gate_t'),
        )
        output_gate_t = model.net.Sum(
            [output_gate_t, input_t_output],
            self.scope('output_gate_t'),
        )

        # Join the gate outputs into a single blob, ordered (reset, update,
        # output) along axis 2; the input contributions were already added
        # above.
        gates_t, _gates_t_concat_dims = model.net.Concat(
            [
                reset_gate_t,
                update_gate_t,
                output_gate_t,
            ],
            [
                self.scope('gates_t'),
                self.scope('_gates_t_concat_dims'),
            ],
            axis=2,
        )
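
        # GRUUnit consumes (h_{t-1}, gates_t, seq_lengths, timestep) and emits
        # the new hidden state. Past a sequence's end it either carries
        # h_{t-1} forward or zeroes the state, depending on drop_states.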

        hidden_t = model.net.GRUUnit(
            [
                hidden_t_prev,
                gates_t,
                seq_lengths,
                timestep,
            ],
            list(self.get_state_names()),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(hidden_t)
        return (hidden_t,)

    # Project the raw input once for the whole sequence: a single FC maps
    # dim_in -> 3 * hidden_size, covering the reset, update, and output gate
    # contributions that _apply() splits apart.
    def prepare_input(self, model, input_blob):
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=3 * self.hidden_size,
            axis=2,
        )

    def get_state_names(self):
        return (self.scope('hidden_t'),)


GRU = functools.partial(rnn_cell._RNN, GRUCell, no_cell_state=True)
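

# A hypothetical usage sketch (illustrative, not part of the original file).
# The keyword names below mirror the LSTM helper in rnn_cell.py and may not
# match _RNN's signature exactly.
def _example_usage():
    from caffe2.python import model_helper

    model = model_helper.ModelHelper(name="gru_example")
    # With no_cell_state=True there is no cell pair, so GRU should return
    # only the hidden outputs: (hidden_all, hidden_last).
    hidden_all, hidden_last = GRU(
        model,
        input_blob='input',          # (T, N, dim_in) sequence tensor
        seq_lengths='seq_lengths',   # (N,) int32 per-example lengths
        initial_states=('hidden_init',),
        dim_in=16,
        dim_out=32,
        scope='gru_example',
    )
    return hidden_all, hidden_last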