mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
NCHWc: ReorderInput improvements (#7442)
Implement various improvements related to reordering a tensor for use by NCHWc operations: Relax the requirement that the input channel count must be a multiple of the NCHWc block size (either 8 or 16 depending on ISA). The requirement now is that the channel count must be a multiple of 4. The implementation of MlasReorderInputNchw would need further work to support relaxing this further, but I don't have any models where I've observed this to be necessary yet. Support fusing a Transpose(NHWC->NCHW) into a following ReorderInput. ReorderInput now has a channels_last attribute as was done in the past for ReorderOutput. This helps with models converted from TF where the converter is unable to remove all Transpose operations. Add threading support to ReorderInput to accelerate performance (ReorderOutput will come later).
This commit is contained in:
parent
82108b18e3
commit
d13e5b2fd9
11 changed files with 513 additions and 116 deletions
|
|
@ -2984,6 +2984,13 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft.nchwc' operator set.
|
||||
|
||||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>channels_last</tt> : int</dt>
|
||||
<dd></dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs
|
||||
|
||||
<dl>
|
||||
|
|
|
|||
|
|
@ -7,90 +7,128 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
#define ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(name, ver, type, builder, ...) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSNchwcDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
ReorderInput,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ReorderInput);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
ReorderOutput,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ReorderOutput);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
Conv,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.MayInplace(3, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcConv);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
MaxPool,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcMaxPool);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
GlobalMaxPool,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcMaxPool);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
AveragePool,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcAveragePool);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
GlobalAveragePool,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcAveragePool);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
Upsample,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcUpsample);
|
||||
|
||||
Status ReorderInput::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto& X_shape = X->Shape();
|
||||
ORT_ENFORCE(X_shape.NumDimensions() == 4);
|
||||
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
|
||||
const auto& X_shape = X->Shape().GetDims();
|
||||
const auto X_rank = X_shape.size();
|
||||
ORT_ENFORCE(X_rank == 4);
|
||||
|
||||
auto* Y = context->Output(0, X_shape);
|
||||
MlasReorderInput(X_shape.GetDims().data(), X->template Data<float>(), Y->template MutableData<float>());
|
||||
const int64_t batch_count = X_shape[0];
|
||||
const int64_t channels = X_shape[channels_last_ ? X_rank - 1 : 1];
|
||||
const auto* X_spatial_dims = X_shape.data() + (channels_last_ ? 1 : 2);
|
||||
|
||||
// The current implementation of MlasReorderInputNchw does not work for channels that
|
||||
// are not a multiple of 4.
|
||||
ORT_ENFORCE((channels % 4) == 0);
|
||||
|
||||
const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
|
||||
const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
|
||||
|
||||
std::vector<int64_t> Y_shape(X_rank);
|
||||
Y_shape[0] = batch_count;
|
||||
Y_shape[1] = nchwc_channels;
|
||||
int64_t spatial_size = 1;
|
||||
for (size_t i = 0; i < X_rank - 2; i++) {
|
||||
const int64_t spatial_dim = X_spatial_dims[i];
|
||||
spatial_size *= spatial_dim;
|
||||
Y_shape[2 + i] = spatial_dim;
|
||||
}
|
||||
|
||||
auto* Y = context->Output(0, Y_shape);
|
||||
|
||||
// Bail out early if one of the dimensions is zero.
|
||||
if (Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Compute the total amount of work depending on NCHW or NHWC format and estimate
|
||||
// a number of workers to use.
|
||||
ptrdiff_t total_work;
|
||||
ptrdiff_t worker_count;
|
||||
|
||||
if (channels_last_) {
|
||||
total_work = static_cast<ptrdiff_t>(batch_count * spatial_size);
|
||||
// Partition the work with the goal of reordering the following number of
|
||||
// elements, so that operations involving a smaller number of channels will
|
||||
// process more rows per worker.
|
||||
constexpr ptrdiff_t worker_goal = 48 * 1024;
|
||||
ptrdiff_t work_per_worker = worker_goal / nchwc_channels;
|
||||
if (work_per_worker == 0) {
|
||||
work_per_worker = 1;
|
||||
}
|
||||
worker_count = total_work / work_per_worker;
|
||||
if (worker_count == 0) {
|
||||
worker_count = 1;
|
||||
}
|
||||
} else {
|
||||
// Each iteration produces one spatial_size chunk of NCHWc blocks.
|
||||
total_work = static_cast<ptrdiff_t>(batch_count * (nchwc_channels / nchwc_block_size));
|
||||
worker_count = total_work;
|
||||
}
|
||||
|
||||
const auto* x_data = X->template Data<float>();
|
||||
auto* y_data = Y->template MutableData<float>();
|
||||
|
||||
auto reorder_worker = [&](ptrdiff_t batch) {
|
||||
auto work = concurrency::ThreadPool::PartitionWork(batch, worker_count, total_work);
|
||||
|
||||
if (channels_last_) {
|
||||
int64_t work_index = static_cast<int64_t>(work.start);
|
||||
int64_t work_remaining = static_cast<int64_t>(work.end - work.start);
|
||||
|
||||
while (work_remaining > 0) {
|
||||
const int64_t batch_index = work_index / spatial_size;
|
||||
const int64_t spatial_index = work_index % spatial_size;
|
||||
const int64_t rows_this_iteration = std::min(work_remaining, spatial_size - spatial_index);
|
||||
|
||||
MlasReorderInputNhwc(
|
||||
x_data + ((batch_index * spatial_size) + spatial_index) * channels,
|
||||
y_data + (batch_index * spatial_size * nchwc_channels) + (spatial_index * nchwc_block_size),
|
||||
static_cast<size_t>(channels),
|
||||
static_cast<size_t>(rows_this_iteration),
|
||||
static_cast<size_t>(spatial_size));
|
||||
|
||||
work_index += rows_this_iteration;
|
||||
work_remaining -= rows_this_iteration;
|
||||
}
|
||||
} else {
|
||||
int64_t work_index = static_cast<int64_t>(work.start) * nchwc_block_size;
|
||||
int64_t work_remaining = static_cast<int64_t>(work.end - work.start) * nchwc_block_size;
|
||||
|
||||
while (work_remaining > 0) {
|
||||
const int64_t batch_index = work_index / nchwc_channels;
|
||||
const int64_t channel_index = work_index % nchwc_channels;
|
||||
const int64_t channels_this_iteration = std::min(work_remaining, channels - channel_index);
|
||||
|
||||
MlasReorderInputNchw(
|
||||
x_data + ((batch_index * channels) + channel_index) * spatial_size,
|
||||
y_data + ((batch_index * nchwc_channels) + channel_index) * spatial_size,
|
||||
static_cast<size_t>(channels_this_iteration),
|
||||
static_cast<size_t>(spatial_size));
|
||||
|
||||
const int64_t nchwc_channels_this_iteration = std::min(work_remaining, nchwc_channels - channel_index);
|
||||
work_index += nchwc_channels_this_iteration;
|
||||
work_remaining -= nchwc_channels_this_iteration;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
|
||||
|
||||
// Handle the work in a single batch if only a single thread is available.
|
||||
if (concurrency::ThreadPool::DegreeOfParallelism(thread_pool) == 1) {
|
||||
worker_count = 1;
|
||||
}
|
||||
|
||||
concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, worker_count, reorder_worker);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReorderOutput::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto& X_shape = X->Shape();
|
||||
const auto X_rank = X_shape.NumDimensions();
|
||||
const auto& X_shape = X->Shape().GetDims();
|
||||
const auto X_rank = X_shape.size();
|
||||
ORT_ENFORCE(X_rank == 4);
|
||||
ORT_ENFORCE(channels_ <= X_shape[1]);
|
||||
|
||||
|
|
@ -235,5 +273,73 @@ Status NchwcUpsample::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#define ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(name, ver, type, builder, ...) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSNchwcDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
ReorderInput,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ReorderInput);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
ReorderOutput,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ReorderOutput);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
Conv,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.MayInplace(3, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcConv);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
MaxPool,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcMaxPool);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
GlobalMaxPool,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcMaxPool);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
AveragePool,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcAveragePool);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
GlobalAveragePool,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcAveragePool);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
Upsample,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcUpsample);
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -15,9 +15,13 @@ namespace contrib {
|
|||
class ReorderInput : public OpKernel {
|
||||
public:
|
||||
ReorderInput(const OpKernelInfo& info) : OpKernel(info) {
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("channels_last", &channels_last_).IsOK());
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
int64_t channels_last_;
|
||||
};
|
||||
|
||||
class ReorderOutput : public OpKernel {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/graph/contrib_ops/contrib_defs.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
namespace ONNX_NAMESPACE {
|
||||
void convPoolShapeInference(
|
||||
|
|
@ -58,10 +59,45 @@ void RegisterNchwcSchemas() {
|
|||
.SetDomain(kMSNchwcDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(For internal use.)DOC")
|
||||
.Attr("channels_last", "", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Input(0, "X", "", "T")
|
||||
.Output(0, "Y", "", "T")
|
||||
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (!hasNInputShapes(ctx, 1)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
|
||||
auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
|
||||
|
||||
auto input_rank = input_shape.dim_size();
|
||||
if (input_rank < 2) {
|
||||
fail_shape_inference("tensor rank too small");
|
||||
}
|
||||
|
||||
auto channels_last = getAttribute(ctx, "channels_last", 0);
|
||||
|
||||
// Copy the batch dimension.
|
||||
*output_shape->add_dim() = input_shape.dim(0);
|
||||
|
||||
// Block align the channel dimension.
|
||||
const auto& input_channel_dim = input_shape.dim((channels_last == 0) ? 1 : input_rank - 1);
|
||||
auto* output_channel_dim = output_shape->add_dim();
|
||||
if (input_channel_dim.has_dim_value()) {
|
||||
const int64_t channels = input_channel_dim.dim_value();
|
||||
const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
|
||||
int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
|
||||
output_channel_dim->set_dim_value(nchwc_channels);
|
||||
}
|
||||
|
||||
// Copy the spatial dimensions.
|
||||
int first_spatial_dim = (channels_last == 0) ? 2 : 1;
|
||||
for (int i = 0; i < input_rank - 2; i++) {
|
||||
*output_shape->add_dim() = input_shape.dim(first_spatial_dim + i);
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderOutput)
|
||||
.SetDomain(kMSNchwcDomain)
|
||||
|
|
@ -78,8 +114,8 @@ void RegisterNchwcSchemas() {
|
|||
return;
|
||||
}
|
||||
|
||||
auto input_shape = ctx.getInputType(0)->tensor_type().shape();
|
||||
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
|
||||
const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
|
||||
auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
|
||||
|
||||
auto input_rank = input_shape.dim_size();
|
||||
if (input_rank < 2) {
|
||||
|
|
@ -92,7 +128,7 @@ void RegisterNchwcSchemas() {
|
|||
fail_shape_inference("invalid channel count");
|
||||
}
|
||||
|
||||
// Copy batch dimension.
|
||||
// Copy the batch dimension.
|
||||
*output_shape->add_dim() = input_shape.dim(0);
|
||||
|
||||
auto channels_last = getAttribute(ctx, "channels_last", 0);
|
||||
|
|
@ -100,7 +136,7 @@ void RegisterNchwcSchemas() {
|
|||
output_shape->add_dim()->set_dim_value(channels);
|
||||
}
|
||||
|
||||
// Copy spatial dimensions.
|
||||
// Copy the spatial dimensions.
|
||||
for (int i = 0; i < input_rank - 2; i++) {
|
||||
*output_shape->add_dim() = input_shape.dim(2 + i);
|
||||
}
|
||||
|
|
@ -161,8 +197,8 @@ void RegisterNchwcSchemas() {
|
|||
return;
|
||||
}
|
||||
|
||||
auto input_shape = ctx.getInputType(0)->tensor_type().shape();
|
||||
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
|
||||
const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
|
||||
auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
|
||||
|
||||
auto input_rank = input_shape.dim_size();
|
||||
if (input_rank < 2) {
|
||||
|
|
|
|||
|
|
@ -552,12 +552,12 @@ MlasGemm(
|
|||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Batched GEMM, for multiplying multiple pairs of matrices.
|
||||
/**
|
||||
* @brief Batched GEMM, for multiplying multiple pairs of matrices.
|
||||
* Note: We only support uniform batching, so shapes and types of the
|
||||
* input must be same: M, N, K, BIsSigned must be the
|
||||
* same across all parameter blocks.
|
||||
*
|
||||
* same across all parameter blocks.
|
||||
*
|
||||
* @param [IN] Shape A single shape descriptor for all the multiplications
|
||||
* @param [IN] DataParams Array of data descriptors for the matrices.
|
||||
* @param [IN] BatchN Size of the parameters array, also number of multiplications to perform
|
||||
|
|
@ -834,10 +834,21 @@ MlasTranspose(
|
|||
|
||||
void
|
||||
MLASCALL
|
||||
MlasReorderInput(
|
||||
const int64_t* InputShape,
|
||||
MlasReorderInputNchw(
|
||||
const float* S,
|
||||
float* D
|
||||
float* D,
|
||||
size_t InputChannels,
|
||||
size_t InputSize
|
||||
);
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasReorderInputNhwc(
|
||||
const float* S,
|
||||
float* D,
|
||||
size_t InputChannels,
|
||||
size_t RowCount,
|
||||
size_t FullRowCount
|
||||
);
|
||||
|
||||
void
|
||||
|
|
|
|||
|
|
@ -179,10 +179,11 @@ Return Value:
|
|||
|
||||
void
|
||||
MLASCALL
|
||||
MlasReorderInput(
|
||||
const int64_t* InputShape,
|
||||
MlasReorderInputNchw(
|
||||
const float* S,
|
||||
float* D
|
||||
float* D,
|
||||
size_t InputChannels,
|
||||
size_t InputSize
|
||||
)
|
||||
/*++
|
||||
|
||||
|
|
@ -192,12 +193,14 @@ Routine Description:
|
|||
|
||||
Arguments:
|
||||
|
||||
InputShape - Supplies the shape of the input tensor.
|
||||
|
||||
S - Supplies the address of the source tensor.
|
||||
|
||||
D - Supplies the address of the destination tensor.
|
||||
|
||||
InputChannels - Supplies the number of NCHW channels.
|
||||
|
||||
InputSize - Supplies the spatial input size of the tensors.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
|
@ -206,11 +209,12 @@ Return Value:
|
|||
{
|
||||
const size_t BlockSize = MlasNchwcGetBlockSize();
|
||||
|
||||
const size_t InputChannels = size_t(InputShape[0] * InputShape[1]);
|
||||
const size_t InputSize = size_t(InputShape[2]) * size_t(InputShape[3]);
|
||||
|
||||
const MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4();
|
||||
|
||||
//
|
||||
// Iterate over BlockSize batches of the input channels.
|
||||
//
|
||||
|
||||
for (size_t i = InputChannels; i > 0;) {
|
||||
|
||||
const size_t InputChannelsThisIteration = std::min(i, BlockSize);
|
||||
|
|
@ -233,7 +237,10 @@ Return Value:
|
|||
}
|
||||
|
||||
for (; bc < BlockSize; bc += 4) {
|
||||
MlasStoreFloat32x4(dd, ZeroFloat32x4);
|
||||
MlasStoreFloat32x4(&dd[BlockSize * 0], ZeroFloat32x4);
|
||||
MlasStoreFloat32x4(&dd[BlockSize * 1], ZeroFloat32x4);
|
||||
MlasStoreFloat32x4(&dd[BlockSize * 2], ZeroFloat32x4);
|
||||
MlasStoreFloat32x4(&dd[BlockSize * 3], ZeroFloat32x4);
|
||||
dd += 4;
|
||||
}
|
||||
|
||||
|
|
@ -267,6 +274,127 @@ Return Value:
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasReorderInputNhwc(
|
||||
const float* S,
|
||||
float* D,
|
||||
size_t InputChannels,
|
||||
size_t RowCount,
|
||||
size_t FullRowCount
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine reorders an input buffer from NHWC to NCHWc format.
|
||||
|
||||
Arguments:
|
||||
|
||||
S - Supplies the address of the source tensor.
|
||||
|
||||
D - Supplies the address of the destination tensor.
|
||||
|
||||
InputChannels - Supplies the number of NHWC channels.
|
||||
|
||||
RowCount - Supplies the number of NHWC rows to process. This number may be
|
||||
less than FullRowCount to support threaded operation.
|
||||
|
||||
FullRowCount - Supplies the total number of NHWC rows per image.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
const size_t BlockSize = MlasNchwcGetBlockSize();
|
||||
|
||||
//
|
||||
// Iterate over batches of the input size to improve locality.
|
||||
//
|
||||
|
||||
for (size_t OuterRowCountRemaining = RowCount; OuterRowCountRemaining > 0; ) {
|
||||
|
||||
constexpr size_t OuterRowCountBatch = 32;
|
||||
|
||||
const size_t OuterRowCountThisIteration = std::min(OuterRowCountRemaining, OuterRowCountBatch);
|
||||
OuterRowCountRemaining -= OuterRowCountThisIteration;
|
||||
|
||||
//
|
||||
// Iterate over BlockSize batches of the input channels.
|
||||
//
|
||||
|
||||
const float* s = S;
|
||||
float* d = D;
|
||||
|
||||
for (size_t i = InputChannels; i > 0;) {
|
||||
|
||||
const size_t InputChannelsThisIteration = std::min(i, BlockSize);
|
||||
i -= InputChannelsThisIteration;
|
||||
|
||||
const float* ss = s;
|
||||
float* dd = d;
|
||||
size_t InnerRowCountRemaining = OuterRowCountThisIteration;
|
||||
|
||||
if (InputChannelsThisIteration == BlockSize) {
|
||||
|
||||
if (BlockSize == 8) {
|
||||
|
||||
while (InnerRowCountRemaining-- > 0) {
|
||||
|
||||
MLAS_FLOAT32X4 v0 = MlasLoadFloat32x4(&ss[0]);
|
||||
MLAS_FLOAT32X4 v1 = MlasLoadFloat32x4(&ss[4]);
|
||||
|
||||
MlasStoreFloat32x4(&dd[0], v0);
|
||||
MlasStoreFloat32x4(&dd[4], v1);
|
||||
|
||||
ss += InputChannels;
|
||||
dd += 8;
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
while (InnerRowCountRemaining-- > 0) {
|
||||
|
||||
MLAS_FLOAT32X4 v0 = MlasLoadFloat32x4(&ss[0]);
|
||||
MLAS_FLOAT32X4 v1 = MlasLoadFloat32x4(&ss[4]);
|
||||
MLAS_FLOAT32X4 v2 = MlasLoadFloat32x4(&ss[8]);
|
||||
MLAS_FLOAT32X4 v3 = MlasLoadFloat32x4(&ss[12]);
|
||||
|
||||
MlasStoreFloat32x4(&dd[0], v0);
|
||||
MlasStoreFloat32x4(&dd[4], v1);
|
||||
MlasStoreFloat32x4(&dd[8], v2);
|
||||
MlasStoreFloat32x4(&dd[12], v3);
|
||||
|
||||
ss += InputChannels;
|
||||
dd += 16;
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
size_t BlockPadding = BlockSize - InputChannelsThisIteration;
|
||||
|
||||
while (InnerRowCountRemaining-- > 0) {
|
||||
|
||||
std::copy_n(ss, InputChannelsThisIteration, dd);
|
||||
std::fill_n(dd + InputChannelsThisIteration, BlockPadding, 0.0f);
|
||||
|
||||
ss += InputChannels;
|
||||
dd += BlockSize;
|
||||
}
|
||||
}
|
||||
|
||||
s += InputChannelsThisIteration;
|
||||
d += BlockSize * FullRowCount;
|
||||
}
|
||||
|
||||
S += InputChannels * OuterRowCountThisIteration;
|
||||
D += BlockSize * OuterRowCountThisIteration;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasReorderOutputNchw(
|
||||
|
|
|
|||
|
|
@ -120,8 +120,9 @@ class NchwcTransformerImpl {
|
|||
void TransformConcat(Node& node);
|
||||
void TransformActivation(Node& node);
|
||||
void TransformBatchNormalization(Node& node);
|
||||
void TransformTranspose(Node& node);
|
||||
void TransformTransposeToNhwc(Node& node);
|
||||
void TransformResize(Node& node);
|
||||
void TrackTransposeFromNhwc(Node& node);
|
||||
|
||||
Graph& graph_;
|
||||
|
||||
|
|
@ -149,6 +150,11 @@ class NchwcTransformerImpl {
|
|||
// or unsplitting the channels dimension of a tensor.
|
||||
std::unordered_map<int64_t, NodeArg*> reshape_split_;
|
||||
std::unordered_map<int64_t, NodeArg*> reshape_unsplit_;
|
||||
|
||||
// Tracks the last Transpose node and output NodeArg that transposed from
|
||||
// NHWC to NCHW format.
|
||||
Node* transpose_from_nhwc_node_{nullptr};
|
||||
NodeArg* transpose_from_nhwc_output_arg_{nullptr};
|
||||
};
|
||||
|
||||
size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) {
|
||||
|
|
@ -209,6 +215,20 @@ void NchwcTransformerImpl::InsertReorderInput(Node& node) {
|
|||
kMSNchwcDomain);
|
||||
reorder_input_node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
input_defs[0] = input_nchwc_arg;
|
||||
|
||||
// Attempt to fuse the ReorderInput with a previous Transpose of NHWC->NCHW.
|
||||
// If the last known node to transpose from NHWC is the same as this input
|
||||
// argument, then the Transpose node can be removed and the ReorderInput node
|
||||
// is modified to consume a tensor with NHWC layout order. The transpose was
|
||||
// already determined to have a single use and not be a graph output.
|
||||
if (transpose_from_nhwc_output_arg_ == input_original_arg) {
|
||||
reorder_input_node.MutableInputDefs()[0] = transpose_from_nhwc_node_->MutableInputDefs()[0];
|
||||
reorder_input_node.AddAttribute("channels_last", static_cast<int64_t>(1));
|
||||
graph_utils::RemoveNodeOutputEdges(graph_, *transpose_from_nhwc_node_);
|
||||
removed_nodes_.push_front(transpose_from_nhwc_node_->Index());
|
||||
transpose_from_nhwc_node_ = nullptr;
|
||||
}
|
||||
|
||||
} else {
|
||||
input_defs[0] = it->second;
|
||||
}
|
||||
|
|
@ -324,14 +344,21 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
|
|||
|
||||
bool do_reorder_input = true;
|
||||
bool reorder_filter_OIHWBo = false;
|
||||
int64_t filter_input_channels = input_channels;
|
||||
int64_t nchwc_group_count = group_count;
|
||||
|
||||
// The current implementation of ReorderInput requires the channel count to be
|
||||
// aligned to this value.
|
||||
constexpr int64_t channel_alignment = 4;
|
||||
|
||||
if (group_count > 1) {
|
||||
if ((output_channels % nchwc_block_size) != 0) {
|
||||
if ((output_channels % channel_alignment) != 0) {
|
||||
return;
|
||||
}
|
||||
if (input_channels == 1 && output_channels == group_count) {
|
||||
// Depthwise convolution.
|
||||
reorder_filter_OIHWBo = true;
|
||||
nchwc_group_count = nchwc_output_channels;
|
||||
} else if (((input_channels % nchwc_block_size) != 0) ||
|
||||
((output_channels % group_count) != 0) ||
|
||||
(((output_channels / group_count) % nchwc_block_size) != 0)) {
|
||||
|
|
@ -342,8 +369,11 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
|
|||
// Use NCHW input buffer directly.
|
||||
reorder_filter_OIHWBo = true;
|
||||
do_reorder_input = false;
|
||||
} else if ((input_channels % nchwc_block_size) != 0) {
|
||||
return;
|
||||
} else {
|
||||
if ((input_channels % channel_alignment) != 0) {
|
||||
return;
|
||||
}
|
||||
filter_input_channels = (input_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -374,15 +404,19 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
|
|||
nchwc_conv_W_arg = filters_it->second;
|
||||
} else {
|
||||
Initializer conv_W{*conv_W_tensor_proto, graph_.ModelPath()};
|
||||
const auto& conv_W_dims = conv_W.dims();
|
||||
|
||||
int64_t reordered_filter_vec_size = conv_W.size() / output_channels * nchwc_output_channels;
|
||||
std::vector<float> reordered_filter(gsl::narrow<size_t>(reordered_filter_vec_size));
|
||||
int64_t reordered_filter_size = nchwc_output_channels * filter_input_channels;
|
||||
for (size_t i = 2; i < 4; i++) {
|
||||
reordered_filter_size *= conv_W_dims[i];
|
||||
}
|
||||
std::vector<float> reordered_filter(gsl::narrow<size_t>(reordered_filter_size));
|
||||
|
||||
// Reorder the weights tensor statically.
|
||||
if (reorder_filter_OIHWBo) {
|
||||
MlasReorderFilterOIHWBo(conv_W.dims().data(), conv_W.data<float>(), reordered_filter.data());
|
||||
MlasReorderFilterOIHWBo(conv_W_dims.data(), conv_W.data<float>(), reordered_filter.data());
|
||||
} else {
|
||||
MlasReorderFilterOIHWBiBo(conv_W.dims().data(), conv_W.data<float>(), reordered_filter.data());
|
||||
MlasReorderFilterOIHWBiBo(conv_W_dims.data(), conv_W.data<float>(), reordered_filter.data());
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorProto nchwc_conv_W_tensor_proto;
|
||||
|
|
@ -392,8 +426,9 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
|
|||
nchwc_conv_W_tensor_proto.set_raw_data(reordered_filter.data(), reordered_filter.size() * sizeof(float));
|
||||
|
||||
nchwc_conv_W_tensor_proto.add_dims(nchwc_output_channels);
|
||||
for (size_t i = 1; i < 4; i++) {
|
||||
nchwc_conv_W_tensor_proto.add_dims(conv_W.dims()[i]);
|
||||
nchwc_conv_W_tensor_proto.add_dims(filter_input_channels);
|
||||
for (size_t i = 2; i < 4; i++) {
|
||||
nchwc_conv_W_tensor_proto.add_dims(conv_W_dims[i]);
|
||||
}
|
||||
|
||||
nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto);
|
||||
|
|
@ -436,6 +471,9 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
|
|||
&node.GetAttributes(),
|
||||
kMSNchwcDomain);
|
||||
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
if (nchwc_group_count != group_count) {
|
||||
nchwc_node.AddAttribute("group", nchwc_group_count);
|
||||
}
|
||||
|
||||
nchwc_node.MutableInputDefs()[1] = nchwc_conv_W_arg;
|
||||
|
||||
|
|
@ -892,7 +930,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
|
|||
removed_nodes_.push_front(node.Index());
|
||||
}
|
||||
|
||||
void NchwcTransformerImpl::TransformTranspose(Node& node) {
|
||||
void NchwcTransformerImpl::TransformTransposeToNhwc(Node& node) {
|
||||
auto& input_defs = node.MutableInputDefs();
|
||||
auto& output_defs = node.MutableOutputDefs();
|
||||
|
||||
|
|
@ -1070,7 +1108,33 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
|
|||
removed_nodes_.push_front(node.Index());
|
||||
}
|
||||
|
||||
void NchwcTransformerImpl::TrackTransposeFromNhwc(Node& node) {
|
||||
const auto* perm_attr = graph_utils::GetNodeAttribute(node, "perm");
|
||||
if (perm_attr == nullptr || perm_attr->ints_size() != 4) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Test if this transposes from NHWC to NCHW layout order.
|
||||
const int64_t* perm_data = perm_attr->ints().data();
|
||||
if (perm_data[0] != 0 || perm_data[1] != 3 || perm_data[2] != 1 || perm_data[3] != 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Verify that the node does not produce a graph output and produces output
|
||||
// for a single node.
|
||||
if (!graph_.GetNodeOutputsInGraphOutputs(node).empty() || node.GetOutputEdgesCount() != 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
transpose_from_nhwc_node_ = &node;
|
||||
transpose_from_nhwc_output_arg_ = node.MutableOutputDefs()[0];
|
||||
}
|
||||
|
||||
void NchwcTransformerImpl::Transform(Node& node) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) {
|
||||
TrackTransposeFromNhwc(node);
|
||||
}
|
||||
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain)) {
|
||||
TransformConv(node);
|
||||
|
|
@ -1097,7 +1161,7 @@ void NchwcTransformerImpl::Transform(Node& node) {
|
|||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) {
|
||||
TransformBatchNormalization(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) {
|
||||
TransformTranspose(node);
|
||||
TransformTransposeToNhwc(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11, 13})) {
|
||||
TransformResize(node);
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ class MlasNchwcConv2DTest : public MlasConv2DTest<Threaded> {
|
|||
if (DoReorderInput) {
|
||||
size_t NchwcInputElements = BatchCount * NchwcInputChannels * InputHeight * InputWidth;
|
||||
float* NchwcInput = BufferNchwcInput.GetBuffer(NchwcInputElements);
|
||||
MlasReorderInput(InputShape, Input, NchwcInput);
|
||||
ReorderInputNchw(InputShape, Input, NchwcInput);
|
||||
Input = NchwcInput;
|
||||
InputShape[1] = NchwcInputChannels;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class MlasNchwcPool2DTest : public MlasPool2DTest<PoolingKind, Threaded> {
|
|||
size_t NchwcOutputElements = size_t(NchwcOutputShape[0]) * size_t(NchwcOutputShape[1]) * size_t(NchwcOutputShape[2]) * size_t(NchwcOutputShape[3]);
|
||||
float* NchwcOutput = BufferNchwcOutput.GetBuffer(NchwcOutputElements);
|
||||
|
||||
MlasReorderInput(InputShape, Input, NchwcInput);
|
||||
ReorderInputNchw(InputShape, Input, NchwcInput);
|
||||
|
||||
MlasNchwcPool(PoolingKind,
|
||||
NchwcInputShape,
|
||||
|
|
|
|||
|
|
@ -211,7 +211,7 @@ class MlasLongExecuteTests : public MlasTestFixture<TMlasTester> {
|
|||
}
|
||||
};
|
||||
|
||||
// Some short Execute may not need to distinguish each parameters,
|
||||
// Some short Execute may not need to distinguish each parameters,
|
||||
// because they finish quickly, and may disturb others by inject too many small tests.
|
||||
// Register it as whole using following helper.
|
||||
template <typename TMlasTester>
|
||||
|
|
@ -237,3 +237,16 @@ class MlasDirectShortExecuteTests : public MlasTestFixture<TMlasTester> {
|
|||
}
|
||||
};
|
||||
|
||||
inline
|
||||
void ReorderInputNchw(const int64_t* input_shape, const float* S, float* D) {
|
||||
const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
|
||||
int64_t batch_count = input_shape[0];
|
||||
int64_t channel_count = input_shape[1];
|
||||
int64_t nchwc_channel_count = (channel_count + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
|
||||
int64_t spatial_count = input_shape[2] * input_shape[3];
|
||||
for (int64_t n = 0; n < batch_count; n++) {
|
||||
MlasReorderInputNchw(S, D, static_cast<size_t>(channel_count), static_cast<size_t>(spatial_count));
|
||||
S += spatial_count * channel_count;
|
||||
D += spatial_count * nchwc_channel_count;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -970,10 +970,10 @@ TEST(NchwcOptimizerTests, MixedOutputUsage) {
|
|||
|
||||
TEST(NchwcOptimizerTests, TensorAlignment) {
|
||||
auto build_test_case = [&](NchwcTestHelper& helper) {
|
||||
// Input channel count must currently be a multiple of the NCHWc block size.
|
||||
auto* input1_arg = helper.MakeInput<float>({1, 60, 28, 42});
|
||||
// Input channel count must currently be a multiple of 4.
|
||||
auto* input1_arg = helper.MakeInput<float>({1, 62, 28, 42});
|
||||
auto* output1_arg = helper.MakeOutput();
|
||||
helper.AddConvNode(input1_arg, output1_arg, {128, 60, 1, 1});
|
||||
helper.AddConvNode(input1_arg, output1_arg, {128, 62, 1, 1});
|
||||
|
||||
// Grouped input channel count must be a multiple of the NCHWc block size.
|
||||
auto* input2_arg = helper.MakeInput<float>({1, 48, 28, 42});
|
||||
|
|
@ -1116,6 +1116,34 @@ TEST(NchwcOptimizerTests, BatchNormalization) {
|
|||
test_case(true);
|
||||
}
|
||||
|
||||
TEST(NchwcOptimizerTests, ConvReorderInputNhwc) {
|
||||
auto test_case = [&](int64_t channels) {
|
||||
auto build_test_case = [&](NchwcTestHelper& helper) {
|
||||
auto* input_arg = helper.MakeInput<float>({5, 27, 29, channels});
|
||||
auto* transpose_output_arg = helper.MakeIntermediate();
|
||||
auto* output_arg = helper.MakeOutput();
|
||||
|
||||
helper.AddTransposeToNchwNode(input_arg, transpose_output_arg);
|
||||
helper.AddConvNode(transpose_output_arg, output_arg, {34, channels, 1, 1});
|
||||
};
|
||||
|
||||
auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 1);
|
||||
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1);
|
||||
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1);
|
||||
EXPECT_EQ(op_to_count["Transpose"], 0);
|
||||
};
|
||||
|
||||
// Verify that a NHWC->NCHW transpose is fused into ReorderInput.
|
||||
NchwcOptimizerTester(build_test_case, check_nchwc_graph);
|
||||
};
|
||||
|
||||
for (int64_t channels = 16; channels <= 32; channels += 4) {
|
||||
test_case(channels);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NchwcOptimizerTests, ConvReorderOutputNhwc) {
|
||||
auto build_test_case = [&](NchwcTestHelper& helper) {
|
||||
auto* input_arg = helper.MakeInput<float>({1, 64, 28, 32});
|
||||
|
|
|
|||
Loading…
Reference in a new issue