mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Use shared threadpool in LSTM (#1167)
This commit is contained in:
parent
7cd2d9f3c4
commit
c6abb17b8d
2 changed files with 11 additions and 10 deletions
|
|
@ -14,6 +14,7 @@
|
|||
#include "core/common/common.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(pop)
|
||||
|
|
@ -352,6 +353,11 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
|
|||
}
|
||||
}
|
||||
|
||||
auto& ctx_internal = static_cast<OpKernelContextInternal&>(context);
|
||||
// the session always has a threadpool so dereferencing is safe
|
||||
// TODO: Fix having to use a const_cast to run tasks using the threadpool
|
||||
auto& thread_pool = const_cast<concurrency::ThreadPool&>(*ctx_internal.GetOperatorThreadPool());
|
||||
|
||||
AllocatorPtr alloc;
|
||||
status = context.GetTempSpaceAllocator(&alloc);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
|
@ -456,7 +462,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
|
|||
activation_funcs_.Entries()[0],
|
||||
activation_funcs_.Entries()[1],
|
||||
activation_funcs_.Entries()[2],
|
||||
clip_, ttp_);
|
||||
clip_, thread_pool);
|
||||
|
||||
bw = std::make_unique<detail::UniDirectionalLstm<T>>(alloc, logger,
|
||||
seq_length, batch_size, input_size,
|
||||
|
|
@ -465,7 +471,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
|
|||
activation_funcs_.Entries()[3],
|
||||
activation_funcs_.Entries()[4],
|
||||
activation_funcs_.Entries()[5],
|
||||
clip_, ttp_);
|
||||
clip_, thread_pool);
|
||||
|
||||
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
|
||||
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
|
||||
|
|
@ -477,7 +483,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
|
|||
activation_funcs_.Entries()[0],
|
||||
activation_funcs_.Entries()[1],
|
||||
activation_funcs_.Entries()[2],
|
||||
clip_, ttp_);
|
||||
clip_, thread_pool);
|
||||
|
||||
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
|
||||
}
|
||||
|
|
@ -1087,6 +1093,7 @@ void UniDirectionalLstm<T>::GateComputations(span_T_iter& out, span_T_iter& out_
|
|||
// DumpMatrix("H" + row_str, pH, 1, hidden_size_);
|
||||
}
|
||||
|
||||
#ifdef DUMP_MATRIXES
|
||||
auto num_rows = local_fused_hidden_rows - row;
|
||||
std::string rows_str = " rows[" + std::to_string(row) + ".." + std::to_string(num_rows) + "]";
|
||||
|
||||
|
|
@ -1096,6 +1103,7 @@ void UniDirectionalLstm<T>::GateComputations(span_T_iter& out, span_T_iter& out_
|
|||
DumpMatrix("c" + rows_str, &*out, num_rows, hidden_size_, 3 * hidden_size_, hidden_size_x4);
|
||||
DumpMatrix("C" + rows_str, &*C_prev, num_rows, hidden_size_); // Ct overwrites the input C_prev value
|
||||
DumpMatrix("H" + rows_str, &*batched_output, num_rows, hidden_size_);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -77,13 +77,6 @@ class DeepCpuLstmOp final : public OpKernel {
|
|||
bool input_forget_ = false;
|
||||
|
||||
rnn::detail::ActivationFuncs activation_funcs_;
|
||||
|
||||
// Threadpool for operator. If concurrent Compute calls are possible, it will be shared
|
||||
// across them. mutable due to this.
|
||||
// The alternative would be to create a threadpool in each call to Compute but that would incur thread creation
|
||||
// cost on every call.
|
||||
mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_LSTM",
|
||||
static_cast<int>(std::thread::hardware_concurrency())};
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue