From d936751aada2d1d66c902a8c37af413aba7d18a2 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Tue, 14 Jun 2022 23:42:12 +0200 Subject: [PATCH] QlinearConv threading adjustments (#11228) * Reserve the first core for the main thread Currently in "auto affinity" mode the worker threads are affinized to cores 0..(N-1), leaving the very last core for the main thread. This patch preserves core #0 for the main thread, and affinizes the worker threads to cores 1..N. * Avoid unneeded spin_pause in thread pool's worker threads Remove unneeded PAUSE instruction (0.1-0.2 usec latency) after a worker thread finds a task to execute. * MLAS/x86: optimize QLinearConv on hybrid CPUs Existing 4x task granularity for task partitioning on hybrid CPUs is not sufficient to compensate the difference of VNNI instructions throughput between performance and efficient cores. This patch... * Increases granularity for QLinearConv by 2x, to have 2x more tasks with 2x smaller output count * Limits QLinearConv task count from above, to avoid output count per task getting smaller than kernel's capability * Remove hardcoded task count for QLineConv as it limited scaling on 16+ cores CPUs * MLAS/x86: optimize QLinearConv on hybrid CPUs Existing 4x task granularity for task partitioning on hybrid CPUs is not sufficient to compensate the difference of VNNI instructions throughput between performance and efficient cores. This patch... * Increases granularity for QLinearConv by 2x, to have 2x more tasks with 2x smaller output count * Limits QLinearConv task count from above, to avoid output count per task getting smaller than kernel's capability * Remove hardcoded task count for QLineConv as it limited scaling on 16+ cores CP * Addressing comments * combining x86 ARM branches in qlinearconv threaded job partition * revert first core assignment Co-authored-by: Saurabh Co-authored-by: Chen Fu --- .../platform/EigenNonBlockingThreadPool.h | 11 +- onnxruntime/core/common/threadpool.cc | 7 + onnxruntime/core/mlas/inc/mlas.h | 49 ++++++ onnxruntime/core/mlas/lib/convsym.cpp | 18 +++ onnxruntime/core/mlas/lib/mlasi.h | 1 - .../mlas/lib/power/qgemm_kernel_power10.cpp | 1 + onnxruntime/core/mlas/lib/qdwconv.cpp | 4 + onnxruntime/core/mlas/lib/qgemm.cpp | 19 +++ onnxruntime/core/mlas/lib/qgemm.h | 1 + .../core/mlas/lib/qgemm_kernel_avx2.cpp | 2 + .../core/mlas/lib/qgemm_kernel_default.cpp | 1 + .../core/mlas/lib/qgemm_kernel_neon.cpp | 2 + .../core/mlas/lib/qgemm_kernel_sdot.cpp | 1 + .../core/mlas/lib/qgemm_kernel_sse.cpp | 1 + .../core/mlas/lib/qgemm_kernel_sse41.cpp | 1 + .../core/mlas/lib/qgemm_kernel_udot.cpp | 1 + .../core/mlas/lib/qgemm_kernel_wasmsimd.cpp | 1 + onnxruntime/core/platform/windows/env.cc | 3 +- .../providers/cpu/quantization/qlinearconv.cc | 140 +++++++++++------- 19 files changed, 211 insertions(+), 53 deletions(-) diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index 75eb86c9a0..90c4329202 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -1020,6 +1020,13 @@ void EndParallelSection(ThreadPoolParallelSection &ps) override { void InitializePreferredWorkers(std::vector &preferred_workers) { static std::atomic next_worker; + + // preferred_workers[0] isn't supposed to be used, so initializng it with -1 to: + // a) fault if inapropriately accessed + // b) avoid wasting next_worker value + if (preferred_workers.size() == 0) + preferred_workers.push_back(-1); + // preferred_workers maps from a par_idx to a q_idx, hence we // initialize slots in the range [0,num_threads_] while (preferred_workers.size() <= num_threads_) { @@ -1465,12 +1472,14 @@ int CurrentThreadId() const final { Task t = q.PopFront(); if (!t) { // Spin waiting for work. - for (int i = 0; i < spin_count && !t && !done_; i++) { + for (int i = 0; i < spin_count && !done_; i++) { if (((i+1)%steal_count == 0)) { t = Steal(StealAttemptKind::TRY_ONE); } else { t = q.PopFront(); } + if (t) break; + onnxruntime::concurrency::SpinPause(); } diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index f3714dbca6..ad12198091 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -377,6 +377,13 @@ ThreadPool::ThreadPool(Env* env, assert(degree_of_parallelism >= 1); if (degree_of_parallelism >= 2) { int threads_to_create = degree_of_parallelism - 1; + + if (!thread_options_.affinity.empty()) { + // Remove first affinity element as designated for the caller thread + thread_options_.affinity.erase(thread_options_.affinity.begin()); + assert(thread_options_.affinity.size() >= size_t(threads_to_create)); + } + extended_eigen_threadpool_ = std::make_unique >(name, threads_to_create, diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 9d8cc96153..c713b5547e 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -829,6 +829,55 @@ MlasConvSymFixupInputZeroPoint( bool InputIsSigned ); +// +// Convolution operators (or maybe others in the future) need to do their +// own job partition. Since filters (right hand side B matrix) is usually +// small in size, activations are divided horizontally. We need to provide +// kernel stride units to facilitate the divide. +// + +int32_t +MlasConvSymGetKernelOutputCount( + bool InputIsSigned + ); + +int32_t +MlasConvSymDepthwiseGetKernelOutputCnt( + bool InputIsSigned + ); + +/** + * @brief Returns the stride M of depthwise conv kernel + * + * Most optimized path is Symmetric conv. See + * MlasConvSymDepthwiseGetKernelOutputCnt(bool) + * + * These kernels are implemented in qdwconv.cpp using + * intrincic, all of them with stride val 1. We use + * a slightly bigger value to improve cache reuse. + * + * This needs to be changed if we optimize depthwise + * kernels. + * + * @return +*/ +inline +int32_t +MlasConvDepthwiseGetKernelOutputCnt() +{ + return 4; +} + +int32_t +MlasSymmQgemmGetKernelOutputCnt(); + +int32_t +MlasQgemmGetKernelOutputCnt( + bool AIsSigned, + bool BIsSigned + ); + + struct MLAS_CONV_SYM_PARAMS { const void* InputDirect; const void* const* InputIndirection; diff --git a/onnxruntime/core/mlas/lib/convsym.cpp b/onnxruntime/core/mlas/lib/convsym.cpp index a667fe4048..5f8be3580b 100644 --- a/onnxruntime/core/mlas/lib/convsym.cpp +++ b/onnxruntime/core/mlas/lib/convsym.cpp @@ -470,6 +470,24 @@ MlasConvSymFixupInputZeroPoint( return zero_point_value; } +int32_t +MlasConvSymGetKernelOutputCount( + bool InputIsSigned + ) +{ + const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned); + return ConvSymDispatch->KernelOutputCount; +} + +int32_t +MlasConvSymDepthwiseGetKernelOutputCnt( + bool InputIsSigned + ) +{ + const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned); + return ConvSymDispatch->KernelDepthwiseOutputCount; +} + void MlasConvSym( diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index df00121c36..da157458ed 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -737,7 +737,6 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; diff --git a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp index 7b6f2db620..633349e800 100644 --- a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp +++ b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp @@ -1190,4 +1190,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10 = { MlasGemmQuantCopyPackB, MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK, MLAS_GEMM_QUANT_KERNEL_POWER10::PackedStrides.K, + 8 // Kernel M stride }; diff --git a/onnxruntime/core/mlas/lib/qdwconv.cpp b/onnxruntime/core/mlas/lib/qdwconv.cpp index 921addab2c..924009ab5c 100644 --- a/onnxruntime/core/mlas/lib/qdwconv.cpp +++ b/onnxruntime/core/mlas/lib/qdwconv.cpp @@ -30,6 +30,10 @@ MlasConvDepthwiseKernel( size_t KernelSize ) { + // + // TODO Modify MlasConvDepthwiseGetKernelOutputCnt() function if this kernel + // is further optimized. + // #if defined(MLAS_SSE2_INTRINSICS) const __m128i ZeroVector = _mm_setzero_si128(); const __m128i InputZeroPointVector = _mm_set1_epi16(InputZeroPoint); diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index 8cc9d695f9..859fcd049a 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -109,6 +109,16 @@ Return Value: } +int32_t +MlasQgemmGetKernelOutputCnt( + bool AIsSigned, + bool BIsSigned + ) +{ + const auto* dispatch = MlasGemmQuantGetDispatch(AIsSigned, BIsSigned); + return int32_t(dispatch->StrideM); +} + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) // VC++ suggests we can attempt to make 'MlasBitsOfFp32' constexpr, but it is not valid. @@ -192,6 +202,15 @@ MlasGemmBatch( }); } + +int32_t +MlasSymmQgemmGetKernelOutputCnt() +{ + const MLAS_SYMM_QGEMM_DISPATCH* dispatch = GetMlasPlatform().SymmQgemmDispatch; + return int32_t(dispatch->StrideM); +} + + void MLASCALL MlasSymmQgemmBatch( diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h index fabd791010..8a1856eb71 100644 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ b/onnxruntime/core/mlas/lib/qgemm.h @@ -799,6 +799,7 @@ struct MLAS_GEMM_QUANT_DISPATCH { MLAS_GEMM_QUANT_COPY_PACKB_ROUTINE* CopyPackBRoutine; size_t PackedK; size_t PackedStrideK; + size_t StrideM; }; struct MLAS_SYMM_QGEMM_DISPATCH { diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp index 810a25bf8d..deec324d01 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp @@ -189,6 +189,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2 = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK, MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides.K, + 6 // assembly kernel M stride }; struct MLAS_GEMM_U8U8_KERNEL_AVX2 @@ -270,4 +271,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2 = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK, MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides.K, + 6 // assembly kernel M stride }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp index eb8ef4e274..8f4baaa0ff 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp @@ -220,4 +220,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault = { nullptr, MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedK, 0, + MLAS_GEMM_QUANT_KERNEL_DEFAULT::Strides.M }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp index 0b747bc7cc..50e23a0251 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp @@ -583,6 +583,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8X8_KERNEL_NEON::PackedK, MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides.K, + 4 // Kernel Stride M }; #if defined(MLAS_TARGET_ARM64) @@ -1214,6 +1215,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon = { MlasGemmQuantCopyPackB, MLAS_GEMM_X8S8_KERNEL_NEON::PackedK, MLAS_GEMM_X8S8_KERNEL_NEON::PackedStrides.K, + 4 // Kernel Stride M }; const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchNeon = { diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp index 411c44ad2b..5370b859bc 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp @@ -756,6 +756,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot = { MlasGemmQuantCopyPackB, MLAS_GEMM_S8S8_KERNEL_SDOT::PackedK, MLAS_GEMM_S8S8_KERNEL_SDOT::PackedStrides.K, + 8 // Kernel Stride M }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp index a2abb4aa7e..06f936f50c 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp @@ -492,4 +492,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse = { nullptr, MLAS_GEMM_U8X8_KERNEL_SSE::PackedK, 0, + 1 // assembly kernel M stride }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp index 4ba4b74791..68931c53ee 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp @@ -445,4 +445,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41 = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8S8_KERNEL_SSE41::PackedK, MLAS_GEMM_U8S8_KERNEL_SSE41::PackedStrides.K, + 1 // assembly kernel M stride }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp index 45e2558ef5..5cec72542d 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp @@ -759,4 +759,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8X8_KERNEL_UDOT::PackedK, MLAS_GEMM_U8X8_KERNEL_UDOT::PackedStrides.K, + 8 }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp index 92c05fa39b..f85fc929f7 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp @@ -504,4 +504,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd = { nullptr, MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedK, 0, + 4 // multiple of kernel stride M }; diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 1a1f30fdf9..4358ddc7c3 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -206,8 +206,9 @@ class WindowsEnv : public Env { ret.push_back(buffer[i].ProcessorMask); } } - if (ret.empty()) + if (ret.empty()){ return generate_vector_of_n(std::thread::hardware_concurrency()); + } return ret; } diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index 36f2f76cfa..9e5872f921 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -11,21 +11,6 @@ #include "core/util/qmath.h" #include "core/mlas/inc/mlas.h" -#if defined(_M_ARM64) || defined(__aarch64__) - -// -// TODO!! Hack! need to move this to MLAS -// -// We use a different job partition with mobile devices -// where the model tends to be smaller. When the entire -// weight matrix fits in the cache, we process a thin -// horizontal slice of the result matrix at a time. The -// thickness of this slice depend on micro-kernel M -// stride. -// -#define GEMM_KERNEL_STRIDE_M 4 - -#endif namespace onnxruntime { @@ -123,31 +108,63 @@ class QLinearConv : public OpKernel { return output_scales; } - static int32_t ComputeTaskCount(int64_t output_image_size, int64_t group_output_channels, int64_t kernel_dim) { - // Replicate the logic from MlasGemmU8X8Schedule to control the number of - // worker threads used for the convolution. - int32_t maximum_thread_count; - if (CPUIDInfo::GetCPUIDInfo().IsHybrid()) { - maximum_thread_count = 64; - } else { - maximum_thread_count = 16; - } - constexpr double thread_complexity = static_cast(64 * 1024); + /** + * @brief Computes the partition stride of the activation tensor. + * + * Current threaded job partiton is limited in that we can't + * partition the filter tensor (a TODO item). So we can only + * horizontally partition the activation tensor into thin + * slices. This function decides the thickness of that slice, + * which is also number of output pixels each job produces. + * + * @param degree_of_parallelism Configured thread parallelism for this run + * @param output_image_size Number of pixels in the output image + * @param group_output_channels Number of filters in this group. + * @param kernel_dim Dimension of a filter + * @param comp_kernel_stride Best stride to fully utilize hand tuned computing kernel. + * @return + */ + static int32_t ComputeOutputStride(int32_t degree_of_parallelism, + int64_t output_image_size, + int64_t group_output_channels, + int64_t kernel_dim, + int64_t comp_kernel_stride) { + // + // The idea is to simply partition the activation tensor using the computation kernel stride, to ensure + // the hand crafted kernel code has maximum throughput in almost all the jobs. Most of the below logic, + // however, is to take care of corner cases where we have either too few or too many partitions. + // + constexpr double MIN_COMPLEXITY = static_cast(64 * 1024); - const double complexity = static_cast(output_image_size) * - static_cast(group_output_channels) * - static_cast(kernel_dim); + const int64_t weights = group_output_channels * kernel_dim; + const int32_t min_stride = static_cast(std::ceil(MIN_COMPLEXITY / static_cast(weights))); - int32_t task_count = maximum_thread_count; - if (complexity < thread_complexity * maximum_thread_count) { - task_count = static_cast(complexity / thread_complexity) + 1; - } - if (task_count > output_image_size) { - // Ensure that every thread produces at least one output. - task_count = static_cast(output_image_size); + int32_t output_stride = static_cast(comp_kernel_stride); + + if (output_stride < min_stride) { + output_stride = (min_stride + output_stride - 1) / output_stride * output_stride; } - return task_count; + const auto task_count = (output_image_size + output_stride - 1) / output_stride; +#if defined(_M_ARM64) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) + const auto large_jobs = degree_of_parallelism << 6; +#else + const auto large_jobs = degree_of_parallelism * 5; +#endif + if (task_count > large_jobs) { + // too many tasks, need a bigger stride + output_stride = static_cast(((output_image_size + large_jobs - 1) / large_jobs + comp_kernel_stride - 1) / comp_kernel_stride * comp_kernel_stride); + } + + // We need a better partiton when we have a big filter tensor and very small activation tensor + // TODO!! we should partition the weight tensor instead + constexpr int64_t BIG_WEIGHT = 1024 * 1024; + if (weights >= BIG_WEIGHT && task_count < (degree_of_parallelism / 8) ) { + int32_t s1 = static_cast((output_image_size + degree_of_parallelism - 1) / degree_of_parallelism); + output_stride = std::max(s1, min_stride); + } + + return output_stride; } bool TryConvSymPrepack(const uint8_t* Wdata, @@ -656,12 +673,41 @@ Status QLinearConv::Compute(OpKernelContext* context) const { } concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); -#if defined(_M_ARM64) || defined(__aarch64__) - int32_t task_count = (output_image_size + (GEMM_KERNEL_STRIDE_M - 1)) / GEMM_KERNEL_STRIDE_M; -#else - int32_t task_count = ComputeTaskCount(output_image_size, group_output_channels, kernel_dim); - task_count = std::min(task_count, concurrency::ThreadPool::DegreeOfParallelism(thread_pool)); -#endif + + /************************************* + * Thread partition idea: we are essentially partition a GEMM A[M,K] x B[K,N]. + * Here B contains the conv filters, which are usually not big, so we assume + * it can be in cache entirely. Then we simply partition A horizontally into + * thin slices along M dimension. This would ensure that the slice of A fits + * into the cache and reduce the chance of kernel waiting for memory. + * + * The thickness of A slice should be multiple of kernel stride M. Since + * we have to choose from many different kernels, the logic of finding + * the stride M is hacky. + */ + + // The following convoluted branches must match the kernel selection logic + // in conv_worker. + int64_t compute_stride; + if (is_symmetric_conv_) { + if (is_depthwise_conv) { + compute_stride = MlasConvSymDepthwiseGetKernelOutputCnt(std::is_signed::value); + } else { + compute_stride = MlasConvSymGetKernelOutputCount(std::is_signed::value); + } + } else if (is_depthwise_conv) { + compute_stride = MlasConvDepthwiseGetKernelOutputCnt(); + } else { + if (is_symmetric_gemm_) { + compute_stride = MlasSymmQgemmGetKernelOutputCnt(); + } else { + compute_stride = MlasQgemmGetKernelOutputCnt(std::is_signed::value, is_W_signed); + } + } + + const int32_t degree_of_par = concurrency::ThreadPool::DegreeOfParallelism(thread_pool); + const int32_t stride_m = ComputeOutputStride(degree_of_par, output_image_size, group_output_channels, kernel_dim, compute_stride); + const int64_t task_count = (output_image_size + stride_m - 1) / stride_m; for (int64_t image_id = 0; image_id < N; ++image_id) { const auto* input_data = Xdata; @@ -699,14 +745,8 @@ Status QLinearConv::Compute(OpKernelContext* context) const { } auto conv_worker = [&](ptrdiff_t batch) { -#if defined(_M_ARM64) || defined(__aarch64__) - int64_t output_start = batch * GEMM_KERNEL_STRIDE_M; - int64_t output_count = std::min((int64_t)GEMM_KERNEL_STRIDE_M, output_image_size - output_start); -#else - auto work = concurrency::ThreadPool::PartitionWork(batch, task_count, static_cast(output_image_size)); - int64_t output_start = static_cast(work.start); - int64_t output_count = static_cast(work.end) - work.start; -#endif + int64_t output_start = batch * stride_m; + int64_t output_count = std::min((int64_t)stride_m, output_image_size - output_start); ActType const** worker_indirection_buffer = nullptr; if (indirection_buffer) {