diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 4727498238..905bd37ed2 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2984,6 +2984,13 @@ This version of the operator has been available since version 1 of the 'com.micr
This version of the operator has been available since version 1 of the 'com.microsoft.nchwc' operator set.
+#### Attributes
+
+
+- channels_last : int
+
+
+
#### Inputs
diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
index e81a04d587..fd9c1f3a40 100644
--- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
+++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
@@ -7,90 +7,128 @@
namespace onnxruntime {
namespace contrib {
-#define ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(name, ver, type, builder, ...) \
- ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSNchwcDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
-
-ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
- ReorderInput,
- 1,
- float,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
- ReorderInput);
-
-ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
- ReorderOutput,
- 1,
- float,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
- ReorderOutput);
-
-ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
- Conv,
- 1,
- float,
- KernelDefBuilder()
- .MayInplace(3, 0)
- .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
- NchwcConv);
-
-ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
- MaxPool,
- 1,
- float,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
- NchwcMaxPool);
-
-ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
- GlobalMaxPool,
- 1,
- float,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
- NchwcMaxPool);
-
-ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
- AveragePool,
- 1,
- float,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
- NchwcAveragePool);
-
-ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
- GlobalAveragePool,
- 1,
- float,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
- NchwcAveragePool);
-
-ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
- Upsample,
- 1,
- float,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
- NchwcUpsample);
-
Status ReorderInput::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
- const auto& X_shape = X->Shape();
- ORT_ENFORCE(X_shape.NumDimensions() == 4);
- ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
+ const auto& X_shape = X->Shape().GetDims();
+ const auto X_rank = X_shape.size();
+ ORT_ENFORCE(X_rank == 4);
- auto* Y = context->Output(0, X_shape);
- MlasReorderInput(X_shape.GetDims().data(), X->template Data<float>(), Y->template MutableData<float>());
+ const int64_t batch_count = X_shape[0];
+ const int64_t channels = X_shape[channels_last_ ? X_rank - 1 : 1];
+ const auto* X_spatial_dims = X_shape.data() + (channels_last_ ? 1 : 2);
+
+ // The current implementation of MlasReorderInputNchw does not work for channels that
+ // are not a multiple of 4.
+ ORT_ENFORCE((channels % 4) == 0);
+
+ const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
+ const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
+
+ std::vector<int64_t> Y_shape(X_rank);
+ Y_shape[0] = batch_count;
+ Y_shape[1] = nchwc_channels;
+ int64_t spatial_size = 1;
+ for (size_t i = 0; i < X_rank - 2; i++) {
+ const int64_t spatial_dim = X_spatial_dims[i];
+ spatial_size *= spatial_dim;
+ Y_shape[2 + i] = spatial_dim;
+ }
+
+ auto* Y = context->Output(0, Y_shape);
+
+ // Bail out early if one of the dimensions is zero.
+ if (Y->Shape().Size() == 0) {
+ return Status::OK();
+ }
+
+ // Compute the total amount of work depending on NCHW or NHWC format and estimate
+ // a number of workers to use.
+ ptrdiff_t total_work;
+ ptrdiff_t worker_count;
+
+ if (channels_last_) {
+ total_work = static_cast<ptrdiff_t>(batch_count * spatial_size);
+ // Partition the work with the goal of reordering the following number of
+ // elements, so that operations involving a smaller number of channels will
+ // process more rows per worker.
+ constexpr ptrdiff_t worker_goal = 48 * 1024;
+ ptrdiff_t work_per_worker = worker_goal / nchwc_channels;
+ if (work_per_worker == 0) {
+ work_per_worker = 1;
+ }
+ worker_count = total_work / work_per_worker;
+ if (worker_count == 0) {
+ worker_count = 1;
+ }
+ } else {
+ // Each iteration produces one spatial_size chunk of NCHWc blocks.
+ total_work = static_cast<ptrdiff_t>(batch_count * (nchwc_channels / nchwc_block_size));
+ worker_count = total_work;
+ }
+
+ const auto* x_data = X->template Data<float>();
+ auto* y_data = Y->template MutableData<float>();
+
+ auto reorder_worker = [&](ptrdiff_t batch) {
+ auto work = concurrency::ThreadPool::PartitionWork(batch, worker_count, total_work);
+
+ if (channels_last_) {
+ int64_t work_index = static_cast<int64_t>(work.start);
+ int64_t work_remaining = static_cast<int64_t>(work.end - work.start);
+
+ while (work_remaining > 0) {
+ const int64_t batch_index = work_index / spatial_size;
+ const int64_t spatial_index = work_index % spatial_size;
+ const int64_t rows_this_iteration = std::min(work_remaining, spatial_size - spatial_index);
+
+ MlasReorderInputNhwc(
+ x_data + ((batch_index * spatial_size) + spatial_index) * channels,
+ y_data + (batch_index * spatial_size * nchwc_channels) + (spatial_index * nchwc_block_size),
+ static_cast<size_t>(channels),
+ static_cast<size_t>(rows_this_iteration),
+ static_cast<size_t>(spatial_size));
+
+ work_index += rows_this_iteration;
+ work_remaining -= rows_this_iteration;
+ }
+ } else {
+ int64_t work_index = static_cast<int64_t>(work.start) * nchwc_block_size;
+ int64_t work_remaining = static_cast<int64_t>(work.end - work.start) * nchwc_block_size;
+
+ while (work_remaining > 0) {
+ const int64_t batch_index = work_index / nchwc_channels;
+ const int64_t channel_index = work_index % nchwc_channels;
+ const int64_t channels_this_iteration = std::min(work_remaining, channels - channel_index);
+
+ MlasReorderInputNchw(
+ x_data + ((batch_index * channels) + channel_index) * spatial_size,
+ y_data + ((batch_index * nchwc_channels) + channel_index) * spatial_size,
+ static_cast<size_t>(channels_this_iteration),
+ static_cast<size_t>(spatial_size));
+
+ const int64_t nchwc_channels_this_iteration = std::min(work_remaining, nchwc_channels - channel_index);
+ work_index += nchwc_channels_this_iteration;
+ work_remaining -= nchwc_channels_this_iteration;
+ }
+ }
+ };
+
+ concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
+
+ // Handle the work in a single batch if only a single thread is available.
+ if (concurrency::ThreadPool::DegreeOfParallelism(thread_pool) == 1) {
+ worker_count = 1;
+ }
+
+ concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, worker_count, reorder_worker);
return Status::OK();
}
Status ReorderOutput::Compute(OpKernelContext* context) const {
const auto* X = context->Input(0);
- const auto& X_shape = X->Shape();
- const auto X_rank = X_shape.NumDimensions();
+ const auto& X_shape = X->Shape().GetDims();
+ const auto X_rank = X_shape.size();
ORT_ENFORCE(X_rank == 4);
ORT_ENFORCE(channels_ <= X_shape[1]);
@@ -235,5 +273,73 @@ Status NchwcUpsample::Compute(OpKernelContext* context) const {
return Status::OK();
}
+#define ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(name, ver, type, builder, ...) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSNchwcDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ ReorderInput,
+ 1,
+ float,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ ReorderInput);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ ReorderOutput,
+ 1,
+ float,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ ReorderOutput);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ Conv,
+ 1,
+ float,
+ KernelDefBuilder()
+ .MayInplace(3, 0)
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcConv);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ MaxPool,
+ 1,
+ float,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcMaxPool);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ GlobalMaxPool,
+ 1,
+ float,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcMaxPool);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ AveragePool,
+ 1,
+ float,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcAveragePool);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ GlobalAveragePool,
+ 1,
+ float,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcAveragePool);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ Upsample,
+ 1,
+ float,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcUpsample);
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h
index 402f029640..f6639d03f4 100644
--- a/onnxruntime/contrib_ops/cpu/nchwc_ops.h
+++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h
@@ -15,9 +15,13 @@ namespace contrib {
class ReorderInput : public OpKernel {
public:
ReorderInput(const OpKernelInfo& info) : OpKernel(info) {
+ ORT_ENFORCE(info.GetAttr<int64_t>("channels_last", &channels_last_).IsOK());
}
Status Compute(OpKernelContext* context) const override;
+
+ private:
+ int64_t channels_last_;
};
class ReorderOutput : public OpKernel {
diff --git a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc
index dcb4d34541..39d7621eec 100644
--- a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc
@@ -4,6 +4,7 @@
#include "core/framework/tensorprotoutils.h"
#include "core/graph/constants.h"
#include "core/graph/contrib_ops/contrib_defs.h"
+#include "core/mlas/inc/mlas.h"
namespace ONNX_NAMESPACE {
void convPoolShapeInference(
@@ -58,10 +59,45 @@ void RegisterNchwcSchemas() {
.SetDomain(kMSNchwcDomain)
.SinceVersion(1)
.SetDoc(R"DOC(For internal use.)DOC")
+ .Attr("channels_last", "", AttributeProto::INT, static_cast<int64_t>(0))
.Input(0, "X", "", "T")
.Output(0, "Y", "", "T")
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
- .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ if (!hasNInputShapes(ctx, 1)) {
+ return;
+ }
+
+ const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
+ auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
+
+ auto input_rank = input_shape.dim_size();
+ if (input_rank < 2) {
+ fail_shape_inference("tensor rank too small");
+ }
+
+ auto channels_last = getAttribute(ctx, "channels_last", 0);
+
+ // Copy the batch dimension.
+ *output_shape->add_dim() = input_shape.dim(0);
+
+ // Block align the channel dimension.
+ const auto& input_channel_dim = input_shape.dim((channels_last == 0) ? 1 : input_rank - 1);
+ auto* output_channel_dim = output_shape->add_dim();
+ if (input_channel_dim.has_dim_value()) {
+ const int64_t channels = input_channel_dim.dim_value();
+ const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
+ int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
+ output_channel_dim->set_dim_value(nchwc_channels);
+ }
+
+ // Copy the spatial dimensions.
+ int first_spatial_dim = (channels_last == 0) ? 2 : 1;
+ for (int i = 0; i < input_rank - 2; i++) {
+ *output_shape->add_dim() = input_shape.dim(first_spatial_dim + i);
+ }
+ });
ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderOutput)
.SetDomain(kMSNchwcDomain)
@@ -78,8 +114,8 @@ void RegisterNchwcSchemas() {
return;
}
- auto input_shape = ctx.getInputType(0)->tensor_type().shape();
- auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
+ const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
+ auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
auto input_rank = input_shape.dim_size();
if (input_rank < 2) {
@@ -92,7 +128,7 @@ void RegisterNchwcSchemas() {
fail_shape_inference("invalid channel count");
}
- // Copy batch dimension.
+ // Copy the batch dimension.
*output_shape->add_dim() = input_shape.dim(0);
auto channels_last = getAttribute(ctx, "channels_last", 0);
@@ -100,7 +136,7 @@ void RegisterNchwcSchemas() {
output_shape->add_dim()->set_dim_value(channels);
}
- // Copy spatial dimensions.
+ // Copy the spatial dimensions.
for (int i = 0; i < input_rank - 2; i++) {
*output_shape->add_dim() = input_shape.dim(2 + i);
}
@@ -161,8 +197,8 @@ void RegisterNchwcSchemas() {
return;
}
- auto input_shape = ctx.getInputType(0)->tensor_type().shape();
- auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
+ const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
+ auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
auto input_rank = input_shape.dim_size();
if (input_rank < 2) {
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index c417bdb17d..b2ce940482 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -552,12 +552,12 @@ MlasGemm(
MLAS_THREADPOOL* ThreadPool
);
-/**
- * @brief Batched GEMM, for multiplying multiple pairs of matrices.
+/**
+ * @brief Batched GEMM, for multiplying multiple pairs of matrices.
* Note: We only support uniform batching, so shapes and types of the
* input must be same: M, N, K, BIsSigned must be the
- * same across all parameter blocks.
- *
+ * same across all parameter blocks.
+ *
* @param [IN] Shape A single shape descriptor for all the multiplications
* @param [IN] DataParams Array of data descriptors for the matrices.
* @param [IN] BatchN Size of the parameters array, also number of multiplications to perform
@@ -834,10 +834,21 @@ MlasTranspose(
void
MLASCALL
-MlasReorderInput(
- const int64_t* InputShape,
+MlasReorderInputNchw(
const float* S,
- float* D
+ float* D,
+ size_t InputChannels,
+ size_t InputSize
+ );
+
+void
+MLASCALL
+MlasReorderInputNhwc(
+ const float* S,
+ float* D,
+ size_t InputChannels,
+ size_t RowCount,
+ size_t FullRowCount
);
void
diff --git a/onnxruntime/core/mlas/lib/reorder.cpp b/onnxruntime/core/mlas/lib/reorder.cpp
index 821bb0c9a2..0d7fbd97a4 100644
--- a/onnxruntime/core/mlas/lib/reorder.cpp
+++ b/onnxruntime/core/mlas/lib/reorder.cpp
@@ -179,10 +179,11 @@ Return Value:
void
MLASCALL
-MlasReorderInput(
- const int64_t* InputShape,
+MlasReorderInputNchw(
const float* S,
- float* D
+ float* D,
+ size_t InputChannels,
+ size_t InputSize
)
/*++
@@ -192,12 +193,14 @@ Routine Description:
Arguments:
- InputShape - Supplies the shape of the input tensor.
-
S - Supplies the address of the source tensor.
D - Supplies the address of the destination tensor.
+ InputChannels - Supplies the number of NCHW channels.
+
+ InputSize - Supplies the spatial input size of the tensors.
+
Return Value:
None.
@@ -206,11 +209,12 @@ Return Value:
{
const size_t BlockSize = MlasNchwcGetBlockSize();
- const size_t InputChannels = size_t(InputShape[0] * InputShape[1]);
- const size_t InputSize = size_t(InputShape[2]) * size_t(InputShape[3]);
-
const MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4();
+ //
+ // Iterate over BlockSize batches of the input channels.
+ //
+
for (size_t i = InputChannels; i > 0;) {
const size_t InputChannelsThisIteration = std::min(i, BlockSize);
@@ -233,7 +237,10 @@ Return Value:
}
for (; bc < BlockSize; bc += 4) {
- MlasStoreFloat32x4(dd, ZeroFloat32x4);
+ MlasStoreFloat32x4(&dd[BlockSize * 0], ZeroFloat32x4);
+ MlasStoreFloat32x4(&dd[BlockSize * 1], ZeroFloat32x4);
+ MlasStoreFloat32x4(&dd[BlockSize * 2], ZeroFloat32x4);
+ MlasStoreFloat32x4(&dd[BlockSize * 3], ZeroFloat32x4);
dd += 4;
}
@@ -267,6 +274,127 @@ Return Value:
}
}
+void
+MLASCALL
+MlasReorderInputNhwc(
+ const float* S,
+ float* D,
+ size_t InputChannels,
+ size_t RowCount,
+ size_t FullRowCount
+ )
+/*++
+
+Routine Description:
+
+ This routine reorders an input buffer from NHWC to NCHWc format.
+
+Arguments:
+
+ S - Supplies the address of the source tensor.
+
+ D - Supplies the address of the destination tensor.
+
+ InputChannels - Supplies the number of NHWC channels.
+
+ RowCount - Supplies the number of NHWC rows to process. This number may be
+ less than FullRowCount to support threaded operation.
+
+ FullRowCount - Supplies the total number of NHWC rows per image.
+
+Return Value:
+
+ None.
+
+--*/
+{
+ const size_t BlockSize = MlasNchwcGetBlockSize();
+
+ //
+ // Iterate over batches of the input size to improve locality.
+ //
+
+ for (size_t OuterRowCountRemaining = RowCount; OuterRowCountRemaining > 0; ) {
+
+ constexpr size_t OuterRowCountBatch = 32;
+
+ const size_t OuterRowCountThisIteration = std::min(OuterRowCountRemaining, OuterRowCountBatch);
+ OuterRowCountRemaining -= OuterRowCountThisIteration;
+
+ //
+ // Iterate over BlockSize batches of the input channels.
+ //
+
+ const float* s = S;
+ float* d = D;
+
+ for (size_t i = InputChannels; i > 0;) {
+
+ const size_t InputChannelsThisIteration = std::min(i, BlockSize);
+ i -= InputChannelsThisIteration;
+
+ const float* ss = s;
+ float* dd = d;
+ size_t InnerRowCountRemaining = OuterRowCountThisIteration;
+
+ if (InputChannelsThisIteration == BlockSize) {
+
+ if (BlockSize == 8) {
+
+ while (InnerRowCountRemaining-- > 0) {
+
+ MLAS_FLOAT32X4 v0 = MlasLoadFloat32x4(&ss[0]);
+ MLAS_FLOAT32X4 v1 = MlasLoadFloat32x4(&ss[4]);
+
+ MlasStoreFloat32x4(&dd[0], v0);
+ MlasStoreFloat32x4(&dd[4], v1);
+
+ ss += InputChannels;
+ dd += 8;
+ }
+
+ } else {
+
+ while (InnerRowCountRemaining-- > 0) {
+
+ MLAS_FLOAT32X4 v0 = MlasLoadFloat32x4(&ss[0]);
+ MLAS_FLOAT32X4 v1 = MlasLoadFloat32x4(&ss[4]);
+ MLAS_FLOAT32X4 v2 = MlasLoadFloat32x4(&ss[8]);
+ MLAS_FLOAT32X4 v3 = MlasLoadFloat32x4(&ss[12]);
+
+ MlasStoreFloat32x4(&dd[0], v0);
+ MlasStoreFloat32x4(&dd[4], v1);
+ MlasStoreFloat32x4(&dd[8], v2);
+ MlasStoreFloat32x4(&dd[12], v3);
+
+ ss += InputChannels;
+ dd += 16;
+ }
+ }
+
+ } else {
+
+ size_t BlockPadding = BlockSize - InputChannelsThisIteration;
+
+ while (InnerRowCountRemaining-- > 0) {
+
+ std::copy_n(ss, InputChannelsThisIteration, dd);
+ std::fill_n(dd + InputChannelsThisIteration, BlockPadding, 0.0f);
+
+ ss += InputChannels;
+ dd += BlockSize;
+ }
+ }
+
+ s += InputChannelsThisIteration;
+ d += BlockSize * FullRowCount;
+ }
+
+ S += InputChannels * OuterRowCountThisIteration;
+ D += BlockSize * OuterRowCountThisIteration;
+ }
+}
+
void
MLASCALL
MlasReorderOutputNchw(
diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc
index 87d931576d..623202b3cb 100644
--- a/onnxruntime/core/optimizer/nchwc_transformer.cc
+++ b/onnxruntime/core/optimizer/nchwc_transformer.cc
@@ -120,8 +120,9 @@ class NchwcTransformerImpl {
void TransformConcat(Node& node);
void TransformActivation(Node& node);
void TransformBatchNormalization(Node& node);
- void TransformTranspose(Node& node);
+ void TransformTransposeToNhwc(Node& node);
void TransformResize(Node& node);
+ void TrackTransposeFromNhwc(Node& node);
Graph& graph_;
@@ -149,6 +150,11 @@ class NchwcTransformerImpl {
// or unsplitting the channels dimension of a tensor.
std::unordered_map reshape_split_;
std::unordered_map reshape_unsplit_;
+
+ // Tracks the last Transpose node and output NodeArg that transposed from
+ // NHWC to NCHW format.
+ Node* transpose_from_nhwc_node_{nullptr};
+ NodeArg* transpose_from_nhwc_output_arg_{nullptr};
};
size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) {
@@ -209,6 +215,20 @@ void NchwcTransformerImpl::InsertReorderInput(Node& node) {
kMSNchwcDomain);
reorder_input_node.SetExecutionProviderType(kCpuExecutionProvider);
input_defs[0] = input_nchwc_arg;
+
+ // Attempt to fuse the ReorderInput with a previous Transpose of NHWC->NCHW.
+ // If the last known node to transpose from NHWC is the same as this input
+ // argument, then the Transpose node can be removed and the ReorderInput node
+ // is modified to consume a tensor with NHWC layout order. The transpose was
+ // already determined to have a single use and not be a graph output.
+ if (transpose_from_nhwc_output_arg_ == input_original_arg) {
+ reorder_input_node.MutableInputDefs()[0] = transpose_from_nhwc_node_->MutableInputDefs()[0];
+ reorder_input_node.AddAttribute("channels_last", static_cast<int64_t>(1));
+ graph_utils::RemoveNodeOutputEdges(graph_, *transpose_from_nhwc_node_);
+ removed_nodes_.push_front(transpose_from_nhwc_node_->Index());
+ transpose_from_nhwc_node_ = nullptr;
+ }
+
} else {
input_defs[0] = it->second;
}
@@ -324,14 +344,21 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
bool do_reorder_input = true;
bool reorder_filter_OIHWBo = false;
+ int64_t filter_input_channels = input_channels;
+ int64_t nchwc_group_count = group_count;
+
+ // The current implementation of ReorderInput requires the channel count to be
+ // aligned to this value.
+ constexpr int64_t channel_alignment = 4;
if (group_count > 1) {
- if ((output_channels % nchwc_block_size) != 0) {
+ if ((output_channels % channel_alignment) != 0) {
return;
}
if (input_channels == 1 && output_channels == group_count) {
// Depthwise convolution.
reorder_filter_OIHWBo = true;
+ nchwc_group_count = nchwc_output_channels;
} else if (((input_channels % nchwc_block_size) != 0) ||
((output_channels % group_count) != 0) ||
(((output_channels / group_count) % nchwc_block_size) != 0)) {
@@ -342,8 +369,11 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
// Use NCHW input buffer directly.
reorder_filter_OIHWBo = true;
do_reorder_input = false;
- } else if ((input_channels % nchwc_block_size) != 0) {
- return;
+ } else {
+ if ((input_channels % channel_alignment) != 0) {
+ return;
+ }
+ filter_input_channels = (input_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
}
}
@@ -374,15 +404,19 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
nchwc_conv_W_arg = filters_it->second;
} else {
Initializer conv_W{*conv_W_tensor_proto, graph_.ModelPath()};
+ const auto& conv_W_dims = conv_W.dims();
- int64_t reordered_filter_vec_size = conv_W.size() / output_channels * nchwc_output_channels;
- std::vector<float> reordered_filter(gsl::narrow<size_t>(reordered_filter_vec_size));
+ int64_t reordered_filter_size = nchwc_output_channels * filter_input_channels;
+ for (size_t i = 2; i < 4; i++) {
+ reordered_filter_size *= conv_W_dims[i];
+ }
+ std::vector<float> reordered_filter(gsl::narrow<size_t>(reordered_filter_size));
// Reorder the weights tensor statically.
if (reorder_filter_OIHWBo) {
- MlasReorderFilterOIHWBo(conv_W.dims().data(), conv_W.data(), reordered_filter.data());
+ MlasReorderFilterOIHWBo(conv_W_dims.data(), conv_W.data(), reordered_filter.data());
} else {
- MlasReorderFilterOIHWBiBo(conv_W.dims().data(), conv_W.data(), reordered_filter.data());
+ MlasReorderFilterOIHWBiBo(conv_W_dims.data(), conv_W.data(), reordered_filter.data());
}
ONNX_NAMESPACE::TensorProto nchwc_conv_W_tensor_proto;
@@ -392,8 +426,9 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
nchwc_conv_W_tensor_proto.set_raw_data(reordered_filter.data(), reordered_filter.size() * sizeof(float));
nchwc_conv_W_tensor_proto.add_dims(nchwc_output_channels);
- for (size_t i = 1; i < 4; i++) {
- nchwc_conv_W_tensor_proto.add_dims(conv_W.dims()[i]);
+ nchwc_conv_W_tensor_proto.add_dims(filter_input_channels);
+ for (size_t i = 2; i < 4; i++) {
+ nchwc_conv_W_tensor_proto.add_dims(conv_W_dims[i]);
}
nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto);
@@ -436,6 +471,9 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
&node.GetAttributes(),
kMSNchwcDomain);
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
+ if (nchwc_group_count != group_count) {
+ nchwc_node.AddAttribute("group", nchwc_group_count);
+ }
nchwc_node.MutableInputDefs()[1] = nchwc_conv_W_arg;
@@ -892,7 +930,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
removed_nodes_.push_front(node.Index());
}
-void NchwcTransformerImpl::TransformTranspose(Node& node) {
+void NchwcTransformerImpl::TransformTransposeToNhwc(Node& node) {
auto& input_defs = node.MutableInputDefs();
auto& output_defs = node.MutableOutputDefs();
@@ -1070,7 +1108,33 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
removed_nodes_.push_front(node.Index());
}
+void NchwcTransformerImpl::TrackTransposeFromNhwc(Node& node) {
+ const auto* perm_attr = graph_utils::GetNodeAttribute(node, "perm");
+ if (perm_attr == nullptr || perm_attr->ints_size() != 4) {
+ return;
+ }
+
+ // Test if this transposes from NHWC to NCHW layout order.
+ const int64_t* perm_data = perm_attr->ints().data();
+ if (perm_data[0] != 0 || perm_data[1] != 3 || perm_data[2] != 1 || perm_data[3] != 2) {
+ return;
+ }
+
+ // Verify that the node does not produce a graph output and produces output
+ // for a single node.
+ if (!graph_.GetNodeOutputsInGraphOutputs(node).empty() || node.GetOutputEdgesCount() != 1) {
+ return;
+ }
+
+ transpose_from_nhwc_node_ = &node;
+ transpose_from_nhwc_output_arg_ = node.MutableOutputDefs()[0];
+}
+
void NchwcTransformerImpl::Transform(Node& node) {
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) {
+ TrackTransposeFromNhwc(node);
+ }
+
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain)) {
TransformConv(node);
@@ -1097,7 +1161,7 @@ void NchwcTransformerImpl::Transform(Node& node) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) {
TransformBatchNormalization(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) {
- TransformTranspose(node);
+ TransformTransposeToNhwc(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11, 13})) {
TransformResize(node);
diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h
index 5741be2d2a..439ebb027e 100644
--- a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h
+++ b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h
@@ -105,7 +105,7 @@ class MlasNchwcConv2DTest : public MlasConv2DTest {
if (DoReorderInput) {
size_t NchwcInputElements = BatchCount * NchwcInputChannels * InputHeight * InputWidth;
float* NchwcInput = BufferNchwcInput.GetBuffer(NchwcInputElements);
- MlasReorderInput(InputShape, Input, NchwcInput);
+ ReorderInputNchw(InputShape, Input, NchwcInput);
Input = NchwcInput;
InputShape[1] = NchwcInputChannels;
}
diff --git a/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h b/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h
index 38efcd7deb..10e3f7f927 100644
--- a/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h
+++ b/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h
@@ -36,7 +36,7 @@ class MlasNchwcPool2DTest : public MlasPool2DTest {
size_t NchwcOutputElements = size_t(NchwcOutputShape[0]) * size_t(NchwcOutputShape[1]) * size_t(NchwcOutputShape[2]) * size_t(NchwcOutputShape[3]);
float* NchwcOutput = BufferNchwcOutput.GetBuffer(NchwcOutputElements);
- MlasReorderInput(InputShape, Input, NchwcInput);
+ ReorderInputNchw(InputShape, Input, NchwcInput);
MlasNchwcPool(PoolingKind,
NchwcInputShape,
diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h
index 9c72409e36..45f8c33f66 100644
--- a/onnxruntime/test/mlas/unittest/test_util.h
+++ b/onnxruntime/test/mlas/unittest/test_util.h
@@ -211,7 +211,7 @@ class MlasLongExecuteTests : public MlasTestFixture {
}
};
-// Some short Execute may not need to distinguish each parameters,
+// Some short Execute may not need to distinguish each parameters,
// because they finish quickly, and may disturb others by inject too many small tests.
// Register it as whole using following helper.
template
@@ -237,3 +237,16 @@ class MlasDirectShortExecuteTests : public MlasTestFixture {
}
};
+inline
+void ReorderInputNchw(const int64_t* input_shape, const float* S, float* D) {
+ const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
+ int64_t batch_count = input_shape[0];
+ int64_t channel_count = input_shape[1];
+ int64_t nchwc_channel_count = (channel_count + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
+ int64_t spatial_count = input_shape[2] * input_shape[3];
+ for (int64_t n = 0; n < batch_count; n++) {
+ MlasReorderInputNchw(S, D, static_cast<size_t>(channel_count), static_cast<size_t>(spatial_count));
+ S += spatial_count * channel_count;
+ D += spatial_count * nchwc_channel_count;
+ }
+}
diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc
index d1b495e536..9f357d57bd 100644
--- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc
@@ -970,10 +970,10 @@ TEST(NchwcOptimizerTests, MixedOutputUsage) {
TEST(NchwcOptimizerTests, TensorAlignment) {
auto build_test_case = [&](NchwcTestHelper& helper) {
- // Input channel count must currently be a multiple of the NCHWc block size.
- auto* input1_arg = helper.MakeInput({1, 60, 28, 42});
+ // Input channel count must currently be a multiple of 4.
+ auto* input1_arg = helper.MakeInput({1, 62, 28, 42});
auto* output1_arg = helper.MakeOutput();
- helper.AddConvNode(input1_arg, output1_arg, {128, 60, 1, 1});
+ helper.AddConvNode(input1_arg, output1_arg, {128, 62, 1, 1});
// Grouped input channel count must be a multiple of the NCHWc block size.
auto* input2_arg = helper.MakeInput({1, 48, 28, 42});
@@ -1116,6 +1116,34 @@ TEST(NchwcOptimizerTests, BatchNormalization) {
test_case(true);
}
+TEST(NchwcOptimizerTests, ConvReorderInputNhwc) {
+ auto test_case = [&](int64_t channels) {
+ auto build_test_case = [&](NchwcTestHelper& helper) {
+ auto* input_arg = helper.MakeInput({5, 27, 29, channels});
+ auto* transpose_output_arg = helper.MakeIntermediate();
+ auto* output_arg = helper.MakeOutput();
+
+ helper.AddTransposeToNchwNode(input_arg, transpose_output_arg);
+ helper.AddConvNode(transpose_output_arg, output_arg, {34, channels, 1, 1});
+ };
+
+ auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
+ auto op_to_count = CountOpsInGraph(session.GetGraph());
+ EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 1);
+ EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1);
+ EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1);
+ EXPECT_EQ(op_to_count["Transpose"], 0);
+ };
+
+ // Verify that a NHWC->NCHW transpose is fused into ReorderInput.
+ NchwcOptimizerTester(build_test_case, check_nchwc_graph);
+ };
+
+ for (int64_t channels = 16; channels <= 32; channels += 4) {
+ test_case(channels);
+ }
+}
+
TEST(NchwcOptimizerTests, ConvReorderOutputNhwc) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput({1, 64, 28, 32});