mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
Fuse 'Add' operator into FP16 Conv (#15213)
### Description Adds 'Add' functionality to the FP16 Conv operator. It takes a tensor that has the same shape as the output tensor and adds it to the result tensor. ### Motivation and Context Needed to run ResNet-50
This commit is contained in:
parent
bb21031cbb
commit
8dce83a818
9 changed files with 430 additions and 170 deletions
|
|
@ -51,6 +51,7 @@ Do not modify directly.*
|
|||
* <a href="#com.microsoft.MurmurHash3">com.microsoft.MurmurHash3</a>
|
||||
* <a href="#com.microsoft.NGramRepeatBlock">com.microsoft.NGramRepeatBlock</a>
|
||||
* <a href="#com.microsoft.NhwcConv">com.microsoft.NhwcConv</a>
|
||||
* <a href="#com.microsoft.NhwcFusedConv">com.microsoft.NhwcFusedConv</a>
|
||||
* <a href="#com.microsoft.NhwcMaxPool">com.microsoft.NhwcMaxPool</a>
|
||||
* <a href="#com.microsoft.PackedAttention">com.microsoft.PackedAttention</a>
|
||||
* <a href="#com.microsoft.Pad">com.microsoft.Pad</a>
|
||||
|
|
@ -2637,6 +2638,64 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.NhwcFusedConv"></a><a name="com.microsoft.nhwcfusedconv">**com.microsoft.NhwcFusedConv**</a>
|
||||
|
||||
NhwcFusedConv is a Conv operator with optional activation and add operators fused in.
|
||||
Only has fp16 implementation as of 2023/04/15.
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
||||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>activation</tt> : string</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>activation_params</tt> : list of floats</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>auto_pad</tt> : string</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>dilations</tt> : list of ints</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>group</tt> : int</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>kernel_shape</tt> : list of ints</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>pads</tt> : list of ints</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>strides</tt> : list of ints</dt>
|
||||
<dd></dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs (2 - 4)
|
||||
|
||||
<dl>
|
||||
<dt><tt>X</tt> : T</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>W</tt> : T</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>B</tt> (optional) : T</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>Z</tt> (optional) : T</dt>
|
||||
<dd>Tensor to be added to the output, must be the same shape and format as the output tensor.</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>Y</tt> : T</dt>
|
||||
<dd></dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(float16)</dt>
|
||||
<dd>Constrain input and output types to float tensors</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.NhwcMaxPool"></a><a name="com.microsoft.nhwcmaxpool">**com.microsoft.NhwcMaxPool**</a>
|
||||
|
||||
#### Version
|
||||
|
|
|
|||
|
|
@ -16,9 +16,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
|
||||
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv);
|
||||
#endif
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
|
||||
|
|
@ -81,6 +78,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
// ******** End: Quantization ******************* //
|
||||
|
||||
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, NhwcFusedConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, AveragePool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, GlobalAveragePool);
|
||||
|
|
@ -159,6 +157,7 @@ Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
|
|||
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, NhwcFusedConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, GlobalAveragePool)>,
|
||||
|
|
@ -232,9 +231,6 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>,
|
||||
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv)>,
|
||||
#endif
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearGlobalAveragePool
|
|||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAveragePool);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConv);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NhwcConv);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NhwcFusedConv);
|
||||
|
||||
// Quantization ops
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DequantizeLinear);
|
||||
|
|
@ -110,6 +111,7 @@ class OpSet_Microsoft_ver1 {
|
|||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAveragePool)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConv)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NhwcConv)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NhwcFusedConv)>());
|
||||
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DequantizeLinear)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DequantizeBFP)>());
|
||||
|
|
|
|||
|
|
@ -383,5 +383,31 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
NhwcConv,
|
||||
1,
|
||||
OpSchema().FillUsing(ConvOpSchemaGenerator()));
|
||||
|
||||
// Schema for the NhwcFusedConv contrib operator, version 1.
// It declares the convolution attributes, the fused activation attributes,
// two required inputs (X, W), two optional inputs (B, Z) and one output (Y),
// all constrained to tensor(float16).
ONNX_MS_OPERATOR_SET_SCHEMA(NhwcFusedConv, 1,
|
||||
OpSchema()
|
||||
.SetDoc(R"DOC(
|
||||
NhwcFusedConv is a Conv operator with optional activation and add operators fused in.
|
||||
Only has fp16 implementation as of 2023/04/15.
|
||||
)DOC")
|
||||
// Convolution attributes. auto_pad defaults to "NOTSET" and group to 1;
// the remaining ones are optional with no default.
.Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"))
|
||||
.Attr("kernel_shape", "", AttributeProto::INTS, OPTIONAL_VALUE)
|
||||
.Attr("dilations", "", AttributeProto::INTS, OPTIONAL_VALUE)
|
||||
.Attr("strides", "", AttributeProto::INTS, OPTIONAL_VALUE)
|
||||
.Attr("pads", "", AttributeProto::INTS, OPTIONAL_VALUE)
|
||||
.Attr("group", "", AttributeProto::INT, static_cast<int64_t>(1))
|
||||
// Fused activation: its name and an optional list of float parameters.
.Attr("activation", "", AttributeProto::STRING, OPTIONAL_VALUE)
|
||||
.Attr("activation_params", "", AttributeProto::FLOATS, OPTIONAL_VALUE)
|
||||
.Input(0, "X", "", "T")
|
||||
.Input(1, "W", "", "T")
|
||||
.Input(2, "B", "", "T", OpSchema::Optional)
|
||||
// Z is the optional tensor for the fused Add.
.Input(3, "Z", "Tensor to be added to the output, must be the same shape and format as the output tensor.", "T", OpSchema::Optional)
|
||||
.Output(0, "Y", "", "T")
|
||||
// NOTE(review): the description says "float tensors" but the only allowed type is tensor(float16).
.TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float tensors")
|
||||
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
|
||||
// The output element type is copied from input 0 (X).
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
// Shape inference is delegated to the NHWC conv/pool helper.
// NOTE(review): the flags (true, false) and indices (0, 1) presumably mean
// "use dilations", "kernel_shape not required", X index and W index - confirm against the helper.
convPoolShapeInferenceNhwc(ctx, true, false, 0, 1);
|
||||
}));
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1437,18 +1437,20 @@ public:
|
|||
};
|
||||
|
||||
/**
|
||||
* @brief Half precision activation functions
|
||||
* @brief Half precision activation functions, with optional sum tensor.
|
||||
* Supplied sum tensor must be the same layout as the GEMM output tensor.
|
||||
* And the supplied sum tensor will be added to the final result.
|
||||
*/
|
||||
class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
|
||||
public:
|
||||
class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR
|
||||
{
|
||||
public:
|
||||
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR(
|
||||
const MLAS_ACTIVATION& Activation
|
||||
) :
|
||||
Activation_(Activation)
|
||||
const MLAS_ACTIVATION& Activation,
|
||||
const MLAS_FP16* SumBuf = nullptr)
|
||||
: Activation_(Activation), SumBuf_(SumBuf)
|
||||
{}
|
||||
|
||||
void
|
||||
Process(
|
||||
void Process(
|
||||
MLAS_FP16* C,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
|
|
@ -1457,8 +1459,9 @@ public:
|
|||
size_t ldc
|
||||
) const override;
|
||||
|
||||
private:
|
||||
private:
|
||||
const MLAS_ACTIVATION& Activation_;
|
||||
const MLAS_FP16* SumBuf_;
|
||||
};
|
||||
|
||||
inline
|
||||
|
|
|
|||
|
|
@ -25,6 +25,20 @@ Abstract:
|
|||
template<MLAS_ACTIVATION_KIND ActivationKind>
|
||||
struct MLAS_HALF_ACTIVATION_FUNCTION;
|
||||
|
||||
template <>
|
||||
struct MLAS_HALF_ACTIVATION_FUNCTION<MlasIdentityActivation> {
|
||||
MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation)
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(Activation);
|
||||
}
|
||||
|
||||
MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { return Value; }
|
||||
|
||||
MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { return Value; }
|
||||
|
||||
float Activate(float Value) { return Value; }
|
||||
};
|
||||
|
||||
template<>
|
||||
struct MLAS_HALF_ACTIVATION_FUNCTION<MlasReluActivation>
|
||||
{
|
||||
|
|
@ -582,7 +596,7 @@ inline
|
|||
void
|
||||
MlasActivationKernel(
|
||||
const MLAS_ACTIVATION& Activation,
|
||||
MLAS_FP16* Buffer,
|
||||
_mlas_fp16_* Buffer,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
|
|
@ -592,8 +606,7 @@ MlasActivationKernel(
|
|||
{
|
||||
MLAS_HALF_ACTIVATION_FUNCTION<ActivationKind> ActivationFunction(Activation);
|
||||
|
||||
auto* CRow = reinterpret_cast<_mlas_fp16_*>(Buffer);
|
||||
CRow += StartM * ldc + StartN;
|
||||
auto* CRow = Buffer + StartM * ldc + StartN;
|
||||
|
||||
while (CountM-- > 0) {
|
||||
_mlas_fp16_* buffer = CRow;
|
||||
|
|
@ -629,7 +642,7 @@ inline
|
|||
void
|
||||
MlasActivationKernel<MlasIdentityActivation>(
|
||||
const MLAS_ACTIVATION& Activation,
|
||||
MLAS_FP16* Buffer,
|
||||
_mlas_fp16_* Buffer,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
|
|
@ -651,6 +664,68 @@ MlasActivationKernel<MlasIdentityActivation>(
|
|||
}
|
||||
|
||||
|
||||
template<MLAS_ACTIVATION_KIND ActivationKind>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasActivationKernel(
|
||||
const MLAS_ACTIVATION& Activation,
|
||||
_mlas_fp16_* Buffer,
|
||||
const _mlas_fp16_* Addon,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t ldc
|
||||
)
|
||||
{
|
||||
MLAS_HALF_ACTIVATION_FUNCTION<ActivationKind> ActivationFunction(Activation);
|
||||
|
||||
auto* CRow = Buffer + StartM * ldc + StartN;
|
||||
const auto* ARow = Addon + StartM * ldc + StartN;
|
||||
|
||||
while (CountM-- > 0) {
|
||||
auto* buffer = CRow;
|
||||
const auto* addsrc = ARow;
|
||||
size_t n = CountN;
|
||||
|
||||
while (n >= 8) {
|
||||
MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer);
|
||||
MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc);
|
||||
addsrc += 8;
|
||||
Vector = ActivationFunction.Activate(Vector);
|
||||
Vector = MlasAddFloat16x8(Vector, AVec);
|
||||
MlasStoreFloat16x8(buffer, Vector);
|
||||
buffer += 8;
|
||||
n -= 8;
|
||||
}
|
||||
|
||||
if (n >= 4) {
|
||||
MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer);
|
||||
MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc);
|
||||
addsrc += 4;
|
||||
Vector = ActivationFunction.Activate(Vector);
|
||||
Vector = MlasAddFloat16x4(Vector, AVec);
|
||||
MlasStoreFloat16x4(buffer, Vector);
|
||||
buffer += 4;
|
||||
n -= 4;
|
||||
}
|
||||
|
||||
if (n > 0) {
|
||||
MLAS_FLOAT16X4 buf;
|
||||
std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_));
|
||||
MLAS_FLOAT16X4 addbuf;
|
||||
std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_));
|
||||
MLAS_FLOAT16X4 res = ActivationFunction.Activate(buf);
|
||||
res = MlasAddFloat16x4(res, addbuf);
|
||||
MlasStorePartialFloat16x4(buffer, res, n);
|
||||
}
|
||||
|
||||
CRow += ldc;
|
||||
ARow += ldc;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
|
||||
MLAS_FP16* C,
|
||||
|
|
@ -661,46 +736,89 @@ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
|
|||
size_t ldc
|
||||
) const
|
||||
{
|
||||
auto* Buffer = reinterpret_cast<_mlas_fp16_*>(C);
|
||||
switch (Activation_.ActivationKind) {
|
||||
case MlasIdentityActivation: {
|
||||
MlasActivationKernel<MlasIdentityActivation>(Activation_, C, StartM, StartN, CountM,
|
||||
CountN, ldc);
|
||||
if (SumBuf_) {
|
||||
MlasActivationKernel<MlasIdentityActivation>(
|
||||
Activation_, Buffer, reinterpret_cast<const _mlas_fp16_*>(SumBuf_), StartM,
|
||||
StartN, CountM, CountN, ldc);
|
||||
} else {
|
||||
MlasActivationKernel<MlasIdentityActivation>(Activation_, Buffer, StartM, StartN,
|
||||
CountM, CountN, ldc);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case MlasReluActivation: {
|
||||
MlasActivationKernel<MlasReluActivation>(Activation_, C, StartM, StartN, CountM, CountN,
|
||||
ldc);
|
||||
if (SumBuf_) {
|
||||
MlasActivationKernel<MlasReluActivation>(
|
||||
Activation_, Buffer, reinterpret_cast<const _mlas_fp16_*>(SumBuf_), StartM,
|
||||
StartN, CountM, CountN, ldc);
|
||||
} else {
|
||||
MlasActivationKernel<MlasReluActivation>(Activation_, Buffer, StartM, StartN,
|
||||
CountM, CountN, ldc);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case MlasLeakyReluActivation: {
|
||||
MlasActivationKernel<MlasLeakyReluActivation>(Activation_, C, StartM, StartN, CountM,
|
||||
CountN, ldc);
|
||||
if (SumBuf_) {
|
||||
MlasActivationKernel<MlasLeakyReluActivation>(
|
||||
Activation_, Buffer, reinterpret_cast<const _mlas_fp16_*>(SumBuf_), StartM,
|
||||
StartN, CountM, CountN, ldc);
|
||||
} else {
|
||||
MlasActivationKernel<MlasLeakyReluActivation>(Activation_, Buffer, StartM, StartN,
|
||||
CountM, CountN, ldc);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case MlasTanhActivation: {
|
||||
MlasActivationKernel<MlasTanhActivation>(Activation_, C, StartM, StartN, CountM, CountN,
|
||||
ldc);
|
||||
if (SumBuf_) {
|
||||
MlasActivationKernel<MlasTanhActivation>(
|
||||
Activation_, Buffer, reinterpret_cast<const _mlas_fp16_*>(SumBuf_), StartM,
|
||||
StartN, CountM, CountN, ldc);
|
||||
} else {
|
||||
MlasActivationKernel<MlasTanhActivation>(Activation_, Buffer, StartM, StartN,
|
||||
CountM, CountN, ldc);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case MlasLogisticActivation: {
|
||||
MlasActivationKernel<MlasLogisticActivation>(Activation_, C, StartM, StartN, CountM,
|
||||
CountN, ldc);
|
||||
if (SumBuf_) {
|
||||
MlasActivationKernel<MlasLogisticActivation>(
|
||||
Activation_, Buffer, reinterpret_cast<const _mlas_fp16_*>(SumBuf_), StartM,
|
||||
StartN, CountM, CountN, ldc);
|
||||
} else {
|
||||
MlasActivationKernel<MlasLogisticActivation>(Activation_, Buffer, StartM, StartN,
|
||||
CountM, CountN, ldc);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case MlasClipActivation: {
|
||||
MlasActivationKernel<MlasClipActivation>(Activation_, C, StartM, StartN, CountM, CountN,
|
||||
ldc);
|
||||
if (SumBuf_) {
|
||||
MlasActivationKernel<MlasClipActivation>(
|
||||
Activation_, Buffer, reinterpret_cast<const _mlas_fp16_*>(SumBuf_), StartM,
|
||||
StartN, CountM, CountN, ldc);
|
||||
} else {
|
||||
MlasActivationKernel<MlasClipActivation>(Activation_, Buffer, StartM, StartN,
|
||||
CountM, CountN, ldc);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case MlasHardSigmoidActivation: {
|
||||
MlasActivationKernel<MlasHardSigmoidActivation>(Activation_, C, StartM, StartN, CountM,
|
||||
CountN, ldc);
|
||||
if (SumBuf_) {
|
||||
MlasActivationKernel<MlasHardSigmoidActivation>(
|
||||
Activation_, Buffer, reinterpret_cast<const _mlas_fp16_*>(SumBuf_), StartM,
|
||||
StartN, CountM, CountN, ldc);
|
||||
} else {
|
||||
MlasActivationKernel<MlasHardSigmoidActivation>(Activation_, Buffer, StartM, StartN,
|
||||
CountM, CountN, ldc);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -744,10 +862,20 @@ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
|
|||
proc.Process(C, StartM, StartN, CountM, CountN, ldc);
|
||||
|
||||
_mlas_fp16_* Output = reinterpret_cast<_mlas_fp16_*>(C);
|
||||
const auto* CRow = buffer.data();
|
||||
auto* CRow = buffer.data();
|
||||
const _mlas_fp16_* CAdd = nullptr;
|
||||
if (SumBuf_) {
|
||||
CAdd = reinterpret_cast<const _mlas_fp16_*>(SumBuf_) + StartM * ldc + StartN;
|
||||
}
|
||||
Output += StartM * ldc + StartN;
|
||||
|
||||
while (CountM-- > 0) {
|
||||
if (CAdd) {
|
||||
for (size_t n = 0; n < CountN; n++) {
|
||||
CRow[n] += MLAS_Half2Float(CAdd[n]);
|
||||
}
|
||||
CAdd += ldc;
|
||||
}
|
||||
CvtFloat2Half(Output, CRow, CountN);
|
||||
CRow += CountN;
|
||||
Output += ldc;
|
||||
|
|
|
|||
|
|
@ -32,18 +32,20 @@ using ConvPadVector = ConvAttributes::ConvPadVector;
|
|||
* 2. Activation
|
||||
* It takes an operator attribute 'activation', which supplies the activation info.
|
||||
*
|
||||
* Add is performed BEFORE activation.
|
||||
* Add is performed AFTER activation.
|
||||
*
|
||||
* The implementation runs faster with NHWC. By default, it converts NCHW to NHWC
|
||||
* before processing, and convert the result back. It can take NHWC tensors directly.
|
||||
* Use operator attribute 'channels_last' to specify that the data layout is NHWC.
|
||||
* The implementation supports both NCHW and NHWC. It runs faster with NHWC.
|
||||
*
|
||||
* Currently this class implements 3 operators: onnx.Conv, ms.FusedConv and ms.NhwcFusedConv
|
||||
* In the constructor, if we see the operator name is NhwcFusedConv, we assume the
|
||||
* input layout to be NHWC, otherwise we assume layout is NCHW.
|
||||
*
|
||||
*/
|
||||
class FusedConvFp16 final : public OpKernel {
|
||||
public:
|
||||
FusedConvFp16(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
|
||||
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
|
||||
channels_last_ = (info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(0)) != 0);
|
||||
channels_last_ = (info.GetKernelDef().OpName() == "NhwcFusedConv");
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
|
@ -55,6 +57,9 @@ class FusedConvFp16 final : public OpKernel {
|
|||
int input_idx,
|
||||
/*out*/ bool& used_shared_buffers) override;
|
||||
|
||||
protected:
|
||||
bool channels_last_{false};
|
||||
|
||||
private:
|
||||
|
||||
/**
|
||||
|
|
@ -89,7 +94,6 @@ class FusedConvFp16 final : public OpKernel {
|
|||
|
||||
MLAS_ACTIVATION activation_;
|
||||
ConvAttributes conv_attrs_;
|
||||
bool channels_last_{false};
|
||||
TensorShape W_shape_;
|
||||
BufferUniquePtr packed_W_buffer_;
|
||||
size_t packed_W_size_{0};
|
||||
|
|
@ -236,10 +240,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
|
|||
const auto& W_shape = W ? W->Shape() : W_shape_;
|
||||
const Tensor* B = num_inputs >= 3 ? context->Input<Tensor>(2) : nullptr;
|
||||
|
||||
// TODO!!
|
||||
// This tensor should be added to the result before activation is applied
|
||||
// We need to augment the post processor to accept an addition operation.
|
||||
// const Tensor* Sum = num_inputs >= 4 ? context->Input<Tensor>(3) : nullptr;
|
||||
// This tensor should be added to the result AFTER activation is applied
|
||||
const Tensor* Sum = num_inputs >= 4 ? context->Input<Tensor>(3) : nullptr;
|
||||
|
||||
const int64_t N = X->Shape()[0];
|
||||
const int64_t M = W_shape[0];
|
||||
|
|
@ -282,6 +284,13 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
|
|||
if (Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (Sum) {
|
||||
if (Sum->Shape() != Y->Shape()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Z shape does not match output shape.",
|
||||
" Z: ", Sum->Shape().ToString().c_str(),
|
||||
" Output: ", Y->Shape().ToString().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t input_image_size = input_shape.Size();
|
||||
const int64_t output_image_size = output_shape.Size();
|
||||
|
|
@ -332,6 +341,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
|
|||
const auto* Xdata = X->Data<MLFloat16>();
|
||||
const auto* Bdata = B != nullptr ? B->Data<MLFloat16>() : nullptr;
|
||||
auto* Ydata = Y->MutableData<MLFloat16>();
|
||||
const auto* SumData = Sum != nullptr ? Sum->Data<MLFloat16>() : nullptr;
|
||||
|
||||
BufferUniquePtr transpose_input_buffer;
|
||||
BufferUniquePtr transpose_output_buffer;
|
||||
|
|
@ -403,6 +413,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
|
|||
for (int64_t image_id = 0; image_id < N; ++image_id) {
|
||||
const auto* input_data = Xdata;
|
||||
auto* output_data = Ydata;
|
||||
const auto* add_src = SumData;
|
||||
|
||||
if (!channels_last_) {
|
||||
// Transpose the input from channels first (CHW) to channels last (HWC).
|
||||
|
|
@ -413,6 +424,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
|
|||
static_cast<size_t>(input_image_size));
|
||||
input_data = static_cast<MLFloat16*>(transpose_input_buffer.get());
|
||||
output_data = static_cast<MLFloat16*>(transpose_output_buffer.get());
|
||||
add_src = nullptr;
|
||||
}
|
||||
|
||||
// Threaded implementation of ND convolution is not yet supported, so
|
||||
|
|
@ -459,10 +471,10 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
auto* worker_output = output_data + output_start * M;
|
||||
const auto* worker_addsrc = add_src == nullptr ? nullptr : add_src + output_start * M;
|
||||
|
||||
if (is_depthwise_conv) {
|
||||
// TODO!! add Sum tensor to activation
|
||||
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_);
|
||||
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_, worker_addsrc);
|
||||
MlasConvDepthwise(
|
||||
worker_indirection_buffer,
|
||||
reordered_W,
|
||||
|
|
@ -531,8 +543,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
|
|||
lda = static_cast<size_t>(C);
|
||||
}
|
||||
|
||||
// TODO!! add Sum tensor to activation
|
||||
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_);
|
||||
const auto* gemm_add = add_src == nullptr ? nullptr : worker_addsrc + group_id * group_output_channels;
|
||||
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_, gemm_add);
|
||||
MLAS_HALF_GEMM_DATA_PARAMS gemm_params;
|
||||
gemm_params.A = AData;
|
||||
gemm_params.lda = lda;
|
||||
|
|
@ -567,15 +579,29 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
|
|||
Ydata,
|
||||
static_cast<size_t>(output_image_size),
|
||||
static_cast<size_t>(M));
|
||||
if (SumData != nullptr) {
|
||||
MLAS_ACTIVATION activation;
|
||||
activation.ActivationKind = MlasIdentityActivation;
|
||||
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation, SumData);
|
||||
proc.Process(Ydata, 0, 0, static_cast<size_t>(M),
|
||||
static_cast<size_t>(output_image_size),
|
||||
static_cast<size_t>(output_image_size));
|
||||
}
|
||||
}
|
||||
|
||||
Xdata += X_offset;
|
||||
Ydata += Y_offset;
|
||||
if (SumData != nullptr) {
|
||||
SumData += Y_offset;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//
|
||||
// Operator definitions
|
||||
//
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Conv,
|
||||
|
|
@ -586,14 +612,29 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
|||
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
|
||||
namespace contrib {
|
||||
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
|
||||
FusedConv,
|
||||
1,
|
||||
MLFloat16,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
|
||||
FusedConvFp16);
|
||||
|
||||
// Registers FusedConvFp16 as the MLFloat16 kernel for the NhwcFusedConv
// operator (kMSDomain, version 1) on the CPU execution provider.
// The type constraint "T" is bound to MLFloat16 tensors.
ONNX_OPERATOR_TYPED_KERNEL_EX(
|
||||
NhwcFusedConv,
|
||||
kMSDomain,
|
||||
1,
|
||||
MLFloat16,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
|
||||
FusedConvFp16);
|
||||
|
||||
// Registers FusedConvFp16 as the MLFloat16 kernel for the FusedConv
// operator (kMSDomain, version 1) on the CPU execution provider.
// The type constraint "T" is bound to MLFloat16 tensors.
ONNX_OPERATOR_TYPED_KERNEL_EX(
|
||||
FusedConv,
|
||||
kMSDomain,
|
||||
1,
|
||||
MLFloat16,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
|
||||
FusedConvFp16);
|
||||
|
||||
} // namespace contrib
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,20 @@
|
|||
|
||||
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
|
||||
/**
 * @brief Approximate equality check for activation test results.
 *
 * Rules, applied in order:
 *  - If either value is NaN, the values match only when both are NaN.
 *  - Values that compare exactly equal match. This covers same-signed
 *    infinities, for which the subtraction below would produce NaN.
 *  - Otherwise the values match when their relative difference is below 0.5%.
 *    When both magnitudes are at most 1e-4 the relative difference is treated
 *    as zero, so tiny values always match.
 *
 * @param actual    Value produced by the code under test.
 * @param expected  Reference value.
 * @return true when the two values are considered equal.
 */
bool check_equal(float actual, float expected) {
  // Test both operands: a NaN reference must not be accepted just because
  // the actual value is close to zero.
  if (std::isnan(actual) || std::isnan(expected)) {
    return std::isnan(actual) && std::isnan(expected);
  }

  // Exact match, including +inf == +inf and -inf == -inf.
  if (actual == expected) {
    return true;
  }

  const float diff = std::abs(actual - expected);
  const float top = std::max(std::abs(actual), std::abs(expected));
  float ratio = 0;
  if (top > 0.0001) {
    ratio = diff / top;
  }
  // A mismatched infinity makes ratio NaN, which fails this comparison.
  return ratio < 0.005;
}
|
||||
|
||||
class MlasFp16ActivationTest : public MlasTestBase {
|
||||
public:
|
||||
static const char* GetTestSuiteName() {
|
||||
|
|
@ -44,14 +58,25 @@ class MlasFp16ActivationTest : public MlasTestBase {
|
|||
|
||||
constexpr size_t M = 5;
|
||||
constexpr size_t N = 23;
|
||||
constexpr float MinimumFillValue = -11.0f;
|
||||
|
||||
MatrixGuardBuffer<MLFp16> HalfBuffer1;
|
||||
auto* testData1 = HalfBuffer1.GetBuffer(M * N, true);
|
||||
MatrixGuardBuffer<MLFp16> HalfBuffer2;
|
||||
auto* testData2 = HalfBuffer2.GetBuffer(M * N, true);
|
||||
MatrixGuardBuffer<MLFp16> HalfBuffer3;
|
||||
auto* testData3 = HalfBuffer3.GetBuffer(M * N, true);
|
||||
MatrixGuardBuffer<MLFp16> AddonBuffer;
|
||||
auto addonData = AddonBuffer.GetBuffer(M * N, true);
|
||||
MatrixGuardBuffer<float> FloatBuffer;
|
||||
auto* fpBuffer = FloatBuffer.GetBuffer(M * N, true);
|
||||
|
||||
size_t o = 3;
|
||||
for (size_t i = 0; i < M * N; i++) {
|
||||
o = (o + 19) % 23;
|
||||
addonData[i] = (MinimumFillValue + o) / 16.0f;
|
||||
}
|
||||
|
||||
MLAS_ACTIVATION_KIND acts[] = {
|
||||
MlasIdentityActivation,
|
||||
MlasReluActivation,
|
||||
|
|
@ -62,8 +87,9 @@ class MlasFp16ActivationTest : public MlasTestBase {
|
|||
MlasHardSigmoidActivation};
|
||||
|
||||
MLAS_ACTIVATION Activation;
|
||||
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(Activation);
|
||||
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(Activation, nullptr);
|
||||
MLAS_HALF_GEMM_2FLOAT_PROCESSOR converter(Activation, fpBuffer, N);
|
||||
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR addon(Activation, reinterpret_cast<const MLAS_FP16*>(addonData));
|
||||
for (auto kind : acts) {
|
||||
Activation.ActivationKind = MLAS_ACTIVATION_KIND(kind);
|
||||
|
||||
|
|
@ -84,37 +110,32 @@ class MlasFp16ActivationTest : public MlasTestBase {
|
|||
for (size_t i = 0; i < _countof(TestData); i++) {
|
||||
testData1[i] = TestData[i].f;
|
||||
testData2[i] = TestData[i].f;
|
||||
testData3[i] = TestData[i].f;
|
||||
}
|
||||
constexpr float MinimumFillValue = -11.0f;
|
||||
size_t offset = 7;
|
||||
for (size_t i = _countof(TestData); i < M * N; i++) {
|
||||
offset = (offset + 19) % 23;
|
||||
testData1[i] = (MinimumFillValue + offset) / 16.0f;
|
||||
testData2[i] = testData1[i];
|
||||
testData3[i] = testData1[i];
|
||||
}
|
||||
|
||||
proc.Process(reinterpret_cast<MLAS_FP16*>(testData1), 0, 0, M, N, N);
|
||||
converter.Process(reinterpret_cast<MLAS_FP16*>(testData2), 0, 0, M, N, N);
|
||||
addon.Process(reinterpret_cast<MLAS_FP16*>(testData3), 0, 0, M, N, N);
|
||||
|
||||
for (size_t i = 0; i < M*N; i++) {
|
||||
float actual = testData1[i].ToFloat();
|
||||
if (std::isnan(actual)) {
|
||||
EXPECT_TRUE(std::isnan(fpBuffer[i]))
|
||||
EXPECT_TRUE(check_equal(actual, fpBuffer[i]))
|
||||
<< ", Vector Activation Kind:" << (int)kind << ", i=" << i << ", value:"
|
||||
<< std::setw(8) << std::setfill('0') << std::hex << actual << ", expecting:"
|
||||
<< std::setw(8) << std::setfill('0') << std::hex << fpBuffer[i];
|
||||
|
||||
} else {
|
||||
float diff = std::abs(actual - fpBuffer[i]);
|
||||
float top = std::max(std::abs(actual), std::abs(fpBuffer[i]));
|
||||
float ratio = 0;
|
||||
if (top > 0.0001) {
|
||||
ratio = diff / top;
|
||||
}
|
||||
EXPECT_TRUE(ratio < 0.005)
|
||||
<< ", Vector Activation Kind:" << (int)kind << ", i=" << i << ", value:"
|
||||
<< actual << ", expecting:" << fpBuffer[i];
|
||||
}
|
||||
float addonActual = testData3[i].ToFloat() - addonData[i].ToFloat();
|
||||
EXPECT_TRUE(check_equal(addonActual, fpBuffer[i]))
|
||||
<< ", Vector + Activation Kind:" << (int)kind << ", i=" << i << ", value:"
|
||||
<< std::setw(8) << std::setfill('0') << std::hex << actual << ", expecting:"
|
||||
<< std::setw(8) << std::setfill('0') << std::hex << fpBuffer[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes,
|
|||
int opset = 11) {
|
||||
std::unique_ptr<OpTester> tester;
|
||||
if (!attributes.activation.empty()) {
|
||||
tester = std::make_unique<OpTester>("FusedConv", 1, onnxruntime::kMSDomain);
|
||||
tester = std::make_unique<OpTester>("NhwcFusedConv", 1, onnxruntime::kMSDomain);
|
||||
tester->AddAttribute("activation", attributes.activation);
|
||||
|
||||
if (!attributes.activation_parameters.empty()) {
|
||||
|
|
@ -68,12 +68,14 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes,
|
|||
}
|
||||
|
||||
|
||||
ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs");
|
||||
const char* szNames[] = {"X", "W", "B"};
|
||||
ORT_ENFORCE(inputs.size() <= 4, "Our name array is only setup to handle 4 inputs");
|
||||
const char* szNames[] = {"X", "W", "B", "Z"};
|
||||
tester->AddInput<MLFloat16>(szNames[0], input_shapes[0], inputs[0]);
|
||||
tester->AddInput<MLFloat16>(szNames[1], input_shapes[1], inputs[1], weight_is_initializer);
|
||||
if (inputs.size() == 3)
|
||||
if (inputs.size() >= 3)
|
||||
tester->AddInput<MLFloat16>(szNames[2], input_shapes[2], inputs[2]);
|
||||
if (inputs.size() >= 4)
|
||||
tester->AddInput<MLFloat16>(szNames[3], input_shapes[3], inputs[3]);
|
||||
|
||||
tester->AddOutput<MLFloat16>("Y", expected_output_shape, expected_output, /*no sort*/ false, 0.002f, 0.0f);
|
||||
|
||||
|
|
@ -506,86 +508,6 @@ TEST(ConvFp16Test, Conv2D_AutoPad2) {
|
|||
TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
|
||||
}
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
// Conv fused with a HardSigmoid activation on a 3x3 fp16 input (values 1..9) and
// two 2x2 filters whose taps are all +0.125 and all -0.125 respectively.
// The expected values agree with y = clamp(0.2 * conv + 0.5, 0, 1) applied to the
// window sums {1.5, 2, 3, 3.5} and their negations; they are listed one output
// channel after the other.
// NOTE(review): 0.2 / 0.5 are presumably HardSigmoid's alpha / beta - confirm
// against the operator schema.
TEST(ConvFp16Test, Conv2D_HardSigmoid) {
  const vector<int64_t> input_dims = {1, 1, 3, 3};
  const vector<int64_t> filter_dims = {2, 1, 2, 2};
  const vector<int64_t> output_dims = {1, 2, 2, 2};

  const ConvOpAndTestAttributes conv_attrs = {
      "",                           // auto_pad
      vector<int64_t>{1, 1},        // dilations
      1,                            // group
      vector<int64_t>{2, 2},        // kernel_shape
      vector<int64_t>{0, 0, 0, 0},  // pads
      vector<int64_t>{1, 1},        // strides
      {},                           // excluded EPs
      "HardSigmoid",                // activation
      vector<float>{0.2f, 0.5f}     // activation_parameters
  };

  const vector<MLFloat16> input = {
      MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
      MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
      MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)};

  const vector<MLFloat16> filters = {
      // first output channel
      MLFloat16(0.125f), MLFloat16(0.125f), MLFloat16(0.125f), MLFloat16(0.125f),
      // second output channel
      MLFloat16(-0.125f), MLFloat16(-0.125f), MLFloat16(-0.125f), MLFloat16(-0.125f)};

  const auto expected = {
      MLFloat16(0.8f), MLFloat16(0.9f), MLFloat16(1.0f), MLFloat16(1.0f),
      MLFloat16(0.2f), MLFloat16(0.1f), MLFloat16(0.0f), MLFloat16(0.0f)};

  TestConvFp16Op(conv_attrs, {input, filters}, {input_dims, filter_dims}, expected, output_dims);
}
|
||||
|
||||
// Conv fused with a Relu activation on a 3x3 fp16 input (values 1..9).
// The first 2x2 filter is all +1 and yields the window sums {12, 16, 24, 28};
// the second is all -1, so its (negative) sums are expected to come out as 0.
// Expected values are listed one output channel after the other.
TEST(ConvFp16Test, Conv2D_Relu) {
  const vector<int64_t> input_dims = {1, 1, 3, 3};
  const vector<int64_t> filter_dims = {2, 1, 2, 2};
  const vector<int64_t> output_dims = {1, 2, 2, 2};

  const ConvOpAndTestAttributes conv_attrs = {
      "",                           // auto_pad
      vector<int64_t>{1, 1},        // dilations
      1,                            // group
      vector<int64_t>{2, 2},        // kernel_shape
      vector<int64_t>{0, 0, 0, 0},  // pads
      vector<int64_t>{1, 1},        // strides
      {},                           // excluded EPs
      "Relu"                        // activation
  };

  const vector<MLFloat16> input = {
      MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
      MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
      MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)};

  const vector<MLFloat16> filters = {
      // first output channel
      MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f),
      // second output channel
      MLFloat16(-1.0f), MLFloat16(-1.0f), MLFloat16(-1.0f), MLFloat16(-1.0f)};

  const auto expected = {
      MLFloat16(12.0f), MLFloat16(16.0f), MLFloat16(24.0f), MLFloat16(28.0f),
      MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f)};

  TestConvFp16Op(conv_attrs, {input, filters}, {input_dims, filter_dims}, expected, output_dims);
}
|
||||
|
||||
// Conv with a per-channel bias, fused with a Relu activation.
// Both 2x2 filters are all ones, so each output channel starts from the window
// sums {12, 16, 24, 28} of the 3x3 input; the bias then shifts the first channel
// by +1 and the second by -1. Every result is positive, so the expected values
// are simply sum + bias, listed one output channel after the other.
TEST(ConvFp16Test, Conv2D_Bias_Relu) {
  const vector<int64_t> input_dims = {1, 1, 3, 3};
  const vector<int64_t> filter_dims = {2, 1, 2, 2};
  const vector<int64_t> bias_dims = {2};
  const vector<int64_t> output_dims = {1, 2, 2, 2};

  const ConvOpAndTestAttributes conv_attrs = {
      "",                           // auto_pad
      vector<int64_t>{1, 1},        // dilations
      1,                            // group
      vector<int64_t>{2, 2},        // kernel_shape
      vector<int64_t>{0, 0, 0, 0},  // pads
      vector<int64_t>{1, 1},        // strides
      {},                           // excluded EPs
      "Relu"                        // activation
  };

  const vector<MLFloat16> input = {
      MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
      MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
      MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)};

  // Eight identical taps: two output channels x (1 x 2 x 2) each.
  const vector<MLFloat16> filters(8, MLFloat16(1.0f));

  const vector<MLFloat16> bias = {MLFloat16(1.0f), MLFloat16(-1.0f)};

  const auto expected = {
      MLFloat16(13.0f), MLFloat16(17.0f), MLFloat16(25.0f), MLFloat16(29.0f),
      MLFloat16(11.0f), MLFloat16(15.0f), MLFloat16(23.0f), MLFloat16(27.0f)};

  TestConvFp16Op(conv_attrs,
                 {input, filters, bias},
                 {input_dims, filter_dims, bias_dims},
                 expected,
                 output_dims);
}
|
||||
#endif // CONTRIB_OPS
|
||||
|
||||
|
||||
TEST(ConvFp16Test, Conv3D_1) {
|
||||
ConvOpAndTestAttributes attrs = {
|
||||
"", // auto_pad
|
||||
|
|
@ -978,31 +900,93 @@ TEST(ConvFp16Test, Pointwise_Relu) {
|
|||
};
|
||||
|
||||
vector<MLFloat16> X = {
|
||||
MLFloat16(-9.f), MLFloat16(1.f), MLFloat16(2.f),
|
||||
MLFloat16(-5.f), MLFloat16(3.f), MLFloat16(-2.f),
|
||||
MLFloat16(5.f), MLFloat16(-3.f), MLFloat16(1.f),
|
||||
MLFloat16(1.f), MLFloat16(8.f), MLFloat16(-4.f),
|
||||
MLFloat16(-1.f), MLFloat16(6.f), MLFloat16(7.f),
|
||||
MLFloat16(-1.f), MLFloat16(4.f), MLFloat16(-5.f),
|
||||
MLFloat16(-9.f), MLFloat16(1.f), MLFloat16(2.f),
|
||||
MLFloat16(-5.f), MLFloat16(3.f), MLFloat16(-2.f),
|
||||
MLFloat16(5.f), MLFloat16(-3.f), MLFloat16(1.f)};
|
||||
MLFloat16(-9.f), MLFloat16(1.f), MLFloat16(-9.f),
|
||||
MLFloat16(1.f), MLFloat16(8.f), MLFloat16(1.f),
|
||||
MLFloat16(2.f), MLFloat16(-4.f), MLFloat16(2.f),
|
||||
MLFloat16(-5.f), MLFloat16(-1.f), MLFloat16(-5.f),
|
||||
MLFloat16(3.f), MLFloat16(6.f), MLFloat16(3.f),
|
||||
MLFloat16(-2.f), MLFloat16(7.f), MLFloat16(-2.f),
|
||||
MLFloat16(5.f), MLFloat16(-1.f), MLFloat16(5.f),
|
||||
MLFloat16(-3.f), MLFloat16(4.f), MLFloat16(-3.f),
|
||||
MLFloat16(1.f), MLFloat16(-5.f), MLFloat16(1.f)};
|
||||
vector<int64_t> X_shape = {1, 3, 3, 3};
|
||||
vector<MLFloat16> W = {MLFloat16(2.f), MLFloat16(-3.f), MLFloat16(0.5f),
|
||||
MLFloat16(0.25f), MLFloat16(-2.f), MLFloat16(-0.75f)};
|
||||
vector<int64_t> W_shape = {2, 3, 1, 1};
|
||||
vector<int64_t> Y_shape = {1, 2, 3, 3};
|
||||
vector<int64_t> Y_shape = {1, 3, 3, 2};
|
||||
auto expected_vals = {
|
||||
MLFloat16(0.f), MLFloat16(0.f), MLFloat16(17.f),
|
||||
MLFloat16(0.f), MLFloat16(0.f), MLFloat16(0.f),
|
||||
MLFloat16(15.5f), MLFloat16(0.f), MLFloat16(17.5f),
|
||||
MLFloat16(2.5f), MLFloat16(0.f), MLFloat16(7.f),
|
||||
MLFloat16(4.5f), MLFloat16(0.f), MLFloat16(0.f),
|
||||
MLFloat16(0.f), MLFloat16(0.f), MLFloat16(9.5f)};
|
||||
MLFloat16(0.f), MLFloat16(2.5f),
|
||||
MLFloat16(0.f), MLFloat16(0.f),
|
||||
MLFloat16(17.f), MLFloat16(7.f),
|
||||
MLFloat16(0.f), MLFloat16(4.5f),
|
||||
MLFloat16(0.f), MLFloat16(0.f),
|
||||
MLFloat16(0.f), MLFloat16(0.f),
|
||||
MLFloat16(15.5f), MLFloat16(0.f),
|
||||
MLFloat16(0.f), MLFloat16(0.f),
|
||||
MLFloat16(17.5f), MLFloat16(9.5f)};
|
||||
|
||||
TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
// Conv fused with a HardSigmoid activation, with the input declared as
// {1, 3, 3, 1} and the expected output interleaved per spatial position
// (channel 0 value, then channel 1 value).
// NOTE(review): this looks like a channels-last (NHWC) layout - confirm against
// the operator the test helper instantiates.
// The two 2x2 filters are all +0.125 and all -0.125; the expected values agree
// with y = clamp(0.2 * conv + 0.5, 0, 1) on the window sums {1.5, 2, 3, 3.5}
// and their negations.
TEST(ConvFp16Test, Conv2D_HardSigmoid) {
  const vector<int64_t> input_dims = {1, 3, 3, 1};
  const vector<int64_t> filter_dims = {2, 1, 2, 2};
  const vector<int64_t> output_dims = {1, 2, 2, 2};

  const ConvOpAndTestAttributes conv_attrs = {
      "",                           // auto_pad
      vector<int64_t>{1, 1},        // dilations
      1,                            // group
      vector<int64_t>{2, 2},        // kernel_shape
      vector<int64_t>{0, 0, 0, 0},  // pads
      vector<int64_t>{1, 1},        // strides
      {},                           // excluded EPs
      "HardSigmoid",                // activation
      vector<float>{0.2f, 0.5f}     // activation_parameters
  };

  const vector<MLFloat16> input = {
      MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
      MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
      MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)};

  const vector<MLFloat16> filters = {
      // first output channel
      MLFloat16(0.125f), MLFloat16(0.125f), MLFloat16(0.125f), MLFloat16(0.125f),
      // second output channel
      MLFloat16(-0.125f), MLFloat16(-0.125f), MLFloat16(-0.125f), MLFloat16(-0.125f)};

  // One row per output position: {channel 0, channel 1}.
  const auto expected = {
      MLFloat16(0.8f), MLFloat16(0.2f),
      MLFloat16(0.9f), MLFloat16(0.1f),
      MLFloat16(1.0f), MLFloat16(0.0f),
      MLFloat16(1.0f), MLFloat16(0.0f)};

  TestConvFp16Op(conv_attrs, {input, filters}, {input_dims, filter_dims}, expected, output_dims);
}
|
||||
|
||||
|
||||
// Conv with a bias input and a fourth input Z that has the same shape as the
// output, fused with a Relu activation.
// Both 2x2 filters are all ones, so each channel starts from the window sums
// {12, 16, 24, 28}; the bias contributes +1 / -1 per channel, and Z is zero
// except for -1 on the first element and +1 on the last. The expected values
// equal sum + bias + Z (all positive), interleaved per spatial position as
// {channel 0, channel 1}.
TEST(ConvFp16Test, Conv2D_Bias_Z_Relu) {
  const vector<int64_t> input_dims = {1, 3, 3, 1};
  const vector<int64_t> filter_dims = {2, 1, 2, 2};
  const vector<int64_t> bias_dims = {2};
  const vector<int64_t> addend_dims = {1, 2, 2, 2};
  const vector<int64_t> output_dims = {1, 2, 2, 2};

  const ConvOpAndTestAttributes conv_attrs = {
      "",                           // auto_pad
      vector<int64_t>{1, 1},        // dilations
      1,                            // group
      vector<int64_t>{2, 2},        // kernel_shape
      vector<int64_t>{0, 0, 0, 0},  // pads
      vector<int64_t>{1, 1},        // strides
      {},                           // excluded EPs
      "Relu"                        // activation
  };

  const vector<MLFloat16> input = {
      MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
      MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
      MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)};

  // Eight identical taps: two output channels x (1 x 2 x 2) each.
  const vector<MLFloat16> filters(8, MLFloat16(1.0f));

  const vector<MLFloat16> bias = {MLFloat16(1.0f), MLFloat16(-1.0f)};

  // Only the first and last elements are non-zero, so exactly two expected
  // values differ from the plain conv + bias result.
  const vector<MLFloat16> addend = {
      MLFloat16(-1.0f), MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f),
      MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(1.0f)};

  const auto expected = {
      MLFloat16(12.0f), MLFloat16(11.0f),
      MLFloat16(17.0f), MLFloat16(15.0f),
      MLFloat16(25.0f), MLFloat16(23.0f),
      MLFloat16(29.0f), MLFloat16(28.0f)};

  TestConvFp16Op(conv_attrs,
                 {input, filters, bias, addend},
                 {input_dims, filter_dims, bias_dims, addend_dims},
                 expected,
                 output_dims);
}
|
||||
|
||||
#endif // CONTRIB_OPS
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue