Use MlasConv for 1D convolutions (#3425)

Use the existing 2D convolution code in MlasConv to also handle 1D convolutions.
This commit is contained in:
Tracy Sharpe 2020-04-04 09:43:10 -07:00 committed by GitHub
parent 5835349614
commit d4d19a75ba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 3 deletions

View file

@ -1037,7 +1037,7 @@ Arguments:
Parameters - Supplies the structure that stores the provided and computed
parameters for the convolution operation.
Dimensions - Supplies the number of dimensions (must be 2 or 3).
Dimensions - Supplies the number of dimensions (must be between 1 and 3).
BatchCount - Supplies the number of batches to the processed.
@ -1080,7 +1080,6 @@ Return Value:
//
Parameters->Activation = Activation;
Parameters->Dimensions = Dimensions;
Parameters->BatchCount = BatchCount;
Parameters->GroupCount = GroupCount;
Parameters->InputChannels = InputChannels;
@ -1117,6 +1116,32 @@ Return Value:
Parameters->OutputSize = OutputSize;
Parameters->K = K;
//
// Promote 1D convolutions to 2D convolutions.
//
if (Dimensions == 1) {
Parameters->InputShape[1] = Parameters->InputShape[0];
Parameters->InputShape[0] = 1;
Parameters->OutputShape[1] = Parameters->OutputShape[0];
Parameters->OutputShape[0] = 1;
Parameters->KernelShape[1] = Parameters->KernelShape[0];
Parameters->KernelShape[0] = 1;
Parameters->DilationShape[1] = Parameters->DilationShape[0];
Parameters->DilationShape[0] = 1;
Parameters->Padding[3] = Parameters->Padding[1];
Parameters->Padding[2] = 0;
Parameters->Padding[1] = Parameters->Padding[0];
Parameters->Padding[0] = 0;
Parameters->StrideShape[1] = Parameters->StrideShape[0];
Parameters->StrideShape[0] = 1;
Dimensions = 2;
}
Parameters->Dimensions = Dimensions;
//
// Evaluate how the convolution will be performed.
//

View file

@ -207,7 +207,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
const size_t kernel_rank = kernel_shape.size();
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
if (kernel_rank == 2 || kernel_rank == 3) {
if (kernel_rank >= 1 && kernel_rank <= 3) {
MLAS_CONV_PARAMETERS Parameters;
size_t WorkingBufferSize;
MlasConvPrepare(&Parameters,