mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
cleanup more useless unique_ptr (#1427)
This commit is contained in:
parent
1ff957f96e
commit
02ded802ab
1 changed files with 75 additions and 30 deletions
|
|
@ -228,77 +228,122 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
|
|||
gsl::span<T> last_cell_2 = last_cell.subspan(last_cell_size_per_direction,
|
||||
last_cell_size_per_direction);
|
||||
|
||||
auto fam = std::make_unique<BahdanauAttention<T>>(
|
||||
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
|
||||
fam->SetWeights(
|
||||
BahdanauAttention<T> fam(
|
||||
alloc,
|
||||
logger,
|
||||
batch_size,
|
||||
max_memory_step,
|
||||
memory_depth,
|
||||
query_depth,
|
||||
am_attn_size,
|
||||
false);
|
||||
|
||||
fam.SetWeights(
|
||||
FirstHalfSpan(am_v_weights.DataAsSpan<T>()),
|
||||
FirstHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
|
||||
FirstHalfSpan(am_memory_layer_weights.DataAsSpan<T>()));
|
||||
fam->PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
|
||||
fam.PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
|
||||
|
||||
auto faw = std::make_unique<AttentionWrapper<T>>(
|
||||
alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *fam);
|
||||
faw->SetWeights(FirstHalfSpan(attn_layer_weights_span));
|
||||
AttentionWrapper<T> faw(
|
||||
alloc,
|
||||
logger,
|
||||
batch_size,
|
||||
memory_depth,
|
||||
attn_layer_depth,
|
||||
hidden_size_,
|
||||
has_attention_layer,
|
||||
fam);
|
||||
faw.SetWeights(FirstHalfSpan(attn_layer_weights_span));
|
||||
|
||||
auto fw = std::make_unique<UniDirectionalAttnLstm<T>>(
|
||||
UniDirectionalAttnLstm<T> fw(
|
||||
alloc, logger,
|
||||
seq_length, batch_size, input_size,
|
||||
hidden_size_, Direction::kForward, input_forget_, *faw,
|
||||
hidden_size_, Direction::kForward, input_forget_, faw,
|
||||
bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1,
|
||||
activation_funcs_.Entries()[0],
|
||||
activation_funcs_.Entries()[1],
|
||||
activation_funcs_.Entries()[2],
|
||||
clip_, ttp_);
|
||||
|
||||
auto bam = std::make_unique<BahdanauAttention<T>>(
|
||||
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
|
||||
bam->SetWeights(
|
||||
BahdanauAttention<T> bam(
|
||||
alloc,
|
||||
logger,
|
||||
batch_size,
|
||||
max_memory_step,
|
||||
memory_depth,
|
||||
query_depth,
|
||||
am_attn_size,
|
||||
false);
|
||||
bam.SetWeights(
|
||||
SecondHalfSpan(am_v_weights.DataAsSpan<T>()),
|
||||
SecondHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
|
||||
SecondHalfSpan(am_memory_layer_weights.DataAsSpan<T>()));
|
||||
bam->PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
|
||||
bam.PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
|
||||
|
||||
auto baw = std::make_unique<AttentionWrapper<T>>(
|
||||
alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *bam);
|
||||
baw->SetWeights(SecondHalfSpan(attn_layer_weights_span));
|
||||
AttentionWrapper<T> baw(
|
||||
alloc,
|
||||
logger,
|
||||
batch_size,
|
||||
memory_depth,
|
||||
attn_layer_depth,
|
||||
hidden_size_,
|
||||
has_attention_layer,
|
||||
bam);
|
||||
baw.SetWeights(SecondHalfSpan(attn_layer_weights_span));
|
||||
|
||||
auto bw = std::make_unique<UniDirectionalAttnLstm<T>>(
|
||||
UniDirectionalAttnLstm<T> bw(
|
||||
alloc, logger,
|
||||
seq_length, batch_size, input_size,
|
||||
hidden_size_, Direction::kReverse, input_forget_, *baw,
|
||||
hidden_size_, Direction::kReverse, input_forget_, baw,
|
||||
bias_2, peephole_weights_2, initial_hidden_2, initial_cell_2,
|
||||
activation_funcs_.Entries()[3],
|
||||
activation_funcs_.Entries()[4],
|
||||
activation_funcs_.Entries()[5],
|
||||
clip_, ttp_);
|
||||
|
||||
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
|
||||
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
|
||||
fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
|
||||
bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
|
||||
|
||||
} else {
|
||||
auto fam = std::make_unique<BahdanauAttention<T>>(
|
||||
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
|
||||
fam->SetWeights(
|
||||
BahdanauAttention<T> fam(
|
||||
alloc,
|
||||
logger,
|
||||
batch_size,
|
||||
max_memory_step,
|
||||
memory_depth,
|
||||
query_depth,
|
||||
am_attn_size,
|
||||
false);
|
||||
|
||||
fam.SetWeights(
|
||||
am_v_weights.DataAsSpan<T>(),
|
||||
am_query_layer_weights.DataAsSpan<T>(),
|
||||
am_memory_layer_weights.DataAsSpan<T>());
|
||||
fam->PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
|
||||
fam.PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
|
||||
|
||||
auto faw = std::make_unique<AttentionWrapper<T>>(
|
||||
alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *fam);
|
||||
faw->SetWeights(attn_layer_weights_span);
|
||||
AttentionWrapper<T> faw(
|
||||
alloc,
|
||||
logger,
|
||||
batch_size,
|
||||
memory_depth,
|
||||
attn_layer_depth,
|
||||
hidden_size_,
|
||||
has_attention_layer,
|
||||
fam);
|
||||
|
||||
auto fw = std::make_unique<UniDirectionalAttnLstm<T>>(
|
||||
faw.SetWeights(attn_layer_weights_span);
|
||||
|
||||
UniDirectionalAttnLstm<T> fw(
|
||||
alloc, logger,
|
||||
seq_length, batch_size, input_size,
|
||||
hidden_size_, direction_, input_forget_, *faw,
|
||||
hidden_size_, direction_, input_forget_, faw,
|
||||
bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1,
|
||||
activation_funcs_.Entries()[0],
|
||||
activation_funcs_.Entries()[1],
|
||||
activation_funcs_.Entries()[2],
|
||||
clip_, ttp_);
|
||||
|
||||
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
|
||||
fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
|
||||
}
|
||||
|
||||
if (!output.empty()) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue