Fix a perf regression by providing a better estimate for the cost in LSTM's TryParallelFor call.

This commit is contained in:
Pranav Sharma 2020-04-28 00:59:02 -07:00 committed by Changming Sun
parent 12d7c2f6e4
commit bad90d7a53

View file

@ -849,7 +849,9 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
}
};
double cost = max_sequence_length * fused_hidden_rows; // TODO: approximate cost, needs more tuning.
// Approximate worst case cost of hidden_gemm_and_activations.
double gemm_cost = fused_hidden_rows * hidden_size_x4 * hidden_size_;
double cost = max_sequence_length * (gemm_cost + fused_hidden_rows);
ExecuteLambdaInParallel(hidden_gemm_and_activations, batch_size_, fused_hidden_rows, cost, mlas_tp_);
} else {