diff --git a/onnxruntime/core/providers/cpu/math/top_k.cc b/onnxruntime/core/providers/cpu/math/top_k.cc index 4eec811d85..422c5b7616 100644 --- a/onnxruntime/core/providers/cpu/math/top_k.cc +++ b/onnxruntime/core/providers/cpu/math/top_k.cc @@ -30,6 +30,7 @@ namespace onnxruntime { template struct GreaterValueCmp { + using DataType = T; bool operator()(const pair& lhs, const pair& rhs) { return (lhs.first > rhs.first || // when values are equal, we want lhs to get higher "priority" @@ -40,6 +41,7 @@ struct GreaterValueCmp { template struct LesserValueCmp { + using DataType = T; bool operator()(const pair& lhs, const pair& rhs) { return (lhs.first < rhs.first || // when values are equal, we want lhs to get higher "priority" @@ -52,11 +54,12 @@ struct LesserValueCmp { // Selects the top k elements (largest or smallest based on template parameter) template -static vector> select_top_k(const ConstEigenMatrixMapRowMajor& raw_data, int64_t row_num, int64_t num_blocks, - int64_t block_slice, int64_t inter_block_offset, const unsigned k, - bool sort_top_k) { +static vector> select_top_k( + const ConstEigenMatrixMapRowMajor& raw_data, int64_t row_num, int64_t num_blocks, + int64_t block_slice, int64_t inter_block_offset, const unsigned k, + bool sort_top_k) { // create a data holder and insert elements - vector> data_holder; + vector> data_holder; data_holder.reserve(num_blocks); for (int64_t l = 0; l < num_blocks; ++l) { data_holder.push_back({raw_data(row_num, l * block_slice + inter_block_offset), l}); @@ -85,11 +88,13 @@ static void extract_top_k_elements(const Tensor* input, const TensorShape& input const int64_t rows = input_shape.SizeToDimension(static_cast(axis_parsed)); const int64_t cols = input->Shape().Size() / rows; auto input_map = - ConstEigenMatrixMapRowMajor(static_cast(input->template Data()), rows, cols); + ConstEigenMatrixMapRowMajor( + static_cast(input->template Data()), rows, cols); // Use Eigen maps to allow indexing into the 2d tensors like Values_map(i,j) const int64_t reduced_cols = output_shape.SizeFromDimension(static_cast(axis_parsed)); - auto values_map = EigenMatrixMapRowMajor(values->template MutableData(), rows, reduced_cols); + auto values_map = EigenMatrixMapRowMajor( + values->template MutableData(), rows, reduced_cols); auto indices_map = EigenMatrixMapRowMajor(indices->template MutableData(), rows, reduced_cols); // This is basically the number of elements within each of the "k" rows @@ -119,7 +124,7 @@ static void extract_top_k_elements(const Tensor* input, const TensorShape& input // Build a min-heap/max-heap, the heap element is pair of (value, idx) // The top of the heap is the smallest/largest value depending on whether it is a min-heap/max-heap // This is a min-heap if largest == true, this is a max-heap if largest == false - priority_queue, vector>, Comparator> heap; + priority_queue, vector>, Comparator> heap; // Maintain the size of heap to be less or equal to k, so the // heap will hold the k largest/smallest values @@ -169,6 +174,7 @@ static void extract_top_k_elements(const Tensor* input, const TensorShape& input } // Wrapper over core TopK implementation +template static Status TopKImpl(OpKernelContext* p_op_kernel_context, const Tensor* input, const int axis, const unsigned k, bool largest = true, bool sorted = true) { const TensorShape& input_shape = input->Shape(); @@ -200,20 +206,20 @@ static Status TopKImpl(OpKernelContext* p_op_kernel_context, const Tensor* input if (sorted && largest) { // extract sorted largest TopK elements - extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, - gsl::narrow_cast(axis_parsed)); + extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, + gsl::narrow_cast(axis_parsed)); } else if (sorted && !largest) { // extract sorted smallest TopK elements - extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, - gsl::narrow_cast(axis_parsed)); + extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, + gsl::narrow_cast(axis_parsed)); } else if (largest) { // extract unsorted (order undefined) largest TopK elements - extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, - gsl::narrow_cast(axis_parsed)); + extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, + gsl::narrow_cast(axis_parsed)); } else { // extract unsorted (order undefined) smallest TopK elements - extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, - gsl::narrow_cast(axis_parsed)); + extract_top_k_elements>(input, input_shape, values, indices, output_shape, k, + gsl::narrow_cast(axis_parsed)); } return Status::OK(); @@ -240,7 +246,7 @@ Status TopK<9, float>::Compute(OpKernelContext* p_op_kernel_context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "input count mismatch, expected 1 input - the tensor to be processed"); } - return TopKImpl(p_op_kernel_context, X, axis_, k_); + return TopKImpl(p_op_kernel_context, X, axis_, k_); } // Opset ver - 10 @@ -272,7 +278,7 @@ Status TopK<10, float>::Compute(OpKernelContext* p_op_kernel_context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "value of k must not be negative"); } - return TopKImpl(p_op_kernel_context, X, axis_, gsl::narrow_cast(parsed_input_k)); + return TopKImpl(p_op_kernel_context, X, axis_, gsl::narrow_cast(parsed_input_k)); } // Opset ver - 11 @@ -302,6 +308,7 @@ TopK<11, int64_t>::TopK(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel TopkOpset11ConstructorCommon(op_kernel_info, axis_, largest_, sorted_); } +template static Status ComputeImplOpset11(OpKernelContext* p_op_kernel_context, int axis, bool is_largest, bool is_sorted) { const auto* X = p_op_kernel_context->Input(0); const auto* Y = p_op_kernel_context->Input(1); @@ -321,18 +328,18 @@ static Status ComputeImplOpset11(OpKernelContext* p_op_kernel_context, int axis, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "value of k must not be negative"); } - return TopKImpl(p_op_kernel_context, X, axis, gsl::narrow_cast(parsed_input_k), is_largest, is_sorted); + return TopKImpl(p_op_kernel_context, X, axis, gsl::narrow_cast(parsed_input_k), is_largest, is_sorted); } // Opset ver - 11 template <> Status TopK<11, float>::Compute(OpKernelContext* p_op_kernel_context) const { - return ComputeImplOpset11(p_op_kernel_context, axis_, largest_, sorted_); + return ComputeImplOpset11(p_op_kernel_context, axis_, largest_, sorted_); } template <> Status TopK<11, int64_t>::Compute(OpKernelContext* p_op_kernel_context) const { - return ComputeImplOpset11(p_op_kernel_context, axis_, largest_, sorted_); + return ComputeImplOpset11(p_op_kernel_context, axis_, largest_, sorted_); } // Register necessary kernels diff --git a/onnxruntime/test/providers/cpu/math/topk_op_test.cc b/onnxruntime/test/providers/cpu/math/topk_op_test.cc index 26e49dab6c..32d1f24f24 100644 --- a/onnxruntime/test/providers/cpu/math/topk_op_test.cc +++ b/onnxruntime/test/providers/cpu/math/topk_op_test.cc @@ -8,11 +8,12 @@ namespace onnxruntime { namespace test { +template static void RunTest(int op_set, int64_t k, - const std::vector& input_vals, + const std::vector& input_vals, const std::vector& input_dimensions, - const std::vector& expected_vals, + const std::vector& expected_vals, const std::vector& expected_indices, const std::vector& expected_dimensions, bool is_tensorrt_supported = true, @@ -34,16 +35,16 @@ static void RunTest(int op_set, test.AddAttribute("sorted", sorted); // Inputs - test.AddInput("X", input_dimensions, input_vals); + test.AddInput("X", input_dimensions, input_vals); if (op_set >= 10) test.AddInput("K", {1}, {k}); // Outputs if (sorted == 1) { - test.AddOutput("Values", expected_dimensions, expected_vals); + test.AddOutput("Values", expected_dimensions, expected_vals); test.AddOutput("Indices", expected_dimensions, expected_indices); } else { - test.AddOutput("Values", expected_dimensions, expected_vals, true); + test.AddOutput("Values", expected_dimensions, expected_vals, true); test.AddOutput("Indices", expected_dimensions, expected_indices, true); } @@ -442,6 +443,24 @@ TEST(TopKOperator, SelectFirstSortNext) { RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); // largest values } +TEST(TopKOperator, SelectFirstSortNextInt64) { + // in this test, we will select the top 5 elements first then sort the chosen 5 elements + // Select + Sort = O(n + k * ln(k)) = 50 + 5 * ln(5) = 58.047 + // Sorted selection: O(n * ln(k)) = 50 * ln(5) = 80.47 + // The algorithm used will be Select + Sort + std::vector input_vals = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50}; + std::vector input_dimensions = {50}; + std::vector expected_vals = {50, 49, 48, 47, 46}; + std::vector expected_indices = {49, 48, 47, 46, 45}; + std::vector expected_dimensions = {5}; + int64_t axis = 0; + RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); // largest values +} + TEST(TopKOperator, SortedSelection) { // in this test, we will use sorted selection (using heap) // Select + Sort = O(n + k * ln(k)) = 10 + 5 * ln(5) = 18.04