From 3c7c1068e73bb2bfab47475aa2ebbf6630b9ea06 Mon Sep 17 00:00:00 2001 From: Tracy Sharpe <42477615+tracysh@users.noreply.github.com> Date: Thu, 6 Dec 2018 09:20:32 -0800 Subject: [PATCH] refactor threading (#110) --- cmake/onnxruntime_mlas.cmake | 1 + onnxruntime/core/mlas/inc/mlas.h | 3 +- onnxruntime/core/mlas/lib/convolve.cpp | 225 +++--------------- onnxruntime/core/mlas/lib/mlasi.h | 40 +++- onnxruntime/core/mlas/lib/platform.cpp | 2 +- onnxruntime/core/mlas/lib/sgemm.cpp | 101 +------- onnxruntime/core/mlas/lib/threading.cpp | 135 +++++++++++ onnxruntime/core/providers/cpu/nn/conv_impl.h | 40 ++-- onnxruntime/test/mlas/unittest.cpp | 35 ++- 9 files changed, 256 insertions(+), 326 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/threading.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 8285ef7952..c15e794673 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -3,6 +3,7 @@ set(mlas_common_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/platform.cpp + ${ONNXRUNTIME_ROOT}/core/mlas/lib/threading.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/sgemm.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/convolve.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/pooling.cpp diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 975620efb5..738777183e 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -102,7 +102,7 @@ struct MLAS_CONV_PARAMETERS { } u; }; -bool +void MLASCALL MlasConvPrepare( MLAS_CONV_PARAMETERS* Parameters, @@ -115,6 +115,7 @@ MlasConvPrepare( const int64_t* DilationShape, const int64_t* Padding, const int64_t* StrideShape, + const int64_t* OutputShape, size_t FilterCount, size_t* WorkingBufferSize ); diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp index f56ca76f2f..b933032011 100644 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ b/onnxruntime/core/mlas/lib/convolve.cpp @@ -29,20 +29,17 @@ Abstract: // struct MLAS_CONV_WORK_BLOCK { -#if defined(MLAS_USE_WIN32_THREADPOOL) - volatile LONG Counter; const MLAS_CONV_PARAMETERS* Parameters; const float* Input; const float* Filter; const float* Bias; float* WorkingBuffer; float* Output; -#endif struct SEGMENT { size_t StartN; size_t CountN; } Segments[MLAS_MAXIMUM_THREAD_COUNT]; - uint32_t TargetThreadCount; + int32_t TargetThreadCount; }; void @@ -610,14 +607,10 @@ Return Value: } } -#if defined(MLAS_USE_WIN32_THREADPOOL) - void -CALLBACK -MlasConvWorkCallback( - PTP_CALLBACK_INSTANCE Instance, +MlasConvOperationThreaded( void* Context, - PTP_WORK WorkObject + int32_t Index ) /*++ @@ -628,11 +621,9 @@ Routine Description: Arguments: - Instance - Supplies the callback instance object. + Context - Supplies the pointer to the context for the threaded operation. - Context - Supplies the pointer to the parameters for the SGEMM operation. - - WorkObject - Supplies the threadpool work object. + Index - Supplies the current index of the threaded operation. Return Value: @@ -640,13 +631,8 @@ Return Value: --*/ { - MLAS_UNREFERENCED_PARAMETER(Instance); - MLAS_UNREFERENCED_PARAMETER(WorkObject); - MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context; - LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1; - MLAS_CONV_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index]; float* ColumnBuffer = @@ -658,11 +644,9 @@ Return Value: } void -CALLBACK -MlasConvGemmDirectWorkCallback( - PTP_CALLBACK_INSTANCE Instance, +MlasConvGemmDirectThreaded( void* Context, - PTP_WORK WorkObject + int32_t Index ) /*++ @@ -673,11 +657,9 @@ Routine Description: Arguments: - Instance - Supplies the callback instance object. + Context - Supplies the pointer to the context for the threaded operation. - Context - Supplies the pointer to the parameters for the SGEMM operation. - - WorkObject - Supplies the threadpool work object. + Index - Supplies the current index of the threaded operation. Return Value: @@ -685,13 +667,8 @@ Return Value: --*/ { - MLAS_UNREFERENCED_PARAMETER(Instance); - MLAS_UNREFERENCED_PARAMETER(WorkObject); - MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context; - LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1; - const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters; // @@ -755,8 +732,6 @@ Return Value: } } -#endif - inline bool MlasConvTryMultithread( @@ -798,7 +773,7 @@ Return Value: --*/ { -#if defined(MLAS_USE_WIN32_THREADPOOL) || defined(MLAS_USE_OPENMP) +#if defined(MLAS_HAS_THREADING_SUPPORT) MLAS_CONV_WORK_BLOCK WorkBlock; @@ -809,23 +784,10 @@ Return Value: return false; } -#if defined(MLAS_USE_WIN32_THREADPOOL) - - // - // Create an object to submit work to the threadpool. - // - - PTP_WORK WorkObject = CreateThreadpoolWork(MlasConvWorkCallback, &WorkBlock, nullptr); - - if (WorkObject == nullptr) { - return false; - } - // // Initialize the common fields of the work block. // - WorkBlock.Counter = 0; WorkBlock.Parameters = Parameters; WorkBlock.Input = Input; WorkBlock.Filter = Filter; @@ -833,13 +795,11 @@ Return Value: WorkBlock.WorkingBuffer = WorkingBuffer; WorkBlock.Output = Output; -#endif - // // Segment the operation across multiple threads. // - uint32_t Index = 0; + int32_t Index = 0; size_t SegmentCountN; for (size_t SegmentStartN = 0; SegmentStartN < OutputSize; SegmentStartN += SegmentCountN) { @@ -853,51 +813,10 @@ Return Value: WorkBlock.Segments[Index].StartN = SegmentStartN; WorkBlock.Segments[Index].CountN = SegmentCountN; -#if defined(MLAS_USE_WIN32_THREADPOOL) - - // - // Execute one of the segments on a worker thread. - // - - if (Index > 0) { - SubmitThreadpoolWork(WorkObject); - } - -#endif - Index++; } -#if defined(MLAS_USE_OPENMP) - - #pragma omp parallel for - for (int32_t tid = 0; tid < int32_t(Index); tid++) { - - MLAS_CONV_WORK_BLOCK::SEGMENT* Segment = &WorkBlock.Segments[tid]; - - float* ColumnBuffer = - WorkingBuffer + tid * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; - - MlasConvOperation(Parameters, Input, Filter, Bias, ColumnBuffer, Output, - Segment->StartN, Segment->CountN); - } - -#elif defined(MLAS_USE_WIN32_THREADPOOL) - - // - // Execute the remaining segment on this thread. - // - - MlasConvWorkCallback(nullptr, &WorkBlock, WorkObject); - - // - // Wait for the worker threads to complete. - // - - WaitForThreadpoolWorkCallbacks(WorkObject, FALSE); - CloseThreadpoolWork(WorkObject); - -#endif + MlasExecuteThreaded(MlasConvOperationThreaded, &WorkBlock, Index); return true; @@ -971,7 +890,7 @@ Return Value: const MLAS_CONV_ALGORITHM Algorithm = Parameters->Algorithm; -#if defined(MLAS_USE_WIN32_THREADPOOL) || defined(MLAS_USE_OPENMP) +#if defined(MLAS_HAS_THREADING_SUPPORT) // // Schedule batches of GEMMs across multiple threads. @@ -981,76 +900,25 @@ Return Value: const size_t BatchGroupCount = BatchCount * GroupCount; -#if defined(MLAS_USE_WIN32_THREADPOOL) - - uint32_t TargetThreadCount = MlasPlatform.GetMaximumThreadCount(); + int32_t TargetThreadCount = MlasPlatform.GetMaximumThreadCount(); if (TargetThreadCount >= BatchGroupCount) { - TargetThreadCount = uint32_t(BatchGroupCount); + TargetThreadCount = int32_t(BatchGroupCount); } - if (TargetThreadCount > 1) { + MLAS_CONV_WORK_BLOCK WorkBlock; - MLAS_CONV_WORK_BLOCK WorkBlock; + WorkBlock.Parameters = Parameters; + WorkBlock.Input = Input; + WorkBlock.Filter = Filter; + WorkBlock.Bias = Bias; + WorkBlock.WorkingBuffer = nullptr; + WorkBlock.Output = Output; + WorkBlock.TargetThreadCount = TargetThreadCount; - PTP_WORK WorkObject = CreateThreadpoolWork(MlasConvGemmDirectWorkCallback, &WorkBlock, nullptr); - - if (WorkObject != nullptr) { - - WorkBlock.Counter = 0; - WorkBlock.Parameters = Parameters; - WorkBlock.Input = Input; - WorkBlock.Filter = Filter; - WorkBlock.Bias = Bias; - WorkBlock.WorkingBuffer = nullptr; - WorkBlock.Output = Output; - WorkBlock.TargetThreadCount = TargetThreadCount; - - for (uint32_t Index = 1; Index < TargetThreadCount; Index++) { - SubmitThreadpoolWork(WorkObject); - } - - MlasConvGemmDirectWorkCallback(nullptr, &WorkBlock, WorkObject); - - WaitForThreadpoolWorkCallbacks(WorkObject, FALSE); - CloseThreadpoolWork(WorkObject); - - return; - } - } - -#else - - #pragma omp parallel for - for (int64_t bg = 0; bg < int64_t(BatchGroupCount); bg++) { - - size_t group = size_t(bg % GroupCount); - - const float* input = Input + bg * InputGroupSize; - const float* filter = Filter + group * FilterGroupSize; - float* output = Output + bg * OutputGroupSize; - - // - // Invoke the non-threaded GEMM directly with the input tensor. - // - - MlasSgemmOperation(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount, - OutputSize, K, 1.0f, filter, K, input, Parameters->u.GemmDirect.ldb, 0.0f, - output, OutputSize); - - // - // Add the optional bias vector. - // - - if (Bias != nullptr) { - MlasBiasAdd(Bias + group * FilterCount, FilterCount, output, OutputSize, OutputSize); - } - } + MlasExecuteThreaded(MlasConvGemmDirectThreaded, &WorkBlock, TargetThreadCount); return; - -#endif - } #endif @@ -1152,7 +1020,7 @@ Return Value: } } -bool +void MLASCALL MlasConvPrepare( MLAS_CONV_PARAMETERS* Parameters, @@ -1165,6 +1033,7 @@ MlasConvPrepare( const int64_t* DilationShape, const int64_t* Padding, const int64_t* StrideShape, + const int64_t* OutputShape, size_t FilterCount, size_t* WorkingBufferSize ) @@ -1200,6 +1069,8 @@ Arguments: StrideShape - Supplies the shape of the stride. + OutputShape - Supplies the shape of the output tensor. + FilterCount - Supplies the number of rows of the filter matrix per group. WorkingBufferSize - Receives the number of elements to allocate for the @@ -1207,20 +1078,12 @@ Arguments: Return Value: - Returns true if implementation can support this operation. + None. --*/ { // - // Support only 2D or 3D convolutions. - // - - if (Dimensions != 2 && Dimensions != 3) { - return false; - } - - // - // Build the convolution parameters. + // Save the convolution parameters. // Parameters->Dimensions = Dimensions; @@ -1240,25 +1103,13 @@ Return Value: for (size_t dim = 0; dim < Dimensions; dim++) { Parameters->InputShape[dim] = size_t(InputShape[dim]); + Parameters->OutputShape[dim] = size_t(OutputShape[dim]); Parameters->KernelShape[dim] = size_t(KernelShape[dim]); Parameters->DilationShape[dim] = size_t(DilationShape[dim]); Parameters->Padding[dim] = size_t(Padding[dim]); Parameters->Padding[dim + Dimensions] = size_t(Padding[dim + Dimensions]); Parameters->StrideShape[dim] = size_t(StrideShape[dim]); - // - // Compute the output shape. - // - - int64_t OutputShape = (InputShape[dim] + Padding[dim] + Padding[dim + Dimensions] - - (DilationShape[dim] * (KernelShape[dim] - 1) + 1)) / StrideShape[dim] + 1; - - if (OutputShape <= 0) { - return false; - } - - Parameters->OutputShape[dim] = size_t(OutputShape); - InputSize *= Parameters->InputShape[dim]; OutputSize *= Parameters->OutputShape[dim]; K *= Parameters->KernelShape[dim]; @@ -1290,7 +1141,7 @@ Return Value: Parameters->u.GemmDirect.TransB = CblasNoTrans; Parameters->u.GemmDirect.ldb = OutputSize; - return true; + return; } if (Dimensions == 2 && AllDilationsAreOne && InputChannels == 1) { @@ -1306,7 +1157,7 @@ Return Value: Parameters->u.GemmDirect.TransB = CblasTrans; Parameters->u.GemmDirect.ldb = Parameters->InputShape[1]; - return true; + return; } if (Parameters->KernelShape[0] == Parameters->InputShape[0] && @@ -1316,7 +1167,7 @@ Return Value: Parameters->u.GemmDirect.TransB = CblasNoTrans; Parameters->u.GemmDirect.ldb = Parameters->InputShape[1]; - return true; + return; } } } @@ -1343,16 +1194,16 @@ Return Value: // threaded path. // - uint32_t TargetThreadCount; + int32_t TargetThreadCount; double Complexity = double(FilterCount) * double(OutputSize) * double(K); if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) { - TargetThreadCount = uint32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + TargetThreadCount = int32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; } else { TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT; } - uint32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount(); + int32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount(); if (TargetThreadCount >= MaximumThreadCount) { TargetThreadCount = MaximumThreadCount; @@ -1384,6 +1235,4 @@ Return Value: *WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; } - - return true; } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index bdefbf46ad..52f8ea3b85 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -82,22 +82,24 @@ Abstract: #if defined(_OPENMP) #include #define MLAS_USE_OPENMP +#define MLAS_HAS_THREADING_SUPPORT #elif defined(_WIN32) #define MLAS_USE_WIN32_THREADPOOL +#define MLAS_HAS_THREADING_SUPPORT #endif // // Define the maximum number of threads supported by this implementation. // -#define MLAS_MAXIMUM_THREAD_COUNT 16 +#define MLAS_MAXIMUM_THREAD_COUNT 16 // // Define the default strides to step through slices of the input matrices. // -#define MLAS_SGEMM_STRIDEN 128 -#define MLAS_SGEMM_STRIDEK 128 +#define MLAS_SGEMM_STRIDEN 128 +#define MLAS_SGEMM_STRIDEK 128 // // Define the alignment for segmenting a SGEMM operation across multiple @@ -108,7 +110,7 @@ Abstract: // the effort at this time. // -#define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16 +#define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16 // // Define the prototypes of the SGEMM platform optimized routines. @@ -193,12 +195,12 @@ extern "C" { // #if defined(MLAS_USE_OPENMP) -#define MLAS_SGEMM_THREAD_COMPLEXITY (64 * 1024) +#define MLAS_SGEMM_THREAD_COMPLEXITY (64 * 1024) #else #if defined(MLAS_TARGET_AMD64) -#define MLAS_SGEMM_THREAD_COMPLEXITY (2 * 1024 * 1024) +#define MLAS_SGEMM_THREAD_COMPLEXITY (2 * 1024 * 1024) #else -#define MLAS_SGEMM_THREAD_COMPLEXITY (1 * 1024 * 1024) +#define MLAS_SGEMM_THREAD_COMPLEXITY (1 * 1024 * 1024) #endif #endif @@ -243,10 +245,10 @@ struct MLAS_PLATFORM { #endif #if defined(MLAS_USE_WIN32_THREADPOOL) - uint32_t MaximumThreadCount; + int32_t MaximumThreadCount; #endif - uint32_t + int32_t GetMaximumThreadCount( void ) @@ -263,6 +265,26 @@ struct MLAS_PLATFORM { extern MLAS_PLATFORM MlasPlatform; +// +// Threading support. +// + +typedef +void +(MLAS_THREADED_ROUTINE)( + void* Context, + int32_t Index + ); + +typedef MLAS_THREADED_ROUTINE* PMLAS_THREADED_ROUTINE; + +void +MlasExecuteThreaded( + PMLAS_THREADED_ROUTINE ThreadedRoutine, + void* Context, + int32_t Iterations + ); + // // Define the missing ARM64 NEON intrinsic macros from arm64_neon.h that enable // cross-compiler support. diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 79ae776554..136517d72a 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -137,7 +137,7 @@ Return Value: GetSystemInfo(&SystemInfo); if (SystemInfo.dwNumberOfProcessors <= MLAS_MAXIMUM_THREAD_COUNT) { - this->MaximumThreadCount = SystemInfo.dwNumberOfProcessors; + this->MaximumThreadCount = int32_t(SystemInfo.dwNumberOfProcessors); } else { this->MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; } diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 9d69ba6b57..08fd2582fd 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -32,8 +32,6 @@ Abstract: // struct MLAS_SGEMM_WORK_BLOCK { -#if defined(MLAS_USE_WIN32_THREADPOOL) - volatile LONG Counter; CBLAS_TRANSPOSE TransA; CBLAS_TRANSPOSE TransB; size_t K; @@ -42,7 +40,6 @@ struct MLAS_SGEMM_WORK_BLOCK { size_t ldc; float alpha; float beta; -#endif struct SEGMENT { size_t M; size_t N; @@ -1034,14 +1031,10 @@ Return Value: } } -#if defined(MLAS_USE_WIN32_THREADPOOL) - void -CALLBACK -MlasSgemmWorkCallback( - PTP_CALLBACK_INSTANCE Instance, +MlasSgemmOperationThreaded( void* Context, - PTP_WORK WorkObject + int32_t Index ) /*++ @@ -1052,11 +1045,9 @@ Routine Description: Arguments: - Instance - Supplies the callback instance object. + Context - Supplies the pointer to the context for the threaded operation. - Context - Supplies the pointer to the parameters for the SGEMM operation. - - WorkObject - Supplies the threadpool work object. + Index - Supplies the current index of the threaded operation. Return Value: @@ -1064,13 +1055,8 @@ Return Value: --*/ { - MLAS_UNREFERENCED_PARAMETER(Instance); - MLAS_UNREFERENCED_PARAMETER(WorkObject); - MLAS_SGEMM_WORK_BLOCK* WorkBlock = (MLAS_SGEMM_WORK_BLOCK*)Context; - LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1; - MLAS_SGEMM_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index]; MlasSgemmOperation(WorkBlock->TransA, WorkBlock->TransB, Segment->M, @@ -1079,8 +1065,6 @@ Return Value: WorkBlock->ldc); } -#endif - inline bool MlasSgemmTryMultithread( @@ -1142,10 +1126,10 @@ Return Value: --*/ { -#if defined(MLAS_USE_WIN32_THREADPOOL) || defined(MLAS_USE_OPENMP) +#if defined(MLAS_HAS_THREADING_SUPPORT) MLAS_SGEMM_WORK_BLOCK WorkBlock; - uint32_t TargetThreadCount; + int32_t TargetThreadCount; // // Compute the number of target threads given the complexity of the SGEMM @@ -1155,12 +1139,12 @@ Return Value: double Complexity = double(M) * double(N) * double(K); if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) { - TargetThreadCount = uint32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + TargetThreadCount = int32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; } else { TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT; } - uint32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount(); + int32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount(); if (TargetThreadCount >= MaximumThreadCount) { TargetThreadCount = MaximumThreadCount; @@ -1170,23 +1154,10 @@ Return Value: return false; } -#if defined(MLAS_USE_WIN32_THREADPOOL) - - // - // Create an object to submit work to the threadpool. - // - - PTP_WORK WorkObject = CreateThreadpoolWork(MlasSgemmWorkCallback, &WorkBlock, nullptr); - - if (WorkObject == nullptr) { - return false; - } - // // Initialize the common fields of the work block. // - WorkBlock.Counter = 0; WorkBlock.TransA = TransA; WorkBlock.TransB = TransB; WorkBlock.K = K; @@ -1196,13 +1167,11 @@ Return Value: WorkBlock.alpha = alpha; WorkBlock.beta = beta; -#endif - // // Segment the operation across multiple threads. // - uint32_t Index = 0; + int32_t Index = 0; if (N > M) { @@ -1231,18 +1200,6 @@ Return Value: WorkBlock.Segments[Index].B = B + n * pldb; WorkBlock.Segments[Index].C = C + n; -#if defined(MLAS_USE_WIN32_THREADPOOL) - - // - // Execute one of the segments on a worker thread. - // - - if (Index > 0) { - SubmitThreadpoolWork(WorkObject); - } - -#endif - Index++; } @@ -1270,49 +1227,11 @@ Return Value: WorkBlock.Segments[Index].B = B; WorkBlock.Segments[Index].C = C + m * ldc; -#if defined(MLAS_USE_WIN32_THREADPOOL) - - // - // Execute one of the segments on a worker thread. - // - - if (Index > 0) { - SubmitThreadpoolWork(WorkObject); - } - -#endif - Index++; } } -#if defined(MLAS_USE_OPENMP) - - #pragma omp parallel for - for (int32_t tid = 0; tid < int32_t(Index); tid++) { - - MLAS_SGEMM_WORK_BLOCK::SEGMENT* Segment = &WorkBlock.Segments[tid]; - - MlasSgemmOperation(TransA, TransB, Segment->M, Segment->N, K, alpha, - Segment->A, lda, Segment->B, ldb, beta, Segment->C, ldc); - } - -#elif defined(MLAS_USE_WIN32_THREADPOOL) - - // - // Execute the remaining segment on this thread. - // - - MlasSgemmWorkCallback(nullptr, &WorkBlock, WorkObject); - - // - // Wait for the worker threads to complete. - // - - WaitForThreadpoolWorkCallbacks(WorkObject, FALSE); - CloseThreadpoolWork(WorkObject); - -#endif + MlasExecuteThreaded(MlasSgemmOperationThreaded, &WorkBlock, Index); return true; diff --git a/onnxruntime/core/mlas/lib/threading.cpp b/onnxruntime/core/mlas/lib/threading.cpp new file mode 100644 index 0000000000..b2c27ce697 --- /dev/null +++ b/onnxruntime/core/mlas/lib/threading.cpp @@ -0,0 +1,135 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + threading.cpp + +Abstract: + + This module implements platform specific threading support. + +--*/ + +#include "mlasi.h" + +#if defined(MLAS_USE_WIN32_THREADPOOL) + +// +// Define the parameters to execute threaded work using the Windows thread pool +// library. +// + +struct MLAS_THREADED_WORK_BLOCK { + volatile LONG Counter; + PMLAS_THREADED_ROUTINE ThreadedRoutine; + void* Context; +}; + +void +CALLBACK +MlasThreadedWorkCallback( + PTP_CALLBACK_INSTANCE Instance, + void* Context, + PTP_WORK WorkObject + ) +/*++ + +Routine Description: + + This routine is invoked from a worker thread to execute one iteration of a + batch of threaded work. + +Arguments: + + Instance - Supplies the callback instance object. + + Context - Supplies the pointer to the parameters for the operation. + + WorkObject - Supplies the threadpool work object. + +Return Value: + + None. + +--*/ +{ + MLAS_THREADED_WORK_BLOCK* WorkBlock = (MLAS_THREADED_WORK_BLOCK*)Context; + + LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1; + + WorkBlock->ThreadedRoutine(WorkBlock->Context, Index); +} + +#endif + +void +MlasExecuteThreaded( + MLAS_THREADED_ROUTINE ThreadedRoutine, + void* Context, + int32_t Iterations + ) +{ + // + // Execute the routine directly if only one iteration is specified. + // + + if (Iterations == 1) { + ThreadedRoutine(Context, 0); + return; + } + +#if defined(MLAS_USE_WIN32_THREADPOOL) + + // + // Schedule the threaded iterations using a work object. + // + + MLAS_THREADED_WORK_BLOCK WorkBlock; + + PTP_WORK WorkObject = CreateThreadpoolWork(MlasThreadedWorkCallback, &WorkBlock, nullptr); + + if (WorkObject != nullptr) { + + WorkBlock.Counter = 0; + WorkBlock.ThreadedRoutine = ThreadedRoutine; + WorkBlock.Context = Context; + + for (int32_t tid = 1; tid < Iterations; tid++) { + SubmitThreadpoolWork(WorkObject); + } + + // + // Execute the remaining iteration on this thread. + // + + ThreadedRoutine(Context, Iterations - 1); + + // + // Wait for the work object callbacks to complete. + // + + WaitForThreadpoolWorkCallbacks(WorkObject, FALSE); + CloseThreadpoolWork(WorkObject); + + return; + } + + // + // Fallback to a serialized implementation. + // + +#endif + + // + // Execute the routine for the specified number of iterations. + // + + #pragma omp parallel for + for (int32_t tid = 0; tid < Iterations; tid++) { + ThreadedRoutine(Context, tid); + } +} diff --git a/onnxruntime/core/providers/cpu/nn/conv_impl.h b/onnxruntime/core/providers/cpu/nn/conv_impl.h index 44fb16c353..e23ab9d715 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_impl.h +++ b/onnxruntime/core/providers/cpu/nn/conv_impl.h @@ -206,30 +206,31 @@ Status Conv::Compute(OpKernelContext* context) const { Tensor* Y = context->Output(0, TensorShape(Y_dims)); TensorShape output_shape = Y->Shape().Slice(2); - const int64_t input_image_size = input_shape.Size(); - const int64_t output_image_size = output_shape.Size(); - const int64_t kernel_size = TensorShape(kernel_shape).Size(); - AllocatorPtr alloc; ONNXRUNTIME_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); const float* Xdata = X->template Data(); float* Ydata = Y->template MutableData(); - MLAS_CONV_PARAMETERS Parameters; - size_t WorkingBufferSize; - if (MlasConvPrepare(&Parameters, - kernel_shape.size(), - static_cast(N), - static_cast(group_), - static_cast(C / group_), - input_shape.GetDims().data(), - kernel_shape.data(), - dilations.data(), - pads.data(), - strides.data(), - static_cast(M / group_), - &WorkingBufferSize)) { + const size_t kernel_rank = kernel_shape.size(); + + if (kernel_rank == 2 || kernel_rank == 3) { + MLAS_CONV_PARAMETERS Parameters; + size_t WorkingBufferSize; + MlasConvPrepare(&Parameters, + kernel_rank, + static_cast(N), + static_cast(group_), + static_cast(C / group_), + input_shape.GetDims().data(), + kernel_shape.data(), + dilations.data(), + pads.data(), + strides.data(), + output_shape.GetDims().data(), + static_cast(M / group_), + &WorkingBufferSize); + auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr; BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc)); @@ -240,6 +241,9 @@ Status Conv::Compute(OpKernelContext* context) const { static_cast(working_buffer.get()), Ydata); } else { + const int64_t input_image_size = input_shape.Size(); + const int64_t output_image_size = output_shape.Size(); + const int64_t kernel_size = TensorShape(kernel_shape).Size(); const int64_t X_offset = C / group_ * input_image_size; const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_; const int64_t W_offset = W->Shape().Size() / group_; diff --git a/onnxruntime/test/mlas/unittest.cpp b/onnxruntime/test/mlas/unittest.cpp index 2fd079f090..13902f460e 100644 --- a/onnxruntime/test/mlas/unittest.cpp +++ b/onnxruntime/test/mlas/unittest.cpp @@ -495,25 +495,24 @@ TrialConv2D( int64_t DilationShape[] = { int64_t(DilationHeight), int64_t(DilationWidth) }; int64_t Padding[] = { int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth), int64_t(PaddingRightHeight), int64_t(PaddingRightWidth) }; int64_t StrideShape[] = { int64_t(StrideHeight), int64_t(StrideWidth) }; + int64_t OutputShape[] = { OutputHeight64, OutputWidth64 }; MLAS_CONV_PARAMETERS Parameters; size_t WorkingBufferSize; - if (!MlasConvPrepare(&Parameters, - 2, - BatchCount, - GroupCount, - InputChannels, - InputShape, - KernelShape, - DilationShape, - Padding, - StrideShape, - FilterCount, - &WorkingBufferSize)) { - printf("MlasConvPrepare failed!!!\n"); - return; - } + MlasConvPrepare(&Parameters, + 2, + BatchCount, + GroupCount, + InputChannels, + InputShape, + KernelShape, + DilationShape, + Padding, + StrideShape, + OutputShape, + FilterCount, + &WorkingBufferSize); size_t OutputHeight = size_t(OutputHeight64); size_t OutputWidth = size_t(OutputWidth64); @@ -1236,9 +1235,9 @@ main( ) { // ExecuteSgemmTests(); -// ExecuteConvTests(); - ExecutePool2DTests(); - ExecutePool3DTests(); + ExecuteConvTests(); +// ExecutePool2DTests(); +// ExecutePool3DTests(); // EvaluateThreadingPerformance(); return 0;