From bad90d7a532cfe8e478b850c8a02a627da7e7376 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Tue, 28 Apr 2020 00:59:02 -0700 Subject: [PATCH] Fix a perf regression by providing a better estimate for the cost in LSTM's TryParallelFor call. --- onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index 4a197a59e0..f8a8b7dac8 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -849,7 +849,9 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, } }; - double cost = max_sequence_length * fused_hidden_rows; // TODO: approximate cost, needs more tuning. + // Approximate worst case cost of hidden_gemm_and_activations. + double gemm_cost = fused_hidden_rows * hidden_size_x4 * hidden_size_; + double cost = max_sequence_length * (gemm_cost + fused_hidden_rows); ExecuteLambdaInParallel(hidden_gemm_and_activations, batch_size_, fused_hidden_rows, cost, mlas_tp_); } else {