mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
refactor threading (#110)
This commit is contained in:
parent
6d80253502
commit
3c7c1068e7
9 changed files with 256 additions and 326 deletions
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
set(mlas_common_srcs
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/platform.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/threading.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/sgemm.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/convolve.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/pooling.cpp
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ struct MLAS_CONV_PARAMETERS {
|
|||
} u;
|
||||
};
|
||||
|
||||
bool
|
||||
void
|
||||
MLASCALL
|
||||
MlasConvPrepare(
|
||||
MLAS_CONV_PARAMETERS* Parameters,
|
||||
|
|
@ -115,6 +115,7 @@ MlasConvPrepare(
|
|||
const int64_t* DilationShape,
|
||||
const int64_t* Padding,
|
||||
const int64_t* StrideShape,
|
||||
const int64_t* OutputShape,
|
||||
size_t FilterCount,
|
||||
size_t* WorkingBufferSize
|
||||
);
|
||||
|
|
|
|||
|
|
@ -29,20 +29,17 @@ Abstract:
|
|||
//
|
||||
|
||||
struct MLAS_CONV_WORK_BLOCK {
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
volatile LONG Counter;
|
||||
const MLAS_CONV_PARAMETERS* Parameters;
|
||||
const float* Input;
|
||||
const float* Filter;
|
||||
const float* Bias;
|
||||
float* WorkingBuffer;
|
||||
float* Output;
|
||||
#endif
|
||||
struct SEGMENT {
|
||||
size_t StartN;
|
||||
size_t CountN;
|
||||
} Segments[MLAS_MAXIMUM_THREAD_COUNT];
|
||||
uint32_t TargetThreadCount;
|
||||
int32_t TargetThreadCount;
|
||||
};
|
||||
|
||||
void
|
||||
|
|
@ -610,14 +607,10 @@ Return Value:
|
|||
}
|
||||
}
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
void
|
||||
CALLBACK
|
||||
MlasConvWorkCallback(
|
||||
PTP_CALLBACK_INSTANCE Instance,
|
||||
MlasConvOperationThreaded(
|
||||
void* Context,
|
||||
PTP_WORK WorkObject
|
||||
int32_t Index
|
||||
)
|
||||
/*++
|
||||
|
||||
|
|
@ -628,11 +621,9 @@ Routine Description:
|
|||
|
||||
Arguments:
|
||||
|
||||
Instance - Supplies the callback instance object.
|
||||
Context - Supplies the pointer to the context for the threaded operation.
|
||||
|
||||
Context - Supplies the pointer to the parameters for the SGEMM operation.
|
||||
|
||||
WorkObject - Supplies the threadpool work object.
|
||||
Index - Supplies the current index of the threaded operation.
|
||||
|
||||
Return Value:
|
||||
|
||||
|
|
@ -640,13 +631,8 @@ Return Value:
|
|||
|
||||
--*/
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(Instance);
|
||||
MLAS_UNREFERENCED_PARAMETER(WorkObject);
|
||||
|
||||
MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
|
||||
|
||||
LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1;
|
||||
|
||||
MLAS_CONV_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index];
|
||||
|
||||
float* ColumnBuffer =
|
||||
|
|
@ -658,11 +644,9 @@ Return Value:
|
|||
}
|
||||
|
||||
void
|
||||
CALLBACK
|
||||
MlasConvGemmDirectWorkCallback(
|
||||
PTP_CALLBACK_INSTANCE Instance,
|
||||
MlasConvGemmDirectThreaded(
|
||||
void* Context,
|
||||
PTP_WORK WorkObject
|
||||
int32_t Index
|
||||
)
|
||||
/*++
|
||||
|
||||
|
|
@ -673,11 +657,9 @@ Routine Description:
|
|||
|
||||
Arguments:
|
||||
|
||||
Instance - Supplies the callback instance object.
|
||||
Context - Supplies the pointer to the context for the threaded operation.
|
||||
|
||||
Context - Supplies the pointer to the parameters for the SGEMM operation.
|
||||
|
||||
WorkObject - Supplies the threadpool work object.
|
||||
Index - Supplies the current index of the threaded operation.
|
||||
|
||||
Return Value:
|
||||
|
||||
|
|
@ -685,13 +667,8 @@ Return Value:
|
|||
|
||||
--*/
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(Instance);
|
||||
MLAS_UNREFERENCED_PARAMETER(WorkObject);
|
||||
|
||||
MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
|
||||
|
||||
LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1;
|
||||
|
||||
const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;
|
||||
|
||||
//
|
||||
|
|
@ -755,8 +732,6 @@ Return Value:
|
|||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
inline
|
||||
bool
|
||||
MlasConvTryMultithread(
|
||||
|
|
@ -798,7 +773,7 @@ Return Value:
|
|||
--*/
|
||||
{
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL) || defined(MLAS_USE_OPENMP)
|
||||
#if defined(MLAS_HAS_THREADING_SUPPORT)
|
||||
|
||||
MLAS_CONV_WORK_BLOCK WorkBlock;
|
||||
|
||||
|
|
@ -809,23 +784,10 @@ Return Value:
|
|||
return false;
|
||||
}
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
//
|
||||
// Create an object to submit work to the threadpool.
|
||||
//
|
||||
|
||||
PTP_WORK WorkObject = CreateThreadpoolWork(MlasConvWorkCallback, &WorkBlock, nullptr);
|
||||
|
||||
if (WorkObject == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
//
|
||||
// Initialize the common fields of the work block.
|
||||
//
|
||||
|
||||
WorkBlock.Counter = 0;
|
||||
WorkBlock.Parameters = Parameters;
|
||||
WorkBlock.Input = Input;
|
||||
WorkBlock.Filter = Filter;
|
||||
|
|
@ -833,13 +795,11 @@ Return Value:
|
|||
WorkBlock.WorkingBuffer = WorkingBuffer;
|
||||
WorkBlock.Output = Output;
|
||||
|
||||
#endif
|
||||
|
||||
//
|
||||
// Segment the operation across multiple threads.
|
||||
//
|
||||
|
||||
uint32_t Index = 0;
|
||||
int32_t Index = 0;
|
||||
size_t SegmentCountN;
|
||||
|
||||
for (size_t SegmentStartN = 0; SegmentStartN < OutputSize; SegmentStartN += SegmentCountN) {
|
||||
|
|
@ -853,51 +813,10 @@ Return Value:
|
|||
WorkBlock.Segments[Index].StartN = SegmentStartN;
|
||||
WorkBlock.Segments[Index].CountN = SegmentCountN;
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
//
|
||||
// Execute one of the segments on a worker thread.
|
||||
//
|
||||
|
||||
if (Index > 0) {
|
||||
SubmitThreadpoolWork(WorkObject);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
Index++;
|
||||
}
|
||||
|
||||
#if defined(MLAS_USE_OPENMP)
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int32_t tid = 0; tid < int32_t(Index); tid++) {
|
||||
|
||||
MLAS_CONV_WORK_BLOCK::SEGMENT* Segment = &WorkBlock.Segments[tid];
|
||||
|
||||
float* ColumnBuffer =
|
||||
WorkingBuffer + tid * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD;
|
||||
|
||||
MlasConvOperation(Parameters, Input, Filter, Bias, ColumnBuffer, Output,
|
||||
Segment->StartN, Segment->CountN);
|
||||
}
|
||||
|
||||
#elif defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
//
|
||||
// Execute the remaining segment on this thread.
|
||||
//
|
||||
|
||||
MlasConvWorkCallback(nullptr, &WorkBlock, WorkObject);
|
||||
|
||||
//
|
||||
// Wait for the worker threads to complete.
|
||||
//
|
||||
|
||||
WaitForThreadpoolWorkCallbacks(WorkObject, FALSE);
|
||||
CloseThreadpoolWork(WorkObject);
|
||||
|
||||
#endif
|
||||
MlasExecuteThreaded(MlasConvOperationThreaded, &WorkBlock, Index);
|
||||
|
||||
return true;
|
||||
|
||||
|
|
@ -971,7 +890,7 @@ Return Value:
|
|||
|
||||
const MLAS_CONV_ALGORITHM Algorithm = Parameters->Algorithm;
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL) || defined(MLAS_USE_OPENMP)
|
||||
#if defined(MLAS_HAS_THREADING_SUPPORT)
|
||||
|
||||
//
|
||||
// Schedule batches of GEMMs across multiple threads.
|
||||
|
|
@ -981,76 +900,25 @@ Return Value:
|
|||
|
||||
const size_t BatchGroupCount = BatchCount * GroupCount;
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
uint32_t TargetThreadCount = MlasPlatform.GetMaximumThreadCount();
|
||||
int32_t TargetThreadCount = MlasPlatform.GetMaximumThreadCount();
|
||||
|
||||
if (TargetThreadCount >= BatchGroupCount) {
|
||||
TargetThreadCount = uint32_t(BatchGroupCount);
|
||||
TargetThreadCount = int32_t(BatchGroupCount);
|
||||
}
|
||||
|
||||
if (TargetThreadCount > 1) {
|
||||
MLAS_CONV_WORK_BLOCK WorkBlock;
|
||||
|
||||
MLAS_CONV_WORK_BLOCK WorkBlock;
|
||||
WorkBlock.Parameters = Parameters;
|
||||
WorkBlock.Input = Input;
|
||||
WorkBlock.Filter = Filter;
|
||||
WorkBlock.Bias = Bias;
|
||||
WorkBlock.WorkingBuffer = nullptr;
|
||||
WorkBlock.Output = Output;
|
||||
WorkBlock.TargetThreadCount = TargetThreadCount;
|
||||
|
||||
PTP_WORK WorkObject = CreateThreadpoolWork(MlasConvGemmDirectWorkCallback, &WorkBlock, nullptr);
|
||||
|
||||
if (WorkObject != nullptr) {
|
||||
|
||||
WorkBlock.Counter = 0;
|
||||
WorkBlock.Parameters = Parameters;
|
||||
WorkBlock.Input = Input;
|
||||
WorkBlock.Filter = Filter;
|
||||
WorkBlock.Bias = Bias;
|
||||
WorkBlock.WorkingBuffer = nullptr;
|
||||
WorkBlock.Output = Output;
|
||||
WorkBlock.TargetThreadCount = TargetThreadCount;
|
||||
|
||||
for (uint32_t Index = 1; Index < TargetThreadCount; Index++) {
|
||||
SubmitThreadpoolWork(WorkObject);
|
||||
}
|
||||
|
||||
MlasConvGemmDirectWorkCallback(nullptr, &WorkBlock, WorkObject);
|
||||
|
||||
WaitForThreadpoolWorkCallbacks(WorkObject, FALSE);
|
||||
CloseThreadpoolWork(WorkObject);
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t bg = 0; bg < int64_t(BatchGroupCount); bg++) {
|
||||
|
||||
size_t group = size_t(bg % GroupCount);
|
||||
|
||||
const float* input = Input + bg * InputGroupSize;
|
||||
const float* filter = Filter + group * FilterGroupSize;
|
||||
float* output = Output + bg * OutputGroupSize;
|
||||
|
||||
//
|
||||
// Invoke the non-threaded GEMM directly with the input tensor.
|
||||
//
|
||||
|
||||
MlasSgemmOperation(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount,
|
||||
OutputSize, K, 1.0f, filter, K, input, Parameters->u.GemmDirect.ldb, 0.0f,
|
||||
output, OutputSize);
|
||||
|
||||
//
|
||||
// Add the optional bias vector.
|
||||
//
|
||||
|
||||
if (Bias != nullptr) {
|
||||
MlasBiasAdd(Bias + group * FilterCount, FilterCount, output, OutputSize, OutputSize);
|
||||
}
|
||||
}
|
||||
MlasExecuteThreaded(MlasConvGemmDirectThreaded, &WorkBlock, TargetThreadCount);
|
||||
|
||||
return;
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
@ -1152,7 +1020,7 @@ Return Value:
|
|||
}
|
||||
}
|
||||
|
||||
bool
|
||||
void
|
||||
MLASCALL
|
||||
MlasConvPrepare(
|
||||
MLAS_CONV_PARAMETERS* Parameters,
|
||||
|
|
@ -1165,6 +1033,7 @@ MlasConvPrepare(
|
|||
const int64_t* DilationShape,
|
||||
const int64_t* Padding,
|
||||
const int64_t* StrideShape,
|
||||
const int64_t* OutputShape,
|
||||
size_t FilterCount,
|
||||
size_t* WorkingBufferSize
|
||||
)
|
||||
|
|
@ -1200,6 +1069,8 @@ Arguments:
|
|||
|
||||
StrideShape - Supplies the shape of the stride.
|
||||
|
||||
OutputShape - Supplies the shape of the output tensor.
|
||||
|
||||
FilterCount - Supplies the number of rows of the filter matrix per group.
|
||||
|
||||
WorkingBufferSize - Receives the number of elements to allocate for the
|
||||
|
|
@ -1207,20 +1078,12 @@ Arguments:
|
|||
|
||||
Return Value:
|
||||
|
||||
Returns true if implementation can support this operation.
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
//
|
||||
// Support only 2D or 3D convolutions.
|
||||
//
|
||||
|
||||
if (Dimensions != 2 && Dimensions != 3) {
|
||||
return false;
|
||||
}
|
||||
|
||||
//
|
||||
// Build the convolution parameters.
|
||||
// Save the convolution parameters.
|
||||
//
|
||||
|
||||
Parameters->Dimensions = Dimensions;
|
||||
|
|
@ -1240,25 +1103,13 @@ Return Value:
|
|||
for (size_t dim = 0; dim < Dimensions; dim++) {
|
||||
|
||||
Parameters->InputShape[dim] = size_t(InputShape[dim]);
|
||||
Parameters->OutputShape[dim] = size_t(OutputShape[dim]);
|
||||
Parameters->KernelShape[dim] = size_t(KernelShape[dim]);
|
||||
Parameters->DilationShape[dim] = size_t(DilationShape[dim]);
|
||||
Parameters->Padding[dim] = size_t(Padding[dim]);
|
||||
Parameters->Padding[dim + Dimensions] = size_t(Padding[dim + Dimensions]);
|
||||
Parameters->StrideShape[dim] = size_t(StrideShape[dim]);
|
||||
|
||||
//
|
||||
// Compute the output shape.
|
||||
//
|
||||
|
||||
int64_t OutputShape = (InputShape[dim] + Padding[dim] + Padding[dim + Dimensions] -
|
||||
(DilationShape[dim] * (KernelShape[dim] - 1) + 1)) / StrideShape[dim] + 1;
|
||||
|
||||
if (OutputShape <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Parameters->OutputShape[dim] = size_t(OutputShape);
|
||||
|
||||
InputSize *= Parameters->InputShape[dim];
|
||||
OutputSize *= Parameters->OutputShape[dim];
|
||||
K *= Parameters->KernelShape[dim];
|
||||
|
|
@ -1290,7 +1141,7 @@ Return Value:
|
|||
Parameters->u.GemmDirect.TransB = CblasNoTrans;
|
||||
Parameters->u.GemmDirect.ldb = OutputSize;
|
||||
|
||||
return true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (Dimensions == 2 && AllDilationsAreOne && InputChannels == 1) {
|
||||
|
|
@ -1306,7 +1157,7 @@ Return Value:
|
|||
Parameters->u.GemmDirect.TransB = CblasTrans;
|
||||
Parameters->u.GemmDirect.ldb = Parameters->InputShape[1];
|
||||
|
||||
return true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (Parameters->KernelShape[0] == Parameters->InputShape[0] &&
|
||||
|
|
@ -1316,7 +1167,7 @@ Return Value:
|
|||
Parameters->u.GemmDirect.TransB = CblasNoTrans;
|
||||
Parameters->u.GemmDirect.ldb = Parameters->InputShape[1];
|
||||
|
||||
return true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1343,16 +1194,16 @@ Return Value:
|
|||
// threaded path.
|
||||
//
|
||||
|
||||
uint32_t TargetThreadCount;
|
||||
int32_t TargetThreadCount;
|
||||
double Complexity = double(FilterCount) * double(OutputSize) * double(K);
|
||||
|
||||
if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
|
||||
TargetThreadCount = uint32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
|
||||
TargetThreadCount = int32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
|
||||
} else {
|
||||
TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
|
||||
}
|
||||
|
||||
uint32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount();
|
||||
int32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount();
|
||||
|
||||
if (TargetThreadCount >= MaximumThreadCount) {
|
||||
TargetThreadCount = MaximumThreadCount;
|
||||
|
|
@ -1384,6 +1235,4 @@ Return Value:
|
|||
|
||||
*WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,22 +82,24 @@ Abstract:
|
|||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#define MLAS_USE_OPENMP
|
||||
#define MLAS_HAS_THREADING_SUPPORT
|
||||
#elif defined(_WIN32)
|
||||
#define MLAS_USE_WIN32_THREADPOOL
|
||||
#define MLAS_HAS_THREADING_SUPPORT
|
||||
#endif
|
||||
|
||||
//
|
||||
// Define the maximum number of threads supported by this implementation.
|
||||
//
|
||||
|
||||
#define MLAS_MAXIMUM_THREAD_COUNT 16
|
||||
#define MLAS_MAXIMUM_THREAD_COUNT 16
|
||||
|
||||
//
|
||||
// Define the default strides to step through slices of the input matrices.
|
||||
//
|
||||
|
||||
#define MLAS_SGEMM_STRIDEN 128
|
||||
#define MLAS_SGEMM_STRIDEK 128
|
||||
#define MLAS_SGEMM_STRIDEN 128
|
||||
#define MLAS_SGEMM_STRIDEK 128
|
||||
|
||||
//
|
||||
// Define the alignment for segmenting a SGEMM operation across multiple
|
||||
|
|
@ -108,7 +110,7 @@ Abstract:
|
|||
// the effort at this time.
|
||||
//
|
||||
|
||||
#define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16
|
||||
#define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16
|
||||
|
||||
//
|
||||
// Define the prototypes of the SGEMM platform optimized routines.
|
||||
|
|
@ -193,12 +195,12 @@ extern "C" {
|
|||
//
|
||||
|
||||
#if defined(MLAS_USE_OPENMP)
|
||||
#define MLAS_SGEMM_THREAD_COMPLEXITY (64 * 1024)
|
||||
#define MLAS_SGEMM_THREAD_COMPLEXITY (64 * 1024)
|
||||
#else
|
||||
#if defined(MLAS_TARGET_AMD64)
|
||||
#define MLAS_SGEMM_THREAD_COMPLEXITY (2 * 1024 * 1024)
|
||||
#define MLAS_SGEMM_THREAD_COMPLEXITY (2 * 1024 * 1024)
|
||||
#else
|
||||
#define MLAS_SGEMM_THREAD_COMPLEXITY (1 * 1024 * 1024)
|
||||
#define MLAS_SGEMM_THREAD_COMPLEXITY (1 * 1024 * 1024)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
|
@ -243,10 +245,10 @@ struct MLAS_PLATFORM {
|
|||
#endif
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
uint32_t MaximumThreadCount;
|
||||
int32_t MaximumThreadCount;
|
||||
#endif
|
||||
|
||||
uint32_t
|
||||
int32_t
|
||||
GetMaximumThreadCount(
|
||||
void
|
||||
)
|
||||
|
|
@ -263,6 +265,26 @@ struct MLAS_PLATFORM {
|
|||
|
||||
extern MLAS_PLATFORM MlasPlatform;
|
||||
|
||||
//
|
||||
// Threading support.
|
||||
//
|
||||
|
||||
typedef
|
||||
void
|
||||
(MLAS_THREADED_ROUTINE)(
|
||||
void* Context,
|
||||
int32_t Index
|
||||
);
|
||||
|
||||
typedef MLAS_THREADED_ROUTINE* PMLAS_THREADED_ROUTINE;
|
||||
|
||||
void
|
||||
MlasExecuteThreaded(
|
||||
PMLAS_THREADED_ROUTINE ThreadedRoutine,
|
||||
void* Context,
|
||||
int32_t Iterations
|
||||
);
|
||||
|
||||
//
|
||||
// Define the missing ARM64 NEON intrinsic macros from arm64_neon.h that enable
|
||||
// cross-compiler support.
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ Return Value:
|
|||
GetSystemInfo(&SystemInfo);
|
||||
|
||||
if (SystemInfo.dwNumberOfProcessors <= MLAS_MAXIMUM_THREAD_COUNT) {
|
||||
this->MaximumThreadCount = SystemInfo.dwNumberOfProcessors;
|
||||
this->MaximumThreadCount = int32_t(SystemInfo.dwNumberOfProcessors);
|
||||
} else {
|
||||
this->MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,8 +32,6 @@ Abstract:
|
|||
//
|
||||
|
||||
struct MLAS_SGEMM_WORK_BLOCK {
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
volatile LONG Counter;
|
||||
CBLAS_TRANSPOSE TransA;
|
||||
CBLAS_TRANSPOSE TransB;
|
||||
size_t K;
|
||||
|
|
@ -42,7 +40,6 @@ struct MLAS_SGEMM_WORK_BLOCK {
|
|||
size_t ldc;
|
||||
float alpha;
|
||||
float beta;
|
||||
#endif
|
||||
struct SEGMENT {
|
||||
size_t M;
|
||||
size_t N;
|
||||
|
|
@ -1034,14 +1031,10 @@ Return Value:
|
|||
}
|
||||
}
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
void
|
||||
CALLBACK
|
||||
MlasSgemmWorkCallback(
|
||||
PTP_CALLBACK_INSTANCE Instance,
|
||||
MlasSgemmOperationThreaded(
|
||||
void* Context,
|
||||
PTP_WORK WorkObject
|
||||
int32_t Index
|
||||
)
|
||||
/*++
|
||||
|
||||
|
|
@ -1052,11 +1045,9 @@ Routine Description:
|
|||
|
||||
Arguments:
|
||||
|
||||
Instance - Supplies the callback instance object.
|
||||
Context - Supplies the pointer to the context for the threaded operation.
|
||||
|
||||
Context - Supplies the pointer to the parameters for the SGEMM operation.
|
||||
|
||||
WorkObject - Supplies the threadpool work object.
|
||||
Index - Supplies the current index of the threaded operation.
|
||||
|
||||
Return Value:
|
||||
|
||||
|
|
@ -1064,13 +1055,8 @@ Return Value:
|
|||
|
||||
--*/
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(Instance);
|
||||
MLAS_UNREFERENCED_PARAMETER(WorkObject);
|
||||
|
||||
MLAS_SGEMM_WORK_BLOCK* WorkBlock = (MLAS_SGEMM_WORK_BLOCK*)Context;
|
||||
|
||||
LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1;
|
||||
|
||||
MLAS_SGEMM_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index];
|
||||
|
||||
MlasSgemmOperation(WorkBlock->TransA, WorkBlock->TransB, Segment->M,
|
||||
|
|
@ -1079,8 +1065,6 @@ Return Value:
|
|||
WorkBlock->ldc);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
inline
|
||||
bool
|
||||
MlasSgemmTryMultithread(
|
||||
|
|
@ -1142,10 +1126,10 @@ Return Value:
|
|||
--*/
|
||||
{
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL) || defined(MLAS_USE_OPENMP)
|
||||
#if defined(MLAS_HAS_THREADING_SUPPORT)
|
||||
|
||||
MLAS_SGEMM_WORK_BLOCK WorkBlock;
|
||||
uint32_t TargetThreadCount;
|
||||
int32_t TargetThreadCount;
|
||||
|
||||
//
|
||||
// Compute the number of target threads given the complexity of the SGEMM
|
||||
|
|
@ -1155,12 +1139,12 @@ Return Value:
|
|||
double Complexity = double(M) * double(N) * double(K);
|
||||
|
||||
if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
|
||||
TargetThreadCount = uint32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
|
||||
TargetThreadCount = int32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
|
||||
} else {
|
||||
TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
|
||||
}
|
||||
|
||||
uint32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount();
|
||||
int32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount();
|
||||
|
||||
if (TargetThreadCount >= MaximumThreadCount) {
|
||||
TargetThreadCount = MaximumThreadCount;
|
||||
|
|
@ -1170,23 +1154,10 @@ Return Value:
|
|||
return false;
|
||||
}
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
//
|
||||
// Create an object to submit work to the threadpool.
|
||||
//
|
||||
|
||||
PTP_WORK WorkObject = CreateThreadpoolWork(MlasSgemmWorkCallback, &WorkBlock, nullptr);
|
||||
|
||||
if (WorkObject == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
//
|
||||
// Initialize the common fields of the work block.
|
||||
//
|
||||
|
||||
WorkBlock.Counter = 0;
|
||||
WorkBlock.TransA = TransA;
|
||||
WorkBlock.TransB = TransB;
|
||||
WorkBlock.K = K;
|
||||
|
|
@ -1196,13 +1167,11 @@ Return Value:
|
|||
WorkBlock.alpha = alpha;
|
||||
WorkBlock.beta = beta;
|
||||
|
||||
#endif
|
||||
|
||||
//
|
||||
// Segment the operation across multiple threads.
|
||||
//
|
||||
|
||||
uint32_t Index = 0;
|
||||
int32_t Index = 0;
|
||||
|
||||
if (N > M) {
|
||||
|
||||
|
|
@ -1231,18 +1200,6 @@ Return Value:
|
|||
WorkBlock.Segments[Index].B = B + n * pldb;
|
||||
WorkBlock.Segments[Index].C = C + n;
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
//
|
||||
// Execute one of the segments on a worker thread.
|
||||
//
|
||||
|
||||
if (Index > 0) {
|
||||
SubmitThreadpoolWork(WorkObject);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
Index++;
|
||||
}
|
||||
|
||||
|
|
@ -1270,49 +1227,11 @@ Return Value:
|
|||
WorkBlock.Segments[Index].B = B;
|
||||
WorkBlock.Segments[Index].C = C + m * ldc;
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
//
|
||||
// Execute one of the segments on a worker thread.
|
||||
//
|
||||
|
||||
if (Index > 0) {
|
||||
SubmitThreadpoolWork(WorkObject);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
Index++;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(MLAS_USE_OPENMP)
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int32_t tid = 0; tid < int32_t(Index); tid++) {
|
||||
|
||||
MLAS_SGEMM_WORK_BLOCK::SEGMENT* Segment = &WorkBlock.Segments[tid];
|
||||
|
||||
MlasSgemmOperation(TransA, TransB, Segment->M, Segment->N, K, alpha,
|
||||
Segment->A, lda, Segment->B, ldb, beta, Segment->C, ldc);
|
||||
}
|
||||
|
||||
#elif defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
//
|
||||
// Execute the remaining segment on this thread.
|
||||
//
|
||||
|
||||
MlasSgemmWorkCallback(nullptr, &WorkBlock, WorkObject);
|
||||
|
||||
//
|
||||
// Wait for the worker threads to complete.
|
||||
//
|
||||
|
||||
WaitForThreadpoolWorkCallbacks(WorkObject, FALSE);
|
||||
CloseThreadpoolWork(WorkObject);
|
||||
|
||||
#endif
|
||||
MlasExecuteThreaded(MlasSgemmOperationThreaded, &WorkBlock, Index);
|
||||
|
||||
return true;
|
||||
|
||||
|
|
|
|||
135
onnxruntime/core/mlas/lib/threading.cpp
Normal file
135
onnxruntime/core/mlas/lib/threading.cpp
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
threading.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements platform specific threading support.
|
||||
|
||||
--*/
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
//
|
||||
// Define the parameters to execute threaded work using the Windows thread pool
|
||||
// library.
|
||||
//
|
||||
|
||||
struct MLAS_THREADED_WORK_BLOCK {
|
||||
volatile LONG Counter;
|
||||
PMLAS_THREADED_ROUTINE ThreadedRoutine;
|
||||
void* Context;
|
||||
};
|
||||
|
||||
void
|
||||
CALLBACK
|
||||
MlasThreadedWorkCallback(
|
||||
PTP_CALLBACK_INSTANCE Instance,
|
||||
void* Context,
|
||||
PTP_WORK WorkObject
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is invoked from a worker thread to execute one iteration of a
|
||||
batch of threaded work.
|
||||
|
||||
Arguments:
|
||||
|
||||
Instance - Supplies the callback instance object.
|
||||
|
||||
Context - Supplies the pointer to the parameters for the operation.
|
||||
|
||||
WorkObject - Supplies the threadpool work object.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
MLAS_THREADED_WORK_BLOCK* WorkBlock = (MLAS_THREADED_WORK_BLOCK*)Context;
|
||||
|
||||
LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1;
|
||||
|
||||
WorkBlock->ThreadedRoutine(WorkBlock->Context, Index);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
void
|
||||
MlasExecuteThreaded(
|
||||
MLAS_THREADED_ROUTINE ThreadedRoutine,
|
||||
void* Context,
|
||||
int32_t Iterations
|
||||
)
|
||||
{
|
||||
//
|
||||
// Execute the routine directly if only one iteration is specified.
|
||||
//
|
||||
|
||||
if (Iterations == 1) {
|
||||
ThreadedRoutine(Context, 0);
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
||||
//
|
||||
// Schedule the threaded iterations using a work object.
|
||||
//
|
||||
|
||||
MLAS_THREADED_WORK_BLOCK WorkBlock;
|
||||
|
||||
PTP_WORK WorkObject = CreateThreadpoolWork(MlasThreadedWorkCallback, &WorkBlock, nullptr);
|
||||
|
||||
if (WorkObject != nullptr) {
|
||||
|
||||
WorkBlock.Counter = 0;
|
||||
WorkBlock.ThreadedRoutine = ThreadedRoutine;
|
||||
WorkBlock.Context = Context;
|
||||
|
||||
for (int32_t tid = 1; tid < Iterations; tid++) {
|
||||
SubmitThreadpoolWork(WorkObject);
|
||||
}
|
||||
|
||||
//
|
||||
// Execute the remaining iteration on this thread.
|
||||
//
|
||||
|
||||
ThreadedRoutine(Context, Iterations - 1);
|
||||
|
||||
//
|
||||
// Wait for the work object callbacks to complete.
|
||||
//
|
||||
|
||||
WaitForThreadpoolWorkCallbacks(WorkObject, FALSE);
|
||||
CloseThreadpoolWork(WorkObject);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
//
|
||||
// Fallback to a serialized implementation.
|
||||
//
|
||||
|
||||
#endif
|
||||
|
||||
//
|
||||
// Execute the routine for the specified number of iterations.
|
||||
//
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int32_t tid = 0; tid < Iterations; tid++) {
|
||||
ThreadedRoutine(Context, tid);
|
||||
}
|
||||
}
|
||||
|
|
@ -206,30 +206,31 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
Tensor* Y = context->Output(0, TensorShape(Y_dims));
|
||||
TensorShape output_shape = Y->Shape().Slice(2);
|
||||
|
||||
const int64_t input_image_size = input_shape.Size();
|
||||
const int64_t output_image_size = output_shape.Size();
|
||||
const int64_t kernel_size = TensorShape(kernel_shape).Size();
|
||||
|
||||
AllocatorPtr alloc;
|
||||
ONNXRUNTIME_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
const float* Xdata = X->template Data<float>();
|
||||
float* Ydata = Y->template MutableData<float>();
|
||||
|
||||
MLAS_CONV_PARAMETERS Parameters;
|
||||
size_t WorkingBufferSize;
|
||||
if (MlasConvPrepare(&Parameters,
|
||||
kernel_shape.size(),
|
||||
static_cast<size_t>(N),
|
||||
static_cast<size_t>(group_),
|
||||
static_cast<size_t>(C / group_),
|
||||
input_shape.GetDims().data(),
|
||||
kernel_shape.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
strides.data(),
|
||||
static_cast<size_t>(M / group_),
|
||||
&WorkingBufferSize)) {
|
||||
const size_t kernel_rank = kernel_shape.size();
|
||||
|
||||
if (kernel_rank == 2 || kernel_rank == 3) {
|
||||
MLAS_CONV_PARAMETERS Parameters;
|
||||
size_t WorkingBufferSize;
|
||||
MlasConvPrepare(&Parameters,
|
||||
kernel_rank,
|
||||
static_cast<size_t>(N),
|
||||
static_cast<size_t>(group_),
|
||||
static_cast<size_t>(C / group_),
|
||||
input_shape.GetDims().data(),
|
||||
kernel_shape.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
strides.data(),
|
||||
output_shape.GetDims().data(),
|
||||
static_cast<size_t>(M / group_),
|
||||
&WorkingBufferSize);
|
||||
|
||||
auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr;
|
||||
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
|
||||
|
||||
|
|
@ -240,6 +241,9 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
static_cast<float*>(working_buffer.get()),
|
||||
Ydata);
|
||||
} else {
|
||||
const int64_t input_image_size = input_shape.Size();
|
||||
const int64_t output_image_size = output_shape.Size();
|
||||
const int64_t kernel_size = TensorShape(kernel_shape).Size();
|
||||
const int64_t X_offset = C / group_ * input_image_size;
|
||||
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
|
||||
const int64_t W_offset = W->Shape().Size() / group_;
|
||||
|
|
|
|||
|
|
@ -495,25 +495,24 @@ TrialConv2D(
|
|||
int64_t DilationShape[] = { int64_t(DilationHeight), int64_t(DilationWidth) };
|
||||
int64_t Padding[] = { int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth), int64_t(PaddingRightHeight), int64_t(PaddingRightWidth) };
|
||||
int64_t StrideShape[] = { int64_t(StrideHeight), int64_t(StrideWidth) };
|
||||
int64_t OutputShape[] = { OutputHeight64, OutputWidth64 };
|
||||
|
||||
MLAS_CONV_PARAMETERS Parameters;
|
||||
size_t WorkingBufferSize;
|
||||
|
||||
if (!MlasConvPrepare(&Parameters,
|
||||
2,
|
||||
BatchCount,
|
||||
GroupCount,
|
||||
InputChannels,
|
||||
InputShape,
|
||||
KernelShape,
|
||||
DilationShape,
|
||||
Padding,
|
||||
StrideShape,
|
||||
FilterCount,
|
||||
&WorkingBufferSize)) {
|
||||
printf("MlasConvPrepare failed!!!\n");
|
||||
return;
|
||||
}
|
||||
MlasConvPrepare(&Parameters,
|
||||
2,
|
||||
BatchCount,
|
||||
GroupCount,
|
||||
InputChannels,
|
||||
InputShape,
|
||||
KernelShape,
|
||||
DilationShape,
|
||||
Padding,
|
||||
StrideShape,
|
||||
OutputShape,
|
||||
FilterCount,
|
||||
&WorkingBufferSize);
|
||||
|
||||
size_t OutputHeight = size_t(OutputHeight64);
|
||||
size_t OutputWidth = size_t(OutputWidth64);
|
||||
|
|
@ -1236,9 +1235,9 @@ main(
|
|||
)
|
||||
{
|
||||
// ExecuteSgemmTests();
|
||||
// ExecuteConvTests();
|
||||
ExecutePool2DTests();
|
||||
ExecutePool3DTests();
|
||||
ExecuteConvTests();
|
||||
// ExecutePool2DTests();
|
||||
// ExecutePool3DTests();
|
||||
// EvaluateThreadingPerformance();
|
||||
|
||||
return 0;
|
||||
|
|
|
|||
Loading…
Reference in a new issue