Use shared threadpool in LSTM (#1167)

This commit is contained in:
Scott McKay 2019-06-06 07:16:31 +10:00 committed by GitHub
parent 7cd2d9f3c4
commit c6abb17b8d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 10 deletions

View file

@ -14,6 +14,7 @@
#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/framework/allocator.h"
#include "core/framework/op_kernel_context_internal.h"
#ifdef _MSC_VER
#pragma warning(pop)
@ -352,6 +353,11 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
}
}
auto& ctx_internal = static_cast<OpKernelContextInternal&>(context);
// The session always has a threadpool, so dereferencing the returned pointer is safe.
// TODO: Remove the need for a const_cast in order to run tasks on the threadpool.
auto& thread_pool = const_cast<concurrency::ThreadPool&>(*ctx_internal.GetOperatorThreadPool());
AllocatorPtr alloc;
status = context.GetTempSpaceAllocator(&alloc);
ORT_RETURN_IF_ERROR(status);
@ -456,7 +462,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
activation_funcs_.Entries()[2],
clip_, ttp_);
clip_, thread_pool);
bw = std::make_unique<detail::UniDirectionalLstm<T>>(alloc, logger,
seq_length, batch_size, input_size,
@ -465,7 +471,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[3],
activation_funcs_.Entries()[4],
activation_funcs_.Entries()[5],
clip_, ttp_);
clip_, thread_pool);
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
@ -477,7 +483,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
activation_funcs_.Entries()[2],
clip_, ttp_);
clip_, thread_pool);
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
}
@ -1087,6 +1093,7 @@ void UniDirectionalLstm<T>::GateComputations(span_T_iter& out, span_T_iter& out_
// DumpMatrix("H" + row_str, pH, 1, hidden_size_);
}
#ifdef DUMP_MATRIXES
auto num_rows = local_fused_hidden_rows - row;
std::string rows_str = " rows[" + std::to_string(row) + ".." + std::to_string(num_rows) + "]";
@ -1096,6 +1103,7 @@ void UniDirectionalLstm<T>::GateComputations(span_T_iter& out, span_T_iter& out_
DumpMatrix("c" + rows_str, &*out, num_rows, hidden_size_, 3 * hidden_size_, hidden_size_x4);
DumpMatrix("C" + rows_str, &*C_prev, num_rows, hidden_size_); // Ct overwrites the input C_prev value
DumpMatrix("H" + rows_str, &*batched_output, num_rows, hidden_size_);
#endif
}
template <typename T>

View file

@ -77,13 +77,6 @@ class DeepCpuLstmOp final : public OpKernel {
bool input_forget_ = false;
rnn::detail::ActivationFuncs activation_funcs_;
// Threadpool for operator. If concurrent Compute calls are possible, it will be shared
// across them. mutable due to this.
// The alternative would be to create a threadpool in each call to Compute but that would incur thread creation
// cost on every call.
mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_LSTM",
static_cast<int>(std::thread::hardware_concurrency())};
};
} // namespace onnxruntime