mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Address two issues: (#1905)
* Address two issues: Thread-safety issue with LTSM/RNN running lambda in parallel Propagate lambda exceptions and report them when running in parallel.
This commit is contained in:
parent
30c7c76552
commit
4d26f2ce86
2 changed files with 63 additions and 17 deletions
|
|
@ -212,8 +212,8 @@ class UniDirectionalLstm {
|
|||
void SetNumThreads();
|
||||
|
||||
void GateComputations(span_T_iter& out, span_T_iter& out_end, span_T_iter& C_prev,
|
||||
span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity
|
||||
span_T_iter& C_prev_clipped, span_T_iter& C_prev_clipped_end, span_T_iter& batched_output,
|
||||
const span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity
|
||||
span_T_iter& C_prev_clipped, const span_T_iter& C_prev_clipped_end, span_T_iter& batched_output,
|
||||
span_T_iter& batched_output_end, const gsl::span<const int>& seq_lengths,
|
||||
int min_sequence_length, int step, int row, int local_fused_hidden_rows, bool output_sequence);
|
||||
|
||||
|
|
@ -794,9 +794,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
// explicitly check just the range for each iteration, however if it's going to run over
|
||||
// it should also run over on the last iteration, so this should be good enough to catch any
|
||||
// logic errors causing bounds violations.
|
||||
span_T_iter C_prev_end = batched_internal_state_prev_one_step.end();
|
||||
span_T_iter C_prev_clipped_end = batched_internal_state_clipped_one_step.end();
|
||||
span_T_const_iter previous_state_end = batched_hidden_state_one_step.end();
|
||||
const span_T_iter C_prev_end = batched_internal_state_prev_one_step.end();
|
||||
const span_T_iter C_prev_clipped_end = batched_internal_state_clipped_one_step.end();
|
||||
|
||||
if (batch_parallel_) {
|
||||
int fused_hidden_rows = batch_size_ / hidden_num_threads_;
|
||||
|
|
@ -805,6 +804,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
|
||||
// lambda to do all processing on fused_hidden_rows rows
|
||||
auto hidden_gemm_and_activations = [&](int row) {
|
||||
span_T_const_iter previous_state_end = batched_hidden_state_one_step.cend();
|
||||
|
||||
//handling boundaries
|
||||
int local_fused_hidden_rows = fused_hidden_rows;
|
||||
if ((row + fused_hidden_rows) > batch_size_)
|
||||
|
|
@ -887,6 +888,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
ExecuteLambdaInParallel("Processing batch", hidden_gemm_and_activations, batch_size_, fused_hidden_rows, lstm_tp_, logger_);
|
||||
|
||||
} else {
|
||||
span_T_const_iter previous_state_end = batched_hidden_state_one_step.cend();
|
||||
|
||||
span_T_iter c_prev = batched_internal_state_prev_one_step.begin();
|
||||
span_T_iter c_prev_clipped = batched_internal_state_clipped_one_step.begin();
|
||||
|
||||
|
|
@ -995,8 +998,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
|
||||
template <typename T>
|
||||
void UniDirectionalLstm<T>::GateComputations(span_T_iter& out, span_T_iter& out_end,
|
||||
span_T_iter& C_prev, span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity
|
||||
span_T_iter& C_prev_clipped, span_T_iter& C_prev_clipped_end,
|
||||
span_T_iter& C_prev, const span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity
|
||||
span_T_iter& C_prev_clipped, const span_T_iter& C_prev_clipped_end,
|
||||
span_T_iter& batched_output, span_T_iter& batched_output_end,
|
||||
const gsl::span<const int>& seq_lengths,
|
||||
const int min_sequence_length,
|
||||
|
|
|
|||
|
|
@ -47,8 +47,8 @@ inline Direction MakeDirection(const std::string& direction) {
|
|||
if (direction == "bidirectional") {
|
||||
return kBidirectional;
|
||||
}
|
||||
ORT_THROW("Invalid 'direction' argument of '", direction,
|
||||
"'. Must be one of 'forward', 'reverse', or 'bidirectional'.");
|
||||
ORT_THROW("Invalid 'direction' argument of '", direction,
|
||||
"'. Must be one of 'forward', 'reverse', or 'bidirectional'.");
|
||||
}
|
||||
|
||||
/** Allocate a unique_ptr using allocator_, and return a span to the allocated memory so usage is safe
|
||||
|
|
@ -230,17 +230,60 @@ void ExecuteLambdaInParallel(const std::string& name, TLambda lambda, int max, i
|
|||
ORT_UNUSED_PARAMETER(name);
|
||||
ORT_UNUSED_PARAMETER(logger);
|
||||
|
||||
std::atomic<int> done(0);
|
||||
for (int i = 0; i < max; i += step) {
|
||||
ttp.Schedule([lambda, i, &done]() {
|
||||
lambda(i);
|
||||
++done;
|
||||
// ORT_ENFORCE may and does throw at times from within the tasks that run
|
||||
// on a thread-pool. Without propagating exceptions the process exits silently
|
||||
// which will make diagnosing bugs more difficult.
|
||||
|
||||
// \! UGLY
|
||||
// We have a problem here with the current thread-pool is that it takes std::function
|
||||
// by value and copies it more than once (even though it is movable).
|
||||
//
|
||||
// To report status and exceptions properly it's better to use
|
||||
// futures and promises but they are not copyable, so we can't come up with a functor
|
||||
// with a promise member and we are downgrading to C++11 where we can't have captures that moved in.
|
||||
//
|
||||
// At the same time promises MUST live in the child thread so if we throw from the main thread
|
||||
// we don't destroy any promises that are on the main thread stack which children threads may still be using.
|
||||
//
|
||||
// The only solution with the current Eigen that comes to mind is to have shared_ptr to with std::promise.
|
||||
//
|
||||
const int total_tasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0);
|
||||
std::vector<std::future<void> > futures;
|
||||
futures.reserve(total_tasks);
|
||||
|
||||
for (int i = 0, t = 0; i < max; i += step, ++t) {
|
||||
auto p_ptr = std::make_shared<std::promise<void> >();
|
||||
futures.push_back(p_ptr->get_future());
|
||||
ttp.Schedule([p_ptr, lambda, i]() {
|
||||
try {
|
||||
lambda(i);
|
||||
p_ptr->set_value();
|
||||
} catch (...) {
|
||||
p_ptr->set_exception(std::current_exception());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
int totalTasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0);
|
||||
while (done != totalTasks)
|
||||
;
|
||||
// We'd like to wait until all of the tasks have finished
|
||||
// even though one or more have already thrown. We will store
|
||||
// the first exception and then will re-throw at the end.
|
||||
std::exception_ptr pending_exception;
|
||||
for (auto& fut : futures) {
|
||||
try {
|
||||
// get() will re-throw any exceptions
|
||||
// the running task may throw
|
||||
fut.get();
|
||||
} catch (...) {
|
||||
if (!pending_exception) {
|
||||
pending_exception = std::current_exception();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (pending_exception) {
|
||||
std::rethrow_exception(pending_exception);
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue