refactor threading (#110)

This commit is contained in:
Tracy Sharpe 2018-12-06 09:20:32 -08:00 committed by GitHub
parent 6d80253502
commit 3c7c1068e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 256 additions and 326 deletions

View file

@ -3,6 +3,7 @@
set(mlas_common_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/platform.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/threading.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/sgemm.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/convolve.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/pooling.cpp

View file

@ -102,7 +102,7 @@ struct MLAS_CONV_PARAMETERS {
} u;
};
bool
void
MLASCALL
MlasConvPrepare(
MLAS_CONV_PARAMETERS* Parameters,
@ -115,6 +115,7 @@ MlasConvPrepare(
const int64_t* DilationShape,
const int64_t* Padding,
const int64_t* StrideShape,
const int64_t* OutputShape,
size_t FilterCount,
size_t* WorkingBufferSize
);

View file

@ -29,20 +29,17 @@ Abstract:
//
struct MLAS_CONV_WORK_BLOCK {
#if defined(MLAS_USE_WIN32_THREADPOOL)
volatile LONG Counter;
const MLAS_CONV_PARAMETERS* Parameters;
const float* Input;
const float* Filter;
const float* Bias;
float* WorkingBuffer;
float* Output;
#endif
struct SEGMENT {
size_t StartN;
size_t CountN;
} Segments[MLAS_MAXIMUM_THREAD_COUNT];
uint32_t TargetThreadCount;
int32_t TargetThreadCount;
};
void
@ -610,14 +607,10 @@ Return Value:
}
}
#if defined(MLAS_USE_WIN32_THREADPOOL)
void
CALLBACK
MlasConvWorkCallback(
PTP_CALLBACK_INSTANCE Instance,
MlasConvOperationThreaded(
void* Context,
PTP_WORK WorkObject
int32_t Index
)
/*++
@ -628,11 +621,9 @@ Routine Description:
Arguments:
Instance - Supplies the callback instance object.
Context - Supplies the pointer to the context for the threaded operation.
Context - Supplies the pointer to the parameters for the SGEMM operation.
WorkObject - Supplies the threadpool work object.
Index - Supplies the current index of the threaded operation.
Return Value:
@ -640,13 +631,8 @@ Return Value:
--*/
{
MLAS_UNREFERENCED_PARAMETER(Instance);
MLAS_UNREFERENCED_PARAMETER(WorkObject);
MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1;
MLAS_CONV_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index];
float* ColumnBuffer =
@ -658,11 +644,9 @@ Return Value:
}
void
CALLBACK
MlasConvGemmDirectWorkCallback(
PTP_CALLBACK_INSTANCE Instance,
MlasConvGemmDirectThreaded(
void* Context,
PTP_WORK WorkObject
int32_t Index
)
/*++
@ -673,11 +657,9 @@ Routine Description:
Arguments:
Instance - Supplies the callback instance object.
Context - Supplies the pointer to the context for the threaded operation.
Context - Supplies the pointer to the parameters for the SGEMM operation.
WorkObject - Supplies the threadpool work object.
Index - Supplies the current index of the threaded operation.
Return Value:
@ -685,13 +667,8 @@ Return Value:
--*/
{
MLAS_UNREFERENCED_PARAMETER(Instance);
MLAS_UNREFERENCED_PARAMETER(WorkObject);
MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1;
const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;
//
@ -755,8 +732,6 @@ Return Value:
}
}
#endif
inline
bool
MlasConvTryMultithread(
@ -798,7 +773,7 @@ Return Value:
--*/
{
#if defined(MLAS_USE_WIN32_THREADPOOL) || defined(MLAS_USE_OPENMP)
#if defined(MLAS_HAS_THREADING_SUPPORT)
MLAS_CONV_WORK_BLOCK WorkBlock;
@ -809,23 +784,10 @@ Return Value:
return false;
}
#if defined(MLAS_USE_WIN32_THREADPOOL)
//
// Create an object to submit work to the threadpool.
//
PTP_WORK WorkObject = CreateThreadpoolWork(MlasConvWorkCallback, &WorkBlock, nullptr);
if (WorkObject == nullptr) {
return false;
}
//
// Initialize the common fields of the work block.
//
WorkBlock.Counter = 0;
WorkBlock.Parameters = Parameters;
WorkBlock.Input = Input;
WorkBlock.Filter = Filter;
@ -833,13 +795,11 @@ Return Value:
WorkBlock.WorkingBuffer = WorkingBuffer;
WorkBlock.Output = Output;
#endif
//
// Segment the operation across multiple threads.
//
uint32_t Index = 0;
int32_t Index = 0;
size_t SegmentCountN;
for (size_t SegmentStartN = 0; SegmentStartN < OutputSize; SegmentStartN += SegmentCountN) {
@ -853,51 +813,10 @@ Return Value:
WorkBlock.Segments[Index].StartN = SegmentStartN;
WorkBlock.Segments[Index].CountN = SegmentCountN;
#if defined(MLAS_USE_WIN32_THREADPOOL)
//
// Execute one of the segments on a worker thread.
//
if (Index > 0) {
SubmitThreadpoolWork(WorkObject);
}
#endif
Index++;
}
#if defined(MLAS_USE_OPENMP)
#pragma omp parallel for
for (int32_t tid = 0; tid < int32_t(Index); tid++) {
MLAS_CONV_WORK_BLOCK::SEGMENT* Segment = &WorkBlock.Segments[tid];
float* ColumnBuffer =
WorkingBuffer + tid * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD;
MlasConvOperation(Parameters, Input, Filter, Bias, ColumnBuffer, Output,
Segment->StartN, Segment->CountN);
}
#elif defined(MLAS_USE_WIN32_THREADPOOL)
//
// Execute the remaining segment on this thread.
//
MlasConvWorkCallback(nullptr, &WorkBlock, WorkObject);
//
// Wait for the worker threads to complete.
//
WaitForThreadpoolWorkCallbacks(WorkObject, FALSE);
CloseThreadpoolWork(WorkObject);
#endif
MlasExecuteThreaded(MlasConvOperationThreaded, &WorkBlock, Index);
return true;
@ -971,7 +890,7 @@ Return Value:
const MLAS_CONV_ALGORITHM Algorithm = Parameters->Algorithm;
#if defined(MLAS_USE_WIN32_THREADPOOL) || defined(MLAS_USE_OPENMP)
#if defined(MLAS_HAS_THREADING_SUPPORT)
//
// Schedule batches of GEMMs across multiple threads.
@ -981,76 +900,25 @@ Return Value:
const size_t BatchGroupCount = BatchCount * GroupCount;
#if defined(MLAS_USE_WIN32_THREADPOOL)
uint32_t TargetThreadCount = MlasPlatform.GetMaximumThreadCount();
int32_t TargetThreadCount = MlasPlatform.GetMaximumThreadCount();
if (TargetThreadCount >= BatchGroupCount) {
TargetThreadCount = uint32_t(BatchGroupCount);
TargetThreadCount = int32_t(BatchGroupCount);
}
if (TargetThreadCount > 1) {
MLAS_CONV_WORK_BLOCK WorkBlock;
MLAS_CONV_WORK_BLOCK WorkBlock;
WorkBlock.Parameters = Parameters;
WorkBlock.Input = Input;
WorkBlock.Filter = Filter;
WorkBlock.Bias = Bias;
WorkBlock.WorkingBuffer = nullptr;
WorkBlock.Output = Output;
WorkBlock.TargetThreadCount = TargetThreadCount;
PTP_WORK WorkObject = CreateThreadpoolWork(MlasConvGemmDirectWorkCallback, &WorkBlock, nullptr);
if (WorkObject != nullptr) {
WorkBlock.Counter = 0;
WorkBlock.Parameters = Parameters;
WorkBlock.Input = Input;
WorkBlock.Filter = Filter;
WorkBlock.Bias = Bias;
WorkBlock.WorkingBuffer = nullptr;
WorkBlock.Output = Output;
WorkBlock.TargetThreadCount = TargetThreadCount;
for (uint32_t Index = 1; Index < TargetThreadCount; Index++) {
SubmitThreadpoolWork(WorkObject);
}
MlasConvGemmDirectWorkCallback(nullptr, &WorkBlock, WorkObject);
WaitForThreadpoolWorkCallbacks(WorkObject, FALSE);
CloseThreadpoolWork(WorkObject);
return;
}
}
#else
#pragma omp parallel for
for (int64_t bg = 0; bg < int64_t(BatchGroupCount); bg++) {
size_t group = size_t(bg % GroupCount);
const float* input = Input + bg * InputGroupSize;
const float* filter = Filter + group * FilterGroupSize;
float* output = Output + bg * OutputGroupSize;
//
// Invoke the non-threaded GEMM directly with the input tensor.
//
MlasSgemmOperation(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount,
OutputSize, K, 1.0f, filter, K, input, Parameters->u.GemmDirect.ldb, 0.0f,
output, OutputSize);
//
// Add the optional bias vector.
//
if (Bias != nullptr) {
MlasBiasAdd(Bias + group * FilterCount, FilterCount, output, OutputSize, OutputSize);
}
}
MlasExecuteThreaded(MlasConvGemmDirectThreaded, &WorkBlock, TargetThreadCount);
return;
#endif
}
#endif
@ -1152,7 +1020,7 @@ Return Value:
}
}
bool
void
MLASCALL
MlasConvPrepare(
MLAS_CONV_PARAMETERS* Parameters,
@ -1165,6 +1033,7 @@ MlasConvPrepare(
const int64_t* DilationShape,
const int64_t* Padding,
const int64_t* StrideShape,
const int64_t* OutputShape,
size_t FilterCount,
size_t* WorkingBufferSize
)
@ -1200,6 +1069,8 @@ Arguments:
StrideShape - Supplies the shape of the stride.
OutputShape - Supplies the shape of the output tensor.
FilterCount - Supplies the number of rows of the filter matrix per group.
WorkingBufferSize - Receives the number of elements to allocate for the
@ -1207,20 +1078,12 @@ Arguments:
Return Value:
Returns true if implementation can support this operation.
None.
--*/
{
//
// Support only 2D or 3D convolutions.
//
if (Dimensions != 2 && Dimensions != 3) {
return false;
}
//
// Build the convolution parameters.
// Save the convolution parameters.
//
Parameters->Dimensions = Dimensions;
@ -1240,25 +1103,13 @@ Return Value:
for (size_t dim = 0; dim < Dimensions; dim++) {
Parameters->InputShape[dim] = size_t(InputShape[dim]);
Parameters->OutputShape[dim] = size_t(OutputShape[dim]);
Parameters->KernelShape[dim] = size_t(KernelShape[dim]);
Parameters->DilationShape[dim] = size_t(DilationShape[dim]);
Parameters->Padding[dim] = size_t(Padding[dim]);
Parameters->Padding[dim + Dimensions] = size_t(Padding[dim + Dimensions]);
Parameters->StrideShape[dim] = size_t(StrideShape[dim]);
//
// Compute the output shape.
//
int64_t OutputShape = (InputShape[dim] + Padding[dim] + Padding[dim + Dimensions] -
(DilationShape[dim] * (KernelShape[dim] - 1) + 1)) / StrideShape[dim] + 1;
if (OutputShape <= 0) {
return false;
}
Parameters->OutputShape[dim] = size_t(OutputShape);
InputSize *= Parameters->InputShape[dim];
OutputSize *= Parameters->OutputShape[dim];
K *= Parameters->KernelShape[dim];
@ -1290,7 +1141,7 @@ Return Value:
Parameters->u.GemmDirect.TransB = CblasNoTrans;
Parameters->u.GemmDirect.ldb = OutputSize;
return true;
return;
}
if (Dimensions == 2 && AllDilationsAreOne && InputChannels == 1) {
@ -1306,7 +1157,7 @@ Return Value:
Parameters->u.GemmDirect.TransB = CblasTrans;
Parameters->u.GemmDirect.ldb = Parameters->InputShape[1];
return true;
return;
}
if (Parameters->KernelShape[0] == Parameters->InputShape[0] &&
@ -1316,7 +1167,7 @@ Return Value:
Parameters->u.GemmDirect.TransB = CblasNoTrans;
Parameters->u.GemmDirect.ldb = Parameters->InputShape[1];
return true;
return;
}
}
}
@ -1343,16 +1194,16 @@ Return Value:
// threaded path.
//
uint32_t TargetThreadCount;
int32_t TargetThreadCount;
double Complexity = double(FilterCount) * double(OutputSize) * double(K);
if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
TargetThreadCount = uint32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
TargetThreadCount = int32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
} else {
TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
}
uint32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount();
int32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount();
if (TargetThreadCount >= MaximumThreadCount) {
TargetThreadCount = MaximumThreadCount;
@ -1384,6 +1235,4 @@ Return Value:
*WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD;
}
return true;
}

View file

@ -82,22 +82,24 @@ Abstract:
#if defined(_OPENMP)
#include <omp.h>
#define MLAS_USE_OPENMP
#define MLAS_HAS_THREADING_SUPPORT
#elif defined(_WIN32)
#define MLAS_USE_WIN32_THREADPOOL
#define MLAS_HAS_THREADING_SUPPORT
#endif
//
// Define the maximum number of threads supported by this implementation.
//
#define MLAS_MAXIMUM_THREAD_COUNT 16
#define MLAS_MAXIMUM_THREAD_COUNT 16
//
// Define the default strides to step through slices of the input matrices.
//
#define MLAS_SGEMM_STRIDEN 128
#define MLAS_SGEMM_STRIDEK 128
#define MLAS_SGEMM_STRIDEN 128
#define MLAS_SGEMM_STRIDEK 128
//
// Define the alignment for segmenting a SGEMM operation across multiple
@ -108,7 +110,7 @@ Abstract:
// the effort at this time.
//
#define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16
#define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16
//
// Define the prototypes of the SGEMM platform optimized routines.
@ -193,12 +195,12 @@ extern "C" {
//
#if defined(MLAS_USE_OPENMP)
#define MLAS_SGEMM_THREAD_COMPLEXITY (64 * 1024)
#define MLAS_SGEMM_THREAD_COMPLEXITY (64 * 1024)
#else
#if defined(MLAS_TARGET_AMD64)
#define MLAS_SGEMM_THREAD_COMPLEXITY (2 * 1024 * 1024)
#define MLAS_SGEMM_THREAD_COMPLEXITY (2 * 1024 * 1024)
#else
#define MLAS_SGEMM_THREAD_COMPLEXITY (1 * 1024 * 1024)
#define MLAS_SGEMM_THREAD_COMPLEXITY (1 * 1024 * 1024)
#endif
#endif
@ -243,10 +245,10 @@ struct MLAS_PLATFORM {
#endif
#if defined(MLAS_USE_WIN32_THREADPOOL)
uint32_t MaximumThreadCount;
int32_t MaximumThreadCount;
#endif
uint32_t
int32_t
GetMaximumThreadCount(
void
)
@ -263,6 +265,26 @@ struct MLAS_PLATFORM {
extern MLAS_PLATFORM MlasPlatform;
//
// Threading support.
//
//
// Define the signature of a routine that executes one iteration of a batch of
// threaded work. Context receives the caller supplied pointer unchanged and
// Index identifies which iteration of the batch is being executed.
//
typedef
void
(MLAS_THREADED_ROUTINE)(
void* Context,
int32_t Index
);
//
// Pointer form of the above, used where the routine is stored or passed.
//
typedef MLAS_THREADED_ROUTINE* PMLAS_THREADED_ROUTINE;
//
// Invokes ThreadedRoutine(Context, Index) for each Index from zero through
// Iterations - 1. The invocations may be spread across multiple threads, so
// the routine must tolerate being run concurrently with different indices.
//
void
MlasExecuteThreaded(
PMLAS_THREADED_ROUTINE ThreadedRoutine,
void* Context,
int32_t Iterations
);
//
// Define the missing ARM64 NEON intrinsic macros from arm64_neon.h that enable
// cross-compiler support.

View file

@ -137,7 +137,7 @@ Return Value:
GetSystemInfo(&SystemInfo);
if (SystemInfo.dwNumberOfProcessors <= MLAS_MAXIMUM_THREAD_COUNT) {
this->MaximumThreadCount = SystemInfo.dwNumberOfProcessors;
this->MaximumThreadCount = int32_t(SystemInfo.dwNumberOfProcessors);
} else {
this->MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
}

View file

@ -32,8 +32,6 @@ Abstract:
//
struct MLAS_SGEMM_WORK_BLOCK {
#if defined(MLAS_USE_WIN32_THREADPOOL)
volatile LONG Counter;
CBLAS_TRANSPOSE TransA;
CBLAS_TRANSPOSE TransB;
size_t K;
@ -42,7 +40,6 @@ struct MLAS_SGEMM_WORK_BLOCK {
size_t ldc;
float alpha;
float beta;
#endif
struct SEGMENT {
size_t M;
size_t N;
@ -1034,14 +1031,10 @@ Return Value:
}
}
#if defined(MLAS_USE_WIN32_THREADPOOL)
void
CALLBACK
MlasSgemmWorkCallback(
PTP_CALLBACK_INSTANCE Instance,
MlasSgemmOperationThreaded(
void* Context,
PTP_WORK WorkObject
int32_t Index
)
/*++
@ -1052,11 +1045,9 @@ Routine Description:
Arguments:
Instance - Supplies the callback instance object.
Context - Supplies the pointer to the context for the threaded operation.
Context - Supplies the pointer to the parameters for the SGEMM operation.
WorkObject - Supplies the threadpool work object.
Index - Supplies the current index of the threaded operation.
Return Value:
@ -1064,13 +1055,8 @@ Return Value:
--*/
{
MLAS_UNREFERENCED_PARAMETER(Instance);
MLAS_UNREFERENCED_PARAMETER(WorkObject);
MLAS_SGEMM_WORK_BLOCK* WorkBlock = (MLAS_SGEMM_WORK_BLOCK*)Context;
LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1;
MLAS_SGEMM_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index];
MlasSgemmOperation(WorkBlock->TransA, WorkBlock->TransB, Segment->M,
@ -1079,8 +1065,6 @@ Return Value:
WorkBlock->ldc);
}
#endif
inline
bool
MlasSgemmTryMultithread(
@ -1142,10 +1126,10 @@ Return Value:
--*/
{
#if defined(MLAS_USE_WIN32_THREADPOOL) || defined(MLAS_USE_OPENMP)
#if defined(MLAS_HAS_THREADING_SUPPORT)
MLAS_SGEMM_WORK_BLOCK WorkBlock;
uint32_t TargetThreadCount;
int32_t TargetThreadCount;
//
// Compute the number of target threads given the complexity of the SGEMM
@ -1155,12 +1139,12 @@ Return Value:
double Complexity = double(M) * double(N) * double(K);
if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
TargetThreadCount = uint32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
TargetThreadCount = int32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
} else {
TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
}
uint32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount();
int32_t MaximumThreadCount = MlasPlatform.GetMaximumThreadCount();
if (TargetThreadCount >= MaximumThreadCount) {
TargetThreadCount = MaximumThreadCount;
@ -1170,23 +1154,10 @@ Return Value:
return false;
}
#if defined(MLAS_USE_WIN32_THREADPOOL)
//
// Create an object to submit work to the threadpool.
//
PTP_WORK WorkObject = CreateThreadpoolWork(MlasSgemmWorkCallback, &WorkBlock, nullptr);
if (WorkObject == nullptr) {
return false;
}
//
// Initialize the common fields of the work block.
//
WorkBlock.Counter = 0;
WorkBlock.TransA = TransA;
WorkBlock.TransB = TransB;
WorkBlock.K = K;
@ -1196,13 +1167,11 @@ Return Value:
WorkBlock.alpha = alpha;
WorkBlock.beta = beta;
#endif
//
// Segment the operation across multiple threads.
//
uint32_t Index = 0;
int32_t Index = 0;
if (N > M) {
@ -1231,18 +1200,6 @@ Return Value:
WorkBlock.Segments[Index].B = B + n * pldb;
WorkBlock.Segments[Index].C = C + n;
#if defined(MLAS_USE_WIN32_THREADPOOL)
//
// Execute one of the segments on a worker thread.
//
if (Index > 0) {
SubmitThreadpoolWork(WorkObject);
}
#endif
Index++;
}
@ -1270,49 +1227,11 @@ Return Value:
WorkBlock.Segments[Index].B = B;
WorkBlock.Segments[Index].C = C + m * ldc;
#if defined(MLAS_USE_WIN32_THREADPOOL)
//
// Execute one of the segments on a worker thread.
//
if (Index > 0) {
SubmitThreadpoolWork(WorkObject);
}
#endif
Index++;
}
}
#if defined(MLAS_USE_OPENMP)
#pragma omp parallel for
for (int32_t tid = 0; tid < int32_t(Index); tid++) {
MLAS_SGEMM_WORK_BLOCK::SEGMENT* Segment = &WorkBlock.Segments[tid];
MlasSgemmOperation(TransA, TransB, Segment->M, Segment->N, K, alpha,
Segment->A, lda, Segment->B, ldb, beta, Segment->C, ldc);
}
#elif defined(MLAS_USE_WIN32_THREADPOOL)
//
// Execute the remaining segment on this thread.
//
MlasSgemmWorkCallback(nullptr, &WorkBlock, WorkObject);
//
// Wait for the worker threads to complete.
//
WaitForThreadpoolWorkCallbacks(WorkObject, FALSE);
CloseThreadpoolWork(WorkObject);
#endif
MlasExecuteThreaded(MlasSgemmOperationThreaded, &WorkBlock, Index);
return true;

View file

@ -0,0 +1,135 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
threading.cpp
Abstract:
This module implements platform specific threading support.
--*/
#include "mlasi.h"
#if defined(MLAS_USE_WIN32_THREADPOOL)
//
// Define the parameters to execute threaded work using the Windows thread pool
// library.
//
struct MLAS_THREADED_WORK_BLOCK {
//
// Shared counter advanced with InterlockedIncrement by each work callback;
// the value before the increment becomes that callback's iteration index.
//
volatile LONG Counter;
//
// Routine invoked by each work callback.
//
PMLAS_THREADED_ROUTINE ThreadedRoutine;
//
// Caller supplied pointer passed through unchanged to the threaded routine.
//
void* Context;
};
void
CALLBACK
MlasThreadedWorkCallback(
    PTP_CALLBACK_INSTANCE Instance,
    void* Context,
    PTP_WORK WorkObject
    )
/*++

Routine Description:

    This routine is the thread pool callback for one submitted unit of
    threaded work. It claims the next unused iteration index from the shared
    work block and runs the client routine for that index.

Arguments:

    Instance - Supplies the callback instance object.

    Context - Supplies the pointer to the MLAS_THREADED_WORK_BLOCK that
        describes the batch of threaded work.

    WorkObject - Supplies the threadpool work object.

Return Value:

    None.

--*/
{
    auto* Block = static_cast<MLAS_THREADED_WORK_BLOCK*>(Context);

    //
    // InterlockedIncrement returns the incremented value, so subtracting one
    // yields a zero based index that no other callback has claimed.
    //

    const LONG Claimed = InterlockedIncrement(&Block->Counter);

    Block->ThreadedRoutine(Block->Context, Claimed - 1);
}
#endif
void
MlasExecuteThreaded(
    MLAS_THREADED_ROUTINE ThreadedRoutine,
    void* Context,
    int32_t Iterations
    )
/*++

Routine Description:

    This routine invokes the supplied routine once for each index in the range
    [0, Iterations), distributing the invocations across threads when a
    threading implementation is available.

Arguments:

    ThreadedRoutine - Supplies the routine to invoke for each iteration.

    Context - Supplies the opaque pointer that is passed through to each
        invocation of the threaded routine.

    Iterations - Supplies the number of times to invoke the threaded routine.
        Nothing is executed if this value is zero or negative.

Return Value:

    None.

--*/
{
    //
    // Nothing to do for an empty batch. Without this check, the thread pool
    // path below would still run the routine once on this thread with an
    // index of (Iterations - 1), which is negative, while the loop at the end
    // of this routine would run it zero times.
    //

    if (Iterations <= 0) {
        return;
    }

    //
    // Execute the routine directly if only one iteration is specified.
    //

    if (Iterations == 1) {
        ThreadedRoutine(Context, 0);
        return;
    }

#if defined(MLAS_USE_WIN32_THREADPOOL)

    //
    // Schedule the threaded iterations using a work object.
    //

    MLAS_THREADED_WORK_BLOCK WorkBlock;

    PTP_WORK WorkObject = CreateThreadpoolWork(MlasThreadedWorkCallback, &WorkBlock, nullptr);

    if (WorkObject != nullptr) {

        //
        // Initialize the work block before any work is submitted so that the
        // callbacks observe a fully populated block.
        //

        WorkBlock.Counter = 0;
        WorkBlock.ThreadedRoutine = ThreadedRoutine;
        WorkBlock.Context = Context;

        //
        // Submit all but one of the iterations to the thread pool. The
        // callbacks claim the indices zero through (Iterations - 2) from the
        // shared counter.
        //

        for (int32_t tid = 1; tid < Iterations; tid++) {
            SubmitThreadpoolWork(WorkObject);
        }

        //
        // Execute the remaining iteration on this thread.
        //

        ThreadedRoutine(Context, Iterations - 1);

        //
        // Wait for the work object callbacks to complete.
        //

        WaitForThreadpoolWorkCallbacks(WorkObject, FALSE);

        CloseThreadpoolWork(WorkObject);

        return;
    }

    //
    // Fall back to a serialized implementation.
    //

#endif

    //
    // Execute the routine for the specified number of iterations.
    //

#pragma omp parallel for
    for (int32_t tid = 0; tid < Iterations; tid++) {
        ThreadedRoutine(Context, tid);
    }
}

View file

@ -206,30 +206,31 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
Tensor* Y = context->Output(0, TensorShape(Y_dims));
TensorShape output_shape = Y->Shape().Slice(2);
const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
const int64_t kernel_size = TensorShape(kernel_shape).Size();
AllocatorPtr alloc;
ONNXRUNTIME_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
const float* Xdata = X->template Data<float>();
float* Ydata = Y->template MutableData<float>();
MLAS_CONV_PARAMETERS Parameters;
size_t WorkingBufferSize;
if (MlasConvPrepare(&Parameters,
kernel_shape.size(),
static_cast<size_t>(N),
static_cast<size_t>(group_),
static_cast<size_t>(C / group_),
input_shape.GetDims().data(),
kernel_shape.data(),
dilations.data(),
pads.data(),
strides.data(),
static_cast<size_t>(M / group_),
&WorkingBufferSize)) {
const size_t kernel_rank = kernel_shape.size();
if (kernel_rank == 2 || kernel_rank == 3) {
MLAS_CONV_PARAMETERS Parameters;
size_t WorkingBufferSize;
MlasConvPrepare(&Parameters,
kernel_rank,
static_cast<size_t>(N),
static_cast<size_t>(group_),
static_cast<size_t>(C / group_),
input_shape.GetDims().data(),
kernel_shape.data(),
dilations.data(),
pads.data(),
strides.data(),
output_shape.GetDims().data(),
static_cast<size_t>(M / group_),
&WorkingBufferSize);
auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr;
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
@ -240,6 +241,9 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
static_cast<float*>(working_buffer.get()),
Ydata);
} else {
const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
const int64_t kernel_size = TensorShape(kernel_shape).Size();
const int64_t X_offset = C / group_ * input_image_size;
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
const int64_t W_offset = W->Shape().Size() / group_;

View file

@ -495,25 +495,24 @@ TrialConv2D(
int64_t DilationShape[] = { int64_t(DilationHeight), int64_t(DilationWidth) };
int64_t Padding[] = { int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth), int64_t(PaddingRightHeight), int64_t(PaddingRightWidth) };
int64_t StrideShape[] = { int64_t(StrideHeight), int64_t(StrideWidth) };
int64_t OutputShape[] = { OutputHeight64, OutputWidth64 };
MLAS_CONV_PARAMETERS Parameters;
size_t WorkingBufferSize;
if (!MlasConvPrepare(&Parameters,
2,
BatchCount,
GroupCount,
InputChannels,
InputShape,
KernelShape,
DilationShape,
Padding,
StrideShape,
FilterCount,
&WorkingBufferSize)) {
printf("MlasConvPrepare failed!!!\n");
return;
}
MlasConvPrepare(&Parameters,
2,
BatchCount,
GroupCount,
InputChannels,
InputShape,
KernelShape,
DilationShape,
Padding,
StrideShape,
OutputShape,
FilterCount,
&WorkingBufferSize);
size_t OutputHeight = size_t(OutputHeight64);
size_t OutputWidth = size_t(OutputWidth64);
@ -1236,9 +1235,9 @@ main(
)
{
// ExecuteSgemmTests();
// ExecuteConvTests();
ExecutePool2DTests();
ExecutePool3DTests();
ExecuteConvTests();
// ExecutePool2DTests();
// ExecutePool3DTests();
// EvaluateThreadingPerformance();
return 0;