diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 8a5117bcf1..ed70ac5971 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -51,6 +51,7 @@ Do not modify directly.*
* com.microsoft.MurmurHash3
* com.microsoft.NGramRepeatBlock
* com.microsoft.NhwcConv
+ * com.microsoft.NhwcFusedConv
* com.microsoft.NhwcMaxPool
* com.microsoft.PackedAttention
* com.microsoft.Pad
@@ -2637,6 +2638,64 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.NhwcFusedConv**
+
+ NhwcFusedConv is a Conv operator with optional activation and add operators fused in.
+ Only has fp16 implementation as of 2023/04/15.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- activation : string
+
+- activation_params : list of floats
+
+- auto_pad : string
+
+- dilations : list of ints
+
+- group : int
+
+- kernel_shape : list of ints
+
+- pads : list of ints
+
+- strides : list of ints
+
+
+
+#### Inputs (2 - 4)
+
+
+- X : T
+
+- W : T
+
+- B (optional) : T
+
+- Z (optional) : T
+- Tensor to be added to the output, must be the same shape and format as the output tensor.
+
+
+#### Outputs
+
+
+- Y : T
+
+
+
+#### Type Constraints
+
+
+- T : tensor(float16)
+- Constrain input and output types to float tensors
+
+
+
### **com.microsoft.NhwcMaxPool**
#### Version
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index a44cf60f9e..bffe5b9818 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -16,9 +16,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv);
-#endif
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
@@ -81,6 +78,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
// ******** End: Quantization ******************* //
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, NhwcFusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, AveragePool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, GlobalAveragePool);
@@ -159,6 +157,7 @@ Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -232,9 +231,6 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
- BuildKernelCreateInfo,
-#endif
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 772f93ab1b..210c0d5aa2 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -13,6 +13,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearGlobalAveragePool
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAveragePool);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConv);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NhwcConv);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NhwcFusedConv);
// Quantization ops
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DequantizeLinear);
@@ -110,6 +111,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+ fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
diff --git a/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc
index 6d59324619..8fe3a4d5f3 100644
--- a/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc
@@ -383,5 +383,31 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
NhwcConv,
1,
OpSchema().FillUsing(ConvOpSchemaGenerator()));
+
+ONNX_MS_OPERATOR_SET_SCHEMA(NhwcFusedConv, 1,
+ OpSchema()
+ .SetDoc(R"DOC(
+NhwcFusedConv is a Conv operator with optional activation and add operators fused in.
+Only has fp16 implementation as of 2023/04/15.
+)DOC")
+ .Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"))  // only attribute with a string default
+ .Attr("kernel_shape", "", AttributeProto::INTS, OPTIONAL_VALUE)
+ .Attr("dilations", "", AttributeProto::INTS, OPTIONAL_VALUE)
+ .Attr("strides", "", AttributeProto::INTS, OPTIONAL_VALUE)
+ .Attr("pads", "", AttributeProto::INTS, OPTIONAL_VALUE)
+ .Attr("group", "", AttributeProto::INT, static_cast(1))
+ .Attr("activation", "", AttributeProto::STRING, OPTIONAL_VALUE)  // optional, no default value
+ .Attr("activation_params", "", AttributeProto::FLOATS, OPTIONAL_VALUE)
+ .Input(0, "X", "", "T")
+ .Input(1, "W", "", "T")
+ .Input(2, "B", "", "T", OpSchema::Optional)
+ .Input(3, "Z", "Tensor to be added to the output, must be the same shape and format as the output tensor.", "T", OpSchema::Optional)
+ .Output(0, "Y", "", "T")
+ .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float tensors")  // float16 is the only type listed for T
+ .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+ ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);  // output 0 (Y) takes its element type from input 0 (X)
+ convPoolShapeInferenceNhwc(ctx, true, false, 0, 1);  // NOTE(review): meaning of (true, false, 0, 1) is defined by the helper, not visible here
+ }));
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index a7cffa384e..d44e0ab155 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1437,18 +1437,20 @@ public:
};
/**
- * @brief Half precision activation functions
+ * @brief Half precision activation functions, with optional sum tensor.
+ * Supplied sum tensor must have the same layout as the GEMM output tensor.
+ * The sum tensor is added to the result after the activation has been applied.
*/
-class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
-public:
+class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR
+{
+ public:
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR(
- const MLAS_ACTIVATION& Activation
- ) :
- Activation_(Activation)
+ const MLAS_ACTIVATION& Activation,
+ const MLAS_FP16* SumBuf = nullptr)
+ : Activation_(Activation), SumBuf_(SumBuf)
{}
- void
- Process(
+ void Process(
MLAS_FP16* C,
size_t StartM,
size_t StartN,
@@ -1457,8 +1459,9 @@ public:
size_t ldc
) const override;
-private:
+ private:
const MLAS_ACTIVATION& Activation_;
+ const MLAS_FP16* SumBuf_;
};
inline
diff --git a/onnxruntime/core/mlas/lib/activate_fp16.cpp b/onnxruntime/core/mlas/lib/activate_fp16.cpp
index d920a5fa47..f62ffe0db9 100644
--- a/onnxruntime/core/mlas/lib/activate_fp16.cpp
+++ b/onnxruntime/core/mlas/lib/activate_fp16.cpp
@@ -25,6 +25,20 @@ Abstract:
template
struct MLAS_HALF_ACTIVATION_FUNCTION;
+template <>
+struct MLAS_HALF_ACTIVATION_FUNCTION {  // NOTE(review): the specialization argument is not visible in this text
+ MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation)
+ {
+ MLAS_UNREFERENCED_PARAMETER(Activation);  // nothing is read from Activation
+ }
+ // Pass-through: every Activate overload returns its argument unchanged.
+ MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { return Value; }
+
+ MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { return Value; }
+
+ float Activate(float Value) { return Value; }
+};
+
template<>
struct MLAS_HALF_ACTIVATION_FUNCTION
{
@@ -582,7 +596,7 @@ inline
void
MlasActivationKernel(
const MLAS_ACTIVATION& Activation,
- MLAS_FP16* Buffer,
+ _mlas_fp16_* Buffer,
size_t StartM,
size_t StartN,
size_t CountM,
@@ -592,8 +606,7 @@ MlasActivationKernel(
{
MLAS_HALF_ACTIVATION_FUNCTION ActivationFunction(Activation);
- auto* CRow = reinterpret_cast<_mlas_fp16_*>(Buffer);
- CRow += StartM * ldc + StartN;
+ auto* CRow = Buffer + StartM * ldc + StartN;
while (CountM-- > 0) {
_mlas_fp16_* buffer = CRow;
@@ -629,7 +642,7 @@ inline
void
MlasActivationKernel(
const MLAS_ACTIVATION& Activation,
- MLAS_FP16* Buffer,
+ _mlas_fp16_* Buffer,
size_t StartM,
size_t StartN,
size_t CountM,
@@ -651,6 +664,68 @@ MlasActivationKernel(
}
+// Applies the activation in place to a CountM x CountN tile of Buffer, then adds the matching elements of Addon.
+template
+MLAS_FORCEINLINE
+void
+MlasActivationKernel(
+ const MLAS_ACTIVATION& Activation,
+ _mlas_fp16_* Buffer,
+ const _mlas_fp16_* Addon,
+ size_t StartM,
+ size_t StartN,
+ size_t CountM,
+ size_t CountN,
+ size_t ldc
+ )
+{
+ MLAS_HALF_ACTIVATION_FUNCTION ActivationFunction(Activation);
+ // Addon is addressed with the same start offsets and row stride (ldc) as Buffer.
+ auto* CRow = Buffer + StartM * ldc + StartN;
+ const auto* ARow = Addon + StartM * ldc + StartN;
+
+ while (CountM-- > 0) {
+ auto* buffer = CRow;
+ const auto* addsrc = ARow;
+ size_t n = CountN;
+ // Groups of 8 elements: activate first, then add the Addon values.
+ while (n >= 8) {
+ MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer);
+ MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc);
+ addsrc += 8;
+ Vector = ActivationFunction.Activate(Vector);
+ Vector = MlasAddFloat16x8(Vector, AVec);
+ MlasStoreFloat16x8(buffer, Vector);
+ buffer += 8;
+ n -= 8;
+ }
+ // At most one group of 4 elements can remain after the loop above.
+ if (n >= 4) {
+ MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer);
+ MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc);
+ addsrc += 4;
+ Vector = ActivationFunction.Activate(Vector);
+ Vector = MlasAddFloat16x4(Vector, AVec);
+ MlasStoreFloat16x4(buffer, Vector);
+ buffer += 4;
+ n -= 4;
+ }
+ // Tail of 1-3 elements: only n elements are copied in, so the rest of buf/addbuf is uninitialized.
+ if (n > 0) {
+ MLAS_FLOAT16X4 buf;
+ std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_));
+ MLAS_FLOAT16X4 addbuf;
+ std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_));
+ MLAS_FLOAT16X4 res = ActivationFunction.Activate(buf);
+ res = MlasAddFloat16x4(res, addbuf);
+ MlasStorePartialFloat16x4(buffer, res, n);  // presumably writes back only n elements; confirm against the helper
+ }
+ // Step both row pointers to the next row.
+ CRow += ldc;
+ ARow += ldc;
+ }
+}
+
void
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
MLAS_FP16* C,
@@ -661,46 +736,89 @@ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
size_t ldc
) const
{
+ auto* Buffer = reinterpret_cast<_mlas_fp16_*>(C);
switch (Activation_.ActivationKind) {
case MlasIdentityActivation: {
- MlasActivationKernel(Activation_, C, StartM, StartN, CountM,
- CountN, ldc);
+ if (SumBuf_) {
+ MlasActivationKernel(
+ Activation_, Buffer, reinterpret_cast(SumBuf_), StartM,
+ StartN, CountM, CountN, ldc);
+ } else {
+ MlasActivationKernel(Activation_, Buffer, StartM, StartN,
+ CountM, CountN, ldc);
+ }
break;
}
case MlasReluActivation: {
- MlasActivationKernel(Activation_, C, StartM, StartN, CountM, CountN,
- ldc);
+ if (SumBuf_) {
+ MlasActivationKernel(
+ Activation_, Buffer, reinterpret_cast(SumBuf_), StartM,
+ StartN, CountM, CountN, ldc);
+ } else {
+ MlasActivationKernel(Activation_, Buffer, StartM, StartN,
+ CountM, CountN, ldc);
+ }
break;
}
case MlasLeakyReluActivation: {
- MlasActivationKernel(Activation_, C, StartM, StartN, CountM,
- CountN, ldc);
+ if (SumBuf_) {
+ MlasActivationKernel(
+ Activation_, Buffer, reinterpret_cast(SumBuf_), StartM,
+ StartN, CountM, CountN, ldc);
+ } else {
+ MlasActivationKernel(Activation_, Buffer, StartM, StartN,
+ CountM, CountN, ldc);
+ }
break;
}
case MlasTanhActivation: {
- MlasActivationKernel(Activation_, C, StartM, StartN, CountM, CountN,
- ldc);
+ if (SumBuf_) {
+ MlasActivationKernel(
+ Activation_, Buffer, reinterpret_cast(SumBuf_), StartM,
+ StartN, CountM, CountN, ldc);
+ } else {
+ MlasActivationKernel(Activation_, Buffer, StartM, StartN,
+ CountM, CountN, ldc);
+ }
break;
}
case MlasLogisticActivation: {
- MlasActivationKernel(Activation_, C, StartM, StartN, CountM,
- CountN, ldc);
+ if (SumBuf_) {
+ MlasActivationKernel(
+ Activation_, Buffer, reinterpret_cast(SumBuf_), StartM,
+ StartN, CountM, CountN, ldc);
+ } else {
+ MlasActivationKernel(Activation_, Buffer, StartM, StartN,
+ CountM, CountN, ldc);
+ }
break;
}
case MlasClipActivation: {
- MlasActivationKernel(Activation_, C, StartM, StartN, CountM, CountN,
- ldc);
+ if (SumBuf_) {
+ MlasActivationKernel(
+ Activation_, Buffer, reinterpret_cast(SumBuf_), StartM,
+ StartN, CountM, CountN, ldc);
+ } else {
+ MlasActivationKernel(Activation_, Buffer, StartM, StartN,
+ CountM, CountN, ldc);
+ }
break;
}
case MlasHardSigmoidActivation: {
- MlasActivationKernel(Activation_, C, StartM, StartN, CountM,
- CountN, ldc);
+ if (SumBuf_) {
+ MlasActivationKernel(
+ Activation_, Buffer, reinterpret_cast(SumBuf_), StartM,
+ StartN, CountM, CountN, ldc);
+ } else {
+ MlasActivationKernel(Activation_, Buffer, StartM, StartN,
+ CountM, CountN, ldc);
+ }
break;
}
@@ -744,10 +862,20 @@ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
proc.Process(C, StartM, StartN, CountM, CountN, ldc);
_mlas_fp16_* Output = reinterpret_cast<_mlas_fp16_*>(C);
- const auto* CRow = buffer.data();
+ auto* CRow = buffer.data();
+ const _mlas_fp16_* CAdd = nullptr;
+ if (SumBuf_) {
+ CAdd = reinterpret_cast(SumBuf_) + StartM * ldc + StartN;
+ }
Output += StartM * ldc + StartN;
while (CountM-- > 0) {
+ if (CAdd) {
+ for (size_t n = 0; n < CountN; n++) {
+ CRow[n] += MLAS_Half2Float(CAdd[n]);
+ }
+ CAdd += ldc;
+ }
CvtFloat2Half(Output, CRow, CountN);
CRow += CountN;
Output += ldc;
diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
index 78820bdf63..596f67919d 100644
--- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
+++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
@@ -32,18 +32,20 @@ using ConvPadVector = ConvAttributes::ConvPadVector;
* 2. Activation
* It takes an operator attribute 'activation', which supplies the activation info.
*
- * Add is performed BEFORE activation.
+ * Add is performed AFTER activation.
*
- * The implementation runs faster with NHWC. By default, it converts NCHW to NHWC
- * before processing, and convert the result back. It can take NHWC tensors directly.
- * Use operator attribute 'channels_last' to specify that the data layout is NHWC.
+ * The implementation supports both NCHW and NHWC. It runs faster with NHWC.
+ *
+ * Currently this class implements 3 operators: onnx.Conv, ms.FusedConv and ms.NhwcFusedConv.
+ * In the constructor, if we see the operator name is NhwcFusedConv, we assume the
+ * input layout to be NHWC, otherwise we assume layout is NCHW.
*
*/
class FusedConvFp16 final : public OpKernel {
public:
FusedConvFp16(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
- channels_last_ = (info.GetAttrOrDefault("channels_last", static_cast(0)) != 0);
+ channels_last_ = (info.GetKernelDef().OpName() == "NhwcFusedConv");
}
Status Compute(OpKernelContext* context) const override;
@@ -55,6 +57,9 @@ class FusedConvFp16 final : public OpKernel {
int input_idx,
/*out*/ bool& used_shared_buffers) override;
+ protected:
+ bool channels_last_{false};
+
private:
/**
@@ -89,7 +94,6 @@ class FusedConvFp16 final : public OpKernel {
MLAS_ACTIVATION activation_;
ConvAttributes conv_attrs_;
- bool channels_last_{false};
TensorShape W_shape_;
BufferUniquePtr packed_W_buffer_;
size_t packed_W_size_{0};
@@ -236,10 +240,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
const auto& W_shape = W ? W->Shape() : W_shape_;
const Tensor* B = num_inputs >= 3 ? context->Input(2) : nullptr;
- // TODO!!
- // This tensor should be added to the result before activation is applied
- // We need to augment the post processor to accept an addition operation.
- // const Tensor* Sum = num_inputs >= 4 ? context->Input(3) : nullptr;
+ // This tensor should be added to the result AFTER activation is applied
+ const Tensor* Sum = num_inputs >= 4 ? context->Input(3) : nullptr;
const int64_t N = X->Shape()[0];
const int64_t M = W_shape[0];
@@ -282,6 +284,13 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
if (Y->Shape().Size() == 0) {
return Status::OK();
}
+ if (Sum) {
+ if (Sum->Shape() != Y->Shape()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Z shape does not match output shape.",
+ " Z: ", Sum->Shape().ToString().c_str(),
+ " Output: ", Y->Shape().ToString().c_str());
+ }
+ }
const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
@@ -332,6 +341,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
const auto* Xdata = X->Data();
const auto* Bdata = B != nullptr ? B->Data() : nullptr;
auto* Ydata = Y->MutableData();
+ const auto* SumData = Sum != nullptr ? Sum->Data() : nullptr;
BufferUniquePtr transpose_input_buffer;
BufferUniquePtr transpose_output_buffer;
@@ -403,6 +413,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
for (int64_t image_id = 0; image_id < N; ++image_id) {
const auto* input_data = Xdata;
auto* output_data = Ydata;
+ const auto* add_src = SumData;
if (!channels_last_) {
// Transpose the input from channels first (CHW) to channels last (HWC).
@@ -413,6 +424,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
static_cast(input_image_size));
input_data = static_cast(transpose_input_buffer.get());
output_data = static_cast(transpose_output_buffer.get());
+ add_src = nullptr;
}
// Threaded implementation of ND convolution is not yet supported, so
@@ -459,10 +471,10 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
}
auto* worker_output = output_data + output_start * M;
+ const auto* worker_addsrc = add_src == nullptr ? nullptr : add_src + output_start * M;
if (is_depthwise_conv) {
- // TODO!! add Sum tensor to activation
- MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_);
+ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_, worker_addsrc);
MlasConvDepthwise(
worker_indirection_buffer,
reordered_W,
@@ -531,8 +543,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
lda = static_cast(C);
}
- // TODO!! add Sum tensor to activation
- MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_);
+ const auto* gemm_add = add_src == nullptr ? nullptr : worker_addsrc + group_id * group_output_channels;
+ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_, gemm_add);
MLAS_HALF_GEMM_DATA_PARAMS gemm_params;
gemm_params.A = AData;
gemm_params.lda = lda;
@@ -567,15 +579,29 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
Ydata,
static_cast(output_image_size),
static_cast(M));
+ if (SumData != nullptr) {
+ MLAS_ACTIVATION activation;
+ activation.ActivationKind = MlasIdentityActivation;
+ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation, SumData);
+ proc.Process(Ydata, 0, 0, static_cast(M),
+ static_cast(output_image_size),
+ static_cast(output_image_size));
+ }
}
Xdata += X_offset;
Ydata += Y_offset;
+ if (SumData != nullptr) {
+ SumData += Y_offset;
+ }
}
return Status::OK();
}
+//
+// Operator definitions
+//
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Conv,
@@ -586,14 +612,29 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
#ifndef DISABLE_CONTRIB_OPS
+
namespace contrib {
- ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- FusedConv,
- 1,
- MLFloat16,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType()),
- FusedConvFp16);
+// Both ops below are bound to FusedConvFp16; its constructor enables channels-last when the op name is "NhwcFusedConv".
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ NhwcFusedConv,
+ kMSDomain,
+ 1,
+ MLFloat16,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()),
+ FusedConvFp16);
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ FusedConv,
+ kMSDomain,
+ 1,
+ MLFloat16,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()),
+ FusedConvFp16);
+
} // namespace contrib
#endif
diff --git a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp
index 6523636225..eb09ed8d30 100644
--- a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp
+++ b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp
@@ -5,6 +5,20 @@
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+// Approximate float comparison: NaN matches only NaN; otherwise the difference
+// relative to the larger magnitude must be below 0.005.
+bool check_equal(float actual, float expected) {
+  if (std::isnan(actual) || std::isnan(expected)) {
+    // Checking both sides keeps a NaN `expected` from matching a near-zero `actual`.
+    return std::isnan(actual) && std::isnan(expected);
+  }
+  const float diff = std::abs(actual - expected);
+  const float top = std::max(std::abs(actual), std::abs(expected));
+  // Magnitudes <= 1e-4 are treated as equal rather than divided by a near-zero top.
+  const float ratio = (top > 0.0001) ? diff / top : 0.0f;
+  return ratio < 0.005;
+}
+
class MlasFp16ActivationTest : public MlasTestBase {
public:
static const char* GetTestSuiteName() {
@@ -44,14 +58,25 @@ class MlasFp16ActivationTest : public MlasTestBase {
constexpr size_t M = 5;
constexpr size_t N = 23;
+ constexpr float MinimumFillValue = -11.0f;
MatrixGuardBuffer HalfBuffer1;
auto* testData1 = HalfBuffer1.GetBuffer(M * N, true);
MatrixGuardBuffer HalfBuffer2;
auto* testData2 = HalfBuffer2.GetBuffer(M * N, true);
+ MatrixGuardBuffer HalfBuffer3;
+ auto* testData3 = HalfBuffer3.GetBuffer(M * N, true);
+ MatrixGuardBuffer AddonBuffer;
+ auto addonData = AddonBuffer.GetBuffer(M * N, true);
MatrixGuardBuffer FloatBuffer;
auto* fpBuffer = FloatBuffer.GetBuffer(M * N, true);
+ size_t o = 3;
+ for (size_t i = 0; i < M * N; i++) {
+ o = (o + 19) % 23;
+ addonData[i] = (MinimumFillValue + o) / 16.0f;
+ }
+
MLAS_ACTIVATION_KIND acts[] = {
MlasIdentityActivation,
MlasReluActivation,
@@ -62,8 +87,9 @@ class MlasFp16ActivationTest : public MlasTestBase {
MlasHardSigmoidActivation};
MLAS_ACTIVATION Activation;
- MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(Activation);
+ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(Activation, nullptr);
MLAS_HALF_GEMM_2FLOAT_PROCESSOR converter(Activation, fpBuffer, N);
+ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR addon(Activation, reinterpret_cast(addonData));
for (auto kind : acts) {
Activation.ActivationKind = MLAS_ACTIVATION_KIND(kind);
@@ -84,37 +110,32 @@ class MlasFp16ActivationTest : public MlasTestBase {
for (size_t i = 0; i < _countof(TestData); i++) {
testData1[i] = TestData[i].f;
testData2[i] = TestData[i].f;
+ testData3[i] = TestData[i].f;
}
- constexpr float MinimumFillValue = -11.0f;
size_t offset = 7;
for (size_t i = _countof(TestData); i < M * N; i++) {
offset = (offset + 19) % 23;
testData1[i] = (MinimumFillValue + offset) / 16.0f;
testData2[i] = testData1[i];
+ testData3[i] = testData1[i];
}
proc.Process(reinterpret_cast(testData1), 0, 0, M, N, N);
converter.Process(reinterpret_cast(testData2), 0, 0, M, N, N);
+ addon.Process(reinterpret_cast(testData3), 0, 0, M, N, N);
for (size_t i = 0; i < M*N; i++) {
float actual = testData1[i].ToFloat();
- if (std::isnan(actual)) {
- EXPECT_TRUE(std::isnan(fpBuffer[i]))
+ EXPECT_TRUE(check_equal(actual, fpBuffer[i]))
<< ", Vector Activation Kind:" << (int)kind << ", i=" << i << ", value:"
<< std::setw(8) << std::setfill('0') << std::hex << actual << ", expecting:"
<< std::setw(8) << std::setfill('0') << std::hex << fpBuffer[i];
- } else {
- float diff = std::abs(actual - fpBuffer[i]);
- float top = std::max(std::abs(actual), std::abs(fpBuffer[i]));
- float ratio = 0;
- if (top > 0.0001) {
- ratio = diff / top;
- }
- EXPECT_TRUE(ratio < 0.005)
- << ", Vector Activation Kind:" << (int)kind << ", i=" << i << ", value:"
- << actual << ", expecting:" << fpBuffer[i];
- }
+ float addonActual = testData3[i].ToFloat() - addonData[i].ToFloat();  // subtract the addon back out so the value is comparable to fpBuffer
+ EXPECT_TRUE(check_equal(addonActual, fpBuffer[i]))
+ << ", Vector + Activation Kind:" << (int)kind << ", i=" << i << ", value:"
+ << std::setw(8) << std::setfill('0') << std::hex << addonActual << ", expecting:"
+ << std::setw(8) << std::setfill('0') << std::hex << fpBuffer[i];
}
}
}
diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
index 46312292ad..d210d67cb2 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
@@ -39,7 +39,7 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes,
int opset = 11) {
std::unique_ptr tester;
if (!attributes.activation.empty()) {
- tester = std::make_unique("FusedConv", 1, onnxruntime::kMSDomain);
+ tester = std::make_unique("NhwcFusedConv", 1, onnxruntime::kMSDomain);
tester->AddAttribute("activation", attributes.activation);
if (!attributes.activation_parameters.empty()) {
@@ -68,12 +68,14 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes,
}
- ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs");
- const char* szNames[] = {"X", "W", "B"};
+ ORT_ENFORCE(inputs.size() <= 4, "Our name array is only setup to handle 4 inputs");
+ const char* szNames[] = {"X", "W", "B", "Z"};
tester->AddInput(szNames[0], input_shapes[0], inputs[0]);
tester->AddInput(szNames[1], input_shapes[1], inputs[1], weight_is_initializer);
- if (inputs.size() == 3)
+ if (inputs.size() >= 3)
tester->AddInput(szNames[2], input_shapes[2], inputs[2]);
+ if (inputs.size() >= 4)
+ tester->AddInput(szNames[3], input_shapes[3], inputs[3]);
tester->AddOutput("Y", expected_output_shape, expected_output, /*no sort*/ false, 0.002f, 0.0f);
@@ -506,86 +508,6 @@ TEST(ConvFp16Test, Conv2D_AutoPad2) {
TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
}
-#ifndef DISABLE_CONTRIB_OPS
-TEST(ConvFp16Test, Conv2D_HardSigmoid) {
- ConvOpAndTestAttributes attrs = {
- "", // auto_pad
- vector{1, 1}, // dilations
- 1, // group
- vector{2, 2}, // kernel_shape
- vector{0, 0, 0, 0}, // pads
- vector{1, 1}, // strides
- {}, // excluded EPs
- "HardSigmoid", // activation
- vector{0.2f, 0.5f} // activation_parameters
- };
-
- vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
- MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
- MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)};
- vector X_shape = {1, 1, 3, 3};
- vector W = {MLFloat16(0.125f), MLFloat16(0.125f), MLFloat16(0.125f), MLFloat16(0.125f),
- MLFloat16(-0.125f), MLFloat16(-0.125f), MLFloat16(-0.125f), MLFloat16(-0.125f)};
- vector W_shape = {2, 1, 2, 2};
- vector Y_shape = {1, 2, 2, 2};
- auto expected_vals = {MLFloat16(0.8f), MLFloat16(0.9f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(0.2f), MLFloat16(0.1f), MLFloat16(0.0f), MLFloat16(0.0f)};
- TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
-}
-
-TEST(ConvFp16Test, Conv2D_Relu) {
- ConvOpAndTestAttributes attrs = {
- "", // auto_pad
- vector{1, 1}, // dilations
- 1, // group
- vector{2, 2}, // kernel_shape
- vector{0, 0, 0, 0}, // pads
- vector{1, 1}, // strides
- {}, // excluded EPs
- "Relu" // activation
- };
-
- vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
- MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
- MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)};
- vector X_shape = {1, 1, 3, 3};
- vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f),
- MLFloat16(-1.0f), MLFloat16(-1.0f), MLFloat16(-1.0f), MLFloat16(-1.0f)};
- vector W_shape = {2, 1, 2, 2};
- vector Y_shape = {1, 2, 2, 2};
- auto expected_vals = {MLFloat16(12.0f), MLFloat16(16.0f), MLFloat16(24.0f), MLFloat16(28.0f),
- MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f)};
- TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
-}
-
-TEST(ConvFp16Test, Conv2D_Bias_Relu) {
- ConvOpAndTestAttributes attrs = {
- "", // auto_pad
- vector{1, 1}, // dilations
- 1, // group
- vector{2, 2}, // kernel_shape
- vector{0, 0, 0, 0}, // pads
- vector{1, 1}, // strides
- {}, // excluded EPs
- "Relu" // activation
- };
-
- vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
- MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
- MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)};
- vector X_shape = {1, 1, 3, 3};
- vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f),
- MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f)};
- vector W_shape = {2, 1, 2, 2};
- vector Y_shape = {1, 2, 2, 2};
- vector B = {MLFloat16(1.0f), MLFloat16(-1.0f)};
- vector B_shape = {2};
- auto expected_vals = {MLFloat16(13.0f), MLFloat16(17.0f), MLFloat16(25.0f), MLFloat16(29.0f),
- MLFloat16(11.0f), MLFloat16(15.0f), MLFloat16(23.0f), MLFloat16(27.0f)};
- TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
-}
-#endif // CONTRIB_OPS
-
-
TEST(ConvFp16Test, Conv3D_1) {
ConvOpAndTestAttributes attrs = {
"", // auto_pad
@@ -978,31 +900,93 @@ TEST(ConvFp16Test, Pointwise_Relu) {
};
vector X = {
- MLFloat16(-9.f), MLFloat16(1.f), MLFloat16(2.f),
- MLFloat16(-5.f), MLFloat16(3.f), MLFloat16(-2.f),
- MLFloat16(5.f), MLFloat16(-3.f), MLFloat16(1.f),
- MLFloat16(1.f), MLFloat16(8.f), MLFloat16(-4.f),
- MLFloat16(-1.f), MLFloat16(6.f), MLFloat16(7.f),
- MLFloat16(-1.f), MLFloat16(4.f), MLFloat16(-5.f),
- MLFloat16(-9.f), MLFloat16(1.f), MLFloat16(2.f),
- MLFloat16(-5.f), MLFloat16(3.f), MLFloat16(-2.f),
- MLFloat16(5.f), MLFloat16(-3.f), MLFloat16(1.f)};
+ MLFloat16(-9.f), MLFloat16(1.f), MLFloat16(-9.f),
+ MLFloat16(1.f), MLFloat16(8.f), MLFloat16(1.f),
+ MLFloat16(2.f), MLFloat16(-4.f), MLFloat16(2.f),
+ MLFloat16(-5.f), MLFloat16(-1.f), MLFloat16(-5.f),
+ MLFloat16(3.f), MLFloat16(6.f), MLFloat16(3.f),
+ MLFloat16(-2.f), MLFloat16(7.f), MLFloat16(-2.f),
+ MLFloat16(5.f), MLFloat16(-1.f), MLFloat16(5.f),
+ MLFloat16(-3.f), MLFloat16(4.f), MLFloat16(-3.f),
+ MLFloat16(1.f), MLFloat16(-5.f), MLFloat16(1.f)};
vector X_shape = {1, 3, 3, 3};
vector W = {MLFloat16(2.f), MLFloat16(-3.f), MLFloat16(0.5f),
MLFloat16(0.25f), MLFloat16(-2.f), MLFloat16(-0.75f)};
vector W_shape = {2, 3, 1, 1};
- vector Y_shape = {1, 2, 3, 3};
+ vector Y_shape = {1, 3, 3, 2};
auto expected_vals = {
- MLFloat16(0.f), MLFloat16(0.f), MLFloat16(17.f),
- MLFloat16(0.f), MLFloat16(0.f), MLFloat16(0.f),
- MLFloat16(15.5f), MLFloat16(0.f), MLFloat16(17.5f),
- MLFloat16(2.5f), MLFloat16(0.f), MLFloat16(7.f),
- MLFloat16(4.5f), MLFloat16(0.f), MLFloat16(0.f),
- MLFloat16(0.f), MLFloat16(0.f), MLFloat16(9.5f)};
+ MLFloat16(0.f), MLFloat16(2.5f),
+ MLFloat16(0.f), MLFloat16(0.f),
+ MLFloat16(17.f), MLFloat16(7.f),
+ MLFloat16(0.f), MLFloat16(4.5f),
+ MLFloat16(0.f), MLFloat16(0.f),
+ MLFloat16(0.f), MLFloat16(0.f),
+ MLFloat16(15.5f), MLFloat16(0.f),
+ MLFloat16(0.f), MLFloat16(0.f),
+ MLFloat16(17.5f), MLFloat16(9.5f)};
TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
}
+// 2x2 convolution with a fused HardSigmoid; a non-empty activation makes TestConvFp16Op build an NhwcFusedConv node.
+TEST(ConvFp16Test, Conv2D_HardSigmoid) {
+ ConvOpAndTestAttributes attrs = {
+ "", // auto_pad
+ vector{1, 1}, // dilations
+ 1, // group
+ vector{2, 2}, // kernel_shape
+ vector{0, 0, 0, 0}, // pads
+ vector{1, 1}, // strides
+ {}, // excluded EPs
+ "HardSigmoid", // activation
+ vector{0.2f, 0.5f} // activation_parameters
+ };
+
+ vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
+ MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
+ MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)};
+ vector X_shape = {1, 3, 3, 1};  // channels-last: N=1, H=3, W=3, C=1
+ vector W = {MLFloat16(0.125f), MLFloat16(0.125f), MLFloat16(0.125f), MLFloat16(0.125f),
+ MLFloat16(-0.125f), MLFloat16(-0.125f), MLFloat16(-0.125f), MLFloat16(-0.125f)};
+ vector W_shape = {2, 1, 2, 2};
+ vector Y_shape = {1, 2, 2, 2};
+ auto expected_vals = {  // each row below is one output position: (channel 0, channel 1)
+ MLFloat16(0.8f), MLFloat16(0.2f),
+ MLFloat16(0.9f), MLFloat16(0.1f),
+ MLFloat16(1.0f), MLFloat16(0.0f),
+ MLFloat16(1.0f), MLFloat16(0.0f)};
+ TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
+}
+
+TEST(ConvFp16Test, Conv2D_Bias_Z_Relu) {  // exercises all four inputs: X, W, bias B and the add tensor Z
+ ConvOpAndTestAttributes attrs = {
+ "", // auto_pad
+ vector{1, 1}, // dilations
+ 1, // group
+ vector{2, 2}, // kernel_shape
+ vector{0, 0, 0, 0}, // pads
+ vector{1, 1}, // strides
+ {}, // excluded EPs
+ "Relu" // activation
+ };
+
+ vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
+ MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
+ MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)};
+ vector X_shape = {1, 3, 3, 1};  // channels-last: N=1, H=3, W=3, C=1
+ vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f),
+ MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f)};
+ vector W_shape = {2, 1, 2, 2};
+ vector Y_shape = {1, 2, 2, 2};
+ vector B = {MLFloat16(1.0f), MLFloat16(-1.0f)};
+ vector B_shape = {2};
+ vector Z = {MLFloat16(-1.0f), MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f),
+ MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(1.0f)};  // non-zero only at the first (-1) and last (+1) element
+ vector Z_shape = {1, 2, 2, 2};  // same shape as Y_shape
+ auto expected_vals = {MLFloat16(12.0f), MLFloat16(11.0f), MLFloat16(17.0f), MLFloat16(15.0f), MLFloat16(25.0f), MLFloat16(23.0f), MLFloat16(29.0f), MLFloat16(28.0f)};  // conv + bias gives 13..27; Z shifts the first to 12 and the last to 28
+ TestConvFp16Op(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape);
+}
+
#endif // CONTRIB_OPS