mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Fix topk type handling to accommodate more types. (#2842)
* Fix topk type handling to accommodate more types + add unit test for int64_t. * Fix Linux build
This commit is contained in:
parent
47e27ec9a1
commit
3853ddf9c7
2 changed files with 51 additions and 25 deletions
|
|
@ -30,6 +30,7 @@ namespace onnxruntime {
|
|||
|
||||
template <typename T>
|
||||
struct GreaterValueCmp {
|
||||
using DataType = T;
|
||||
bool operator()(const pair<T, int64_t>& lhs, const pair<T, int64_t>& rhs) {
|
||||
return (lhs.first > rhs.first ||
|
||||
// when values are equal, we want lhs to get higher "priority"
|
||||
|
|
@ -40,6 +41,7 @@ struct GreaterValueCmp {
|
|||
|
||||
template <typename T>
|
||||
struct LesserValueCmp {
|
||||
using DataType = T;
|
||||
bool operator()(const pair<T, int64_t>& lhs, const pair<T, int64_t>& rhs) {
|
||||
return (lhs.first < rhs.first ||
|
||||
// when values are equal, we want lhs to get higher "priority"
|
||||
|
|
@ -52,11 +54,12 @@ struct LesserValueCmp {
|
|||
|
||||
// Selects the top k elements (largest or smallest based on template parameter)
|
||||
template <class Comparator>
|
||||
static vector<pair<float, int64_t>> select_top_k(const ConstEigenMatrixMapRowMajor<float>& raw_data, int64_t row_num, int64_t num_blocks,
|
||||
int64_t block_slice, int64_t inter_block_offset, const unsigned k,
|
||||
bool sort_top_k) {
|
||||
static vector<pair<typename Comparator::DataType, int64_t>> select_top_k(
|
||||
const ConstEigenMatrixMapRowMajor<typename Comparator::DataType>& raw_data, int64_t row_num, int64_t num_blocks,
|
||||
int64_t block_slice, int64_t inter_block_offset, const unsigned k,
|
||||
bool sort_top_k) {
|
||||
// create a data holder and insert elements
|
||||
vector<pair<float, int64_t>> data_holder;
|
||||
vector<pair<typename Comparator::DataType, int64_t>> data_holder;
|
||||
data_holder.reserve(num_blocks);
|
||||
for (int64_t l = 0; l < num_blocks; ++l) {
|
||||
data_holder.push_back({raw_data(row_num, l * block_slice + inter_block_offset), l});
|
||||
|
|
@ -85,11 +88,13 @@ static void extract_top_k_elements(const Tensor* input, const TensorShape& input
|
|||
const int64_t rows = input_shape.SizeToDimension(static_cast<size_t>(axis_parsed));
|
||||
const int64_t cols = input->Shape().Size() / rows;
|
||||
auto input_map =
|
||||
ConstEigenMatrixMapRowMajor<float>(static_cast<const float*>(input->template Data<float>()), rows, cols);
|
||||
ConstEigenMatrixMapRowMajor<typename Comparator::DataType>(
|
||||
static_cast<const typename Comparator::DataType*>(input->template Data<typename Comparator::DataType>()), rows, cols);
|
||||
|
||||
// Use Eigen maps to allow indexing into the 2d tensors like Values_map(i,j)
|
||||
const int64_t reduced_cols = output_shape.SizeFromDimension(static_cast<size_t>(axis_parsed));
|
||||
auto values_map = EigenMatrixMapRowMajor<float>(values->template MutableData<float>(), rows, reduced_cols);
|
||||
auto values_map = EigenMatrixMapRowMajor<typename Comparator::DataType>(
|
||||
values->template MutableData<typename Comparator::DataType>(), rows, reduced_cols);
|
||||
auto indices_map = EigenMatrixMapRowMajor<int64_t>(indices->template MutableData<int64_t>(), rows, reduced_cols);
|
||||
|
||||
// This is basically the number of elements within each of the "k" rows
|
||||
|
|
@ -119,7 +124,7 @@ static void extract_top_k_elements(const Tensor* input, const TensorShape& input
|
|||
// Build a min-heap/max-heap, the heap element is pair of (value, idx)
|
||||
// The top of the heap is the smallest/largest value depending on whether it is a min-heap/max-heap
|
||||
// This is a min-heap if largest == true, this is a max-heap if largest == false
|
||||
priority_queue<pair<float, int64_t>, vector<pair<float, int64_t>>, Comparator> heap;
|
||||
priority_queue<pair<typename Comparator::DataType, int64_t>, vector<pair<typename Comparator::DataType, int64_t>>, Comparator> heap;
|
||||
|
||||
// Maintain the size of heap to be less or equal to k, so the
|
||||
// heap will hold the k largest/smallest values
|
||||
|
|
@ -169,6 +174,7 @@ static void extract_top_k_elements(const Tensor* input, const TensorShape& input
|
|||
}
|
||||
|
||||
// Wrapper over core TopK implementation
|
||||
template <typename T>
|
||||
static Status TopKImpl(OpKernelContext* p_op_kernel_context, const Tensor* input, const int axis, const unsigned k,
|
||||
bool largest = true, bool sorted = true) {
|
||||
const TensorShape& input_shape = input->Shape();
|
||||
|
|
@ -200,20 +206,20 @@ static Status TopKImpl(OpKernelContext* p_op_kernel_context, const Tensor* input
|
|||
|
||||
if (sorted && largest) {
|
||||
// extract sorted largest TopK elements
|
||||
extract_top_k_elements<true, true, GreaterValueCmp<float>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
extract_top_k_elements<true, true, GreaterValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
} else if (sorted && !largest) {
|
||||
// extract sorted smallest TopK elements
|
||||
extract_top_k_elements<false, true, LesserValueCmp<float>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
extract_top_k_elements<false, true, LesserValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
} else if (largest) {
|
||||
// extract unsorted (order undefined) largest TopK elements
|
||||
extract_top_k_elements<true, false, GreaterValueCmp<float>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
extract_top_k_elements<true, false, GreaterValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
} else {
|
||||
// extract unsorted (order undefined) smallest TopK elements
|
||||
extract_top_k_elements<false, false, LesserValueCmp<float>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
extract_top_k_elements<false, false, LesserValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -240,7 +246,7 @@ Status TopK<9, float>::Compute(OpKernelContext* p_op_kernel_context) const {
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "input count mismatch, expected 1 input - the tensor to be processed");
|
||||
}
|
||||
|
||||
return TopKImpl(p_op_kernel_context, X, axis_, k_);
|
||||
return TopKImpl<float>(p_op_kernel_context, X, axis_, k_);
|
||||
}
|
||||
|
||||
// Opset ver - 10
|
||||
|
|
@ -272,7 +278,7 @@ Status TopK<10, float>::Compute(OpKernelContext* p_op_kernel_context) const {
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "value of k must not be negative");
|
||||
}
|
||||
|
||||
return TopKImpl(p_op_kernel_context, X, axis_, gsl::narrow_cast<unsigned>(parsed_input_k));
|
||||
return TopKImpl<float>(p_op_kernel_context, X, axis_, gsl::narrow_cast<unsigned>(parsed_input_k));
|
||||
}
|
||||
|
||||
// Opset ver - 11
|
||||
|
|
@ -302,6 +308,7 @@ TopK<11, int64_t>::TopK(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel
|
|||
TopkOpset11ConstructorCommon(op_kernel_info, axis_, largest_, sorted_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Status ComputeImplOpset11(OpKernelContext* p_op_kernel_context, int axis, bool is_largest, bool is_sorted) {
|
||||
const auto* X = p_op_kernel_context->Input<Tensor>(0);
|
||||
const auto* Y = p_op_kernel_context->Input<Tensor>(1);
|
||||
|
|
@ -321,18 +328,18 @@ static Status ComputeImplOpset11(OpKernelContext* p_op_kernel_context, int axis,
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "value of k must not be negative");
|
||||
}
|
||||
|
||||
return TopKImpl(p_op_kernel_context, X, axis, gsl::narrow_cast<unsigned>(parsed_input_k), is_largest, is_sorted);
|
||||
return TopKImpl<T>(p_op_kernel_context, X, axis, gsl::narrow_cast<unsigned>(parsed_input_k), is_largest, is_sorted);
|
||||
}
|
||||
|
||||
// Opset ver - 11
|
||||
template <>
|
||||
Status TopK<11, float>::Compute(OpKernelContext* p_op_kernel_context) const {
|
||||
return ComputeImplOpset11(p_op_kernel_context, axis_, largest_, sorted_);
|
||||
return ComputeImplOpset11<float>(p_op_kernel_context, axis_, largest_, sorted_);
|
||||
}
|
||||
|
||||
template <>
|
||||
Status TopK<11, int64_t>::Compute(OpKernelContext* p_op_kernel_context) const {
|
||||
return ComputeImplOpset11(p_op_kernel_context, axis_, largest_, sorted_);
|
||||
return ComputeImplOpset11<int64_t>(p_op_kernel_context, axis_, largest_, sorted_);
|
||||
}
|
||||
|
||||
// Register necessary kernels
|
||||
|
|
|
|||
|
|
@ -8,11 +8,12 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
template <typename T = float>
|
||||
static void RunTest(int op_set,
|
||||
int64_t k,
|
||||
const std::vector<float>& input_vals,
|
||||
const std::vector<T>& input_vals,
|
||||
const std::vector<int64_t>& input_dimensions,
|
||||
const std::vector<float>& expected_vals,
|
||||
const std::vector<T>& expected_vals,
|
||||
const std::vector<int64_t>& expected_indices,
|
||||
const std::vector<int64_t>& expected_dimensions,
|
||||
bool is_tensorrt_supported = true,
|
||||
|
|
@ -34,16 +35,16 @@ static void RunTest(int op_set,
|
|||
test.AddAttribute("sorted", sorted);
|
||||
|
||||
// Inputs
|
||||
test.AddInput<float>("X", input_dimensions, input_vals);
|
||||
test.AddInput<T>("X", input_dimensions, input_vals);
|
||||
if (op_set >= 10)
|
||||
test.AddInput<int64_t>("K", {1}, {k});
|
||||
|
||||
// Outputs
|
||||
if (sorted == 1) {
|
||||
test.AddOutput<float>("Values", expected_dimensions, expected_vals);
|
||||
test.AddOutput<T>("Values", expected_dimensions, expected_vals);
|
||||
test.AddOutput<int64_t>("Indices", expected_dimensions, expected_indices);
|
||||
} else {
|
||||
test.AddOutput<float>("Values", expected_dimensions, expected_vals, true);
|
||||
test.AddOutput<T>("Values", expected_dimensions, expected_vals, true);
|
||||
test.AddOutput<int64_t>("Indices", expected_dimensions, expected_indices, true);
|
||||
}
|
||||
|
||||
|
|
@ -442,6 +443,24 @@ TEST(TopKOperator, SelectFirstSortNext) {
|
|||
RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); // largest values
|
||||
}
|
||||
|
||||
TEST(TopKOperator, SelectFirstSortNextInt64) {
|
||||
// in this test, we will select the top 5 elements first then sort the chosen 5 elements
|
||||
// Select + Sort = O(n + k * ln(k)) = 50 + 5 * ln(5) = 58.047
|
||||
// Sorted selection: O(n * ln(k)) = 50 * ln(5) = 80.47
|
||||
// The algorithm used will be Select + Sort
|
||||
std::vector<int64_t> input_vals = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
|
||||
11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
|
||||
21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
|
||||
31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
|
||||
41, 42, 43, 44, 45, 46, 47, 48, 49, 50};
|
||||
std::vector<int64_t> input_dimensions = {50};
|
||||
std::vector<int64_t> expected_vals = {50, 49, 48, 47, 46};
|
||||
std::vector<int64_t> expected_indices = {49, 48, 47, 46, 45};
|
||||
std::vector<int64_t> expected_dimensions = {5};
|
||||
int64_t axis = 0;
|
||||
RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); // largest values
|
||||
}
|
||||
|
||||
TEST(TopKOperator, SortedSelection) {
|
||||
// in this test, we will use sorted selection (using heap)
|
||||
// Select + Sort = O(n + k * ln(k)) = 10 + 5 * ln(5) = 18.04
|
||||
|
|
|
|||
Loading…
Reference in a new issue