mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51769 Remove some Python 2 compatibility code that otherwise causes errors to be reported from static type checkers. Static type checkers complain that the old Python 2 modules and functions referenced by this code do not exist. Given that Python 2 support is entirely deprecated now we can simply remove the compatibility code. ghstack-source-id: 121313191 Test Plan: Was able to get Pyre to successfully type check the `caffe2/python` directory with this and some other changes. Reviewed By: Tianshu-Bao Differential Revision: D26271723 Pulled By: simpkins fbshipit-source-id: fec8a09466be6867388832380480aafd36616aa1
655 lines
24 KiB
Python
655 lines
24 KiB
Python
## @package translate
|
|
# Module caffe2.python.models.seq2seq.translate
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
import argparse
|
|
from future.utils import viewitems
|
|
import logging
|
|
import numpy as np
|
|
import sys
|
|
|
|
from caffe2.python import core, rnn_cell, workspace
|
|
from caffe2.python.models.seq2seq.beam_search import BeamSearchForwardOnly
|
|
from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper
|
|
import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util
|
|
|
|
|
|
# Module-wide logger: INFO and above is written to standard error.
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler(stream=sys.stderr))
logger.setLevel(logging.INFO)
|
|
|
|
|
|
def _weighted_sum(model, values, weight, output_name):
    """Add a WeightedSum op that applies the same ``weight`` to every value.

    The op input list is laid out as ``[v0, weight, v1, weight, ...]`` and the
    call returns whatever ``model.net.WeightedSum`` returns for
    ``output_name``.
    """
    count = len(values)
    # Even slots hold the values, odd slots hold the shared weight.
    interleaved = [None] * (2 * count)
    interleaved[0::2] = values
    interleaved[1::2] = [weight] * count
    return model.net.WeightedSum(interleaved, output_name)
|
|
|
|
|
|
class Seq2SeqModelCaffe2EnsembleDecoderBase(metaclass=ABCMeta):
    """Abstract base for ensemble decoders that load Caffe2 checkpoints.

    Subclasses say where each model's checkpoint lives and which DB type it
    uses. ``load_models`` additionally expects the subclass to have set
    ``self.models``, ``self.decoder_scope_names`` and ``self.model``.
    """

    @abstractmethod
    def get_model_file(self, model):
        """Return the checkpoint location for one entry of ``self.models``."""
        pass

    @abstractmethod
    def get_db_type(self):
        """Return the ``db_type`` string passed to the CreateDB operator."""
        pass

    def build_word_rewards(self, vocab_size, word_reward, unk_reward):
        """Build the per-token reward vector.

        Returns a float32 array of length ``vocab_size`` filled with
        ``word_reward``; the PAD, GO and EOS entries are zeroed and the UNK
        entry is set to ``word_reward + unk_reward``.
        """
        word_rewards = np.full([vocab_size], word_reward, dtype=np.float32)
        word_rewards[seq2seq_util.PAD_ID] = 0
        word_rewards[seq2seq_util.GO_ID] = 0
        word_rewards[seq2seq_util.EOS_ID] = 0
        word_rewards[seq2seq_util.UNK_ID] = word_reward + unk_reward
        return word_rewards

    def load_models(self):
        """Load the parameters of every ensemble member from its checkpoint.

        For each (model, scope name) pair a DB is created under the blob name
        ``'reader'`` and a Load operator fills every parameter of
        ``self.model`` whose name starts with that scope name.

        Raises:
            AssertionError: if either operator reports failure.
        """
        db_reader = 'reader'
        for model, scope_name in zip(
            self.models,
            self.decoder_scope_names,
        ):
            model_file = self.get_model_file(model)
            params_for_current_model = [
                param
                for param in self.model.GetAllParams()
                if str(param).startswith(scope_name)
            ]
            # The operators are run outside of `assert` statements so that
            # they still execute when Python runs with -O (which strips
            # asserts). AssertionError is kept as the failure type.
            db_created = workspace.RunOperatorOnce(core.CreateOperator(
                'CreateDB',
                [], [db_reader],
                db=model_file,
                db_type=self.get_db_type())
            )
            if not db_created:
                raise AssertionError(
                    'Failed to create db {}'.format(model_file)
                )
            # NOTE(review): strip_prefix='gpu_0/' presumably matches the scope
            # the checkpoint was saved under -- confirm against the trainer.
            params_loaded = workspace.RunOperatorOnce(core.CreateOperator(
                'Load',
                [db_reader],
                params_for_current_model,
                load_all=1,
                add_prefix=scope_name + '/',
                strip_prefix='gpu_0/',
            ))
            if not params_loaded:
                raise AssertionError(
                    'Failed to load params for {} from {}'.format(
                        scope_name, model_file)
                )
            logger.info('Model {} is loaded from a checkpoint {}'.format(
                scope_name, model_file))
