cleanup more useless unique_ptr (#1427)

Yufeng Li 2019-07-18 09:50:48 -07:00 committed by GitHub
parent 1ff957f96e
commit 02ded802ab


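The change itself is mechanical: each attention/LSTM helper was heap-allocated through `std::make_unique` and then used only within the same scope through `->` and dereferences such as `*fam` / `*faw`. The commit replaces them with plain stack objects, so the dereferences become direct member calls and the downstream constructors receive `fam` instead of `*fam`. A minimal sketch of the pattern, using a hypothetical `Attention` class rather than the real onnxruntime types:

```cpp
#include <memory>

// Hypothetical stand-in for types such as BahdanauAttention<T>; not the onnxruntime class.
struct Attention {
  explicit Attention(int depth) : depth_(depth) {}
  void PrepareMemory() { /* ... */ }
  int depth_;
};

void Before() {
  // Heap allocation via unique_ptr, even though the object never leaves this scope.
  auto fam = std::make_unique<Attention>(128);
  fam->PrepareMemory();
}

void After() {
  // Automatic (stack) object: same lifetime, but no allocation and no pointer indirection.
  Attention fam(128);
  fam.PrepareMemory();
}
```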
@@ -228,77 +228,122 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
     gsl::span<T> last_cell_2 = last_cell.subspan(last_cell_size_per_direction,
                                                  last_cell_size_per_direction);
 
-    auto fam = std::make_unique<BahdanauAttention<T>>(
-        alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
-    fam->SetWeights(
+    BahdanauAttention<T> fam(
+        alloc,
+        logger,
+        batch_size,
+        max_memory_step,
+        memory_depth,
+        query_depth,
+        am_attn_size,
+        false);
+    fam.SetWeights(
         FirstHalfSpan(am_v_weights.DataAsSpan<T>()),
         FirstHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
         FirstHalfSpan(am_memory_layer_weights.DataAsSpan<T>()));
-    fam->PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
+    fam.PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
 
-    auto faw = std::make_unique<AttentionWrapper<T>>(
-        alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *fam);
-    faw->SetWeights(FirstHalfSpan(attn_layer_weights_span));
+    AttentionWrapper<T> faw(
+        alloc,
+        logger,
+        batch_size,
+        memory_depth,
+        attn_layer_depth,
+        hidden_size_,
+        has_attention_layer,
+        fam);
+    faw.SetWeights(FirstHalfSpan(attn_layer_weights_span));
 
-    auto fw = std::make_unique<UniDirectionalAttnLstm<T>>(
+    UniDirectionalAttnLstm<T> fw(
         alloc, logger,
         seq_length, batch_size, input_size,
-        hidden_size_, Direction::kForward, input_forget_, *faw,
+        hidden_size_, Direction::kForward, input_forget_, faw,
         bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1,
         activation_funcs_.Entries()[0],
         activation_funcs_.Entries()[1],
         activation_funcs_.Entries()[2],
         clip_, ttp_);
 
-    auto bam = std::make_unique<BahdanauAttention<T>>(
-        alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
-    bam->SetWeights(
+    BahdanauAttention<T> bam(
+        alloc,
+        logger,
+        batch_size,
+        max_memory_step,
+        memory_depth,
+        query_depth,
+        am_attn_size,
+        false);
+    bam.SetWeights(
         SecondHalfSpan(am_v_weights.DataAsSpan<T>()),
         SecondHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
         SecondHalfSpan(am_memory_layer_weights.DataAsSpan<T>()));
-    bam->PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
+    bam.PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
 
-    auto baw = std::make_unique<AttentionWrapper<T>>(
-        alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *bam);
-    baw->SetWeights(SecondHalfSpan(attn_layer_weights_span));
+    AttentionWrapper<T> baw(
+        alloc,
+        logger,
+        batch_size,
+        memory_depth,
+        attn_layer_depth,
+        hidden_size_,
+        has_attention_layer,
+        bam);
+    baw.SetWeights(SecondHalfSpan(attn_layer_weights_span));
 
-    auto bw = std::make_unique<UniDirectionalAttnLstm<T>>(
+    UniDirectionalAttnLstm<T> bw(
         alloc, logger,
         seq_length, batch_size, input_size,
-        hidden_size_, Direction::kReverse, input_forget_, *baw,
+        hidden_size_, Direction::kReverse, input_forget_, baw,
         bias_2, peephole_weights_2, initial_hidden_2, initial_cell_2,
         activation_funcs_.Entries()[3],
         activation_funcs_.Entries()[4],
         activation_funcs_.Entries()[5],
         clip_, ttp_);
 
-    fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
-    bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
+    fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
+    bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
   } else {
-    auto fam = std::make_unique<BahdanauAttention<T>>(
-        alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
-    fam->SetWeights(
+    BahdanauAttention<T> fam(
+        alloc,
+        logger,
+        batch_size,
+        max_memory_step,
+        memory_depth,
+        query_depth,
+        am_attn_size,
+        false);
+    fam.SetWeights(
         am_v_weights.DataAsSpan<T>(),
         am_query_layer_weights.DataAsSpan<T>(),
         am_memory_layer_weights.DataAsSpan<T>());
-    fam->PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
+    fam.PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
 
-    auto faw = std::make_unique<AttentionWrapper<T>>(
-        alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *fam);
-    faw->SetWeights(attn_layer_weights_span);
+    AttentionWrapper<T> faw(
+        alloc,
+        logger,
+        batch_size,
+        memory_depth,
+        attn_layer_depth,
+        hidden_size_,
+        has_attention_layer,
+        fam);
+    faw.SetWeights(attn_layer_weights_span);
 
-    auto fw = std::make_unique<UniDirectionalAttnLstm<T>>(
+    UniDirectionalAttnLstm<T> fw(
         alloc, logger,
         seq_length, batch_size, input_size,
-        hidden_size_, direction_, input_forget_, *faw,
+        hidden_size_, direction_, input_forget_, faw,
         bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1,
         activation_funcs_.Entries()[0],
         activation_funcs_.Entries()[1],
         activation_funcs_.Entries()[2],
         clip_, ttp_);
 
-    fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
+    fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
   }
 
   if (!output.empty()) {