Address two issues: (#1905)

* Address two issues:
  Thread-safety issue with LSTM/RNN running lambda in parallel
  Propagate lambda exceptions and report them when running in
  parallel.
This commit is contained in:
Dmitri Smirnov 2019-09-25 09:57:11 -07:00 committed by GitHub
parent 30c7c76552
commit 4d26f2ce86
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 17 deletions

View file

@ -212,8 +212,8 @@ class UniDirectionalLstm {
void SetNumThreads();
void GateComputations(span_T_iter& out, span_T_iter& out_end, span_T_iter& C_prev,
span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity
span_T_iter& C_prev_clipped, span_T_iter& C_prev_clipped_end, span_T_iter& batched_output,
const span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity
span_T_iter& C_prev_clipped, const span_T_iter& C_prev_clipped_end, span_T_iter& batched_output,
span_T_iter& batched_output_end, const gsl::span<const int>& seq_lengths,
int min_sequence_length, int step, int row, int local_fused_hidden_rows, bool output_sequence);
@ -794,9 +794,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
// explicitly check just the range for each iteration, however if it's going to run over
// it should also run over on the last iteration, so this should be good enough to catch any
// logic errors causing bounds violations.
span_T_iter C_prev_end = batched_internal_state_prev_one_step.end();
span_T_iter C_prev_clipped_end = batched_internal_state_clipped_one_step.end();
span_T_const_iter previous_state_end = batched_hidden_state_one_step.end();
const span_T_iter C_prev_end = batched_internal_state_prev_one_step.end();
const span_T_iter C_prev_clipped_end = batched_internal_state_clipped_one_step.end();
if (batch_parallel_) {
int fused_hidden_rows = batch_size_ / hidden_num_threads_;
@ -805,6 +804,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
// lambda to do all processing on fused_hidden_rows rows
auto hidden_gemm_and_activations = [&](int row) {
span_T_const_iter previous_state_end = batched_hidden_state_one_step.cend();
//handling boundaries
int local_fused_hidden_rows = fused_hidden_rows;
if ((row + fused_hidden_rows) > batch_size_)
@ -887,6 +888,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
ExecuteLambdaInParallel("Processing batch", hidden_gemm_and_activations, batch_size_, fused_hidden_rows, lstm_tp_, logger_);
} else {
span_T_const_iter previous_state_end = batched_hidden_state_one_step.cend();
span_T_iter c_prev = batched_internal_state_prev_one_step.begin();
span_T_iter c_prev_clipped = batched_internal_state_clipped_one_step.begin();
@ -995,8 +998,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
template <typename T>
void UniDirectionalLstm<T>::GateComputations(span_T_iter& out, span_T_iter& out_end,
span_T_iter& C_prev, span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity
span_T_iter& C_prev_clipped, span_T_iter& C_prev_clipped_end,
span_T_iter& C_prev, const span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity
span_T_iter& C_prev_clipped, const span_T_iter& C_prev_clipped_end,
span_T_iter& batched_output, span_T_iter& batched_output_end,
const gsl::span<const int>& seq_lengths,
const int min_sequence_length,

View file

@ -47,8 +47,8 @@ inline Direction MakeDirection(const std::string& direction) {
if (direction == "bidirectional") {
return kBidirectional;
}
ORT_THROW("Invalid 'direction' argument of '", direction,
"'. Must be one of 'forward', 'reverse', or 'bidirectional'.");
ORT_THROW("Invalid 'direction' argument of '", direction,
"'. Must be one of 'forward', 'reverse', or 'bidirectional'.");
}
/** Allocate a unique_ptr using allocator_, and return a span to the allocated memory so usage is safe
@ -230,17 +230,60 @@ void ExecuteLambdaInParallel(const std::string& name, TLambda lambda, int max, i
ORT_UNUSED_PARAMETER(name);
ORT_UNUSED_PARAMETER(logger);
std::atomic<int> done(0);
for (int i = 0; i < max; i += step) {
ttp.Schedule([lambda, i, &done]() {
lambda(i);
++done;
// ORT_ENFORCE may and does throw at times from within the tasks that run
// on a thread-pool. Without propagating exceptions the process exits silently
// which will make diagnosing bugs more difficult.
// \! UGLY
// The problem with the current thread-pool is that it takes std::function
// by value and copies it more than once (even though it is movable).
//
// To report status and exceptions properly it's better to use
// futures and promises, but they are not copyable, so we can't write a functor
// with a promise member, and we are downgrading to C++11 where lambdas can't capture by move.
//
// At the same time promises MUST live in the child thread so if we throw from the main thread
// we don't destroy any promises that are on the main thread stack which children threads may still be using.
//
// The only solution with the current Eigen that comes to mind is to hold the std::promise in a shared_ptr.
//
const int total_tasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0);
std::vector<std::future<void> > futures;
futures.reserve(total_tasks);
for (int i = 0, t = 0; i < max; i += step, ++t) {
auto p_ptr = std::make_shared<std::promise<void> >();
futures.push_back(p_ptr->get_future());
ttp.Schedule([p_ptr, lambda, i]() {
try {
lambda(i);
p_ptr->set_value();
} catch (...) {
p_ptr->set_exception(std::current_exception());
}
});
}
int totalTasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0);
while (done != totalTasks)
;
// We'd like to wait until all of the tasks have finished
// even though one or more have already thrown. We will store
// the first exception and then will re-throw at the end.
std::exception_ptr pending_exception;
for (auto& fut : futures) {
try {
// get() will re-throw any exceptions
// the running task may throw
fut.get();
} catch (...) {
if (!pending_exception) {
pending_exception = std::current_exception();
}
}
}
if (pending_exception) {
std::rethrow_exception(pending_exception);
}
#endif
}