From 02ded802ab0abe3fe4e1b0ac4c59e1d3eeacaffa Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Thu, 18 Jul 2019 09:50:48 -0700 Subject: [PATCH] cleanup more useless unique_ptr (#1427) --- .../cpu/attnlstm/deep_cpu_attn_lstm.cc | 105 +++++++++++++----- 1 file changed, 75 insertions(+), 30 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc index 23eb0cc8e1..7f7102475c 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc @@ -228,77 +228,122 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { gsl::span last_cell_2 = last_cell.subspan(last_cell_size_per_direction, last_cell_size_per_direction); - auto fam = std::make_unique>( - alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false); - fam->SetWeights( + BahdanauAttention fam( + alloc, + logger, + batch_size, + max_memory_step, + memory_depth, + query_depth, + am_attn_size, + false); + + fam.SetWeights( FirstHalfSpan(am_v_weights.DataAsSpan()), FirstHalfSpan(am_query_layer_weights.DataAsSpan()), FirstHalfSpan(am_memory_layer_weights.DataAsSpan())); - fam->PrepareMemory(attn_memory.DataAsSpan(), memory_seq_lens_span); + fam.PrepareMemory(attn_memory.DataAsSpan(), memory_seq_lens_span); - auto faw = std::make_unique>( - alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *fam); - faw->SetWeights(FirstHalfSpan(attn_layer_weights_span)); + AttentionWrapper faw( + alloc, + logger, + batch_size, + memory_depth, + attn_layer_depth, + hidden_size_, + has_attention_layer, + fam); + faw.SetWeights(FirstHalfSpan(attn_layer_weights_span)); - auto fw = std::make_unique>( + UniDirectionalAttnLstm fw( alloc, logger, seq_length, batch_size, input_size, - hidden_size_, Direction::kForward, input_forget_, *faw, + hidden_size_, Direction::kForward, input_forget_, faw, bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1, activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], activation_funcs_.Entries()[2], clip_, ttp_); - auto bam = std::make_unique>( - alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false); - bam->SetWeights( + BahdanauAttention bam( + alloc, + logger, + batch_size, + max_memory_step, + memory_depth, + query_depth, + am_attn_size, + false); + bam.SetWeights( SecondHalfSpan(am_v_weights.DataAsSpan()), SecondHalfSpan(am_query_layer_weights.DataAsSpan()), SecondHalfSpan(am_memory_layer_weights.DataAsSpan())); - bam->PrepareMemory(attn_memory.DataAsSpan(), memory_seq_lens_span); + bam.PrepareMemory(attn_memory.DataAsSpan(), memory_seq_lens_span); - auto baw = std::make_unique>( - alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *bam); - baw->SetWeights(SecondHalfSpan(attn_layer_weights_span)); + AttentionWrapper baw( + alloc, + logger, + batch_size, + memory_depth, + attn_layer_depth, + hidden_size_, + has_attention_layer, + bam); + baw.SetWeights(SecondHalfSpan(attn_layer_weights_span)); - auto bw = std::make_unique>( + UniDirectionalAttnLstm bw( alloc, logger, seq_length, batch_size, input_size, - hidden_size_, Direction::kReverse, input_forget_, *baw, + hidden_size_, Direction::kReverse, input_forget_, baw, bias_2, peephole_weights_2, initial_hidden_2, initial_cell_2, activation_funcs_.Entries()[3], activation_funcs_.Entries()[4], activation_funcs_.Entries()[5], clip_, ttp_); - fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); - bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2); + fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); + bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2); } else { - auto fam = std::make_unique>( - alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false); - fam->SetWeights( + BahdanauAttention fam( + alloc, + logger, + batch_size, + max_memory_step, + memory_depth, + query_depth, + am_attn_size, + false); + + fam.SetWeights( am_v_weights.DataAsSpan(), am_query_layer_weights.DataAsSpan(), am_memory_layer_weights.DataAsSpan()); - fam->PrepareMemory(attn_memory.DataAsSpan(), memory_seq_lens_span); + fam.PrepareMemory(attn_memory.DataAsSpan(), memory_seq_lens_span); - auto faw = std::make_unique>( - alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *fam); - faw->SetWeights(attn_layer_weights_span); + AttentionWrapper faw( + alloc, + logger, + batch_size, + memory_depth, + attn_layer_depth, + hidden_size_, + has_attention_layer, + fam); - auto fw = std::make_unique>( + faw.SetWeights(attn_layer_weights_span); + + UniDirectionalAttnLstm fw( alloc, logger, seq_length, batch_size, input_size, - hidden_size_, direction_, input_forget_, *faw, + hidden_size_, direction_, input_forget_, faw, bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1, activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], activation_funcs_.Entries()[2], clip_, ttp_); - fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); + fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); } if (!output.empty()) {