Improve TopK performance. (#3612)

* Update TopK implementation.
  - add faster heap
  - special case k=1
  - update selector for when to use heap and when to use nth_element based on performance testing
  - parallelize if enough work to do
  - reduce templatized code
  - add some extra unit tests.

Perf tested vs. master. Average speedup is 3.75x using this combination of input sizes:

```
    batches = [10, 25, 50]
    batch_size = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
    k = [1, 2, 4, 6, 8, 16, 24, 32, 48, 64, 128]
```

For larger batches (e.g. 50x2048) the speedup is over 20x.
This commit is contained in:
Scott McKay 2020-04-22 10:05:13 +10:00 committed by GitHub
parent 9636da3951
commit b4508dbdc6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 327 additions and 154 deletions

View file

@ -20,6 +20,7 @@
#include "core/common/exceptions.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensor.h"
#include "core/platform/threadpool.h"
#include "core/util/math_cpuonly.h"
#include <queue>
#include <algorithm>
@ -31,145 +32,301 @@ namespace onnxruntime {
template <typename T>
struct GreaterValueCmp {
using DataType = T;
bool operator()(const pair<T, int64_t>& lhs, const pair<T, int64_t>& rhs) {
return (lhs.first > rhs.first ||
GreaterValueCmp(const T* data = nullptr) : data_(data) {
}
bool operator()(const int64_t lhs_idx, const int64_t rhs_idx) const {
return (data_[lhs_idx] > data_[rhs_idx] ||
// when values are equal, we want lhs to get higher "priority"
// if its corresponding index comes first (i.e.) is lower
(lhs.first == rhs.first && lhs.second < rhs.second));
(data_[lhs_idx] == data_[rhs_idx] && lhs_idx < rhs_idx));
}
bool CompareValueOnly(const T& lhs, const T& rhs) const {
return lhs > rhs;
}
private:
const T* data_;
};
template <typename T>
struct LesserValueCmp {
using DataType = T;
bool operator()(const pair<T, int64_t>& lhs, const pair<T, int64_t>& rhs) {
return (lhs.first < rhs.first ||
LesserValueCmp(const T* data = nullptr) : data_(data) {
}
bool operator()(const int64_t lhs_idx, const int64_t rhs_idx) const {
return (data_[lhs_idx] < data_[rhs_idx] ||
// when values are equal, we want lhs to get higher "priority"
// if its corresponding index comes first (i.e.) is lower
(lhs.first == rhs.first && lhs.second < rhs.second));
(data_[lhs_idx] == data_[rhs_idx] && lhs_idx < rhs_idx));
}
bool CompareValueOnly(const T& lhs, const T& rhs) const {
return lhs < rhs;
}
private:
const T* data_;
};
/*
Maintain a binary heap where HeapComp of the parent with either child is false.
e.g. if the comparison is 'greater than', the parent is smaller than both children.
There is no ordering within a level.
NOTE: The comparison is backwards compared to std::priority_queue as we use the same comparator for this as for
nth_element in SelectTopK. As such for a heap selecting the largest values the comparator is 'greater than'.
*/
template <class HeapCmp>
static void HeapifyIthPosition(int64_t* heap, size_t i, size_t k, const HeapCmp& heap_cmp) {
while (true) {
size_t left = 2 * i + 1;
size_t right = left + 1;
if (right < k) {
// need to check both left and right children as either could be replaced
// check if we should move child up. check left node as well as whether left is preferred over right.
// if 'i' can replace left, check whether right would replace left (if so, i replaces left as it's the weakest)
bool i_replaces_left = heap_cmp(heap[i], heap[left]);
if (i_replaces_left && heap_cmp(heap[right], heap[left])) {
// left is going to be pushed up as both i and right beat it
// NOTE: std::swap is slower as it uses std::move
auto tmp = heap[i];
heap[i] = heap[left];
heap[left] = tmp;
i = left;
} else if (i_replaces_left || heap_cmp(heap[i], heap[right])) {
// i_replaces_left implies left replaces right due to 'if' so replace right with i as right is the weakest.
// also check if i only beats right
auto tmp = heap[i];
heap[i] = heap[right];
heap[right] = tmp;
i = right;
} else
break;
} else if ((left < k) && heap_cmp(heap[i], heap[left])) {
auto tmp = heap[i];
heap[i] = heap[left];
heap[left] = tmp;
i = left;
} else
break;
}
}
// Static helpers that implement the core logic for each of the 'TopK' operator flavor
// Selects the top k elements (largest or smallest based on template parameter)
template <class Comparator>
static vector<pair<typename Comparator::DataType, int64_t>> select_top_k(
const ConstEigenMatrixMapRowMajor<typename Comparator::DataType>& raw_data, int64_t row_num, int64_t num_blocks,
int64_t block_slice, int64_t inter_block_offset, const unsigned k,
bool sort_top_k) {
// create a data holder and insert elements
vector<pair<typename Comparator::DataType, int64_t>> data_holder;
data_holder.reserve(num_blocks);
static void SelectTopK(const Comparator& comparer,
int64_t row_offset, int64_t num_blocks, int64_t block_slice, int64_t inter_block_offset,
const unsigned k, bool sort_top_k, vector<int64_t>& data_holder) {
for (int64_t l = 0; l < num_blocks; ++l) {
data_holder.push_back({raw_data(row_num, l * block_slice + inter_block_offset), l});
data_holder[l] = (row_offset + (l * block_slice + inter_block_offset));
}
// find the top k (largest or smallest) elements in the data holder - O(n)
nth_element(data_holder.begin(), data_holder.begin() + (k - 1), data_holder.end(), Comparator());
// find the top k (largest or smallest) elements in the data holder - O(n) average. O(n*n) worst case.
// See https://en.wikipedia.org/wiki/Quickselect
nth_element(data_holder.begin(), data_holder.begin() + (k - 1), data_holder.end(), comparer);
// sort the top k elements if needed - O (k log k)
if (sort_top_k) {
std::sort(data_holder.begin(), data_holder.begin() + k, Comparator());
std::sort(data_holder.begin(), data_holder.begin() + k, comparer);
}
// the data_holder now contains the top k elements in the first k indices
return data_holder;
// the data_holder now contains the indices of the top k elements in the first k elements
}
// Given an input tensor 'input' and metadata values - 'k' and 'axis_parsed',
// this method will extract the sorted top k largest/smallest elements and place them in the output tensor 'values'
// along with the metadata output 'indices'
template <bool largest, bool sorted, class Comparator>
static void extract_top_k_elements(const Tensor* input, const TensorShape& input_shape, Tensor* values,
Tensor* indices, const TensorShape& output_shape, const unsigned k,
const unsigned axis_parsed) {
template <class Comparator>
static void FindTopKElements(const Tensor* input, const TensorShape& input_shape, Tensor* values,
Tensor* indices, const TensorShape& output_shape, const unsigned k, bool sorted,
const unsigned axis_parsed, concurrency::ThreadPool* threadpool) {
// Cache some values that will be used in the implementation below
const int64_t rows = input_shape.SizeToDimension(static_cast<size_t>(axis_parsed));
const int64_t cols = input->Shape().Size() / rows;
auto input_map =
ConstEigenMatrixMapRowMajor<typename Comparator::DataType>(
static_cast<const typename Comparator::DataType*>(input->template Data<typename Comparator::DataType>()), rows, cols);
const auto* input_data = input->template Data<typename Comparator::DataType>();
// Use Eigen maps to allow indexing into the 2d tensors like Values_map(i,j)
// Use Eigen maps for convenient indexing into the 2d tensors like Values_map(i,j)
const int64_t reduced_cols = output_shape.SizeFromDimension(static_cast<size_t>(axis_parsed));
auto values_map = EigenMatrixMapRowMajor<typename Comparator::DataType>(
values->template MutableData<typename Comparator::DataType>(), rows, reduced_cols);
auto indices_map = EigenMatrixMapRowMajor<int64_t>(indices->template MutableData<int64_t>(), rows, reduced_cols);
auto* values_data = values->template MutableData<typename Comparator::DataType>();
auto* indices_data = indices->template MutableData<int64_t>();
auto values_map = EigenMatrixMapRowMajor<typename Comparator::DataType>(values_data, rows, reduced_cols);
auto indices_map = EigenMatrixMapRowMajor<int64_t>(indices_data, rows, reduced_cols);
// This is basically the number of elements within each of the "k" rows
const int64_t block_slice = reduced_cols / k;
const int64_t num_blocks = input_shape[axis_parsed];
const int64_t block_slice = reduced_cols / k;
for (int64_t i = 0; i < rows; ++i) {
for (int64_t j = 0; j < block_slice; ++j) {
// Since sorted == true, we will use a Heap to hold the top K values in sorted fashion
if (sorted) { // The optimizer will clean-up the redundant condition based on the template parameter 'sorted'
auto n_casted = static_cast<double>(num_blocks);
auto k_casted = static_cast<double>(k);
if ((n_casted + k_casted * log(k_casted)) < (n_casted * log(k_casted))) {
// Select first - O(n), then sort O(k * ln(k))
// Overall complexity = O (n + k * ln(k))
const auto& data_holder = select_top_k<Comparator>(input_map, i, num_blocks, block_slice, j, k, true);
for (int64_t l = 0; l < k; ++l) {
const auto& elem = data_holder[l];
auto col_index = l * block_slice + j;
values_map(i, col_index) = elem.first;
indices_map(i, col_index) = elem.second;
}
} else {
// Perform sorted selection by passing 'n' elements over a heap of size 'k'
// overall complexity = O (n * ln(k))
int64_t tp_threads = threadpool != nullptr ? threadpool->NumThreads() : 1;
int64_t num_threads = std::min(tp_threads, rows); // split on rows so can't have more threads than rows
// Build a min-heap/max-heap, the heap element is pair of (value, idx)
// The top of the heap is the smallest/largest value depending on whether it is a min-heap/max-heap
// This is a min-heap if largest == true, this is a max-heap if largest == false
priority_queue<pair<typename Comparator::DataType, int64_t>, vector<pair<typename Comparator::DataType, int64_t>>, Comparator> heap;
// rough attempt to make sure there's enough work for each thread. if there's insufficient work the usage of
// too many threads degrades performance.
// TODO: May want a different calculation for each branch below instead.
int64_t threads_needed = static_cast<int64_t>(std::floor(input_shape.Size() * k / (128 * 1024)));
num_threads = std::max(std::min(threads_needed, num_threads), static_cast<int64_t>(1));
// Maintain the size of heap to be less or equal to k, so the
// heap will hold the k largest/smallest values
for (int64_t l = 0; l < num_blocks; ++l) {
const auto value = input_map(i, l * block_slice + j);
// largest == true: insert into the min-heap if the size is < k or if the new
// element is greater than the min element in the min-heap
// from testing various batch sizes relative to k, the following appears to work well as a selector.
// tested with following combinations
// batch_size = [ 8, 16, 32, 64, 128, 256, 512, 1024, 2048 ]
// k = [ 1, 2, 4, 6, 8, 16, 24, 32, 48, 64, 128 ]
bool use_priority_queue = k != 1 && (k < 4 || (std::log2(k) / std::log2(num_blocks)) < 0.725);
// largest == false: insert into the min-heap if the size is < k or if the new
// element is lesser than the max element in the max-heap
if ((heap.size() < k) || (largest && value > heap.top().first) ||
(!largest && value < heap.top().first)) { // the optimizer will clean-up the redundant condition based
// on the template parameter 'largest'
heap.push({value, l});
}
if (heap.size() > k) {
heap.pop();
std::function<void(std::ptrdiff_t batch)> find_top_k;
if (k == 1) {
// just need to compare values and not indexes as the first instance of the best value is always selected
find_top_k =
[num_threads, rows, block_slice, num_blocks, input_data, cols, &values_map, &indices_map](std::ptrdiff_t batch) {
int64_t start_row = static_cast<int64_t>(batch * rows / num_threads);
int64_t end_row = static_cast<int64_t>((batch + 1) * rows / num_threads);
Comparator comparer(input_data);
for (int64_t i = start_row; i < end_row; ++i) {
auto row_offset = i * cols;
for (int64_t j = 0; j < block_slice; ++j) {
int64_t cur_idx = row_offset + j;
const auto* cur_value = input_data + cur_idx; // using pointer to data is faster than input_data[cur_idx]
auto best = *cur_value; // save best value so we only have one load in the CompareValueOnly call
int64_t top_idx = cur_idx;
for (int64_t l = 1; l < num_blocks; ++l) {
cur_value += block_slice;
if (comparer.CompareValueOnly(*cur_value, best)) {
best = *cur_value;
top_idx = cur_value - input_data;
}
}
values_map(i, j) = best;
// convert overall index to result index
// avoid '/' if possible for perf reasons
indices_map(i, j) = block_slice == 1 ? (top_idx - row_offset - j)
: (top_idx - row_offset - j) / block_slice;
}
}
// Extract these k elements and place them in the results placeholder
for (int64_t l = 0; l < k; ++l) {
const auto& elem = heap.top();
auto col_index = (k - l - 1) * block_slice + j;
values_map(i, col_index) = elem.first;
indices_map(i, col_index) = elem.second;
heap.pop();
};
} else if (use_priority_queue) {
find_top_k =
[num_threads, rows, block_slice, num_blocks, k, sorted,
input_data, cols, &values_map, &indices_map](std::ptrdiff_t batch) {
int64_t start_row = static_cast<int64_t>(batch * rows / num_threads);
int64_t end_row = static_cast<int64_t>((batch + 1) * rows / num_threads);
Comparator comparer(input_data);
// the heap is stored in indices_data. each iteration overwrites the old data when it adds the
// initial k values, so we don't need to clear it.
std::vector<int64_t> indices_data(k);
int64_t* indices = indices_data.data(); // raw pointer is slightly faster for HeapifyIthPosition
for (int64_t i = start_row; i < end_row; ++i) {
const auto row_offset = i * cols;
for (int64_t j = 0; j < block_slice; ++j) {
int64_t l = 0;
auto cur_idx = row_offset + j;
// add first k items starting from the bottom up
for (; l < k; ++l) {
indices[k - l - 1] = cur_idx;
HeapifyIthPosition(indices, k - l - 1, k, comparer);
cur_idx += block_slice;
}
// insert remainder if the next value would replace the top of the heap (current worst top k value)
// save top so we only have one load in the CompareValueOnly call
auto top = input_data[indices[0]];
for (; l < num_blocks; ++l) {
// we can compare value only. if the current value is equal to the top of the heap it won't
// replace it as the index will be higher.
if (comparer.CompareValueOnly(input_data[cur_idx], top)) {
indices[0] = cur_idx;
HeapifyIthPosition(indices, 0, k, comparer);
top = input_data[indices[0]];
}
cur_idx += block_slice;
}
if (sorted) {
// Extract these k elements and place them in the results placeholder
for (l = 0; l < k; ++l) {
auto idx = indices[0];
auto col_index = (k - l - 1) * block_slice + j;
values_map(i, col_index) = input_data[idx];
// convert overall index to result index. avoid '/' if possible for perf reasons
indices_map(i, col_index) = block_slice == 1 ? (idx - row_offset - j)
: (idx - row_offset - j) / block_slice;
// put the last value at the top of the heap to replace the removed one, and push it into
// place in a heap one smaller.
indices[0] = indices[k - l - 1];
HeapifyIthPosition(indices, 0, k - l - 1, comparer);
}
} else {
for (l = 0; l < k; ++l) {
int64_t idx = indices[l];
auto col_index = l * block_slice + j;
values_map(i, col_index) = input_data[idx];
// convert overall index to result index. avoid '/' if possible for perf reasons
indices_map(i, col_index) = block_slice == 1 ? (idx - row_offset - j)
: (idx - row_offset - j) / block_slice;
}
}
}
}
}
} else { // sorted == false
// The optimizer will clean-up the redundant condition based on the template parameter 'sorted'
};
} else {
find_top_k =
[num_threads, rows, block_slice, num_blocks, k, sorted,
input_data, cols,
&values_map, &indices_map](std::ptrdiff_t batch) {
int64_t start_row = static_cast<int64_t>(batch * rows / num_threads);
int64_t end_row = static_cast<int64_t>((batch + 1) * rows / num_threads);
// If the top K values are not required to be sorted, we use a more optimal selection algorithm
// Average - O(n). Worst - O(n * ln(n)) or O(n^2) depending on the implementation, where 'n' is the number of input
Comparator comparer(input_data);
const auto& data_holder = select_top_k<Comparator>(input_map, i, num_blocks, block_slice, j, k, false);
// we re-use a single data_holder for performance. avoids allocating memory on each iteration.
// the call to SelectTopK overwrites any existing data so we don't need to clear on each iteration.
std::vector<int64_t> data_holder(num_blocks);
// Insert the top 'k' (largest or smallest) elements into the final output buffers
for (int64_t l = 0; l < k; ++l) {
const auto& elem = data_holder[l];
auto col_index = l * block_slice + j;
values_map(i, col_index) = elem.first;
indices_map(i, col_index) = elem.second;
}
}
}
for (int64_t i = start_row; i < end_row; ++i) {
auto row_offset = i * cols;
for (int64_t j = 0; j < block_slice; ++j) {
SelectTopK<Comparator>(comparer, row_offset, num_blocks, block_slice, j, k, sorted, data_holder);
// Insert the top 'k' (largest or smallest) elements into the final output buffers
for (int64_t l = 0; l < k; ++l) {
int64_t idx = data_holder[l];
auto col_index = l * block_slice + j;
values_map(i, col_index) = input_data[idx];
// convert overall index to result index. avoid the cost of the '/' is possible
indices_map(i, col_index) = block_slice == 1 ? (idx - row_offset - j)
: (idx - row_offset - j) / block_slice;
}
}
}
};
}
if (num_threads <= 1) {
find_top_k(0);
} else {
// we want to re-use the storage variables in each lambda as much as possible to minimize allocations
// on each iteration, so the lambda does multiple rows. e.g. the data_holder and indices_data vectors.
// the alternative would be to use TryBatchParallelFor with the lambda doing one row.
threadpool->SimpleParallelFor(num_threads, find_top_k);
}
}
@ -204,22 +361,14 @@ static Status TopKImpl(OpKernelContext* p_op_kernel_context, const Tensor* input
return Status::OK();
}
if (sorted && largest) {
// extract sorted largest TopK elements
extract_top_k_elements<true, true, GreaterValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
gsl::narrow_cast<unsigned>(axis_parsed));
} else if (sorted && !largest) {
// extract sorted smallest TopK elements
extract_top_k_elements<false, true, LesserValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
gsl::narrow_cast<unsigned>(axis_parsed));
} else if (largest) {
// extract unsorted (order undefined) largest TopK elements
extract_top_k_elements<true, false, GreaterValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
gsl::narrow_cast<unsigned>(axis_parsed));
auto* threadpool = p_op_kernel_context->GetOperatorThreadPool();
if (largest) {
FindTopKElements<GreaterValueCmp<T>>(input, input_shape, values, indices, output_shape, k, sorted,
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
} else {
// extract unsorted (order undefined) smallest TopK elements
extract_top_k_elements<false, false, LesserValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
gsl::narrow_cast<unsigned>(axis_parsed));
FindTopKElements<LesserValueCmp<T>>(input, input_shape, values, indices, output_shape, k, sorted,
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
}
return Status::OK();

View file

@ -425,54 +425,17 @@ TEST(TopKOperator, Top1ExplicitAxisMultiDInputSmallestElements) {
top_1_explicit_axis_MultiD_input_smallest(11, 0); //unsorted
}
TEST(TopKOperator, SelectFirstSortNext) {
// in this test, we will select the top 5 elements first then sort the chosen 5 elements
// Select + Sort = O(n + k * ln(k)) = 50 + 5 * ln(5) = 58.047
// Sorted selection: O(n * ln(k)) = 50 * ln(5) = 80.47
// The algorithm used will be Select + Sort
std::vector<float> input_vals = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0,
11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0,
21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0,
31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0,
41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0};
std::vector<int64_t> input_dimensions = {50};
std::vector<float> expected_vals = {50.0f, 49.0f, 48.0f, 47.0f, 46.0f};
std::vector<int64_t> expected_indices = {49, 48, 47, 46, 45};
std::vector<int64_t> expected_dimensions = {5};
int64_t axis = 0;
RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); // largest values
}
TEST(TopKOperator, SelectFirstSortNextInt64) {
// in this test, we will select the top 5 elements first then sort the chosen 5 elements
// Select + Sort = O(n + k * ln(k)) = 50 + 5 * ln(5) = 58.047
// Sorted selection: O(n * ln(k)) = 50 * ln(5) = 80.47
// The algorithm used will be Select + Sort
std::vector<int64_t> input_vals = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
41, 42, 43, 44, 45, 46, 47, 48, 49, 50};
std::vector<int64_t> input_dimensions = {50};
std::vector<int64_t> expected_vals = {50, 49, 48, 47, 46};
std::vector<int64_t> expected_indices = {49, 48, 47, 46, 45};
std::vector<int64_t> expected_dimensions = {5};
int64_t axis = 0;
RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); // largest values
}
TEST(TopKOperator, SortedSelection) {
// in this test, we will use sorted selection (using heap)
// Select + Sort = O(n + k * ln(k)) = 10 + 5 * ln(5) = 18.04
// Sorted selection: O(n * ln(k)) = 10 * ln(5) = 16.09
// The algorithm used will be Sorted selection
std::vector<float> input_vals = {10.0f, 8.0f, 7.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 9.0f, 3.0};
std::vector<int64_t> input_dimensions = {10};
std::vector<float> expected_vals = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
std::vector<int64_t> expected_indices = {6, 7, 9, 3, 4};
std::vector<int64_t> expected_dimensions = {5};
int64_t axis = 0;
RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis, 0); // smallest values
// test path where SelectTopK is used (select using std::nth_element)
// we use a custom path for n=1, and priority queue based implementation if
// bool use_priority_queue = k != 1 && (k < 4 || (std::log2(k) / std::log2(n)) < 0.725);
// so easiest way to test is for k to be 4 and n to be a little larger
TEST(TopKOperator, NthElement) {
std::vector<float> input_vals = {10.0f, 8.0f, 7.0f, 4.0f, 5.0f, 6.0f};
std::vector<int64_t> input_dimensions = {6};
std::vector<float> expected_vals = {10.0f, 8.0f, 7.0f, 6.0f};
std::vector<int64_t> expected_indices = {0, 1, 2, 5};
std::vector<int64_t> expected_dimensions = {4};
RunTest(11, 4, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
}
// test dimension in range (GridDim::maxThreadsPerBlock, GridDim::maxThreadsPerBlock * 2], ie. [257, 512]
@ -532,5 +495,66 @@ TEST(TopKOperator, BigArrayBigTopKSorted) {
RunTest(11, 9000, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1);
}
static void top_3_all_same(int opset_version, int64_t largest = 1) {
// whether it's largest or smallest we should pick the first instance/s of a number if there are multiple
std::vector<float> input_vals = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f};
std::vector<int64_t> input_dimensions = {2, 4};
std::vector<float> expected_vals = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f};
std::vector<int64_t> expected_indices = {0, 1, 2, 0, 1, 2};
std::vector<int64_t> expected_dimensions = {2, 3};
RunTest(opset_version, 3, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, -1, largest);
}
TEST(TopKOperator, Top3AllSame) {
top_3_all_same(10);
top_3_all_same(11);
top_3_all_same(10, 0); // smallest
top_3_explicit_axis(11, 0);
}
static void TestThreaded(int64_t k, int64_t n, int64_t batch_size) {
std::vector<float> input_vals(n * batch_size, 0.0f);
std::iota(input_vals.begin(), input_vals.end(), 0.0f);
std::vector<int64_t> input_dimensions = {n, batch_size};
std::vector<float> expected_vals(n * k, 0.0f);
std::vector<int64_t> expected_indices(n * k, 0);
std::vector<int64_t> expected_dimensions = {n, k};
for (int64_t i = 0; i < n; ++i) {
auto begin_batch_output = expected_vals.begin() + i * k;
std::iota(begin_batch_output, begin_batch_output + k, static_cast<float>(((i + 1) * batch_size) - k));
std::reverse(begin_batch_output, begin_batch_output + k);
// indices are within the axis so don't need adjusting by the batch number
auto begin_indices_output = expected_indices.begin() + i * k;
std::iota(begin_indices_output, begin_indices_output + k, batch_size - k);
std::reverse(begin_indices_output, begin_indices_output + k);
}
RunTest(11, k, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
}
// create input of 2x1000 and select 200 so 2 threads are needed based on there being 2 rows
// and sufficient items to process given this calculation:
// int64_t threads_needed = static_cast<int64_t>(std::floor(input_shape.Size() * k / (128 * 1024)));
TEST(TopKOperator, PriorityQueueThreaded) {
const int64_t k = 200;
const int64_t n = 2;
const int64_t batch_size = 1000;
TestThreaded(k, n, batch_size);
}
// create input of 2x500 and select 400 so 2 threads are needed based on there being 2 rows
// and sufficient items to process given this calculation:
// int64_t threads_needed = static_cast<int64_t>(std::floor(input_shape.Size() * k / (128 * 1024)));
TEST(TopKOperator, SelectTopKThreaded) {
const int64_t k = 400;
const int64_t n = 2;
const int64_t batch_size = 500;
TestThreaded(k, n, batch_size);
}
} // namespace test
} // namespace onnxruntime