diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index 8bddcdcb08..fbff94cfd6 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -14,6 +14,7 @@ #include "core/common/common.h" #include "core/common/logging/logging.h" #include "core/framework/allocator.h" +#include "core/framework/op_kernel_context_internal.h" #ifdef _MSC_VER #pragma warning(pop) @@ -352,6 +353,11 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const { } } + auto& ctx_internal = static_cast(context); + // the session always has a threadpool so dereferencing is safe + // TODO: Fix having to use a const_cast to run tasks using the threadpool + auto& thread_pool = const_cast(*ctx_internal.GetOperatorThreadPool()); + AllocatorPtr alloc; status = context.GetTempSpaceAllocator(&alloc); ORT_RETURN_IF_ERROR(status); @@ -456,7 +462,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], activation_funcs_.Entries()[2], - clip_, ttp_); + clip_, thread_pool); bw = std::make_unique>(alloc, logger, seq_length, batch_size, input_size, @@ -465,7 +471,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[3], activation_funcs_.Entries()[4], activation_funcs_.Entries()[5], - clip_, ttp_); + clip_, thread_pool); fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2); @@ -477,7 +483,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], activation_funcs_.Entries()[2], - clip_, ttp_); + clip_, thread_pool); fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); } @@ -1087,6 +1093,7 @@ void UniDirectionalLstm::GateComputations(span_T_iter& out, span_T_iter& out_ // DumpMatrix("H" + row_str, pH, 1, hidden_size_); } +#ifdef DUMP_MATRIXES auto num_rows = local_fused_hidden_rows - row; std::string rows_str = " rows[" + std::to_string(row) + ".." + std::to_string(num_rows) + "]"; @@ -1096,6 +1103,7 @@ void UniDirectionalLstm::GateComputations(span_T_iter& out, span_T_iter& out_ DumpMatrix("c" + rows_str, &*out, num_rows, hidden_size_, 3 * hidden_size_, hidden_size_x4); DumpMatrix("C" + rows_str, &*C_prev, num_rows, hidden_size_); // Ct overwrites the input C_prev value DumpMatrix("H" + rows_str, &*batched_output, num_rows, hidden_size_); +#endif } template diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h index 606dfbf5b1..0ebab72e2b 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h @@ -77,13 +77,6 @@ class DeepCpuLstmOp final : public OpKernel { bool input_forget_ = false; rnn::detail::ActivationFuncs activation_funcs_; - - // Threadpool for operator. If concurrent Compute calls are possible, it will be shared - // across them. mutable due to this. - // The alternative would be to create a threadpool in each call to Compute but that would incur thread creation - // cost on every call. - mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_LSTM", - static_cast(std::thread::hardware_concurrency())}; }; } // namespace onnxruntime