diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index 682dabd926..05d6753cd7 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -212,8 +212,8 @@ class UniDirectionalLstm { void SetNumThreads(); void GateComputations(span_T_iter& out, span_T_iter& out_end, span_T_iter& C_prev, - span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity - span_T_iter& C_prev_clipped, span_T_iter& C_prev_clipped_end, span_T_iter& batched_output, + const span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity + span_T_iter& C_prev_clipped, const span_T_iter& C_prev_clipped_end, span_T_iter& batched_output, span_T_iter& batched_output_end, const gsl::span& seq_lengths, int min_sequence_length, int step, int row, int local_fused_hidden_rows, bool output_sequence); @@ -794,9 +794,8 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, // explicitly check just the range for each iteration, however if it's going to run over // it should also run over on the last iteration, so this should be good enough to catch any // logic errors causing bounds violations. 
- span_T_iter C_prev_end = batched_internal_state_prev_one_step.end(); - span_T_iter C_prev_clipped_end = batched_internal_state_clipped_one_step.end(); - span_T_const_iter previous_state_end = batched_hidden_state_one_step.end(); + const span_T_iter C_prev_end = batched_internal_state_prev_one_step.end(); + const span_T_iter C_prev_clipped_end = batched_internal_state_clipped_one_step.end(); if (batch_parallel_) { int fused_hidden_rows = batch_size_ / hidden_num_threads_; @@ -805,6 +804,8 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, // lambda to do all processing on fused_hidden_rows rows auto hidden_gemm_and_activations = [&](int row) { + span_T_const_iter previous_state_end = batched_hidden_state_one_step.cend(); + //handling boundaries int local_fused_hidden_rows = fused_hidden_rows; if ((row + fused_hidden_rows) > batch_size_) @@ -887,6 +888,8 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, ExecuteLambdaInParallel("Processing batch", hidden_gemm_and_activations, batch_size_, fused_hidden_rows, lstm_tp_, logger_); } else { + span_T_const_iter previous_state_end = batched_hidden_state_one_step.cend(); + span_T_iter c_prev = batched_internal_state_prev_one_step.begin(); span_T_iter c_prev_clipped = batched_internal_state_clipped_one_step.begin(); @@ -995,8 +998,8 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, template void UniDirectionalLstm::GateComputations(span_T_iter& out, span_T_iter& out_end, - span_T_iter& C_prev, span_T_iter& C_prev_end, // Ct-1 value not 'ct'. using 'C' for clarity - span_T_iter& C_prev_clipped, span_T_iter& C_prev_clipped_end, + span_T_iter& C_prev, const span_T_iter& C_prev_end, // Ct-1 value not 'ct'. 
using 'C' for clarity + span_T_iter& C_prev_clipped, const span_T_iter& C_prev_clipped_end, span_T_iter& batched_output, span_T_iter& batched_output_end, const gsl::span& seq_lengths, const int min_sequence_length, diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h index f1038e63a3..943d3fcd28 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h @@ -47,8 +47,8 @@ inline Direction MakeDirection(const std::string& direction) { if (direction == "bidirectional") { return kBidirectional; } - ORT_THROW("Invalid 'direction' argument of '", direction, - "'. Must be one of 'forward', 'reverse', or 'bidirectional'."); + ORT_THROW("Invalid 'direction' argument of '", direction, + "'. Must be one of 'forward', 'reverse', or 'bidirectional'."); } /** Allocate a unique_ptr using allocator_, and return a span to the allocated memory so usage is safe @@ -230,17 +230,60 @@ void ExecuteLambdaInParallel(const std::string& name, TLambda lambda, int max, i ORT_UNUSED_PARAMETER(name); ORT_UNUSED_PARAMETER(logger); - std::atomic done(0); - for (int i = 0; i < max; i += step) { - ttp.Schedule([lambda, i, &done]() { - lambda(i); - ++done; + // ORT_ENFORCE may and does throw at times from within the tasks that run + // on a thread-pool. Without propagating exceptions the process exits silently + // which will make diagnosing bugs more difficult. + + // \! UGLY + // The problem here is that the current thread-pool takes std::function + // by value and copies it more than once (even though it is movable). + // + // To report status and exceptions properly it's better to use + // futures and promises but they are not copyable, so we can't come up with a functor + // with a promise member and we are downgrading to C++11 where we can't have captures that are moved in. 
+ // + // At the same time promises MUST live in the child thread so if we throw from the main thread + // we don't destroy any promises that are on the main thread stack which children threads may still be using. + // + // The only solution with the current Eigen that comes to mind is to hold each std::promise in a shared_ptr. + // + const int total_tasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0); + std::vector > futures; + futures.reserve(total_tasks); + + for (int i = 0, t = 0; i < max; i += step, ++t) { + auto p_ptr = std::make_shared >(); + futures.push_back(p_ptr->get_future()); + ttp.Schedule([p_ptr, lambda, i]() { + try { + lambda(i); + p_ptr->set_value(); + } catch (...) { + p_ptr->set_exception(std::current_exception()); + } }); } - int totalTasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0); - while (done != totalTasks) - ; + // We'd like to wait until all of the tasks have finished + // even though one or more have already thrown. We will store + // the first exception and then will re-throw at the end. + std::exception_ptr pending_exception; + for (auto& fut : futures) { + try { + // get() will re-throw any exceptions + // the running task may throw + fut.get(); + } catch (...) { + if (!pending_exception) { + pending_exception = std::current_exception(); + } + } + } + + if (pending_exception) { + std::rethrow_exception(pending_exception); + } + #endif }