mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Fix a perf regression by providing a better estimate for the cost in LSTM's TryParallelFor call.
This commit is contained in:
parent
12d7c2f6e4
commit
bad90d7a53
1 changed file with 3 additions and 1 deletion
|
|
@@ -849,7 +849,9 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
}
|
||||
};
|
||||
|
||||
- double cost = max_sequence_length * fused_hidden_rows; // TODO: approximate cost, needs more tuning.
|
||||
+ // Approximate worst case cost of hidden_gemm_and_activations.
|
||||
+ double gemm_cost = fused_hidden_rows * hidden_size_x4 * hidden_size_;
|
||||
+ double cost = max_sequence_length * (gemm_cost + fused_hidden_rows);
|
||||
ExecuteLambdaInParallel(hidden_gemm_and_activations, batch_size_, fused_hidden_rows, cost, mlas_tp_);
|
||||
|
||||
} else {
|
||||
|
|
|
|||
Loading…
Reference in a new issue