|
|
|
|
|
|
class Seq2SeqModelCaffe2EnsembleDecoder(Seq2SeqModelCaffe2EnsembleDecoderBase):
    """Beam-search translator built from one or more seq2seq models.

    The constructor wires every ensemble member into ``self.model`` and a
    shared beam-search step model; ``decode`` feeds one source sentence and
    reads the best hypothesis back from the workspace.
    """

    def __init__(
        self,
        translate_params,
    ):
        """Build the ensemble decoding net.

        ``translate_params`` must provide ``'ensemble_models'`` (a non-empty
        list of dicts with ``'source_vocab'``, ``'target_vocab'`` and
        ``'model_params'``; all vocabularies must be equal) and
        ``'decoding_params'`` (with ``'beam_size'``, ``'word_reward'`` and
        ``'unk_reward'``).
        """
        self.models = translate_params['ensemble_models']
        decoding_params = translate_params['decoding_params']
        self.beam_size = decoding_params['beam_size']

        assert len(self.models) > 0
        reference_model = self.models[0]
        source_vocab = reference_model['source_vocab']
        target_vocab = reference_model['target_vocab']
        for ensemble_member in self.models:
            assert ensemble_member['source_vocab'] == source_vocab
            assert ensemble_member['target_vocab'] == target_vocab

        self.source_vocab_size = len(source_vocab)
        self.target_vocab_size = len(target_vocab)

        # One name scope per ensemble member: model0, model1, ...
        self.decoder_scope_names = [
            'model{}'.format(index) for index in range(len(self.models))
        ]

        self.model = Seq2SeqModelHelper(init_params=True)

        self.encoder_inputs = self.model.net.AddExternalInput(
            'encoder_inputs'
        )
        self.encoder_lengths = self.model.net.AddExternalInput(
            'encoder_lengths'
        )
        self.max_output_seq_len = self.model.net.AddExternalInput(
            'max_output_seq_len'
        )

        fake_seq_lengths = self.model.param_init_net.ConstantFill(
            [],
            'fake_seq_lengths',
            shape=[self.beam_size],
            value=100000,
            dtype=core.DataType.INT32,
        )

        beam_decoder = BeamSearchForwardOnly(
            beam_size=self.beam_size,
            model=self.model,
            go_token_id=seq2seq_util.GO_ID,
            eos_token_id=seq2seq_util.EOS_ID,
        )
        step_model = beam_decoder.get_step_model()

        state_configs = []
        output_log_probs = []
        attention_weights = []
        for ensemble_member, scope_name in zip(
            self.models,
            self.decoder_scope_names,
        ):
            decoder_parts = self._build_decoder(
                model=self.model,
                step_model=step_model,
                model_params=ensemble_member['model_params'],
                scope=scope_name,
                previous_tokens=beam_decoder.get_previous_tokens(),
                timestep=beam_decoder.get_timestep(),
                fake_seq_lengths=fake_seq_lengths,
            )
            member_states, member_log_probs, member_attention = decoder_parts
            state_configs.extend(member_states)
            output_log_probs.append(member_log_probs)
            if member_attention is not None:
                attention_weights.append(member_attention)

        assert len(attention_weights) > 0
        attention_scale = self.model.param_init_net.ConstantFill(
            [],
            'num_decoders_with_attention_blob',
            value=1 / float(len(attention_weights)),
            shape=[1],
        )
        # [beam_size, encoder_length, 1]
        attention_weights_average = _weighted_sum(
            model=step_model,
            values=attention_weights,
            weight=attention_scale,
            output_name='attention_weights_average',
        )

        log_prob_scale = self.model.param_init_net.ConstantFill(
            [],
            'num_decoders_blob',
            value=1 / float(len(output_log_probs)),
            shape=[1],
        )
        # [beam_size, target_vocab_size]
        output_log_probs_average = _weighted_sum(
            model=step_model,
            values=output_log_probs,
            weight=log_prob_scale,
            output_name='output_log_probs_average',
        )

        # Zero-filled here; the real rewards are fed after param init below.
        word_rewards = self.model.param_init_net.ConstantFill(
            [],
            'word_rewards',
            shape=[self.target_vocab_size],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        beam_outputs = beam_decoder.apply(
            inputs=self.encoder_inputs,
            length=self.max_output_seq_len,
            log_probs=output_log_probs_average,
            attentions=attention_weights_average,
            state_configs=state_configs,
            data_dependencies=[],
            word_rewards=word_rewards,
        )
        (
            self.output_token_beam_list,
            self.output_prev_index_beam_list,
            self.output_score_beam_list,
            self.output_attention_weights_beam_list,
        ) = beam_outputs

        workspace.RunNetOnce(self.model.param_init_net)
        rewards = self.build_word_rewards(
            vocab_size=self.target_vocab_size,
            word_reward=translate_params['decoding_params']['word_reward'],
            unk_reward=translate_params['decoding_params']['unk_reward'],
        )
        workspace.FeedBlob('word_rewards', rewards)

        net_inputs = [
            str(self.encoder_inputs),
            str(self.encoder_lengths),
            str(self.max_output_seq_len),
        ]
        workspace.CreateNet(self.model.net, input_blobs=net_inputs)

        logger.info('Params created: ')
        for param in self.model.params:
            logger.info(param)

    def get_model_file(self, model):
        """Return the checkpoint path stored under ``model['model_file']``."""
        return model['model_file']

    def get_db_type(self):
        """Return the checkpoint DB type, ``'minidb'``."""
        return 'minidb'

    def scope(self, scope_name, blob_name):
        """Return ``scope_name/blob_name``, or ``blob_name`` if no scope."""
        if scope_name is None:
            return blob_name
        return scope_name + '/' + blob_name

    def _tile_states_across_beam(self, model, state_blobs, name_format):
        """Tile each non-None blob in ``state_blobs`` in place along axis 1.

        ``state_blobs`` may be None, in which case nothing happens. The output
        blob of entry ``i`` is named ``name_format.format(i)``.
        """
        if state_blobs is None:
            return
        for index, blob in enumerate(state_blobs):
            if blob is None:
                continue
            state_blobs[index] = model.net.Tile(
                blob,
                name_format.format(index),
                tiles=self.beam_size,
                axis=1,
            )

    def _build_decoder(
        self,
        model,
        step_model,
        model_params,
        scope,
        previous_tokens,
        timestep,
        fake_seq_lengths,
    ):
        """Add one ensemble member's encoder and decoder step to the nets.

        Encoder ops go into ``model``; the per-timestep decoder ops go into
        ``step_model``. Blobs are created under the name scope ``scope``.

        Returns a tuple ``(state_configs, output_log_probs,
        attention_weights)`` where ``state_configs`` is a list of
        ``BeamSearchForwardOnly.StateConfig`` and the other two are blobs of
        ``step_model``.
        """
        attention_type = model_params['attention']
        assert attention_type in ['none', 'regular']
        use_attention = attention_type != 'none'

        with core.NameScope(scope):
            encoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.source_vocab_size,
                embedding_size=model_params['encoder_embedding_size'],
                name='encoder_embeddings',
                freeze_embeddings=False,
            )

        encoder_results = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=model_params['encoder_type'],
            num_decoder_layers=len(model_params['decoder_layer_configs']),
            inputs=self.encoder_inputs,
            input_lengths=self.encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=encoder_embeddings,
            embedding_size=model_params['encoder_embedding_size'],
            use_attention=use_attention,
            num_gpus=0,
            forward_only=True,
            scope=scope,
        )
        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_hidden_states,
            final_cell_states,
            encoder_units_per_layer,
        ) = encoder_results

        with core.NameScope(scope):
            if use_attention:
                # [max_source_length, beam_size, encoder_output_dim]
                encoder_outputs = model.net.Tile(
                    encoder_outputs,
                    'encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            if weighted_encoder_outputs is not None:
                weighted_encoder_outputs = model.net.Tile(
                    weighted_encoder_outputs,
                    'weighted_encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            decoder_embedding_size = model_params['decoder_embedding_size']
            decoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.target_vocab_size,
                embedding_size=decoder_embedding_size,
                name='decoder_embeddings',
                freeze_embeddings=False,
            )
            embedded_tokens_t_prev = step_model.net.Gather(
                [decoder_embeddings, previous_tokens],
                'embedded_tokens_t_prev',
            )

        # The first layer consumes the token embedding; every later layer
        # consumes the output of the layer before it.
        decoder_cells = []
        decoder_units_per_layer = []
        layer_input_size = decoder_embedding_size
        for layer_config in model_params['decoder_layer_configs']:
            num_units = layer_config['num_units']
            decoder_units_per_layer.append(num_units)
            decoder_cells.append(rnn_cell.LSTMCell(
                forward_only=True,
                input_size=layer_input_size,
                hidden_size=num_units,
                forget_bias=0.0,
                memory_optimization=False,
            ))
            layer_input_size = num_units

        with core.NameScope(scope):
            self._tile_states_across_beam(
                model,
                final_hidden_states,
                'final_encoder_hidden_tiled_{}',
            )
            self._tile_states_across_beam(
                model,
                final_cell_states,
                'final_encoder_cell_tiled_{}',
            )
            initial_states = seq2seq_util.build_initial_rnn_decoder_states(
                model=model,
                encoder_units_per_layer=encoder_units_per_layer,
                decoder_units_per_layer=decoder_units_per_layer,
                final_encoder_hidden_states=final_hidden_states,
                final_encoder_cell_states=final_cell_states,
                use_attention=use_attention,
            )

        attention_decoder = seq2seq_util.LSTMWithAttentionDecoder(
            encoder_outputs=encoder_outputs,
            encoder_output_dim=encoder_units_per_layer[-1],
            encoder_lengths=None,
            vocab_size=self.target_vocab_size,
            attention_type=attention_type,
            embedding_size=model_params['decoder_embedding_size'],
            decoder_num_units=decoder_units_per_layer[-1],
            decoder_cells=decoder_cells,
            weighted_encoder_outputs=weighted_encoder_outputs,
            name=scope,
        )
        prev_state_names = [
            '{}/{}_prev'.format(scope, state_name)
            for state_name in attention_decoder.get_state_names()
        ]
        states_prev = step_model.net.AddExternalInputs(*prev_state_names)
        decoder_outputs, states = attention_decoder.apply(
            model=step_model,
            input_t=embedded_tokens_t_prev,
            seq_lengths=fake_seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

        # Pair each recurrent state with its initial value and with the blobs
        # holding the previous / current timestep.
        state_configs = []
        for initial_state, state_prev, state in zip(
            initial_states,
            states_prev,
            states,
        ):
            prev_link = BeamSearchForwardOnly.LinkConfig(
                blob=state_prev,
                offset=0,
                window=1,
            )
            current_link = BeamSearchForwardOnly.LinkConfig(
                blob=state,
                offset=1,
                window=1,
            )
            state_configs.append(BeamSearchForwardOnly.StateConfig(
                initial_value=initial_state,
                state_prev_link=prev_link,
                state_link=current_link,
            ))

        with core.NameScope(scope):
            reshape_outputs = [
                'decoder_outputs_flattened',
                'decoder_outputs_and_contexts_combination_old_shape',
            ]
            decoder_outputs_flattened, _ = step_model.net.Reshape(
                [decoder_outputs],
                reshape_outputs,
                shape=[-1, attention_decoder.get_output_dim()],
            )
            output_logits = seq2seq_util.output_projection(
                model=step_model,
                decoder_outputs=decoder_outputs_flattened,
                decoder_output_size=attention_decoder.get_output_dim(),
                target_vocab_size=self.target_vocab_size,
                decoder_softmax_size=model_params['decoder_softmax_size'],
            )
            # [1, beam_size, target_vocab_size]
            output_probs = step_model.net.Softmax(
                output_logits,
                'output_probs',
            )
            output_log_probs = step_model.net.Log(
                output_probs,
                'output_log_probs',
            )
            if use_attention:
                attention_weights = attention_decoder.get_attention_weights()
            else:
                # Without attention, emit all-zero weights derived from the
                # encoder input blob so callers still receive a blob.
                zero_weights = step_model.net.ConstantFill(
                    [self.encoder_inputs],
                    'zero_attention_weights_tmp_1',
                    value=0.0,
                )
                zero_weights_transposed = step_model.net.Transpose(
                    zero_weights,
                    'zero_attention_weights_tmp_2',
                )
                attention_weights = step_model.net.Tile(
                    zero_weights_transposed,
                    'zero_attention_weights_tmp',
                    tiles=self.beam_size,
                    axis=0,
                )

        return state_configs, output_log_probs, attention_weights

    def decode(self, numberized_input, max_output_seq_len):
        """Translate one numberized sentence with beam search.

        Returns ``(output, attention_weights_per_token, best_score)``:
        the token ids of the best hypothesis in generation order, one list of
        attention weights per output token (flipped back to source order and
        cut to ``len(numberized_input)``), and the negated score of that
        hypothesis.
        """
        # The source tokens are fed in reverse order, one token per row.
        reversed_tokens = [
            [token_id] for token_id in reversed(numberized_input)
        ]
        workspace.FeedBlob(
            self.encoder_inputs,
            np.array(reversed_tokens).astype(dtype=np.int32),
        )
        workspace.FeedBlob(
            self.encoder_lengths,
            np.array([len(numberized_input)]).astype(dtype=np.int32),
        )
        workspace.FeedBlob(
            self.max_output_seq_len,
            np.array([max_output_seq_len]).astype(dtype=np.int64),
        )

        workspace.RunNet(self.model.net)

        num_steps = max_output_seq_len
        score_beam_list = workspace.FetchBlob(self.output_score_beam_list)
        token_beam_list = workspace.FetchBlob(self.output_token_beam_list)
        prev_index_beam_list = workspace.FetchBlob(
            self.output_prev_index_beam_list
        )
        attention_weights_beam_list = workspace.FetchBlob(
            self.output_attention_weights_beam_list
        )

        # A hypothesis is a candidate if it emitted EOS at this step or if
        # this is the last step; keep the candidate with the highest score.
        best_step, best_hyp = num_steps, 0
        for step in range(num_steps + 1):
            for hyp in range(self.beam_size):
                is_candidate = (
                    token_beam_list[step][hyp][0] == seq2seq_util.EOS_ID or
                    step == num_steps
                )
                if not is_candidate:
                    continue
                best_so_far = score_beam_list[best_step][best_hyp][0]
                if score_beam_list[step][hyp][0] > best_so_far:
                    best_step, best_hyp = step, hyp

        best_score = -score_beam_list[best_step][best_hyp][0]

        # Follow the back-pointers from the chosen step down to step 1.
        output = []
        backward_attention = []
        step, hyp = best_step, best_hyp
        while step > 0:
            output.append(token_beam_list[step][hyp][0])
            backward_attention.append(attention_weights_beam_list[step][hyp])
            hyp = prev_index_beam_list[step][hyp][0]
            step -= 1

        # encoder_inputs are reversed, see get_batch func
        source_length = len(numberized_input)
        attention_weights_per_token = []
        for weights in reversed(backward_attention):
            restored = list(reversed(weights))
            attention_weights_per_token.append(restored[:source_length])

        output.reverse()
        return output, attention_weights_per_token, best_score
|
|
|
|
|
|
def run_seq2seq_beam_decoder(args, model_params, decoding_params):
    """Translate every line read from stdin and print the result.

    Vocabularies are generated from ``args.source_corpus`` and
    ``args.target_corpus`` with ``args.unk_threshold``; a single-model
    ensemble is built from ``model_params`` and loaded from
    ``args.checkpoint``. Each input line is decoded with a maximum output
    length of ``2 * len(sentence) + 5`` tokens.
    """
    source_vocab = seq2seq_util.gen_vocab(
        args.source_corpus,
        args.unk_threshold,
    )
    logger.info('Source vocab size {}'.format(len(source_vocab)))
    target_vocab = seq2seq_util.gen_vocab(
        args.target_corpus,
        args.unk_threshold,
    )
    # Plain dict.items() replaces the Python 2 compatibility shim
    # future.utils.viewitems; the rest of this file is Python 3 only.
    inversed_target_vocab = {v: k for (k, v) in target_vocab.items()}
    logger.info('Target vocab size {}'.format(len(target_vocab)))

    decoder = Seq2SeqModelCaffe2EnsembleDecoder(
        translate_params=dict(
            ensemble_models=[dict(
                source_vocab=source_vocab,
                target_vocab=target_vocab,
                model_params=model_params,
                model_file=args.checkpoint,
            )],
            decoding_params=decoding_params,
        ),
    )
    decoder.load_models()

    for line in sys.stdin:
        numerized_source_sentence = seq2seq_util.get_numberized_sentence(
            line,
            source_vocab,
        )
        # Only the token ids are printed; alignment and score are discarded.
        translation, _alignment, _score = decoder.decode(
            numerized_source_sentence,
            2 * len(numerized_source_sentence) + 5,
        )
        print(' '.join(inversed_target_vocab[tid] for tid in translation))
|
|
|
|
|
|
def main():
    """Parse the command line and run the stdin-to-stdout translator."""
    parser = argparse.ArgumentParser(
        description='Caffe2: Seq2Seq Translation',
    )
    parser.add_argument('--source-corpus', type=str, default=None,
                        help='Path to source corpus in a text file format. Each '
                        'line in the file should contain a single sentence',
                        required=True)
    parser.add_argument('--target-corpus', type=str, default=None,
                        help='Path to target corpus in a text file format',
                        required=True)
    parser.add_argument('--unk-threshold', type=int, default=50,
                        help='Threshold frequency under which token becomes '
                        'labeled unknown token')

    parser.add_argument('--use-bidirectional-encoder', action='store_true',
                        help='Set flag to use bidirectional recurrent network '
                        'in encoder')
    parser.add_argument('--use-attention', action='store_true',
                        help='Set flag to use seq2seq with attention model')
    parser.add_argument('--encoder-cell-num-units', type=int, default=512,
                        help='Number of cell units per encoder layer')
    parser.add_argument('--encoder-num-layers', type=int, default=2,
                        help='Number encoder layers')
    parser.add_argument('--decoder-cell-num-units', type=int, default=512,
                        help='Number of cell units in the decoder layer')
    parser.add_argument('--decoder-num-layers', type=int, default=2,
                        help='Number decoder layers')
    parser.add_argument('--encoder-embedding-size', type=int, default=256,
                        help='Size of embedding in the encoder layer')
    parser.add_argument('--decoder-embedding-size', type=int, default=512,
                        help='Size of embedding in the decoder layer')
    parser.add_argument('--decoder-softmax-size', type=int, default=None,
                        help='Size of softmax layer in the decoder')

    parser.add_argument('--beam-size', type=int, default=6,
                        help='Size of beam for the decoder')
    parser.add_argument('--word-reward', type=float, default=0.0,
                        help='Reward per each word generated.')
    parser.add_argument('--unk-reward', type=float, default=0.0,
                        help='Reward per each UNK token generated. '
                        'Typically should be negative.')

    parser.add_argument('--checkpoint', type=str, default=None,
                        help='Path to checkpoint', required=True)

    args = parser.parse_args()

    # NOTE(review): list multiplication repeats the SAME dict object, so the
    # halving below is visible through every layer entry, not just the first.
    # Left unchanged on purpose; presumably checkpoints were trained with the
    # same configuration code -- confirm before de-aliasing.
    encoder_layer_configs = [
        dict(
            num_units=args.encoder_cell_num_units,
        ),
    ] * args.encoder_num_layers

    if args.use_bidirectional_encoder:
        assert args.encoder_cell_num_units % 2 == 0, (
            '--encoder-cell-num-units must be even with '
            '--use-bidirectional-encoder'
        )
        # Floor division keeps the unit count an int; true division (/=)
        # would turn it into a float.
        encoder_layer_configs[0]['num_units'] //= 2

    decoder_layer_configs = [
        dict(
            num_units=args.decoder_cell_num_units,
        ),
    ] * args.decoder_num_layers

    run_seq2seq_beam_decoder(
        args,
        model_params=dict(
            attention=('regular' if args.use_attention else 'none'),
            decoder_layer_configs=decoder_layer_configs,
            encoder_type=dict(
                encoder_layer_configs=encoder_layer_configs,
                use_bidirectional_encoder=args.use_bidirectional_encoder,
            ),
            encoder_embedding_size=args.encoder_embedding_size,
            decoder_embedding_size=args.decoder_embedding_size,
            decoder_softmax_size=args.decoder_softmax_size,
        ),
        decoding_params=dict(
            beam_size=args.beam_size,
            word_reward=args.word_reward,
            unk_reward=args.unk_reward,
        ),
    )
|
|
|
|
|
|
# Run the command-line translator only when executed as a script.
if __name__ == '__main__':
    main()
|