mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Use MlasConv for 1D convolutions (#3425)
Use the existing 2D convolution code in MlasConv to also handle 1D convolutions.
This commit is contained in:
parent
5835349614
commit
d4d19a75ba
2 changed files with 28 additions and 3 deletions
|
|
@ -1037,7 +1037,7 @@ Arguments:
|
|||
Parameters - Supplies the structure that stores the provided and computed
|
||||
parameters for the convolution operation.
|
||||
|
||||
Dimensions - Supplies the number of dimensions (must be 2 or 3).
|
||||
Dimensions - Supplies the number of dimensions (must be between 1 and 3).
|
||||
|
||||
BatchCount - Supplies the number of batches to the processed.
|
||||
|
||||
|
|
@ -1080,7 +1080,6 @@ Return Value:
|
|||
//
|
||||
|
||||
Parameters->Activation = Activation;
|
||||
Parameters->Dimensions = Dimensions;
|
||||
Parameters->BatchCount = BatchCount;
|
||||
Parameters->GroupCount = GroupCount;
|
||||
Parameters->InputChannels = InputChannels;
|
||||
|
|
@ -1117,6 +1116,32 @@ Return Value:
|
|||
Parameters->OutputSize = OutputSize;
|
||||
Parameters->K = K;
|
||||
|
||||
//
|
||||
// Promote 1D convolutions to 2D convolutions.
|
||||
//
|
||||
|
||||
if (Dimensions == 1) {
|
||||
|
||||
Parameters->InputShape[1] = Parameters->InputShape[0];
|
||||
Parameters->InputShape[0] = 1;
|
||||
Parameters->OutputShape[1] = Parameters->OutputShape[0];
|
||||
Parameters->OutputShape[0] = 1;
|
||||
Parameters->KernelShape[1] = Parameters->KernelShape[0];
|
||||
Parameters->KernelShape[0] = 1;
|
||||
Parameters->DilationShape[1] = Parameters->DilationShape[0];
|
||||
Parameters->DilationShape[0] = 1;
|
||||
Parameters->Padding[3] = Parameters->Padding[1];
|
||||
Parameters->Padding[2] = 0;
|
||||
Parameters->Padding[1] = Parameters->Padding[0];
|
||||
Parameters->Padding[0] = 0;
|
||||
Parameters->StrideShape[1] = Parameters->StrideShape[0];
|
||||
Parameters->StrideShape[0] = 1;
|
||||
|
||||
Dimensions = 2;
|
||||
}
|
||||
|
||||
Parameters->Dimensions = Dimensions;
|
||||
|
||||
//
|
||||
// Evaluate how the convolution will be performed.
|
||||
//
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
const size_t kernel_rank = kernel_shape.size();
|
||||
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
|
||||
|
||||
if (kernel_rank == 2 || kernel_rank == 3) {
|
||||
if (kernel_rank >= 1 && kernel_rank <= 3) {
|
||||
MLAS_CONV_PARAMETERS Parameters;
|
||||
size_t WorkingBufferSize;
|
||||
MlasConvPrepare(&Parameters,
|
||||
|
|
|
|||
Loading…
Reference in a new issue