diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc
index 4a197a59e0..f8a8b7dac8 100644
--- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc
+++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc
@@ -849,7 +849,11 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg,
       }
     };
 
-    double cost = max_sequence_length * fused_hidden_rows;  // TODO: approximate cost, needs more tuning.
+    // Approximate worst case cost of hidden_gemm_and_activations.
+    // Widen to double before multiplying: if these operands are integral, the product
+    // rows * 4H * H can overflow for large hidden sizes before it is converted to double.
+    double gemm_cost = static_cast<double>(fused_hidden_rows) * hidden_size_x4 * hidden_size_;
+    double cost = max_sequence_length * (gemm_cost + fused_hidden_rows);
     ExecuteLambdaInParallel(hidden_gemm_and_activations, batch_size_, fused_hidden_rows, cost, mlas_tp_);
 
   } else {