mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
QlinearConv threading adjustments (#11228)
* Reserve the first core for the main thread Currently in "auto affinity" mode the worker threads are affinized to cores 0..(N-1), leaving the very last core for the main thread. This patch preserves core #0 for the main thread, and affinizes the worker threads to cores 1..N. * Avoid unneeded spin_pause in thread pool's worker threads Remove unneeded PAUSE instruction (0.1-0.2 usec latency) after a worker thread finds a task to execute. * MLAS/x86: optimize QLinearConv on hybrid CPUs Existing 4x task granularity for task partitioning on hybrid CPUs is not sufficient to compensate the difference of VNNI instructions throughput between performance and efficient cores. This patch... * Increases granularity for QLinearConv by 2x, to have 2x more tasks with 2x smaller output count * Limits QLinearConv task count from above, to avoid output count per task getting smaller than kernel's capability * Remove hardcoded task count for QLineConv as it limited scaling on 16+ cores CPUs * MLAS/x86: optimize QLinearConv on hybrid CPUs Existing 4x task granularity for task partitioning on hybrid CPUs is not sufficient to compensate the difference of VNNI instructions throughput between performance and efficient cores. This patch... * Increases granularity for QLinearConv by 2x, to have 2x more tasks with 2x smaller output count * Limits QLinearConv task count from above, to avoid output count per task getting smaller than kernel's capability * Remove hardcoded task count for QLineConv as it limited scaling on 16+ cores CP * Addressing comments * combining x86 ARM branches in qlinearconv threaded job partition * revert first core assignment Co-authored-by: Saurabh <saurabh.tangri@intel.com> Co-authored-by: Chen Fu <fuchen@microsoft.com>
This commit is contained in:
parent
80d8c4c7ff
commit
d936751aad
19 changed files with 211 additions and 53 deletions
|
|
@ -1020,6 +1020,13 @@ void EndParallelSection(ThreadPoolParallelSection &ps) override {
|
|||
|
||||
void InitializePreferredWorkers(std::vector<int> &preferred_workers) {
|
||||
static std::atomic<unsigned> next_worker;
|
||||
|
||||
// preferred_workers[0] isn't supposed to be used, so initializng it with -1 to:
|
||||
// a) fault if inapropriately accessed
|
||||
// b) avoid wasting next_worker value
|
||||
if (preferred_workers.size() == 0)
|
||||
preferred_workers.push_back(-1);
|
||||
|
||||
// preferred_workers maps from a par_idx to a q_idx, hence we
|
||||
// initialize slots in the range [0,num_threads_]
|
||||
while (preferred_workers.size() <= num_threads_) {
|
||||
|
|
@ -1465,12 +1472,14 @@ int CurrentThreadId() const final {
|
|||
Task t = q.PopFront();
|
||||
if (!t) {
|
||||
// Spin waiting for work.
|
||||
for (int i = 0; i < spin_count && !t && !done_; i++) {
|
||||
for (int i = 0; i < spin_count && !done_; i++) {
|
||||
if (((i+1)%steal_count == 0)) {
|
||||
t = Steal(StealAttemptKind::TRY_ONE);
|
||||
} else {
|
||||
t = q.PopFront();
|
||||
}
|
||||
if (t) break;
|
||||
|
||||
onnxruntime::concurrency::SpinPause();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -377,6 +377,13 @@ ThreadPool::ThreadPool(Env* env,
|
|||
assert(degree_of_parallelism >= 1);
|
||||
if (degree_of_parallelism >= 2) {
|
||||
int threads_to_create = degree_of_parallelism - 1;
|
||||
|
||||
if (!thread_options_.affinity.empty()) {
|
||||
// Remove first affinity element as designated for the caller thread
|
||||
thread_options_.affinity.erase(thread_options_.affinity.begin());
|
||||
assert(thread_options_.affinity.size() >= size_t(threads_to_create));
|
||||
}
|
||||
|
||||
extended_eigen_threadpool_ =
|
||||
std::make_unique<ThreadPoolTempl<Env> >(name,
|
||||
threads_to_create,
|
||||
|
|
|
|||
|
|
@ -829,6 +829,55 @@ MlasConvSymFixupInputZeroPoint(
|
|||
bool InputIsSigned
|
||||
);
|
||||
|
||||
//
|
||||
// Convolution operators (or maybe others in the future) need to do their
|
||||
// own job partition. Since filters (right hand side B matrix) is usually
|
||||
// small in size, activations are divided horizontally. We need to provide
|
||||
// kernel stride units to facilitate the divide.
|
||||
//
|
||||
|
||||
int32_t
|
||||
MlasConvSymGetKernelOutputCount(
|
||||
bool InputIsSigned
|
||||
);
|
||||
|
||||
int32_t
|
||||
MlasConvSymDepthwiseGetKernelOutputCnt(
|
||||
bool InputIsSigned
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Returns the stride M of depthwise conv kernel
|
||||
*
|
||||
* Most optimized path is Symmetric conv. See
|
||||
* MlasConvSymDepthwiseGetKernelOutputCnt(bool)
|
||||
*
|
||||
* These kernels are implemented in qdwconv.cpp using
|
||||
* intrincic, all of them with stride val 1. We use
|
||||
* a slightly bigger value to improve cache reuse.
|
||||
*
|
||||
* This needs to be changed if we optimize depthwise
|
||||
* kernels.
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
inline
|
||||
int32_t
|
||||
MlasConvDepthwiseGetKernelOutputCnt()
|
||||
{
|
||||
return 4;
|
||||
}
|
||||
|
||||
int32_t
|
||||
MlasSymmQgemmGetKernelOutputCnt();
|
||||
|
||||
int32_t
|
||||
MlasQgemmGetKernelOutputCnt(
|
||||
bool AIsSigned,
|
||||
bool BIsSigned
|
||||
);
|
||||
|
||||
|
||||
struct MLAS_CONV_SYM_PARAMS {
|
||||
const void* InputDirect;
|
||||
const void* const* InputIndirection;
|
||||
|
|
|
|||
|
|
@ -470,6 +470,24 @@ MlasConvSymFixupInputZeroPoint(
|
|||
return zero_point_value;
|
||||
}
|
||||
|
||||
int32_t
|
||||
MlasConvSymGetKernelOutputCount(
|
||||
bool InputIsSigned
|
||||
)
|
||||
{
|
||||
const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned);
|
||||
return ConvSymDispatch->KernelOutputCount;
|
||||
}
|
||||
|
||||
int32_t
|
||||
MlasConvSymDepthwiseGetKernelOutputCnt(
|
||||
bool InputIsSigned
|
||||
)
|
||||
{
|
||||
const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned);
|
||||
return ConvSymDispatch->KernelDepthwiseOutputCount;
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
MlasConvSym(
|
||||
|
|
|
|||
|
|
@ -737,7 +737,6 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2;
|
|||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2;
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon;
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon;
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchNeon;
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot;
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot;
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd;
|
||||
|
|
|
|||
|
|
@ -1190,4 +1190,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10 = {
|
|||
MlasGemmQuantCopyPackB<MLAS_GEMM_QUANT_KERNEL_POWER10>,
|
||||
MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK,
|
||||
MLAS_GEMM_QUANT_KERNEL_POWER10::PackedStrides.K,
|
||||
8 // Kernel M stride
|
||||
};
|
||||
|
|
|
|||
|
|
@ -30,6 +30,10 @@ MlasConvDepthwiseKernel(
|
|||
size_t KernelSize
|
||||
)
|
||||
{
|
||||
//
|
||||
// TODO Modify MlasConvDepthwiseGetKernelOutputCnt() function if this kernel
|
||||
// is further optimized.
|
||||
//
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
const __m128i ZeroVector = _mm_setzero_si128();
|
||||
const __m128i InputZeroPointVector = _mm_set1_epi16(InputZeroPoint);
|
||||
|
|
|
|||
|
|
@ -109,6 +109,16 @@ Return Value:
|
|||
}
|
||||
|
||||
|
||||
int32_t
|
||||
MlasQgemmGetKernelOutputCnt(
|
||||
bool AIsSigned,
|
||||
bool BIsSigned
|
||||
)
|
||||
{
|
||||
const auto* dispatch = MlasGemmQuantGetDispatch(AIsSigned, BIsSigned);
|
||||
return int32_t(dispatch->StrideM);
|
||||
}
|
||||
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(push)
|
||||
// VC++ suggests we can attempt to make 'MlasBitsOfFp32' constexpr, but it is not valid.
|
||||
|
|
@ -192,6 +202,15 @@ MlasGemmBatch(
|
|||
});
|
||||
}
|
||||
|
||||
|
||||
int32_t
|
||||
MlasSymmQgemmGetKernelOutputCnt()
|
||||
{
|
||||
const MLAS_SYMM_QGEMM_DISPATCH* dispatch = GetMlasPlatform().SymmQgemmDispatch;
|
||||
return int32_t(dispatch->StrideM);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasSymmQgemmBatch(
|
||||
|
|
|
|||
|
|
@ -799,6 +799,7 @@ struct MLAS_GEMM_QUANT_DISPATCH {
|
|||
MLAS_GEMM_QUANT_COPY_PACKB_ROUTINE* CopyPackBRoutine;
|
||||
size_t PackedK;
|
||||
size_t PackedStrideK;
|
||||
size_t StrideM;
|
||||
};
|
||||
|
||||
struct MLAS_SYMM_QGEMM_DISPATCH {
|
||||
|
|
|
|||
|
|
@ -189,6 +189,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2 = {
|
|||
MlasGemmQuantCopyPackB<MLAS_GEMM_U8S8_KERNEL_AVX2>,
|
||||
MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK,
|
||||
MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides.K,
|
||||
6 // assembly kernel M stride
|
||||
};
|
||||
|
||||
struct MLAS_GEMM_U8U8_KERNEL_AVX2
|
||||
|
|
@ -270,4 +271,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2 = {
|
|||
MlasGemmQuantCopyPackB<MLAS_GEMM_U8U8_KERNEL_AVX2>,
|
||||
MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK,
|
||||
MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides.K,
|
||||
6 // assembly kernel M stride
|
||||
};
|
||||
|
|
|
|||
|
|
@ -220,4 +220,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault = {
|
|||
nullptr,
|
||||
MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedK,
|
||||
0,
|
||||
MLAS_GEMM_QUANT_KERNEL_DEFAULT::Strides.M
|
||||
};
|
||||
|
|
|
|||
|
|
@ -583,6 +583,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon = {
|
|||
MlasGemmQuantCopyPackB<MLAS_GEMM_U8X8_KERNEL_NEON>,
|
||||
MLAS_GEMM_U8X8_KERNEL_NEON::PackedK,
|
||||
MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides.K,
|
||||
4 // Kernel Stride M
|
||||
};
|
||||
|
||||
#if defined(MLAS_TARGET_ARM64)
|
||||
|
|
@ -1214,6 +1215,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon = {
|
|||
MlasGemmQuantCopyPackB<MLAS_GEMM_X8S8_KERNEL_NEON>,
|
||||
MLAS_GEMM_X8S8_KERNEL_NEON::PackedK,
|
||||
MLAS_GEMM_X8S8_KERNEL_NEON::PackedStrides.K,
|
||||
4 // Kernel Stride M
|
||||
};
|
||||
|
||||
const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchNeon = {
|
||||
|
|
|
|||
|
|
@ -756,6 +756,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot = {
|
|||
MlasGemmQuantCopyPackB<MLAS_GEMM_S8S8_KERNEL_SDOT>,
|
||||
MLAS_GEMM_S8S8_KERNEL_SDOT::PackedK,
|
||||
MLAS_GEMM_S8S8_KERNEL_SDOT::PackedStrides.K,
|
||||
8 // Kernel Stride M
|
||||
};
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -492,4 +492,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse = {
|
|||
nullptr,
|
||||
MLAS_GEMM_U8X8_KERNEL_SSE::PackedK,
|
||||
0,
|
||||
1 // assembly kernel M stride
|
||||
};
|
||||
|
|
|
|||
|
|
@ -445,4 +445,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41 = {
|
|||
MlasGemmQuantCopyPackB<MLAS_GEMM_U8S8_KERNEL_SSE41>,
|
||||
MLAS_GEMM_U8S8_KERNEL_SSE41::PackedK,
|
||||
MLAS_GEMM_U8S8_KERNEL_SSE41::PackedStrides.K,
|
||||
1 // assembly kernel M stride
|
||||
};
|
||||
|
|
|
|||
|
|
@ -759,4 +759,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot = {
|
|||
MlasGemmQuantCopyPackB<MLAS_GEMM_U8X8_KERNEL_UDOT>,
|
||||
MLAS_GEMM_U8X8_KERNEL_UDOT::PackedK,
|
||||
MLAS_GEMM_U8X8_KERNEL_UDOT::PackedStrides.K,
|
||||
8
|
||||
};
|
||||
|
|
|
|||
|
|
@ -504,4 +504,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd = {
|
|||
nullptr,
|
||||
MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedK,
|
||||
0,
|
||||
4 // multiple of kernel stride M
|
||||
};
|
||||
|
|
|
|||
|
|
@ -206,8 +206,9 @@ class WindowsEnv : public Env {
|
|||
ret.push_back(buffer[i].ProcessorMask);
|
||||
}
|
||||
}
|
||||
if (ret.empty())
|
||||
if (ret.empty()){
|
||||
return generate_vector_of_n(std::thread::hardware_concurrency());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,21 +11,6 @@
|
|||
#include "core/util/qmath.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
#if defined(_M_ARM64) || defined(__aarch64__)
|
||||
|
||||
//
|
||||
// TODO!! Hack! need to move this to MLAS
|
||||
//
|
||||
// We use a different job partition with mobile devices
|
||||
// where the model tends to be smaller. When the entire
|
||||
// weight matrix fits in the cache, we process a thin
|
||||
// horizontal slice of the result matrix at a time. The
|
||||
// thickness of this slice depend on micro-kernel M
|
||||
// stride.
|
||||
//
|
||||
#define GEMM_KERNEL_STRIDE_M 4
|
||||
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -123,31 +108,63 @@ class QLinearConv : public OpKernel {
|
|||
return output_scales;
|
||||
}
|
||||
|
||||
static int32_t ComputeTaskCount(int64_t output_image_size, int64_t group_output_channels, int64_t kernel_dim) {
|
||||
// Replicate the logic from MlasGemmU8X8Schedule to control the number of
|
||||
// worker threads used for the convolution.
|
||||
int32_t maximum_thread_count;
|
||||
if (CPUIDInfo::GetCPUIDInfo().IsHybrid()) {
|
||||
maximum_thread_count = 64;
|
||||
} else {
|
||||
maximum_thread_count = 16;
|
||||
}
|
||||
constexpr double thread_complexity = static_cast<double>(64 * 1024);
|
||||
/**
|
||||
* @brief Computes the partition stride of the activation tensor.
|
||||
*
|
||||
* Current threaded job partiton is limited in that we can't
|
||||
* partition the filter tensor (a TODO item). So we can only
|
||||
* horizontally partition the activation tensor into thin
|
||||
* slices. This function decides the thickness of that slice,
|
||||
* which is also number of output pixels each job produces.
|
||||
*
|
||||
* @param degree_of_parallelism Configured thread parallelism for this run
|
||||
* @param output_image_size Number of pixels in the output image
|
||||
* @param group_output_channels Number of filters in this group.
|
||||
* @param kernel_dim Dimension of a filter
|
||||
* @param comp_kernel_stride Best stride to fully utilize hand tuned computing kernel.
|
||||
* @return
|
||||
*/
|
||||
static int32_t ComputeOutputStride(int32_t degree_of_parallelism,
|
||||
int64_t output_image_size,
|
||||
int64_t group_output_channels,
|
||||
int64_t kernel_dim,
|
||||
int64_t comp_kernel_stride) {
|
||||
//
|
||||
// The idea is to simply partition the activation tensor using the computation kernel stride, to ensure
|
||||
// the hand crafted kernel code has maximum throughput in almost all the jobs. Most of the below logic,
|
||||
// however, is to take care of corner cases where we have either too few or too many partitions.
|
||||
//
|
||||
constexpr double MIN_COMPLEXITY = static_cast<double>(64 * 1024);
|
||||
|
||||
const double complexity = static_cast<double>(output_image_size) *
|
||||
static_cast<double>(group_output_channels) *
|
||||
static_cast<double>(kernel_dim);
|
||||
const int64_t weights = group_output_channels * kernel_dim;
|
||||
const int32_t min_stride = static_cast<int32_t>(std::ceil(MIN_COMPLEXITY / static_cast<double>(weights)));
|
||||
|
||||
int32_t task_count = maximum_thread_count;
|
||||
if (complexity < thread_complexity * maximum_thread_count) {
|
||||
task_count = static_cast<int32_t>(complexity / thread_complexity) + 1;
|
||||
}
|
||||
if (task_count > output_image_size) {
|
||||
// Ensure that every thread produces at least one output.
|
||||
task_count = static_cast<int32_t>(output_image_size);
|
||||
int32_t output_stride = static_cast<int32_t>(comp_kernel_stride);
|
||||
|
||||
if (output_stride < min_stride) {
|
||||
output_stride = (min_stride + output_stride - 1) / output_stride * output_stride;
|
||||
}
|
||||
|
||||
return task_count;
|
||||
const auto task_count = (output_image_size + output_stride - 1) / output_stride;
|
||||
#if defined(_M_ARM64) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__)
|
||||
const auto large_jobs = degree_of_parallelism << 6;
|
||||
#else
|
||||
const auto large_jobs = degree_of_parallelism * 5;
|
||||
#endif
|
||||
if (task_count > large_jobs) {
|
||||
// too many tasks, need a bigger stride
|
||||
output_stride = static_cast<int32_t>(((output_image_size + large_jobs - 1) / large_jobs + comp_kernel_stride - 1) / comp_kernel_stride * comp_kernel_stride);
|
||||
}
|
||||
|
||||
// We need a better partiton when we have a big filter tensor and very small activation tensor
|
||||
// TODO!! we should partition the weight tensor instead
|
||||
constexpr int64_t BIG_WEIGHT = 1024 * 1024;
|
||||
if (weights >= BIG_WEIGHT && task_count < (degree_of_parallelism / 8) ) {
|
||||
int32_t s1 = static_cast<int32_t>((output_image_size + degree_of_parallelism - 1) / degree_of_parallelism);
|
||||
output_stride = std::max(s1, min_stride);
|
||||
}
|
||||
|
||||
return output_stride;
|
||||
}
|
||||
|
||||
bool TryConvSymPrepack(const uint8_t* Wdata,
|
||||
|
|
@ -656,12 +673,41 @@ Status QLinearConv<ActType>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
|
||||
#if defined(_M_ARM64) || defined(__aarch64__)
|
||||
int32_t task_count = (output_image_size + (GEMM_KERNEL_STRIDE_M - 1)) / GEMM_KERNEL_STRIDE_M;
|
||||
#else
|
||||
int32_t task_count = ComputeTaskCount(output_image_size, group_output_channels, kernel_dim);
|
||||
task_count = std::min(task_count, concurrency::ThreadPool::DegreeOfParallelism(thread_pool));
|
||||
#endif
|
||||
|
||||
/*************************************
|
||||
* Thread partition idea: we are essentially partition a GEMM A[M,K] x B[K,N].
|
||||
* Here B contains the conv filters, which are usually not big, so we assume
|
||||
* it can be in cache entirely. Then we simply partition A horizontally into
|
||||
* thin slices along M dimension. This would ensure that the slice of A fits
|
||||
* into the cache and reduce the chance of kernel waiting for memory.
|
||||
*
|
||||
* The thickness of A slice should be multiple of kernel stride M. Since
|
||||
* we have to choose from many different kernels, the logic of finding
|
||||
* the stride M is hacky.
|
||||
*/
|
||||
|
||||
// The following convoluted branches must match the kernel selection logic
|
||||
// in conv_worker.
|
||||
int64_t compute_stride;
|
||||
if (is_symmetric_conv_) {
|
||||
if (is_depthwise_conv) {
|
||||
compute_stride = MlasConvSymDepthwiseGetKernelOutputCnt(std::is_signed<ActType>::value);
|
||||
} else {
|
||||
compute_stride = MlasConvSymGetKernelOutputCount(std::is_signed<ActType>::value);
|
||||
}
|
||||
} else if (is_depthwise_conv) {
|
||||
compute_stride = MlasConvDepthwiseGetKernelOutputCnt();
|
||||
} else {
|
||||
if (is_symmetric_gemm_) {
|
||||
compute_stride = MlasSymmQgemmGetKernelOutputCnt();
|
||||
} else {
|
||||
compute_stride = MlasQgemmGetKernelOutputCnt(std::is_signed<ActType>::value, is_W_signed);
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t degree_of_par = concurrency::ThreadPool::DegreeOfParallelism(thread_pool);
|
||||
const int32_t stride_m = ComputeOutputStride(degree_of_par, output_image_size, group_output_channels, kernel_dim, compute_stride);
|
||||
const int64_t task_count = (output_image_size + stride_m - 1) / stride_m;
|
||||
|
||||
for (int64_t image_id = 0; image_id < N; ++image_id) {
|
||||
const auto* input_data = Xdata;
|
||||
|
|
@ -699,14 +745,8 @@ Status QLinearConv<ActType>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
auto conv_worker = [&](ptrdiff_t batch) {
|
||||
#if defined(_M_ARM64) || defined(__aarch64__)
|
||||
int64_t output_start = batch * GEMM_KERNEL_STRIDE_M;
|
||||
int64_t output_count = std::min((int64_t)GEMM_KERNEL_STRIDE_M, output_image_size - output_start);
|
||||
#else
|
||||
auto work = concurrency::ThreadPool::PartitionWork(batch, task_count, static_cast<ptrdiff_t>(output_image_size));
|
||||
int64_t output_start = static_cast<int64_t>(work.start);
|
||||
int64_t output_count = static_cast<int64_t>(work.end) - work.start;
|
||||
#endif
|
||||
int64_t output_start = batch * stride_m;
|
||||
int64_t output_count = std::min((int64_t)stride_m, output_image_size - output_start);
|
||||
|
||||
ActType const** worker_indirection_buffer = nullptr;
|
||||
if (indirection_buffer) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue