diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index 75eb86c9a0..90c4329202 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -1020,6 +1020,13 @@ void EndParallelSection(ThreadPoolParallelSection &ps) override { void InitializePreferredWorkers(std::vector &preferred_workers) { static std::atomic next_worker; + + // preferred_workers[0] isn't supposed to be used, so initializing it with -1 to: + // a) fault if inappropriately accessed + // b) avoid wasting next_worker value + if (preferred_workers.size() == 0) + preferred_workers.push_back(-1); + // preferred_workers maps from a par_idx to a q_idx, hence we // initialize slots in the range [0,num_threads_] while (preferred_workers.size() <= num_threads_) { @@ -1465,12 +1472,14 @@ int CurrentThreadId() const final { Task t = q.PopFront(); if (!t) { // Spin waiting for work. 
- for (int i = 0; i < spin_count && !t && !done_; i++) { + for (int i = 0; i < spin_count && !done_; i++) { if (((i+1)%steal_count == 0)) { t = Steal(StealAttemptKind::TRY_ONE); } else { t = q.PopFront(); } + if (t) break; + onnxruntime::concurrency::SpinPause(); } diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index f3714dbca6..ad12198091 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -377,6 +377,13 @@ ThreadPool::ThreadPool(Env* env, assert(degree_of_parallelism >= 1); if (degree_of_parallelism >= 2) { int threads_to_create = degree_of_parallelism - 1; + + if (!thread_options_.affinity.empty()) { + // Remove first affinity element as designated for the caller thread + thread_options_.affinity.erase(thread_options_.affinity.begin()); + assert(thread_options_.affinity.size() >= size_t(threads_to_create)); + } + extended_eigen_threadpool_ = std::make_unique >(name, threads_to_create, diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 9d8cc96153..c713b5547e 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -829,6 +829,55 @@ MlasConvSymFixupInputZeroPoint( bool InputIsSigned ); +// +// Convolution operators (or maybe others in the future) need to do their +// own job partition. Since filters (right hand side B matrix) are usually +// small in size, activations are divided horizontally. We need to provide +// kernel stride units to facilitate the divide. +// + +int32_t +MlasConvSymGetKernelOutputCount( + bool InputIsSigned + ); + +int32_t +MlasConvSymDepthwiseGetKernelOutputCnt( + bool InputIsSigned + ); + +/** + * @brief Returns the stride M of depthwise conv kernel + * + * Most optimized path is Symmetric conv. See + * MlasConvSymDepthwiseGetKernelOutputCnt(bool) + * + * These kernels are implemented in qdwconv.cpp using + * intrinsics, all of them with stride val 1. 
We use + * a slightly bigger value to improve cache reuse. + * + * This needs to be changed if we optimize depthwise + * kernels. + * + * @return +*/ +inline +int32_t +MlasConvDepthwiseGetKernelOutputCnt() +{ + return 4; +} + +int32_t +MlasSymmQgemmGetKernelOutputCnt(); + +int32_t +MlasQgemmGetKernelOutputCnt( + bool AIsSigned, + bool BIsSigned + ); + + struct MLAS_CONV_SYM_PARAMS { const void* InputDirect; const void* const* InputIndirection; diff --git a/onnxruntime/core/mlas/lib/convsym.cpp b/onnxruntime/core/mlas/lib/convsym.cpp index a667fe4048..5f8be3580b 100644 --- a/onnxruntime/core/mlas/lib/convsym.cpp +++ b/onnxruntime/core/mlas/lib/convsym.cpp @@ -470,6 +470,24 @@ MlasConvSymFixupInputZeroPoint( return zero_point_value; } +int32_t +MlasConvSymGetKernelOutputCount( + bool InputIsSigned + ) +{ + const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned); + return ConvSymDispatch->KernelOutputCount; +} + +int32_t +MlasConvSymDepthwiseGetKernelOutputCnt( + bool InputIsSigned + ) +{ + const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned); + return ConvSymDispatch->KernelDepthwiseOutputCount; +} + void MlasConvSym( diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index df00121c36..da157458ed 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -737,7 +737,6 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; diff --git a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp 
b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp index 7b6f2db620..633349e800 100644 --- a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp +++ b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp @@ -1190,4 +1190,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10 = { MlasGemmQuantCopyPackB, MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK, MLAS_GEMM_QUANT_KERNEL_POWER10::PackedStrides.K, + 8 // Kernel M stride }; diff --git a/onnxruntime/core/mlas/lib/qdwconv.cpp b/onnxruntime/core/mlas/lib/qdwconv.cpp index 921addab2c..924009ab5c 100644 --- a/onnxruntime/core/mlas/lib/qdwconv.cpp +++ b/onnxruntime/core/mlas/lib/qdwconv.cpp @@ -30,6 +30,10 @@ MlasConvDepthwiseKernel( size_t KernelSize ) { + // + // TODO Modify MlasConvDepthwiseGetKernelOutputCnt() function if this kernel + // is further optimized. + // #if defined(MLAS_SSE2_INTRINSICS) const __m128i ZeroVector = _mm_setzero_si128(); const __m128i InputZeroPointVector = _mm_set1_epi16(InputZeroPoint); diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index 8cc9d695f9..859fcd049a 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -109,6 +109,16 @@ Return Value: } +int32_t +MlasQgemmGetKernelOutputCnt( + bool AIsSigned, + bool BIsSigned + ) +{ + const auto* dispatch = MlasGemmQuantGetDispatch(AIsSigned, BIsSigned); + return int32_t(dispatch->StrideM); +} + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) // VC++ suggests we can attempt to make 'MlasBitsOfFp32' constexpr, but it is not valid. 
@@ -192,6 +202,15 @@ MlasGemmBatch( }); } + +int32_t +MlasSymmQgemmGetKernelOutputCnt() +{ + const MLAS_SYMM_QGEMM_DISPATCH* dispatch = GetMlasPlatform().SymmQgemmDispatch; + return int32_t(dispatch->StrideM); +} + + void MLASCALL MlasSymmQgemmBatch( diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h index fabd791010..8a1856eb71 100644 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ b/onnxruntime/core/mlas/lib/qgemm.h @@ -799,6 +799,7 @@ struct MLAS_GEMM_QUANT_DISPATCH { MLAS_GEMM_QUANT_COPY_PACKB_ROUTINE* CopyPackBRoutine; size_t PackedK; size_t PackedStrideK; + size_t StrideM; }; struct MLAS_SYMM_QGEMM_DISPATCH { diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp index 810a25bf8d..deec324d01 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp @@ -189,6 +189,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2 = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK, MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides.K, + 6 // assembly kernel M stride }; struct MLAS_GEMM_U8U8_KERNEL_AVX2 @@ -270,4 +271,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2 = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK, MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides.K, + 6 // assembly kernel M stride }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp index eb8ef4e274..8f4baaa0ff 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp @@ -220,4 +220,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault = { nullptr, MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedK, 0, + MLAS_GEMM_QUANT_KERNEL_DEFAULT::Strides.M }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp index 0b747bc7cc..50e23a0251 100644 --- 
a/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp @@ -583,6 +583,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8X8_KERNEL_NEON::PackedK, MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides.K, + 4 // Kernel Stride M }; #if defined(MLAS_TARGET_ARM64) @@ -1214,6 +1215,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon = { MlasGemmQuantCopyPackB, MLAS_GEMM_X8S8_KERNEL_NEON::PackedK, MLAS_GEMM_X8S8_KERNEL_NEON::PackedStrides.K, + 4 // Kernel Stride M }; const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchNeon = { diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp index 411c44ad2b..5370b859bc 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp @@ -756,6 +756,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot = { MlasGemmQuantCopyPackB, MLAS_GEMM_S8S8_KERNEL_SDOT::PackedK, MLAS_GEMM_S8S8_KERNEL_SDOT::PackedStrides.K, + 8 // Kernel Stride M }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp index a2abb4aa7e..06f936f50c 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp @@ -492,4 +492,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse = { nullptr, MLAS_GEMM_U8X8_KERNEL_SSE::PackedK, 0, + 1 // assembly kernel M stride }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp index 4ba4b74791..68931c53ee 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp @@ -445,4 +445,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41 = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8S8_KERNEL_SSE41::PackedK, MLAS_GEMM_U8S8_KERNEL_SSE41::PackedStrides.K, + 1 // assembly kernel M stride }; diff --git 
a/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp index 45e2558ef5..5cec72542d 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp @@ -759,4 +759,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8X8_KERNEL_UDOT::PackedK, MLAS_GEMM_U8X8_KERNEL_UDOT::PackedStrides.K, + 8 }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp index 92c05fa39b..f85fc929f7 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp @@ -504,4 +504,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd = { nullptr, MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedK, 0, + 4 // multiple of kernel stride M }; diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 1a1f30fdf9..4358ddc7c3 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -206,8 +206,9 @@ class WindowsEnv : public Env { ret.push_back(buffer[i].ProcessorMask); } } - if (ret.empty()) + if (ret.empty()){ return generate_vector_of_n(std::thread::hardware_concurrency()); + } return ret; } diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index 36f2f76cfa..9e5872f921 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -11,21 +11,6 @@ #include "core/util/qmath.h" #include "core/mlas/inc/mlas.h" -#if defined(_M_ARM64) || defined(__aarch64__) - -// -// TODO!! Hack! need to move this to MLAS -// -// We use a different job partition with mobile devices -// where the model tends to be smaller. 
When the entire -// weight matrix fits in the cache, we process a thin -// horizontal slice of the result matrix at a time. The -// thickness of this slice depend on micro-kernel M -// stride. -// -#define GEMM_KERNEL_STRIDE_M 4 - -#endif namespace onnxruntime { @@ -123,31 +108,63 @@ class QLinearConv : public OpKernel { return output_scales; } - static int32_t ComputeTaskCount(int64_t output_image_size, int64_t group_output_channels, int64_t kernel_dim) { - // Replicate the logic from MlasGemmU8X8Schedule to control the number of - // worker threads used for the convolution. - int32_t maximum_thread_count; - if (CPUIDInfo::GetCPUIDInfo().IsHybrid()) { - maximum_thread_count = 64; - } else { - maximum_thread_count = 16; - } - constexpr double thread_complexity = static_cast(64 * 1024); + /** + * @brief Computes the partition stride of the activation tensor. + * + * Current threaded job partition is limited in that we can't + * partition the filter tensor (a TODO item). So we can only + * horizontally partition the activation tensor into thin + * slices. This function decides the thickness of that slice, + * which is also the number of output pixels each job produces. + * + * @param degree_of_parallelism Configured thread parallelism for this run + * @param output_image_size Number of pixels in the output image + * @param group_output_channels Number of filters in this group. + * @param kernel_dim Dimension of a filter + * @param comp_kernel_stride Best stride to fully utilize hand tuned computing kernel. + * @return + */ + static int32_t ComputeOutputStride(int32_t degree_of_parallelism, + int64_t output_image_size, + int64_t group_output_channels, + int64_t kernel_dim, + int64_t comp_kernel_stride) { + // + // The idea is to simply partition the activation tensor using the computation kernel stride, to ensure + // the hand crafted kernel code has maximum throughput in almost all the jobs. 
Most of the below logic, + // however, is to take care of corner cases where we have either too few or too many partitions. + // + constexpr double MIN_COMPLEXITY = static_cast(64 * 1024); - const double complexity = static_cast(output_image_size) * - static_cast(group_output_channels) * - static_cast(kernel_dim); + const int64_t weights = group_output_channels * kernel_dim; + const int32_t min_stride = static_cast(std::ceil(MIN_COMPLEXITY / static_cast(weights))); - int32_t task_count = maximum_thread_count; - if (complexity < thread_complexity * maximum_thread_count) { - task_count = static_cast(complexity / thread_complexity) + 1; - } - if (task_count > output_image_size) { - // Ensure that every thread produces at least one output. - task_count = static_cast(output_image_size); + int32_t output_stride = static_cast(comp_kernel_stride); + + if (output_stride < min_stride) { + output_stride = (min_stride + output_stride - 1) / output_stride * output_stride; } - return task_count; + const auto task_count = (output_image_size + output_stride - 1) / output_stride; +#if defined(_M_ARM64) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) + const auto large_jobs = degree_of_parallelism << 6; +#else + const auto large_jobs = degree_of_parallelism * 5; +#endif + if (task_count > large_jobs) { + // too many tasks, need a bigger stride + output_stride = static_cast(((output_image_size + large_jobs - 1) / large_jobs + comp_kernel_stride - 1) / comp_kernel_stride * comp_kernel_stride); + } + + // We need a better partiton when we have a big filter tensor and very small activation tensor + // TODO!! 
we should partition the weight tensor instead + constexpr int64_t BIG_WEIGHT = 1024 * 1024; + if (weights >= BIG_WEIGHT && task_count < (degree_of_parallelism / 8) ) { + int32_t s1 = static_cast((output_image_size + degree_of_parallelism - 1) / degree_of_parallelism); + output_stride = std::max(s1, min_stride); + } + + return output_stride; } bool TryConvSymPrepack(const uint8_t* Wdata, @@ -656,12 +673,41 @@ Status QLinearConv::Compute(OpKernelContext* context) const { } concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); -#if defined(_M_ARM64) || defined(__aarch64__) - int32_t task_count = (output_image_size + (GEMM_KERNEL_STRIDE_M - 1)) / GEMM_KERNEL_STRIDE_M; -#else - int32_t task_count = ComputeTaskCount(output_image_size, group_output_channels, kernel_dim); - task_count = std::min(task_count, concurrency::ThreadPool::DegreeOfParallelism(thread_pool)); -#endif + + /************************************* + * Thread partition idea: we are essentially partitioning a GEMM A[M,K] x B[K,N]. + * Here B contains the conv filters, which are usually not big, so we assume + * it can be in cache entirely. Then we simply partition A horizontally into + * thin slices along M dimension. This would ensure that the slice of A fits + * into the cache and reduce the chance of kernel waiting for memory. + * + * The thickness of A slice should be a multiple of kernel stride M. Since + * we have to choose from many different kernels, the logic of finding + * the stride M is hacky. + */ + + // The following convoluted branches must match the kernel selection logic + // in conv_worker. 
+ int64_t compute_stride; + if (is_symmetric_conv_) { + if (is_depthwise_conv) { + compute_stride = MlasConvSymDepthwiseGetKernelOutputCnt(std::is_signed::value); + } else { + compute_stride = MlasConvSymGetKernelOutputCount(std::is_signed::value); + } + } else if (is_depthwise_conv) { + compute_stride = MlasConvDepthwiseGetKernelOutputCnt(); + } else { + if (is_symmetric_gemm_) { + compute_stride = MlasSymmQgemmGetKernelOutputCnt(); + } else { + compute_stride = MlasQgemmGetKernelOutputCnt(std::is_signed::value, is_W_signed); + } + } + + const int32_t degree_of_par = concurrency::ThreadPool::DegreeOfParallelism(thread_pool); + const int32_t stride_m = ComputeOutputStride(degree_of_par, output_image_size, group_output_channels, kernel_dim, compute_stride); + const int64_t task_count = (output_image_size + stride_m - 1) / stride_m; for (int64_t image_id = 0; image_id < N; ++image_id) { const auto* input_data = Xdata; @@ -699,14 +745,8 @@ Status QLinearConv::Compute(OpKernelContext* context) const { } auto conv_worker = [&](ptrdiff_t batch) { -#if defined(_M_ARM64) || defined(__aarch64__) - int64_t output_start = batch * GEMM_KERNEL_STRIDE_M; - int64_t output_count = std::min((int64_t)GEMM_KERNEL_STRIDE_M, output_image_size - output_start); -#else - auto work = concurrency::ThreadPool::PartitionWork(batch, task_count, static_cast(output_image_size)); - int64_t output_start = static_cast(work.start); - int64_t output_count = static_cast(work.end) - work.start; -#endif + int64_t output_start = batch * stride_m; + int64_t output_count = std::min((int64_t)stride_m, output_image_size - output_start); ActType const** worker_indirection_buffer = nullptr; if (indirection_buffer) {