diff --git a/onnxruntime/core/providers/cpu/math/top_k.cc b/onnxruntime/core/providers/cpu/math/top_k.cc index 422c5b7616..26995bcfc8 100644 --- a/onnxruntime/core/providers/cpu/math/top_k.cc +++ b/onnxruntime/core/providers/cpu/math/top_k.cc @@ -20,6 +20,7 @@ #include "core/common/exceptions.h" #include "core/framework/op_kernel.h" #include "core/framework/tensor.h" +#include "core/platform/threadpool.h" #include "core/util/math_cpuonly.h" #include #include @@ -31,145 +32,301 @@ namespace onnxruntime { template struct GreaterValueCmp { using DataType = T; - bool operator()(const pair& lhs, const pair& rhs) { - return (lhs.first > rhs.first || + GreaterValueCmp(const T* data = nullptr) : data_(data) { + } + + bool operator()(const int64_t lhs_idx, const int64_t rhs_idx) const { + return (data_[lhs_idx] > data_[rhs_idx] || // when values are equal, we want lhs to get higher "priority" // if its corresponding index comes first (i.e.) is lower - (lhs.first == rhs.first && lhs.second < rhs.second)); + (data_[lhs_idx] == data_[rhs_idx] && lhs_idx < rhs_idx)); } + + bool CompareValueOnly(const T& lhs, const T& rhs) const { + return lhs > rhs; + } + + private: + const T* data_; }; template struct LesserValueCmp { using DataType = T; - bool operator()(const pair& lhs, const pair& rhs) { - return (lhs.first < rhs.first || + + LesserValueCmp(const T* data = nullptr) : data_(data) { + } + + bool operator()(const int64_t lhs_idx, const int64_t rhs_idx) const { + return (data_[lhs_idx] < data_[rhs_idx] || // when values are equal, we want lhs to get higher "priority" // if its corresponding index comes first (i.e.) is lower - (lhs.first == rhs.first && lhs.second < rhs.second)); + (data_[lhs_idx] == data_[rhs_idx] && lhs_idx < rhs_idx)); } + + bool CompareValueOnly(const T& lhs, const T& rhs) const { + return lhs < rhs; + } + + private: + const T* data_; }; +/* +Maintain a binary heap where HeapComp of the parent with either child is false. + e.g. if the comparison is 'greater than', the parent is smaller than both children. +There is no ordering within a level. + +NOTE: The comparison is backwards compared to std::priority_queue as we use the same comparator for this as for + nth_element in SelectTopK. As such for a heap selecting the largest values the comparator is 'greater than'. +*/ +template +static void HeapifyIthPosition(int64_t* heap, size_t i, size_t k, const HeapCmp& heap_cmp) { + while (true) { + size_t left = 2 * i + 1; + size_t right = left + 1; + if (right < k) { + // need to check both left and right children as either could be replaced + + // check if we should move child up. check left node as well as whether left is preferred over right. + // if 'i' can replace left, check whether right would replace left (if so, i replaces left as it's the weakest) + bool i_replaces_left = heap_cmp(heap[i], heap[left]); + if (i_replaces_left && heap_cmp(heap[right], heap[left])) { + // left is going to be pushed up as both i and right beat it + // NOTE: std::swap is slower as it uses std::move + auto tmp = heap[i]; + heap[i] = heap[left]; + heap[left] = tmp; + i = left; + } else if (i_replaces_left || heap_cmp(heap[i], heap[right])) { + // i_replaces_left implies left replaces right due to 'if' so replace right with i as right is the weakest. + // also check if i only beats right + auto tmp = heap[i]; + heap[i] = heap[right]; + heap[right] = tmp; + i = right; + } else + break; + } else if ((left < k) && heap_cmp(heap[i], heap[left])) { + auto tmp = heap[i]; + heap[i] = heap[left]; + heap[left] = tmp; + i = left; + } else + break; + } +} + // Static helpers that implement the core logic for each of the 'TopK' operator flavor // Selects the top k elements (largest or smallest based on template parameter) template -static vector> select_top_k( - const ConstEigenMatrixMapRowMajor& raw_data, int64_t row_num, int64_t num_blocks, - int64_t block_slice, int64_t inter_block_offset, const unsigned k, - bool sort_top_k) { - // create a data holder and insert elements - vector> data_holder; - data_holder.reserve(num_blocks); +static void SelectTopK(const Comparator& comparer, + int64_t row_offset, int64_t num_blocks, int64_t block_slice, int64_t inter_block_offset, + const unsigned k, bool sort_top_k, vector& data_holder) { for (int64_t l = 0; l < num_blocks; ++l) { - data_holder.push_back({raw_data(row_num, l * block_slice + inter_block_offset), l}); + data_holder[l] = (row_offset + (l * block_slice + inter_block_offset)); } - // find the top k (largest or smallest) elements in the data holder - O(n) - nth_element(data_holder.begin(), data_holder.begin() + (k - 1), data_holder.end(), Comparator()); + // find the top k (largest or smallest) elements in the data holder - O(n) average. O(n*n) worst case. + // See https://en.wikipedia.org/wiki/Quickselect + nth_element(data_holder.begin(), data_holder.begin() + (k - 1), data_holder.end(), comparer); // sort the top k elements if needed - O (k log k) if (sort_top_k) { - std::sort(data_holder.begin(), data_holder.begin() + k, Comparator()); + std::sort(data_holder.begin(), data_holder.begin() + k, comparer); } - // the data_holder now contains the top k elements in the first k indices - return data_holder; + // the data_holder now contains the indices of the top k elements in the first k elements } // Given an input tensor 'input' and metadata values - 'k' and 'axis_parsed', // this method will extract the sorted top k largest/smallest elements and place them in the output tensor 'values' // along with the metadata output 'indices' -template -static void extract_top_k_elements(const Tensor* input, const TensorShape& input_shape, Tensor* values, - Tensor* indices, const TensorShape& output_shape, const unsigned k, - const unsigned axis_parsed) { +template +static void FindTopKElements(const Tensor* input, const TensorShape& input_shape, Tensor* values, + Tensor* indices, const TensorShape& output_shape, const unsigned k, bool sorted, + const unsigned axis_parsed, concurrency::ThreadPool* threadpool) { // Cache some values that will be used in the implementation below const int64_t rows = input_shape.SizeToDimension(static_cast(axis_parsed)); const int64_t cols = input->Shape().Size() / rows; - auto input_map = - ConstEigenMatrixMapRowMajor( - static_cast(input->template Data()), rows, cols); + const auto* input_data = input->template Data(); - // Use Eigen maps to allow indexing into the 2d tensors like Values_map(i,j) + // Use Eigen maps for convenient indexing into the 2d tensors like Values_map(i,j) const int64_t reduced_cols = output_shape.SizeFromDimension(static_cast(axis_parsed)); - auto values_map = EigenMatrixMapRowMajor( - values->template MutableData(), rows, reduced_cols); - auto indices_map = EigenMatrixMapRowMajor(indices->template MutableData(), rows, reduced_cols); + + auto* values_data = values->template MutableData(); + auto* indices_data = indices->template MutableData(); + auto values_map = EigenMatrixMapRowMajor(values_data, rows, reduced_cols); + auto indices_map = EigenMatrixMapRowMajor(indices_data, rows, reduced_cols); // This is basically the number of elements within each of the "k" rows - const int64_t block_slice = reduced_cols / k; const int64_t num_blocks = input_shape[axis_parsed]; + const int64_t block_slice = reduced_cols / k; - for (int64_t i = 0; i < rows; ++i) { - for (int64_t j = 0; j < block_slice; ++j) { - // Since sorted == true, we will use a Heap to hold the top K values in sorted fashion - if (sorted) { // The optimizer will clean-up the redundant condition based on the template parameter 'sorted' - auto n_casted = static_cast(num_blocks); - auto k_casted = static_cast(k); - if ((n_casted + k_casted * log(k_casted)) < (n_casted * log(k_casted))) { - // Select first - O(n), then sort O(k * ln(k)) - // Overall complexity = O (n + k * ln(k)) - const auto& data_holder = select_top_k(input_map, i, num_blocks, block_slice, j, k, true); - for (int64_t l = 0; l < k; ++l) { - const auto& elem = data_holder[l]; - auto col_index = l * block_slice + j; - values_map(i, col_index) = elem.first; - indices_map(i, col_index) = elem.second; - } - } else { - // Perform sorted selection by passing 'n' elements over a heap of size 'k' - // overall complexity = O (n * ln(k)) + int64_t tp_threads = threadpool != nullptr ? threadpool->NumThreads() : 1; + int64_t num_threads = std::min(tp_threads, rows); // split on rows so can't have more threads than rows - // Build a min-heap/max-heap, the heap element is pair of (value, idx) - // The top of the heap is the smallest/largest value depending on whether it is a min-heap/max-heap - // This is a min-heap if largest == true, this is a max-heap if largest == false - priority_queue, vector>, Comparator> heap; + // rough attempt to make sure there's enough work for each thread. if there's insufficient work the usage of + // too many threads degrades performance. + // TODO: May want a different calculation for each branch below instead. + int64_t threads_needed = static_cast(std::floor(input_shape.Size() * k / (128 * 1024))); + num_threads = std::max(std::min(threads_needed, num_threads), static_cast(1)); - // Maintain the size of heap to be less or equal to k, so the - // heap will hold the k largest/smallest values - for (int64_t l = 0; l < num_blocks; ++l) { - const auto value = input_map(i, l * block_slice + j); - // largest == true: insert into the min-heap if the size is < k or if the new - // element is greater than the min element in the min-heap + // from testing various batch sizes relative to k, the following appears to work well as a selector. + // tested with following combinations + // batch_size = [ 8, 16, 32, 64, 128, 256, 512, 1024, 2048 ] + // k = [ 1, 2, 4, 6, 8, 16, 24, 32, 48, 64, 128 ] + bool use_priority_queue = k != 1 && (k < 4 || (std::log2(k) / std::log2(num_blocks)) < 0.725); - // largest == false: insert into the min-heap if the size is < k or if the new - // element is lesser than the max element in the max-heap - if ((heap.size() < k) || (largest && value > heap.top().first) || - (!largest && value < heap.top().first)) { // the optimizer will clean-up the redundant condition based - // on the template parameter 'largest' - heap.push({value, l}); - } - if (heap.size() > k) { - heap.pop(); + std::function find_top_k; + + if (k == 1) { + // just need to compare values and not indexes as the first instance of the best value is always selected + find_top_k = + [num_threads, rows, block_slice, num_blocks, input_data, cols, &values_map, &indices_map](std::ptrdiff_t batch) { + int64_t start_row = static_cast(batch * rows / num_threads); + int64_t end_row = static_cast((batch + 1) * rows / num_threads); + + Comparator comparer(input_data); + + for (int64_t i = start_row; i < end_row; ++i) { + auto row_offset = i * cols; + for (int64_t j = 0; j < block_slice; ++j) { + int64_t cur_idx = row_offset + j; + + const auto* cur_value = input_data + cur_idx; // using pointer to data is faster than input_data[cur_idx] + auto best = *cur_value; // save best value so we only have one load in the CompareValueOnly call + int64_t top_idx = cur_idx; + + for (int64_t l = 1; l < num_blocks; ++l) { + cur_value += block_slice; + if (comparer.CompareValueOnly(*cur_value, best)) { + best = *cur_value; + top_idx = cur_value - input_data; + } + } + + values_map(i, j) = best; + // convert overall index to result index + // avoid '/' if possible for perf reasons + indices_map(i, j) = block_slice == 1 ? (top_idx - row_offset - j) + : (top_idx - row_offset - j) / block_slice; } } - // Extract these k elements and place them in the results placeholder - for (int64_t l = 0; l < k; ++l) { - const auto& elem = heap.top(); - auto col_index = (k - l - 1) * block_slice + j; - values_map(i, col_index) = elem.first; - indices_map(i, col_index) = elem.second; - heap.pop(); + }; + } else if (use_priority_queue) { + find_top_k = + [num_threads, rows, block_slice, num_blocks, k, sorted, + input_data, cols, &values_map, &indices_map](std::ptrdiff_t batch) { + int64_t start_row = static_cast(batch * rows / num_threads); + int64_t end_row = static_cast((batch + 1) * rows / num_threads); + + Comparator comparer(input_data); + + // the heap is stored in indices_data. each iteration overwrites the old data when it adds the + // initial k values, so we don't need to clear it. + std::vector indices_data(k); + int64_t* indices = indices_data.data(); // raw pointer is slightly faster for HeapifyIthPosition + + for (int64_t i = start_row; i < end_row; ++i) { + const auto row_offset = i * cols; + + for (int64_t j = 0; j < block_slice; ++j) { + int64_t l = 0; + auto cur_idx = row_offset + j; + + // add first k items starting from the bottom up + for (; l < k; ++l) { + indices[k - l - 1] = cur_idx; + HeapifyIthPosition(indices, k - l - 1, k, comparer); + + cur_idx += block_slice; + } + + // insert remainder if the next value would replace the top of the heap (current worst top k value) + // save top so we only have one load in the CompareValueOnly call + auto top = input_data[indices[0]]; + for (; l < num_blocks; ++l) { + // we can compare value only. if the current value is equal to the top of the heap it won't + // replace it as the index will be higher. + if (comparer.CompareValueOnly(input_data[cur_idx], top)) { + indices[0] = cur_idx; + HeapifyIthPosition(indices, 0, k, comparer); + top = input_data[indices[0]]; + } + + cur_idx += block_slice; + } + + if (sorted) { + // Extract these k elements and place them in the results placeholder + for (l = 0; l < k; ++l) { + auto idx = indices[0]; + auto col_index = (k - l - 1) * block_slice + j; + values_map(i, col_index) = input_data[idx]; + // convert overall index to result index. avoid '/' if possible for perf reasons + indices_map(i, col_index) = block_slice == 1 ? (idx - row_offset - j) + : (idx - row_offset - j) / block_slice; + + // put the last value at the top of the heap to replace the removed one, and push it into + // place in a heap one smaller. + indices[0] = indices[k - l - 1]; + HeapifyIthPosition(indices, 0, k - l - 1, comparer); + } + } else { + for (l = 0; l < k; ++l) { + int64_t idx = indices[l]; + auto col_index = l * block_slice + j; + values_map(i, col_index) = input_data[idx]; + // convert overall index to result index. avoid '/' if possible for perf reasons + indices_map(i, col_index) = block_slice == 1 ? (idx - row_offset - j) + : (idx - row_offset - j) / block_slice; + } + } + } } - } - } else { // sorted == false - // The optimizer will clean-up the redundant condition based on the template parameter 'sorted' + }; + } else { + find_top_k = + [num_threads, rows, block_slice, num_blocks, k, sorted, + input_data, cols, + &values_map, &indices_map](std::ptrdiff_t batch) { + int64_t start_row = static_cast(batch * rows / num_threads); + int64_t end_row = static_cast((batch + 1) * rows / num_threads); - // If the top K values are not required to be sorted, we use a more optimal selection algorithm - // Average - O(n). Worst - O(n * ln(n)) or O(n^2) depending on the implementation, where 'n' is the number of input + Comparator comparer(input_data); - const auto& data_holder = select_top_k(input_map, i, num_blocks, block_slice, j, k, false); + // we re-use a single data_holder for performance. avoids allocating memory on each iteration. + // the call to SelectTopK overwrites any existing data so we don't need to clear on each iteration. + std::vector data_holder(num_blocks); - // Insert the top 'k' (largest or smallest) elements into the final output buffers - for (int64_t l = 0; l < k; ++l) { - const auto& elem = data_holder[l]; - auto col_index = l * block_slice + j; - values_map(i, col_index) = elem.first; - indices_map(i, col_index) = elem.second; - } - } - } + for (int64_t i = start_row; i < end_row; ++i) { + auto row_offset = i * cols; + for (int64_t j = 0; j < block_slice; ++j) { + SelectTopK(comparer, row_offset, num_blocks, block_slice, j, k, sorted, data_holder); + + // Insert the top 'k' (largest or smallest) elements into the final output buffers + for (int64_t l = 0; l < k; ++l) { + int64_t idx = data_holder[l]; + auto col_index = l * block_slice + j; + values_map(i, col_index) = input_data[idx]; + // convert overall index to result index. avoid the cost of the '/' is possible + indices_map(i, col_index) = block_slice == 1 ? (idx - row_offset - j) + : (idx - row_offset - j) / block_slice; + } + } + } + }; + } + + if (num_threads <= 1) { + find_top_k(0); + } else { + // we want to re-use the storage variables in each lambda as much as possible to minimize allocations + // on each iteration, so the lambda does multiple rows. e.g. the data_holder and indices_data vectors. + // the alternative would be to use TryBatchParallelFor with the lambda doing one row. + threadpool->SimpleParallelFor(num_threads, find_top_k); } } @@ -204,22 +361,14 @@ static Status TopKImpl(OpKernelContext* p_op_kernel_context, const Tensor* input return Status::OK(); } - if (sorted && largest) { - // extract sorted largest TopK elements - extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, - gsl::narrow_cast(axis_parsed)); - } else if (sorted && !largest) { - // extract sorted smallest TopK elements - extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, - gsl::narrow_cast(axis_parsed)); - } else if (largest) { - // extract unsorted (order undefined) largest TopK elements - extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, - gsl::narrow_cast(axis_parsed)); + auto* threadpool = p_op_kernel_context->GetOperatorThreadPool(); + + if (largest) { + FindTopKElements>(input, input_shape, values, indices, output_shape, k, sorted, + gsl::narrow_cast(axis_parsed), threadpool); } else { - // extract unsorted (order undefined) smallest TopK elements - extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, - gsl::narrow_cast(axis_parsed)); + FindTopKElements>(input, input_shape, values, indices, output_shape, k, sorted, + gsl::narrow_cast(axis_parsed), threadpool); } return Status::OK(); diff --git a/onnxruntime/test/providers/cpu/math/topk_op_test.cc b/onnxruntime/test/providers/cpu/math/topk_op_test.cc index d3a01de6e3..eb8f775c67 100644 --- a/onnxruntime/test/providers/cpu/math/topk_op_test.cc +++ b/onnxruntime/test/providers/cpu/math/topk_op_test.cc @@ -425,54 +425,17 @@ TEST(TopKOperator, Top1ExplicitAxisMultiDInputSmallestElements) { top_1_explicit_axis_MultiD_input_smallest(11, 0); //unsorted } -TEST(TopKOperator, SelectFirstSortNext) { - // in this test, we will select the top 5 elements first then sort the chosen 5 elements - // Select + Sort = O(n + k * ln(k)) = 50 + 5 * ln(5) = 58.047 - // Sorted selection: O(n * ln(k)) = 50 * ln(5) = 80.47 - // The algorithm used will be Select + Sort - std::vector input_vals = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0, - 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0, - 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0, - 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0, - 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0}; - std::vector input_dimensions = {50}; - std::vector expected_vals = {50.0f, 49.0f, 48.0f, 47.0f, 46.0f}; - std::vector expected_indices = {49, 48, 47, 46, 45}; - std::vector expected_dimensions = {5}; - int64_t axis = 0; - RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); // largest values -} - -TEST(TopKOperator, SelectFirstSortNextInt64) { - // in this test, we will select the top 5 elements first then sort the chosen 5 elements - // Select + Sort = O(n + k * ln(k)) = 50 + 5 * ln(5) = 58.047 - // Sorted selection: O(n * ln(k)) = 50 * ln(5) = 80.47 - // The algorithm used will be Select + Sort - std::vector input_vals = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, - 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, - 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, - 41, 42, 43, 44, 45, 46, 47, 48, 49, 50}; - std::vector input_dimensions = {50}; - std::vector expected_vals = {50, 49, 48, 47, 46}; - std::vector expected_indices = {49, 48, 47, 46, 45}; - std::vector expected_dimensions = {5}; - int64_t axis = 0; - RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); // largest values -} - -TEST(TopKOperator, SortedSelection) { - // in this test, we will use sorted selection (using heap) - // Select + Sort = O(n + k * ln(k)) = 10 + 5 * ln(5) = 18.04 - // Sorted selection: O(n * ln(k)) = 10 * ln(5) = 16.09 - // The algorithm used will be Sorted selection - std::vector input_vals = {10.0f, 8.0f, 7.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 9.0f, 3.0}; - std::vector input_dimensions = {10}; - std::vector expected_vals = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; - std::vector expected_indices = {6, 7, 9, 3, 4}; - std::vector expected_dimensions = {5}; - int64_t axis = 0; - RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis, 0); // smallest values +// test path where SelectTopK is used (select using std::nth_element) +// we use a custom path for n=1, and priority queue based implementation if +// bool use_priority_queue = k != 1 && (k < 4 || (std::log2(k) / std::log2(n)) < 0.725); +// so easiest way to test is for k to be 4 and n to be a little larger +TEST(TopKOperator, NthElement) { + std::vector input_vals = {10.0f, 8.0f, 7.0f, 4.0f, 5.0f, 6.0f}; + std::vector input_dimensions = {6}; + std::vector expected_vals = {10.0f, 8.0f, 7.0f, 6.0f}; + std::vector expected_indices = {0, 1, 2, 5}; + std::vector expected_dimensions = {4}; + RunTest(11, 4, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); } // test dimension in range (GridDim::maxThreadsPerBlock, GridDim::maxThreadsPerBlock * 2], ie. [257, 512] @@ -532,5 +495,66 @@ TEST(TopKOperator, BigArrayBigTopKSorted) { RunTest(11, 9000, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1); } +static void top_3_all_same(int opset_version, int64_t largest = 1) { + // whether it's largest or smallest we should pick the first instance/s of a number if there are multiple + std::vector input_vals = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}; + std::vector input_dimensions = {2, 4}; + std::vector expected_vals = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}; + std::vector expected_indices = {0, 1, 2, 0, 1, 2}; + std::vector expected_dimensions = {2, 3}; + RunTest(opset_version, 3, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, -1, largest); +} + +TEST(TopKOperator, Top3AllSame) { + top_3_all_same(10); + top_3_all_same(11); + top_3_all_same(10, 0); // smallest + top_3_explicit_axis(11, 0); +} + +static void TestThreaded(int64_t k, int64_t n, int64_t batch_size) { + std::vector input_vals(n * batch_size, 0.0f); + std::iota(input_vals.begin(), input_vals.end(), 0.0f); + + std::vector input_dimensions = {n, batch_size}; + + std::vector expected_vals(n * k, 0.0f); + std::vector expected_indices(n * k, 0); + std::vector expected_dimensions = {n, k}; + + for (int64_t i = 0; i < n; ++i) { + auto begin_batch_output = expected_vals.begin() + i * k; + std::iota(begin_batch_output, begin_batch_output + k, static_cast(((i + 1) * batch_size) - k)); + std::reverse(begin_batch_output, begin_batch_output + k); + + // indices are within the axis so don't need adjusting by the batch number + auto begin_indices_output = expected_indices.begin() + i * k; + std::iota(begin_indices_output, begin_indices_output + k, batch_size - k); + std::reverse(begin_indices_output, begin_indices_output + k); + } + + RunTest(11, k, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); +} + +// create input of 2x1000 and select 200 so 2 threads are needed based on there being 2 rows +// and sufficient items to process given this calculation: +// int64_t threads_needed = static_cast(std::floor(input_shape.Size() * k / (128 * 1024))); +TEST(TopKOperator, PriorityQueueThreaded) { + const int64_t k = 200; + const int64_t n = 2; + const int64_t batch_size = 1000; + TestThreaded(k, n, batch_size); +} + +// create input of 2x500 and select 400 so 2 threads are needed based on there being 2 rows +// and sufficient items to process given this calculation: +// int64_t threads_needed = static_cast(std::floor(input_shape.Size() * k / (128 * 1024))); +TEST(TopKOperator, SelectTopKThreaded) { + const int64_t k = 400; + const int64_t n = 2; + const int64_t batch_size = 500; + TestThreaded(k, n, batch_size); +} + } // namespace test } // namespace onnxruntime