diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 73d21973ae..e79e1ed985 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -21,6 +21,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sgemm.cpp ${MLAS_SRC_DIR}/halfgemm.cpp ${MLAS_SRC_DIR}/qgemm.cpp + ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/qdwconv.cpp ${MLAS_SRC_DIR}/convolve.cpp ${MLAS_SRC_DIR}/convsym.cpp @@ -326,6 +327,7 @@ else() ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index a04ef0d71b..de4e699b52 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -16,6 +16,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv); +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv); +#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); @@ -198,6 +201,9 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index a8dcf35499..f3853a211f 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1603,3 +1603,41 @@ MlasHalfGemmConvertPackB( size_t ldb, void* PackedB ); + +/** + * @brief Indirect Depthwise convolution for fp16 + * @param Input Supplies the indirect buffer for NHWC input + * @param Filter Supplies the address for filter tensor + * @param Output Supplies the address for the result tensor + * @param Channels # of input channels + * @param OutputCount # of output pixels + * @param KernelSize # kernel size + * @return +*/ +void +MLASCALL +MlasConvDepthwise( + const MLAS_FP16* const* Input, + const MLAS_FP16* Filter, + MLAS_FP16* Output, + size_t Channels, + size_t OutputCount, + size_t KernelSize, + MLAS_HALF_GEMM_POSTPROCESSOR* PostProc + ); + + +inline +void +MlasTranspose( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t M, + size_t N + ) +{ + MlasTranspose( + reinterpret_cast(Input), + reinterpret_cast(Output), + M, N); +} diff --git a/onnxruntime/core/mlas/lib/dwconv.cpp b/onnxruntime/core/mlas/lib/dwconv.cpp new file mode 100644 index 0000000000..15511d2d8c --- /dev/null +++ b/onnxruntime/core/mlas/lib/dwconv.cpp @@ -0,0 +1,154 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + dwconv.cpp + +Abstract: + + This module implements the half precision floating point depthwise convolution routines. 
+ +--*/ + + +#include "fp16_common.h" + +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + +MLAS_FORCEINLINE +void +MlasConvDepthwiseKernel( + const _mlas_fp16_* const* Input, + const _mlas_fp16_* Filter, + _mlas_fp16_* Output, + size_t Channels, + size_t OutputCount, + size_t KernelSize, + MLAS_HALF_GEMM_POSTPROCESSOR* PostProc + ) +{ + while (OutputCount > 0) { + size_t ChannelOffset = 0; + size_t c = Channels; + + while (c >= 8) { + MLAS_FLOAT16X8 Accumulator = MlasZeroFloat16x8(); + size_t ChannelKernelOffset = ChannelOffset; + + for (size_t k = 0; k < KernelSize; k++) { + MLAS_FLOAT16X8 InputVector = MlasLoadFloat16x8(&Input[k][ChannelOffset]); + MLAS_FLOAT16X8 FilterVector = MlasLoadFloat16x8(&Filter[ChannelKernelOffset]); + + Accumulator = MlasMultiplyAddFloat16x8(InputVector, FilterVector, Accumulator); + ChannelKernelOffset += Channels; + } + MlasStoreFloat16x8(Output, Accumulator); + Output += 8; + + ChannelOffset += 8; + c -= 8; + } + + if (c >= 4) { + MLAS_FLOAT16X4 Accumulator = MlasZeroFloat16x4(); + size_t ChannelKernelOffset = ChannelOffset; + + for (size_t k = 0; k < KernelSize; k++) { + MLAS_FLOAT16X4 InputVector = MlasLoadFloat16x4(&Input[k][ChannelOffset]); + MLAS_FLOAT16X4 FilterVector = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]); + + Accumulator = MlasMultiplyAddFloat16x4(InputVector, FilterVector, Accumulator); + ChannelKernelOffset += Channels; + } + MlasStoreFloat16x4(Output, Accumulator); + Output += 4; + + ChannelOffset += 4; + c -= 4; + } + + if (c > 0) { + MLAS_FLOAT16X4 Accumulator = MlasZeroFloat16x4(); + size_t ChannelKernelOffset = ChannelOffset; + + for (size_t k = 0; k < KernelSize; k++) { + MLAS_FLOAT16X4 InputValue = MlasLoadFloat16x4(&Input[k][ChannelOffset]); + MLAS_FLOAT16X4 FilterValue = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]); + + Accumulator = MlasMultiplyAddFloat16x4(InputValue, FilterValue, Accumulator); + ChannelKernelOffset += Channels; + } + MlasStorePartialFloat16x4(Output, Accumulator, c); + Output += c; + } + 
if (PostProc) { + PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, + Channels); + } + Input += KernelSize; + OutputCount -= 1; + } +} + +#else + +MLAS_FORCEINLINE +void +MlasConvDepthwiseKernel( + const _mlas_fp16_* const* Input, + const _mlas_fp16_* Filter, + _mlas_fp16_* Output, + size_t Channels, + size_t OutputCount, + size_t KernelSize, + MLAS_HALF_GEMM_POSTPROCESSOR* PostProc + ) +{ + while (OutputCount > 0) { + for (size_t ChannelOffset = 0; ChannelOffset < Channels; ChannelOffset++) { + float Accumulator = 0.0f; + size_t ChannelKernelOffset = ChannelOffset; + + for (size_t k = 0; k < KernelSize; k++) { + Accumulator += MLAS_Half2Float(Input[k][ChannelOffset]) * MLAS_Half2Float(Filter[ChannelKernelOffset]); + ChannelKernelOffset += Channels; + } + *Output++ = MLAS_Float2Half(Accumulator); + } + if (PostProc) { + PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, + Channels); + } + Input += KernelSize; + OutputCount -= 1; + } +} + +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED + + +void +MLASCALL +MlasConvDepthwise( + const MLAS_FP16* const* Input, + const MLAS_FP16* Filter, + MLAS_FP16* Output, + size_t Channels, + size_t OutputCount, + size_t KernelSize, + MLAS_HALF_GEMM_POSTPROCESSOR* PostProc + ) +{ + MlasConvDepthwiseKernel( + reinterpret_cast(Input), + reinterpret_cast(Filter), + reinterpret_cast<_mlas_fp16_*>(Output), + Channels, + OutputCount, + KernelSize, + PostProc); +} diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index e952a5667c..8b00190757 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -52,6 +52,10 @@ MLAS_FORCEINLINE MLAS_FLOAT16X8 MlasZeroFloat16x8(void) { return vreinterpretq_f16_f32(vdupq_n_f32(0.0f)); } +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasZeroFloat16x4(void) { return vreinterpret_f16_f32(vdup_n_f32(0.0f)); } + MLAS_FORCEINLINE MLAS_FLOAT16X8 MlasLoadFloat16x8(const _mlas_fp16_* Buffer) { 
return vreinterpretq_f16_u16(vld1q_u16(Buffer)); } diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index c866226cf7..6322f6e460 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -441,6 +441,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Av class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, MaxUnpool); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 17, LpPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Conv); +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, MLFloat16, Conv); +#endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, If); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceLength); @@ -1474,6 +1477,9 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc new file mode 100644 index 0000000000..78820bdf63 --- /dev/null +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -0,0 +1,602 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This file contains implementation of a fp16 convolution operator. 
+// + +#include "core/mlas/inc/mlas.h" + +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + +#include "core/common/safeint.h" +#include "core/framework/float16.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" + +#include "contrib_ops/cpu/fused_activation.h" + +namespace onnxruntime { + +using ConvPadVector = ConvAttributes::ConvPadVector; + +/** + * @brief Convolution Operator for FP16 tensors + * + * With two optional fused operations: + * + * 1. Add + * It takes an extra (optional) input Sum, a tensor with the same shape as the output. + * Sum is added to the output tensor (not implemented yet, see the TODO in Compute). + * + * 2. Activation + * It takes an operator attribute 'activation', which supplies the activation info. + * + * Add is performed BEFORE activation. + * + * The implementation runs faster with NHWC. By default, it converts NCHW to NHWC + * before processing, and converts the result back. It can take NHWC tensors directly. + * Use operator attribute 'channels_last' to specify that the data layout is NHWC. + * +*/ +class FusedConvFp16 final : public OpKernel { + public: + FusedConvFp16(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { + ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); + channels_last_ = (info.GetAttrOrDefault("channels_last", static_cast(0)) != 0); + } + + Status Compute(OpKernelContext* context) const override; + + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + private: + + /** + * @brief Reorder filter data to facilitate compute. + * + * Based on Conv operator spec, filters are organized as (M x C/group x kH x kW), + * where C is the number of input channels, and kH and kW are the height and width + * of the kernel, and M is the number of feature maps.
We need to change it into + * (kH x kW x C/group) x M, forming a matrix of M columns, where each kernel is a + * single column in channel last format. + * + * @param input + * @param output + * @param output_channels number of feature maps + * @param input_channels + * @param kernel_size kH x kW + */ + static void ReorderFilter(const MLFloat16* input, + MLFloat16* output, + size_t output_channels, + size_t input_channels, + size_t kernel_size) { + for (size_t k = 0; k < kernel_size; k++) { + for (size_t ic = 0; ic < input_channels; ic++) { + for (size_t oc = 0; oc < output_channels; oc++) { + size_t index = (oc * input_channels * kernel_size) + (ic * kernel_size) + k; + *output++ = input[index]; + } + } + } + } + + MLAS_ACTIVATION activation_; + ConvAttributes conv_attrs_; + bool channels_last_{false}; + TensorShape W_shape_; + BufferUniquePtr packed_W_buffer_; + size_t packed_W_size_{0}; + bool is_W_packed_{false}; + BufferUniquePtr reordered_W_buffer_; +}; + + +Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (input_idx != 1) { + // Only pack filter tensor (aka weights) + return Status::OK(); + } + + const auto& shape = tensor.Shape().GetDims(); + size_t rank = shape.size(); + if (rank <= 2) { + return Status::OK(); + } + + const int64_t M = shape[0]; + const int64_t C = shape[1]; + + // Verify that the total number of output channels is a multiple of the group count. + if (M % conv_attrs_.group != 0) { + return Status::OK(); + } + + // Note: The tensor has already been allocated with this tensor shape, so all + // shape indices are guaranteed to fit inside size_t. 
+ const size_t output_channels = static_cast(M); + const size_t group_input_channels = static_cast(C); + const size_t kernel_size = + static_cast(std::accumulate(shape.data() + 2, shape.data() + rank, 1LL, std::multiplies())); + + const auto* Wdata = static_cast(tensor.DataRaw()); + W_shape_ = shape; + + const size_t group_count = static_cast(conv_attrs_.group); + const size_t group_output_channels = output_channels / group_count; + const size_t kernel_dim = group_input_channels * kernel_size; + + bool share_prepacked_weights = (prepacked_weights != nullptr); + + // Don't pack the filter buffer if the MlasConvDepthwise path is used. + if (!(group_input_channels == 1 && group_output_channels == 1)) { + packed_W_size_ = MlasHalfGemmPackBSize(group_output_channels, kernel_dim, false); + if (packed_W_size_ != 0) { + size_t packed_W_data_size = SafeInt(group_count) * packed_W_size_; + auto* packed_W = static_cast(alloc->Alloc(packed_W_data_size)); + + // Initialize memory to 0 as there could be some padding associated with pre-packed + // buffer memory and we do not want it uninitialized and generate different hashes + // if and when we try to cache this pre-packed buffer for sharing between sessions. + memset((void*)packed_W, 0, packed_W_data_size); + + packed_W_buffer_ = BufferUniquePtr(packed_W, BufferDeleter(alloc)); + + // Allocate a temporary buffer to hold the reordered oihw->hwio filter for + // a single group. + // + // Note: The size of this buffer is less than or equal to the size of the original + // weight tensor, so the allocation size is guaranteed to fit inside size_t.
+ auto* group_reordered_W = static_cast( + alloc->Alloc(group_output_channels * kernel_dim * sizeof(MLFloat16))); + BufferUniquePtr group_reordered_W_buffer(group_reordered_W, BufferDeleter(alloc)); + + const size_t W_offset = group_output_channels * kernel_dim; + + for (int64_t group_id = 0; group_id < conv_attrs_.group; ++group_id) { + ReorderFilter(Wdata, group_reordered_W, group_output_channels, group_input_channels, kernel_size); + MlasHalfGemmPackB(group_output_channels, kernel_dim, group_reordered_W, group_output_channels, packed_W); + packed_W += packed_W_size_; + Wdata += W_offset; + } + + if (share_prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_W_buffer_)); + prepacked_weights->buffer_sizes_.push_back(packed_W_data_size); + } + + is_W_packed_ = true; + is_packed = true; + return Status::OK(); + } + } + + if (share_prepacked_weights) { + prepacked_weights->buffers_.push_back(nullptr); // packed_W_buffer_ is nullptr + prepacked_weights->buffer_sizes_.push_back(0); + } + + size_t reordered_w_data_size = SafeInt(sizeof(MLFloat16)) * output_channels * kernel_dim; + auto* reordered_W = static_cast(alloc->Alloc(reordered_w_data_size)); + + // Initialize memory to 0 as there could be some padding associated with pre-packed + // buffer memory and we do not want it uninitialized and generate different hashes + // if and when we try to cache this pre-packed buffer for sharing between sessions.
+ memset((void*)reordered_W, 0, reordered_w_data_size); + + reordered_W_buffer_ = BufferUniquePtr(reordered_W, BufferDeleter(alloc)); + + ReorderFilter(Wdata, reordered_W, output_channels, group_input_channels, kernel_size); + + if (share_prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(reordered_W_buffer_)); + prepacked_weights->buffer_sizes_.push_back(reordered_w_data_size); + } + + is_W_packed_ = true; + is_packed = true; + return Status::OK(); +} + +Status FusedConvFp16::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, + /*out*/ bool& used_shared_buffers) { + if (input_idx != 1) { + // only the filter tensor is packed + return Status::OK(); + } + + used_shared_buffers = true; + + if (prepacked_buffers.size() == 1) { // This means that only packed_W_ exists + packed_W_buffer_ = std::move(prepacked_buffers[0]); + } else if (prepacked_buffers.size() == 2) { // This means that only reordered_W_ exists + // Enforce that the first "placeholder" buffer is nullptr + ORT_ENFORCE(prepacked_buffers[0].get() == nullptr); + reordered_W_buffer_ = std::move(prepacked_buffers[1]); + } + + return Status::OK(); +} + + +Status FusedConvFp16::Compute(OpKernelContext* context) const { + size_t num_inputs = OpKernel::Node().InputDefs().size(); + const Tensor* X = context->Input(0); + const Tensor* W = is_W_packed_? nullptr : context->Input(1); + const auto& W_shape = W ? W->Shape() : W_shape_; + const Tensor* B = num_inputs >= 3 ? context->Input(2) : nullptr; + + // TODO!! + // This tensor should be added to the result before activation is applied + // We need to augment the post processor to accept an addition operation. + // const Tensor* Sum = num_inputs >= 4 ? 
context->Input(3) : nullptr; + + const int64_t N = X->Shape()[0]; + const int64_t M = W_shape[0]; + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W_shape, channels_last_)); + + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W_shape, kernel_shape)); + const size_t kernel_rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(kernel_rank * 2, 0); + } + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(kernel_rank, 1); + } + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(kernel_rank, 1); + } + + const int64_t C = X->Shape()[channels_last_ ? 1 + kernel_rank : 1]; + const size_t spatial_dim_start = channels_last_ ? 1 : 2; + const size_t spatial_dim_end = spatial_dim_start + kernel_rank; + + TensorShapeVector Y_dims({N}); + if (!channels_last_) { + Y_dims.push_back(M); + } + TensorShape input_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); + ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); + if (channels_last_) { + Y_dims.push_back(M); + } + Tensor* Y = context->Output(0, TensorShape(Y_dims)); + TensorShape output_shape = Y->Shape().Slice(spatial_dim_start, spatial_dim_end); + + // Bail out early if one of the dimensions is zero. + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + const int64_t input_image_size = input_shape.Size(); + const int64_t output_image_size = output_shape.Size(); + const int64_t kernel_size = TensorShape(kernel_shape).Size(); + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + + // Handle the case of a dynamic weight filter. + BufferUniquePtr reordered_W_buffer; + MLFloat16* reordered_W = nullptr; + if (!packed_W_buffer_) { + if (reordered_W_buffer_) { + // Weight was constant and reordered. 
+ reordered_W = static_cast(reordered_W_buffer_.get()); + } else { + // Weight tensor was not constant or prepacking is disabled. + reordered_W = static_cast(alloc->Alloc(SafeInt(sizeof(MLFloat16)) * W_shape.Size())); + reordered_W_buffer = BufferUniquePtr(reordered_W, BufferDeleter(alloc)); + ReorderFilter( + static_cast(W->DataRaw()), + reordered_W, + static_cast(M), + static_cast(W_shape[1]), + static_cast(kernel_size)); + } + } + + int64_t group_count = conv_attrs_.group; + int64_t group_input_channels = W_shape[1]; + int64_t group_output_channels = M / group_count; + + // Test for depthwise convolution. + const bool is_depthwise_conv = (group_input_channels == 1 && group_output_channels == 1); + if (is_depthwise_conv) { + // Update the input and output channels to the number of groups in order to + // reuse as much of the below standard convolution path. + group_input_channels = group_count; + group_output_channels = group_count; + group_count = 1; + } + + const int64_t X_offset = C * input_image_size; + const int64_t Y_offset = M * output_image_size; + const int64_t kernel_dim = group_input_channels * kernel_size; + const int64_t col_buffer_size = kernel_dim * output_image_size; + + const auto* Xdata = X->Data(); + const auto* Bdata = B != nullptr ? B->Data() : nullptr; + auto* Ydata = Y->MutableData(); + + BufferUniquePtr transpose_input_buffer; + BufferUniquePtr transpose_output_buffer; + + // Allocate temporary buffers for transposing to channels last format. 
+ if (!channels_last_) { + auto* transpose_input = alloc->Alloc(SafeInt(sizeof(MLFloat16)) * X_offset + MLAS_SYMM_QGEMM_BUF_OVERRUN); + transpose_input_buffer = BufferUniquePtr(transpose_input, BufferDeleter(alloc)); + auto* transpose_output = alloc->Alloc(SafeInt(sizeof(MLFloat16)) * Y_offset); + transpose_output_buffer = BufferUniquePtr(transpose_output, BufferDeleter(alloc)); + } + + BufferUniquePtr col_buffer; + BufferUniquePtr indirection_buffer; + size_t ind_buf_length = 0; + std::vector padding_data; + + bool use_indirection_buffer = false; + if (is_depthwise_conv) { + use_indirection_buffer = true; + } else if (kernel_size != 1 || !conv_attrs_.HasStridesOneAndNoPadding()) { +// if (is_symmetric_conv_) { +// use_indirection_buffer = true; +// } else { + // Pointwise convolutions can use the original input tensor in place, + // otherwise a temporary buffer is required for the im2col transform. + int64_t group_col_buffer_size = (kernel_rank > 2) ? group_count * col_buffer_size : col_buffer_size; + group_col_buffer_size += MLAS_SYMM_QGEMM_BUF_OVERRUN; + auto* col_data = alloc->Alloc(SafeInt(sizeof(MLFloat16)) * group_col_buffer_size); + col_buffer = BufferUniquePtr(col_data, BufferDeleter(alloc)); + memset(col_data, 0, SafeInt(sizeof(MLFloat16)) * group_col_buffer_size); +// } + } + +// bool parallel_batch = is_symmetric_conv_ && channels_last_; + + if (use_indirection_buffer) { + // Allocate indirection buffer pointers and prepare a padding vector for + // the im2col transform. 
+ ind_buf_length = SafeInt(sizeof(const MLFloat16*)) * kernel_size * output_image_size; +// if (parallel_batch) +// ind_buf_length *= SafeInt(N); // ind buffer per each image in the batch + auto* indirection_data = alloc->Alloc(ind_buf_length); + indirection_buffer = BufferUniquePtr(indirection_data, BufferDeleter(alloc)); + padding_data.resize(static_cast(C), MLFloat16()); + } + + + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + + /************************************* + * Thread partition idea: we are essentially partitioning a GEMM A[M,K] x B[K,N]. + * Here B contains the conv filters, which are usually not big, so we assume + * it can be in cache entirely. Then we simply partition A horizontally into + * thin slices along the M dimension. This would ensure that the slice of A fits + * into the cache and reduce the chance of the kernel waiting for memory. + * + * The thickness of a slice of A should be a multiple of kernel stride M. Since + * we have to choose from many different kernels, the logic of finding + * the stride M is hacky. + */ + + // The following convoluted branches must match the kernel selection logic + // in conv_worker. + + const int32_t stride_m = 6; + const int64_t task_count = (output_image_size + stride_m - 1) / stride_m; + + for (int64_t image_id = 0; image_id < N; ++image_id) { + const auto* input_data = Xdata; + auto* output_data = Ydata; + + if (!channels_last_) { + // Transpose the input from channels first (CHW) to channels last (HWC). + MlasTranspose( + Xdata, + static_cast(transpose_input_buffer.get()), + static_cast(C), + static_cast(input_image_size)); + input_data = static_cast(transpose_input_buffer.get()); + output_data = static_cast(transpose_output_buffer.get()); + } + + // Threaded implementation of ND convolution is not yet supported, so + // prepare all im2col transformations here.
+ if (col_buffer && kernel_rank > 2) { + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + math::Im2col()( + input_data + group_id * group_input_channels, + group_input_channels, + C, + input_shape.GetDims().data(), + output_shape.GetDims().data(), + kernel_shape.data(), + strides.data(), + dilations.data(), + pads.data(), + static_cast(kernel_rank), + static_cast(col_buffer.get()) + group_id * col_buffer_size, + MLFloat16()); + } + } + + auto conv_worker = [&](ptrdiff_t batch) { + int64_t output_start = (int64_t)batch * (int64_t)stride_m; + int64_t output_count = std::min((int64_t)stride_m, output_image_size - output_start); + + MLFloat16 const** worker_indirection_buffer = nullptr; + if (indirection_buffer) { + worker_indirection_buffer = static_cast(indirection_buffer.get()) + output_start * kernel_size; + math::Im2col()( + input_data, + C, + input_shape.GetDims().data(), + output_shape.GetDims().data(), + kernel_shape.data(), + strides.data(), + dilations.data(), + pads.data(), + static_cast(kernel_rank), + output_start, + output_count, + worker_indirection_buffer, + padding_data.data()); + } + + auto* worker_output = output_data + output_start * M; + + if (is_depthwise_conv) { + // TODO!! add Sum tensor to activation + MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_); + MlasConvDepthwise( + worker_indirection_buffer, + reordered_W, + worker_output, + static_cast(M), + static_cast(output_count), + static_cast(kernel_size), + &act); + } else { + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + // Prepare the im2col transformation or use the input buffer directly for + // pointwise convolutions. 
+ const auto* group_input_data = input_data + group_id * group_input_channels; + const MLFloat16* AData; + size_t lda; + if (col_buffer) { + auto* worker_col_buffer = static_cast(col_buffer.get()) + output_start * kernel_dim; + if (kernel_rank == 2) { + math::Im2col()( + group_input_data, + group_input_channels, + C, + input_shape[0], + input_shape[1], + kernel_shape[0], + kernel_shape[1], + dilations[0], + dilations[1], + pads[0], + pads[1], + strides[0], + strides[1], + output_shape[1], + output_start, + output_count, + worker_col_buffer, + MLFloat16()); + } else if (kernel_rank == 1) { + math::Im2col()( + group_input_data, + group_input_channels, + C, + 1, + input_shape[0], + 1, + kernel_shape[0], + 1, + dilations[0], + 0, + pads[0], + 1, + strides[0], + output_shape[0], + output_start, + output_count, + worker_col_buffer, + MLFloat16()); + } else { + // Use the im2col buffer prepared outside the thread, indexed by group. + worker_col_buffer += group_id * col_buffer_size; + } + AData = reinterpret_cast(worker_col_buffer); + lda = static_cast(kernel_dim); + } else { + AData = reinterpret_cast(group_input_data + output_start * C); + lda = static_cast(C); + } + + // TODO!! 
add Sum tensor to activation + MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_); + MLAS_HALF_GEMM_DATA_PARAMS gemm_params; + gemm_params.A = AData; + gemm_params.lda = lda; + if (packed_W_buffer_) { + gemm_params.B = static_cast(packed_W_buffer_.get()) + group_id * packed_W_size_, + gemm_params.ldb = 0; + } else { + gemm_params.B = reordered_W + group_id * group_output_channels, + gemm_params.ldb = static_cast(M); + } + gemm_params.C = worker_output + group_id * group_output_channels; + gemm_params.ldc = static_cast(M); + gemm_params.Bias = Bdata; + gemm_params.OutputProcessor = &act; // process fused activation and add + + MlasHalfGemmBatch( + static_cast(output_count), + static_cast(group_output_channels), + static_cast(kernel_dim), + 1, &gemm_params, thread_pool); + } + } + + }; + + concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, onnxruntime::narrow(task_count), conv_worker); + + if (!channels_last_) { + // Transpose the output from channels last (NHWC) to channels first (NCHW). 
+ MlasTranspose( + output_data, + Ydata, + static_cast(output_image_size), + static_cast(M)); + } + + Xdata += X_offset; + Ydata += Y_offset; + } + + return Status::OK(); +} + + +ONNX_CPU_OPERATOR_TYPED_KERNEL( + Conv, + 11, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + FusedConvFp16); + + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { + ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + FusedConv, + 1, + MLFloat16, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + FusedConvFp16); +} // namespace contrib +#endif + +} // namespace onnxruntime + +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index a505b8e018..47e31bb828 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -17,6 +17,7 @@ #include "core/util/math_cpuonly.h" #include "core/util/math.h" +#include "core/framework/float16.h" #include #include "core/common/narrow.h" @@ -654,6 +655,7 @@ void Im2col::operator()( template struct Im2col; template struct Im2col; +template struct Im2col; template <> void Col2im(const float* data_col, int64_t channels, int64_t height, diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc new file mode 100644 index 0000000000..46312292ad --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -0,0 +1,1110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/mlas/inc/mlas.h" + +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/providers/run_options_config_keys.h" +#include "default_providers.h" + +using namespace std; +namespace onnxruntime { +namespace test { + +namespace { + +struct ConvOpAndTestAttributes { + string auto_pad; + vector dilations; + int64_t group; + vector kernel_shape; + vector pads; + vector strides; + std::unordered_set excluded_providers; + string activation = ""; + vector activation_parameters = {}; +}; + +void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, + const vector>& inputs, + const vector>& input_shapes, + const std::initializer_list& expected_output, + const vector& expected_output_shape, + bool weight_is_initializer = false, + OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, + const std::string& err_str = "", + int opset = 11) { + std::unique_ptr tester; + if (!attributes.activation.empty()) { + tester = std::make_unique("FusedConv", 1, onnxruntime::kMSDomain); + tester->AddAttribute("activation", attributes.activation); + + if (!attributes.activation_parameters.empty()) { + tester->AddAttribute("activation_params", attributes.activation_parameters); + } + } else { + tester = std::make_unique("Conv", opset); + } + + tester->AddAttribute("group", attributes.group); + tester->AddAttribute("kernel_shape", attributes.kernel_shape); + + if (!attributes.dilations.empty()) { + tester->AddAttribute("dilations", attributes.dilations); + } + + // Only one of pads / auto_pad can be present + if (!attributes.pads.empty()) { + tester->AddAttribute("pads", attributes.pads); + } else { + tester->AddAttribute("auto_pad", attributes.auto_pad); + } + + if (!attributes.strides.empty()) { + tester->AddAttribute("strides", attributes.strides); + } + + + ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs"); + const char* szNames[] = {"X", 
"W", "B"}; + tester->AddInput(szNames[0], input_shapes[0], inputs[0]); + tester->AddInput(szNames[1], input_shapes[1], inputs[1], weight_is_initializer); + if (inputs.size() == 3) + tester->AddInput(szNames[2], input_shapes[2], inputs[2]); + + tester->AddOutput("Y", expected_output_shape, expected_output, /*no sort*/ false, 0.002f, 0.0f); + + std::unordered_set excluded_providers(attributes.excluded_providers); + // Disable TensorRT because weight as input is not supported + excluded_providers.insert(kTensorrtExecutionProvider); + // QNN have issue with dynamic weight, auto pad with SAME_UPPER, SAME_LOWER + if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" || attributes.auto_pad == "SAME_LOWER") { + excluded_providers.insert(kQnnExecutionProvider); + } + + tester->Run(expect_result, err_str, excluded_providers); +} + +} // namespace + + +TEST(ConvFp16Test, Conv1D_Invalid_Input_Shape) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1}, // dilations + 1, // group + vector{2}, // kernel_shape + vector{0, 0}, // pads + vector{1}, // strides + {} // excluded EPs + }; + + vector X = vector(1, MLFloat16(1.0f)); + vector X_shape = {1, 1, 1}; + vector dummy_shape = {1, 1, 2}; + auto dummy_vals = {MLFloat16(0.0f), MLFloat16(0.0f)}; + TestConvFp16Op(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, + OpTester::ExpectResult::kExpectFailure, + "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " + "Both source and target dimension have values but they differ. 
Source=0 Target=2 Dimension=2", + -1); // use latest opset for shape inferencing errors +} + +TEST(ConvFp16Test, Conv2D_Invalid_Input_Shape) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = vector(1 * 3 * 1 * 111, MLFloat16(1.0f)); + vector X_shape = {1, 3, 1, 111}; + vector dummy_shape = {2, 2, 1, 2}; + auto dummy_vals = {MLFloat16(-0.0f), MLFloat16(0.0f), MLFloat16(-0.0f), MLFloat16(-0.0f), + MLFloat16(-0.0f), MLFloat16(0.0f), MLFloat16(-0.0f), MLFloat16(-0.0f)}; + TestConvFp16Op(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, + OpTester::ExpectResult::kExpectFailure, + "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " + "Both source and target dimension have values but they differ. Source=1 Target=2 Dimension=0", + -1); // use latest opset for shape inferencing errors +} + + +TEST(ConvFp16Test, Conv1D_1) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1}, // dilations + 1, // group + vector{1}, // kernel_shape + vector{0, 0}, // pads + vector{1}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(-0.215576172f), MLFloat16(0.469238281f), MLFloat16(0.442626953f), + MLFloat16(-0.451660156f), MLFloat16(-0.0521545410f), MLFloat16(0.290771484f), MLFloat16(0.250976562f)}; + vector X_shape = {1, 1, 7}; + vector W = {MLFloat16(0.244750977f)}; + vector W_shape = {1, 1, 1}; + vector Y_shape = {1, 1, 7}; + auto expected_vals = {MLFloat16(-0.0527624786f), MLFloat16(0.114846528f), MLFloat16(0.108333379f), + MLFloat16(-0.110544264f), MLFloat16(-0.0127648748f), MLFloat16(0.0711666048f), MLFloat16(0.0614267588f)}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + + +TEST(ConvFp16Test, 
Conv1D_1_DefaultStridesAndDilations) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{}, // dilations + 1, // group + vector{1}, // kernel_shape + vector{0, 0}, // pads + vector{}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(-0.215576172f), MLFloat16(0.469238281f), MLFloat16(0.442626953f), + MLFloat16(-0.451660156f), MLFloat16(-0.0521545410f), MLFloat16(0.290771484f), + MLFloat16(0.250976562f)}; + vector X_shape = {1, 1, 7}; + vector W = {MLFloat16(0.244750977f)}; + vector W_shape = {1, 1, 1}; + vector Y_shape = {1, 1, 7}; + auto expected_vals = {MLFloat16(-0.0527624786f), MLFloat16(0.114846528f), MLFloat16(0.108333379f), + MLFloat16(-0.110544264f), MLFloat16(-0.0127648748f), MLFloat16(0.0711666048f), + MLFloat16(0.0614267588f)}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + // CoreML EP requires weight to be an initializer + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + + +TEST(ConvFp16Test, Conv1D_2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{2}, // dilations + 1, // group + vector{2}, // kernel_shape + vector{2, 2}, // pads + vector{2}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(0.112f), MLFloat16(-0.0038f), MLFloat16(0.382f), MLFloat16(0.332f), + MLFloat16(0.0279f), MLFloat16(-0.0836f), MLFloat16(-0.41f), MLFloat16(-0.095f), + MLFloat16(-0.113f), MLFloat16(-0.0254f), MLFloat16(0.369f), MLFloat16(0.352f), + MLFloat16(-0.349f), MLFloat16(-0.22f), MLFloat16(0.231f), MLFloat16(-0.457f), + MLFloat16(-0.176f), MLFloat16(-0.0603f), MLFloat16(-0.399f), MLFloat16(-0.193f), + MLFloat16(-0.104f), MLFloat16(-0.145f), MLFloat16(-0.319f), MLFloat16(-0.153f)}; + vector X_shape = {3, 1, 8}; + vector W = {MLFloat16(0.132f), MLFloat16(0.0975f), MLFloat16(0.346f), MLFloat16(0.474f)}; + vector W_shape = {2, 1, 2}; + vector Y_shape = {3, 2, 5}; + auto expected_vals = { + MLFloat16(0.0109176636f), MLFloat16(0.0520324707f), 
MLFloat16(0.0531311035f), MLFloat16(-0.0362854004f), + MLFloat16(-0.0540771484f), MLFloat16(0.0531005859f), MLFloat16(0.219848633f), MLFloat16(0.145385742f), + MLFloat16(-0.184692383f), MLFloat16(-0.141845703f), MLFloat16(-0.0110092163f), MLFloat16(0.0210418701f), + MLFloat16(0.0146484375f), MLFloat16(-0.0235595703f), MLFloat16(0.0304718018f), MLFloat16(-0.0535583496f), + MLFloat16(0.135864258f), MLFloat16(-0.0379028320f), MLFloat16(-0.0112762451f), MLFloat16(0.0798950195f), + MLFloat16(-0.0171508789f), MLFloat16(-0.0621032715f), MLFloat16(-0.0628051758f), MLFloat16(-0.0448303223f), + MLFloat16(-0.0421142578f), MLFloat16(-0.0834350586f), MLFloat16(-0.250000000f), MLFloat16(-0.187377930f), + MLFloat16(-0.187255859f), MLFloat16(-0.110412598f)}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + + +// Conv1 +TEST(ConvFp16Test, Conv1D_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{2}, // dilations + 1, // group + vector{1}, // kernel_shape + vector{1, 1}, // pads + vector{3}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(0.458251953f), MLFloat16(0.387695312f), MLFloat16(-0.0541381836f), + MLFloat16(-0.301513672f), MLFloat16(0.192993164f), MLFloat16(-0.475830078f), + MLFloat16(0.467041016f), MLFloat16(0.407958984f), MLFloat16(0.240112305f), + MLFloat16(0.416503906f), MLFloat16(-0.0383300781f), MLFloat16(0.229736328f), + MLFloat16(0.356445312f), MLFloat16(0.128173828f), MLFloat16(0.100952148f), + MLFloat16(0.256835938f), MLFloat16(0.416992188f), MLFloat16(0.341064453f), + MLFloat16(-0.429931641f), MLFloat16(0.354492188f), MLFloat16(0.403320312f), + MLFloat16(0.101745605f), MLFloat16(0.457031250f), MLFloat16(0.0857543945f), + MLFloat16(0.380859375f), MLFloat16(0.163818359f), MLFloat16(0.123229980f), + MLFloat16(-0.199340820f), MLFloat16(0.260253906f), MLFloat16(-0.184082031f), + MLFloat16(0.311035156f), 
MLFloat16(0.155517578f), MLFloat16(-0.146240234f), + MLFloat16(-0.177978516f), MLFloat16(-0.0139007568f), MLFloat16(-0.0926513672f)}; + vector X_shape = {2, 2, 9}; + vector W = {MLFloat16(-0.172119141f), MLFloat16(0.323730469f)}; + vector W_shape = {1, 2, 1}; + vector B = {MLFloat16(0.378906250f)}; + vector B_shape = {1}; + vector Y_shape = {2, 1, 4}; + auto expected_vals = {MLFloat16(0.378906250f), MLFloat16(0.462597132f), MLFloat16(0.493487000f), + MLFloat16(0.447991282f), MLFloat16(0.378906250f), MLFloat16(0.249894142f), + MLFloat16(0.316803873f), MLFloat16(0.327701926f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + + +TEST(ConvFp16Test, Conv2D_1) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + vector{1, 1, 1, 2}, // pads + vector{3, 1}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(-0.0910644531f), MLFloat16(-0.325195312f)}; + vector X_shape = {2, 1, 1, 1}; + vector W = {MLFloat16(0.431152344f), MLFloat16(-0.125610352f), MLFloat16(0.448974609f), + MLFloat16(-0.310058594f), MLFloat16(0.135253906f), MLFloat16(-0.0679321289f), + MLFloat16(0.226684570f), MLFloat16(-0.173950195f), MLFloat16(-0.312988281f), + MLFloat16(-0.315429688f), MLFloat16(0.065612793f), MLFloat16(0.265625f), + MLFloat16(0.413574219f), MLFloat16(0.312255859f), MLFloat16(-0.375976562f), + MLFloat16(-0.00571060181f), MLFloat16(0.349121094f), MLFloat16(0.450927734f)}; + vector W_shape = {2, 1, 3, 3}; + vector Y_shape = {2, 2, 1, 2}; + auto expected_vals = {MLFloat16(-0.012316823f), MLFloat16(0.0282353163f), + MLFloat16(-0.0284354091f), MLFloat16(-0.0376619101f), + MLFloat16(-0.0439839363f), MLFloat16(0.100829601f), + MLFloat16(-0.101544142f), MLFloat16(-0.134492397f)}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + // 
NNAPI/CoreML EP requires weight to be an initializer + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, Conv2D_2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(0.452392578f), MLFloat16(0.155029297f), MLFloat16(0.111999512f), + MLFloat16(-0.394287109f), MLFloat16(0.262695312f), MLFloat16(0.134155273f), + MLFloat16(-0.271728516f), MLFloat16(-0.430175781f), MLFloat16(-0.268310547f), + MLFloat16(0.389404297f), MLFloat16(-0.136352539f), MLFloat16(-0.00959014893f), + MLFloat16(-0.487792969f), MLFloat16(-0.252685547f), MLFloat16(-0.281250000f), + MLFloat16(0.404296875f), MLFloat16(0.0779418945f), MLFloat16(0.326904297f), + MLFloat16(0.131103516f), MLFloat16(-0.441650391f), MLFloat16(0.124450684f), + MLFloat16(0.367431641f), MLFloat16(0.169921875f), MLFloat16(0.200927734f), + MLFloat16(0.233398438f), MLFloat16(0.386230469f), MLFloat16(0.111145020f), + MLFloat16(0.387695312f), MLFloat16(0.208129883f), MLFloat16(-0.343017578f), + MLFloat16(-0.0292510986f), MLFloat16(-0.204833984f), MLFloat16(-0.192382812f), + MLFloat16(-0.111022949f), MLFloat16(-0.328369141f), MLFloat16(-0.0180053711f), + MLFloat16(0.361816406f), MLFloat16(-0.409423828f), MLFloat16(-0.182495117f), + MLFloat16(-0.334960938f), MLFloat16(-0.340820312f), MLFloat16(0.00649642944f), + MLFloat16(0.453857422f), MLFloat16(0.0800781250f), MLFloat16(-0.147827148f), + MLFloat16(0.0344543457f), MLFloat16(-0.333251953f), MLFloat16(0.0604858398f), + MLFloat16(0.426269531f)}; + vector X_shape = {1, 1, 7, 7}; + vector W = {MLFloat16(-0.440673828f)}; + vector W_shape = {1, 1, 1, 1}; + vector Y_shape = {1, 1, 7, 7}; + auto expected_vals = { + MLFloat16(-0.199340820f), MLFloat16(-0.0682983398f), MLFloat16(-0.0493469238f), + MLFloat16(0.173706055f), MLFloat16(-0.115783691f), 
MLFloat16(-0.0591125488f), + MLFloat16(0.119750977f), MLFloat16(0.189575195f), MLFloat16(0.118225098f), + MLFloat16(-0.171630859f), MLFloat16(0.0600891113f), MLFloat16(0.00422668457f), + MLFloat16(0.214965820f), MLFloat16(0.111328125f), MLFloat16(0.123962402f), + MLFloat16(-0.178222656f), MLFloat16(-0.0343322754f), MLFloat16(-0.144042969f), + MLFloat16(-0.0577697754f), MLFloat16(0.194580078f), MLFloat16(-0.0548400879f), + MLFloat16(-0.161865234f), MLFloat16(-0.0748901367f), MLFloat16(-0.0885620117f), + MLFloat16(-0.102844238f), MLFloat16(-0.170166016f), MLFloat16(-0.0489807129f), + MLFloat16(-0.170898438f), MLFloat16(-0.0917358398f), MLFloat16(0.151123047f), + MLFloat16(0.0128936768f), MLFloat16(0.0902709961f), MLFloat16(0.0847778320f), + MLFloat16(0.0489196777f), MLFloat16(0.144653320f), MLFloat16(0.00793457031f), + MLFloat16(-0.159423828f), MLFloat16(0.180419922f), MLFloat16(0.0804443359f), + MLFloat16(0.147583008f), MLFloat16(0.150146484f), MLFloat16(-0.00286293030f), + MLFloat16(-0.199951172f), MLFloat16(-0.0352783203f), MLFloat16(0.0651245117f), + MLFloat16(-0.0151824951f), MLFloat16(0.146850586f), MLFloat16(-0.0266571045f), + MLFloat16(-0.187866211f)}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + + +TEST(ConvFp16Test, Conv2D_Bias_1) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)}; + vector X_shape = {1, 1, 3, 3}; + vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), 
MLFloat16(1.0f)}; + vector W_shape = {2, 1, 2, 2}; + vector Y_shape = {1, 2, 2, 2}; + vector B = {MLFloat16(1.0f), MLFloat16(-1.0f)}; + vector B_shape = {2}; + auto expected_vals = {MLFloat16(13.0f), MLFloat16(17.0f), MLFloat16(25.0f), MLFloat16(29.0f), MLFloat16(11.0f), MLFloat16(15.0f), MLFloat16(23.0f), MLFloat16(27.0f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +// Conv48 +TEST(ConvFp16Test, Conv2D_Bias_2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{4, 4}, // kernel_shape + vector{1, 2, 3, 1}, // pads + vector{2, 3}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(-0.625f), MLFloat16(0.4375f), MLFloat16(0.0625f), + MLFloat16(-0.3125f), MLFloat16(-0.6875f), MLFloat16(0.375f), + MLFloat16(0.0625f), MLFloat16(-0.375f), MLFloat16(0.6875f), + MLFloat16(0.3125f), MLFloat16(-0.0625f), MLFloat16(-0.4375f), + MLFloat16(0.625f), MLFloat16(0.25f), MLFloat16(-0.125f), + MLFloat16(-0.5f), MLFloat16(0.5625f), MLFloat16(0.1875f), + MLFloat16(-0.1875f), MLFloat16(-0.5625f), MLFloat16(0.5f), + MLFloat16(0.125f), MLFloat16(-0.25f), MLFloat16(-0.625f), + MLFloat16(0.4375f), MLFloat16(0.0625f), MLFloat16(-0.3125f), + MLFloat16(-0.6875f), MLFloat16(0.375f), MLFloat16(0.25f), + MLFloat16(-0.375f), MLFloat16(0.6875f), MLFloat16(0.3125f), + MLFloat16(-0.0625f), MLFloat16(-0.4375f), MLFloat16(0.625f), + MLFloat16(0.25f), MLFloat16(-0.125f), MLFloat16(-0.5f), + MLFloat16(0.5625f), MLFloat16(0.1875f), MLFloat16(-0.1875f), + MLFloat16(-0.5625f), MLFloat16(0.5f), MLFloat16(0.125f), + MLFloat16(-0.25f), MLFloat16(-0.625f), MLFloat16(0.4375f), + MLFloat16(0.0625f), MLFloat16(-0.3125f), MLFloat16(-0.6875f), + MLFloat16(0.375f), MLFloat16(0.125f), MLFloat16(-0.375f), + MLFloat16(0.6875f), MLFloat16(0.3125f), MLFloat16(-0.0625f), + MLFloat16(-0.4375f), 
MLFloat16(0.625f), MLFloat16(0.25f), + MLFloat16(-0.125f), MLFloat16(-0.5f), MLFloat16(0.5625f), + MLFloat16(0.1875f), MLFloat16(-0.1875f), MLFloat16(-0.5625f), + MLFloat16(0.5f), MLFloat16(0.125f), MLFloat16(-0.25f), + MLFloat16(-0.625f), MLFloat16(0.4375f), MLFloat16(0.0625f)}; + vector X_shape = {1, 2, 6, 6}; + vector W = { + MLFloat16(-0.3125f), MLFloat16(-0.6875f), MLFloat16(0.375f), MLFloat16(0.025f), + MLFloat16(-0.375f), MLFloat16(0.6875f), MLFloat16(0.3125f), MLFloat16(-0.0625f), + MLFloat16(-0.4375f), MLFloat16(0.625f), MLFloat16(0.25f), MLFloat16(-0.125f), + MLFloat16(-0.5f), MLFloat16(0.5625f), MLFloat16(0.1875f), MLFloat16(-0.1875f), + MLFloat16(-0.5625f), MLFloat16(0.5f), MLFloat16(0.125f), MLFloat16(-0.25f), + MLFloat16(-0.625f), MLFloat16(0.4375f), MLFloat16(0.0625f), MLFloat16(-0.3125f), + MLFloat16(-0.6875f), MLFloat16(0.375f), MLFloat16(-0.125f), MLFloat16(-0.375f), + MLFloat16(0.6875f), MLFloat16(0.3125f), MLFloat16(-0.0625f), MLFloat16(-0.4375f)}; + vector W_shape = {1, 2, 4, 4}; + vector B = {MLFloat16(-0.8125f)}; + vector B_shape = {1}; + vector Y_shape = {1, 1, 4, 2}; + auto expected_vals = { + MLFloat16(-0.83203125f), MLFloat16(-1.40625f), MLFloat16(-0.595312476f), MLFloat16(-1.93906248f), + MLFloat16(-0.896875024f), MLFloat16(-1.53750002f), MLFloat16(-0.904687524f), MLFloat16(-1.65937495f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + + +TEST(ConvFp16Test, Conv2D_AutoPad1) { + ConvOpAndTestAttributes attrs = { + "SAME_UPPER", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + {}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = vector(25, MLFloat16(1.0f)); + vector X_shape = {1, 1, 5, 5}; + vector W = {MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), 
MLFloat16(7.0f), MLFloat16(8.0f)}; + + vector W_shape = {1, 1, 3, 3}; + vector Y_shape = {1, 1, 5, 5}; + auto expected_vals = {MLFloat16(24.0f), MLFloat16(33.0f), MLFloat16(33.0f), MLFloat16(33.0f), MLFloat16(20.0f), + MLFloat16(27.0f), MLFloat16(36.0f), MLFloat16(36.0f), MLFloat16(36.0f), MLFloat16(21.0f), + MLFloat16(27.0f), MLFloat16(36.0f), MLFloat16(36.0f), MLFloat16(36.0f), MLFloat16(21.0f), + MLFloat16(27.0f), MLFloat16(36.0f), MLFloat16(36.0f), MLFloat16(36.0f), MLFloat16(21.0f), + MLFloat16(12.0f), MLFloat16(15.0f), MLFloat16(15.0f), MLFloat16(15.0f), MLFloat16(8.0f)}; + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, Conv2D_AutoPad2) { + ConvOpAndTestAttributes attrs = { + "SAME_LOWER", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + {}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f)}; + vector X_shape = {1, 1, 5, 5}; + vector W = {MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f)}; + + vector W_shape = {1, 1, 3, 3}; + vector Y_shape = {1, 1, 5, 5}; + auto expected_vals = {MLFloat16(11.0f), MLFloat16(22.0f), MLFloat16(11.0f), MLFloat16(22.0f), MLFloat16(11.0f), + MLFloat16(12.0f), MLFloat16(24.0f), MLFloat16(12.0f), MLFloat16(24.0f), MLFloat16(12.0f), + 
MLFloat16(12.0f), MLFloat16(24.0f), MLFloat16(12.0f), MLFloat16(24.0f), MLFloat16(12.0f), + MLFloat16(12.0f), MLFloat16(24.0f), MLFloat16(12.0f), MLFloat16(24.0f), MLFloat16(12.0f), + MLFloat16(5.0f), MLFloat16(10.0f), MLFloat16(5.0f), MLFloat16(10.0f), MLFloat16(5.0f)}; + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +#ifndef DISABLE_CONTRIB_OPS +TEST(ConvFp16Test, Conv2D_HardSigmoid) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {}, // excluded EPs + "HardSigmoid", // activation + vector{0.2f, 0.5f} // activation_parameters + }; + + vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), + MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f), + MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)}; + vector X_shape = {1, 1, 3, 3}; + vector W = {MLFloat16(0.125f), MLFloat16(0.125f), MLFloat16(0.125f), MLFloat16(0.125f), + MLFloat16(-0.125f), MLFloat16(-0.125f), MLFloat16(-0.125f), MLFloat16(-0.125f)}; + vector W_shape = {2, 1, 2, 2}; + vector Y_shape = {1, 2, 2, 2}; + auto expected_vals = {MLFloat16(0.8f), MLFloat16(0.9f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(0.2f), MLFloat16(0.1f), MLFloat16(0.0f), MLFloat16(0.0f)}; + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); +} + +TEST(ConvFp16Test, Conv2D_Relu) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {}, // excluded EPs + "Relu" // activation + }; + + vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), + MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f), + MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)}; + vector X_shape = {1, 1, 3, 3}; + vector W = {MLFloat16(1.0f), 
MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(-1.0f), MLFloat16(-1.0f), MLFloat16(-1.0f), MLFloat16(-1.0f)}; + vector W_shape = {2, 1, 2, 2}; + vector Y_shape = {1, 2, 2, 2}; + auto expected_vals = {MLFloat16(12.0f), MLFloat16(16.0f), MLFloat16(24.0f), MLFloat16(28.0f), + MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f)}; + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); +} + +TEST(ConvFp16Test, Conv2D_Bias_Relu) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {}, // excluded EPs + "Relu" // activation + }; + + vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), + MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f), + MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)}; + vector X_shape = {1, 1, 3, 3}; + vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f)}; + vector W_shape = {2, 1, 2, 2}; + vector Y_shape = {1, 2, 2, 2}; + vector B = {MLFloat16(1.0f), MLFloat16(-1.0f)}; + vector B_shape = {2}; + auto expected_vals = {MLFloat16(13.0f), MLFloat16(17.0f), MLFloat16(25.0f), MLFloat16(29.0f), + MLFloat16(11.0f), MLFloat16(15.0f), MLFloat16(23.0f), MLFloat16(27.0f)}; + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); +} +#endif // CONTRIB_OPS + + +TEST(ConvFp16Test, Conv3D_1) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1, 1}, // dilations + 1, // group + vector{1, 1, 1}, // kernel_shape + vector{0, 0, 0, 0, 0, 0}, // pads + vector{1, 1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(-0.433349609f), MLFloat16(-0.483886719f), MLFloat16(-0.309570312f), + MLFloat16(0.160766602f), MLFloat16(-0.466796875f), MLFloat16(0.465820312f), + MLFloat16(-0.370605469f), MLFloat16(0.406005859f), 
MLFloat16(-0.0354919434f), + MLFloat16(-0.312500000f), MLFloat16(0.426757812f), MLFloat16(0.398437500f), + MLFloat16(-0.390625000f), MLFloat16(0.259033203f), MLFloat16(-0.206420898f), + MLFloat16(0.138183594f), MLFloat16(-0.201538086f), MLFloat16(0.100280762f), + MLFloat16(-0.241333008f), MLFloat16(0.123107910f), MLFloat16(0.0327453613f), + MLFloat16(0.296142578f), MLFloat16(-0.231201172f), MLFloat16(0.334472656f), + MLFloat16(0.0256805420f), MLFloat16(0.245849609f), MLFloat16(0.117248535f)}; + vector X_shape = {1, 1, 3, 3, 3}; + vector W = {MLFloat16(-0.442138672f)}; + vector W_shape = {1, 1, 1, 1, 1}; + vector Y_shape = {1, 1, 3, 3, 3}; + auto expected_vals = { + MLFloat16(0.191600621f), MLFloat16(0.213945031f), MLFloat16(0.136873007f), + MLFloat16(-0.0710811317f), MLFloat16(0.206388950f), MLFloat16(-0.205957174f), + MLFloat16(0.163859010f), MLFloat16(-0.179510891f), MLFloat16(0.0156923607f), + MLFloat16(0.138168335f), MLFloat16(-0.188686132f), MLFloat16(-0.176164627f), + MLFloat16(0.172710419f), MLFloat16(-0.114528596f), MLFloat16(0.0912666619f), + MLFloat16(-0.0610963106f), MLFloat16(0.0891077816f), MLFloat16(-0.0443380028f), + MLFloat16(0.106702656f), MLFloat16(-0.0544307679f), MLFloat16(-0.0144779906f), + MLFloat16(-0.130936086f), MLFloat16(0.102222979f), MLFloat16(-0.147883296f), + MLFloat16(-0.0113543607f), MLFloat16(-0.108699620f), MLFloat16(-0.0518401116f)}; + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + + +TEST(ConvFp16Test, Conv3D_2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1, 1}, // dilations + 1, // group + vector{1, 1, 1}, // kernel_shape + vector{2, 2, 2, 2, 2, 2}, // pads + vector{2, 2, 2}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(0.0107727051f), MLFloat16(-0.437988281f), MLFloat16(0.455322266f), MLFloat16(-0.286621094f), + MLFloat16(0.456787109f), MLFloat16(-0.0320434570f), 
MLFloat16(0.422851562f), MLFloat16(-0.187255859f), + MLFloat16(-0.458496094f), MLFloat16(0.0420532227f), MLFloat16(-0.133300781f), MLFloat16(-0.253662109f), + MLFloat16(-0.238403320f), MLFloat16(0.122131348f), MLFloat16(-0.177856445f), MLFloat16(0.189208984f), + MLFloat16(0.379638672f), MLFloat16(-0.0339965820f), MLFloat16(0.127319336f), MLFloat16(-0.0402832031f), + MLFloat16(0.464355469f), MLFloat16(-0.226928711f), MLFloat16(0.173950195f), MLFloat16(-0.301513672f), + MLFloat16(-0.404296875f), MLFloat16(-0.332031250f), MLFloat16(0.0465393066f), MLFloat16(-0.494873047f), + MLFloat16(0.0755004883f), MLFloat16(0.117309570f), MLFloat16(0.470458984f), MLFloat16(0.482421875f), + MLFloat16(-0.377441406f), MLFloat16(-0.0564880371f), MLFloat16(-0.107910156f), MLFloat16(0.0434875488f), + MLFloat16(0.244750977f), MLFloat16(-0.409912109f), MLFloat16(0.0616149902f), MLFloat16(0.229736328f), + MLFloat16(0.278808594f), MLFloat16(0.0814819336f), MLFloat16(0.245361328f), MLFloat16(0.0825195312f), + MLFloat16(-0.147216797f), MLFloat16(-0.430175781f), MLFloat16(0.0271759033f), MLFloat16(0.360595703f), + MLFloat16(0.249511719f), MLFloat16(-0.225097656f), MLFloat16(-0.362792969f), MLFloat16(-0.476806641f), + MLFloat16(0.112731934f), MLFloat16(0.497802734f), MLFloat16(0.268554688f), MLFloat16(0.0255279541f), + MLFloat16(-0.303710938f), MLFloat16(0.411376953f), MLFloat16(0.361572266f), MLFloat16(0.00883483887f), + MLFloat16(-0.0795898438f), MLFloat16(0.360107422f), MLFloat16(0.173217773f), MLFloat16(-0.0120086670f)}; + vector X_shape = {1, 1, 4, 4, 4}; + vector W = {MLFloat16(0.328125f)}; + vector W_shape = {1, 1, 1, 1, 1}; + vector Y_shape = {1, 1, 4, 4, 4}; + auto expected_vals = {MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), + MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), + MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(0.00353479385f), 
MLFloat16(0.149402618f), MLFloat16(), + MLFloat16(), MLFloat16(-0.150444031f), MLFloat16(-0.0437393188f), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), + MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(-0.123847961f), MLFloat16(-0.03540802f), MLFloat16(), + MLFloat16(), MLFloat16(0.0914840698f), MLFloat16(0.0805091858f), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), + MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), + MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16(), MLFloat16()}; + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + + +TEST(ConvFp16Test, Conv3D_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{2, 2, 2}, // dilations + 1, // group + vector{2, 2, 2}, // kernel_shape + vector{2, 2, 2, 2, 2, 2}, // pads + vector{2, 2, 2}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(0.468017578f), MLFloat16(-0.461425781f), MLFloat16(0.335205078f), MLFloat16(-0.401123047f), + MLFloat16(0.417236328f), MLFloat16(-0.0481262207f), MLFloat16(0.204101562f), MLFloat16(0.0318908691f), + MLFloat16(-0.0477905273f), MLFloat16(-0.0795288086f), MLFloat16(0.498779297f), MLFloat16(0.350585938f), + MLFloat16(0.480712891f), MLFloat16(0.269775391f), MLFloat16(-0.246337891f), MLFloat16(0.190429688f), + MLFloat16(-0.118286133f), MLFloat16(-0.257568359f), MLFloat16(-0.339355469f), MLFloat16(-0.258056641f), + MLFloat16(-0.0828247070f), MLFloat16(0.351318359f), MLFloat16(-0.291259766f), MLFloat16(-0.433593750f), + MLFloat16(-0.134277344f), MLFloat16(0.440429688f), MLFloat16(0.0530700684f), MLFloat16(-0.350097656f), + MLFloat16(-0.284667969f), MLFloat16(-0.442138672f), MLFloat16(-0.0741577148f), MLFloat16(-0.109191895f), + MLFloat16(0.284423828f), MLFloat16(0.349853516f), 
MLFloat16(-0.193115234f), MLFloat16(0.326171875f), + MLFloat16(0.488037109f), MLFloat16(0.0557556152f), MLFloat16(-0.464599609f), MLFloat16(-0.0252380371f), + MLFloat16(-0.187866211f), MLFloat16(-0.147216797f), MLFloat16(0.207641602f), MLFloat16(0.471679688f), + MLFloat16(-0.0556640625f), MLFloat16(-0.498779297f), MLFloat16(0.227416992f), MLFloat16(0.458984375f), + MLFloat16(-0.472412109f), MLFloat16(-0.435791016f), MLFloat16(0.284179688f), MLFloat16(-0.270263672f), + MLFloat16(0.342285156f), MLFloat16(0.335693359f), MLFloat16(-0.194824219f), MLFloat16(-0.276855469f), + MLFloat16(-0.423828125f), MLFloat16(-0.438476562f), MLFloat16(0.437255859f), MLFloat16(0.306396484f), + MLFloat16(0.457031250f), MLFloat16(0.0529174805f), MLFloat16(-0.0236206055f), MLFloat16(-0.186035156f), + MLFloat16(0.0866699219f), MLFloat16(0.325439453f), MLFloat16(0.184570312f), MLFloat16(-0.198486328f), + MLFloat16(-0.275390625f), MLFloat16(0.320068359f), MLFloat16(-0.348388672f), MLFloat16(0.0999755859f), + MLFloat16(-0.113769531f), MLFloat16(0.212280273f), MLFloat16(-0.0231475830f), MLFloat16(0.167114258f), + MLFloat16(0.223144531f), MLFloat16(0.0361022949f), MLFloat16(-0.158691406f), MLFloat16(0.0599975586f), + MLFloat16(-0.0395202637f), MLFloat16(-0.484130859f), MLFloat16(0.329101562f), MLFloat16(-0.231201172f), + MLFloat16(0.394531250f), MLFloat16(-0.355468750f), MLFloat16(-0.170288086f), MLFloat16(-0.0550842285f), + MLFloat16(0.158569336f), MLFloat16(-0.418457031f), MLFloat16(-0.247436523f), MLFloat16(0.0360412598f), + MLFloat16(-0.283691406f), MLFloat16(0.460205078f), MLFloat16(0.291015625f), MLFloat16(-0.199340820f), + MLFloat16(0.380859375f), MLFloat16(-0.138427734f), MLFloat16(-0.238403320f), MLFloat16(-0.190673828f), + MLFloat16(-0.110595703f), MLFloat16(-0.0871582031f), MLFloat16(0.244506836f), MLFloat16(-0.147216797f), + MLFloat16(0.143676758f), MLFloat16(0.395507812f), MLFloat16(-0.125366211f), MLFloat16(0.115905762f), + MLFloat16(0.459716797f), MLFloat16(-0.300048828f), 
MLFloat16(-0.465820312f), MLFloat16(-0.339599609f), + MLFloat16(-0.267089844f), MLFloat16(0.361083984f), MLFloat16(-0.114257812f), MLFloat16(-0.0838012695f), + MLFloat16(-0.318115234f), MLFloat16(0.145141602f), MLFloat16(0.315673828f), MLFloat16(0.331787109f), + MLFloat16(-0.255859375f), MLFloat16(0.118896484f), MLFloat16(0.128295898f), MLFloat16(-0.331054688f), + MLFloat16(0.254882812f), MLFloat16(-0.467529297f), MLFloat16(-0.119812012f), MLFloat16(0.183471680f)}; + vector X_shape = {2, 1, 4, 4, 4}; + vector W = { + MLFloat16(0.388183594f), MLFloat16(-0.163696289f), + MLFloat16(-0.428710938f), MLFloat16(0.427734375f), + MLFloat16(0.215209961f), MLFloat16(0.00791168213f), + MLFloat16(0.338867188f), MLFloat16(0.218383789f), + MLFloat16(0.341064453f), MLFloat16(-0.170410156f), + MLFloat16(-0.0135726929f), MLFloat16(-0.267822266f), + MLFloat16(-0.348632812f), MLFloat16(-0.267333984f), + MLFloat16(-0.366943359f), MLFloat16(0.373046875f)}; + vector W_shape = {2, 1, 2, 2, 2}; + vector B = {MLFloat16(0.430908203f), MLFloat16(-0.456298828f)}; + vector B_shape = {2}; + vector Y_shape = {2, 2, 3, 3, 3}; + + auto expected_vals = { + MLFloat16(0.533115625f), MLFloat16(0.662707329f), MLFloat16(0.544498205f), + MLFloat16(0.424174339f), MLFloat16(0.627012968f), MLFloat16(0.672067642f), + MLFloat16(0.430530101f), MLFloat16(0.424569398f), MLFloat16(0.538250446f), + MLFloat16(0.693208933f), MLFloat16(0.427851349f), MLFloat16(0.221761703f), + MLFloat16(0.295077145f), MLFloat16(0.832913339f), MLFloat16(0.375999779f), + MLFloat16(0.437245011f), MLFloat16(0.291920483f), MLFloat16(0.669212699f), + MLFloat16(0.552566051f), MLFloat16(0.226370573f), MLFloat16(0.513698816f), + MLFloat16(0.303992242f), MLFloat16(0.742284894f), MLFloat16(0.266925812f), + MLFloat16(0.461661220f), MLFloat16(0.323991477f), MLFloat16(0.511511266f), + MLFloat16(-0.281706333f), MLFloat16(-0.502987564f), MLFloat16(-0.579300106f), + MLFloat16(-0.599243939f), MLFloat16(-0.505472362f), MLFloat16(-0.756186068f), + 
MLFloat16(-0.443522811f), MLFloat16(-0.572978139f), MLFloat16(-0.630189657f), + MLFloat16(-0.475540936f), MLFloat16(-0.728834927f), MLFloat16(-0.389986098f), + MLFloat16(-0.669373453f), MLFloat16(-0.387869477f), MLFloat16(-0.357608467f), + MLFloat16(-0.397931814f), MLFloat16(-0.547608852f), MLFloat16(-0.358573616f), + MLFloat16(-0.532473862f), MLFloat16(-0.408438683f), MLFloat16(-0.453677744f), + MLFloat16(-0.454452783f), MLFloat16(-0.379444361f), MLFloat16(-0.524981856f), + MLFloat16(-0.424284518f), MLFloat16(-0.555757523f), MLFloat16(-0.385479659f), + MLFloat16(0.449835509f), MLFloat16(0.500584960f), MLFloat16(0.493453026f), + MLFloat16(0.406748474f), MLFloat16(0.407412887f), MLFloat16(0.462785602f), + MLFloat16(0.430008084f), MLFloat16(0.406240731f), MLFloat16(0.425926626f), + MLFloat16(0.551153421f), MLFloat16(0.549696267f), MLFloat16(0.270993829f), + MLFloat16(0.402447432f), MLFloat16(0.574599743f), MLFloat16(0.418689728f), + MLFloat16(0.450668573f), MLFloat16(0.420462728f), MLFloat16(0.394942641f), + MLFloat16(0.593814850f), MLFloat16(0.165656328f), MLFloat16(0.533114314f), + MLFloat16(0.430018425f), MLFloat16(0.502558053f), MLFloat16(0.392109811f), + MLFloat16(0.407388866f), MLFloat16(0.507203162f), MLFloat16(0.382243097f), + MLFloat16(-0.423966885f), MLFloat16(-0.419248402f), MLFloat16(-0.524025679f), + MLFloat16(-0.521910012f), MLFloat16(-0.502744913f), MLFloat16(-0.512152255f), + MLFloat16(-0.425884366f), MLFloat16(-0.410446912f), MLFloat16(-0.448228836f), + MLFloat16(-0.337432563f), MLFloat16(-0.735596657f), MLFloat16(-0.371323436f), + MLFloat16(-0.488816738f), MLFloat16(-0.618983328f), MLFloat16(-0.263916761f), + MLFloat16(-0.475321025f), MLFloat16(-0.507732749f), MLFloat16(-0.420486867f), + MLFloat16(-0.558301449f), MLFloat16(-0.397618413f), MLFloat16(-0.453063041f), + MLFloat16(-0.559680939f), MLFloat16(-0.254149109f), MLFloat16(-0.535908163f), + MLFloat16(-0.480782807f), MLFloat16(-0.385932118f), MLFloat16(-0.499056786f)}; + TestConvFp16Op(attrs, {X, 
W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + + +TEST(ConvFp16Test, Conv2D_group) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 2, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), + MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), + MLFloat16(8.0f), MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), + MLFloat16(12.0f), MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), + MLFloat16(16.0f), MLFloat16(17.0f)}; + vector X_shape = {1, 2, 3, 3}; + vector W = {MLFloat16(1.0f), MLFloat16(2.0f)}; + vector W_shape = {2, 1, 1, 1}; + vector Y_shape = {1, 2, 3, 3}; + auto expected_vals = { + MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), + MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), + MLFloat16(8.0f), MLFloat16(18.0f), MLFloat16(20.0f), MLFloat16(22.0f), + MLFloat16(24.0f), MLFloat16(26.0f), MLFloat16(28.0f), MLFloat16(30.0f), + MLFloat16(32.0f), MLFloat16(34.0f)}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, ConvDimWithZero) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X; + vector X_shape = {0, 2, 4, 4}; // N of 0 should be handled + vector W = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(1.0f), MLFloat16(2.0f)}; + vector W_shape = {2, 2, 1, 1}; + vector out_shape = {0, 2, 4, 4}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape); +} + +TEST(ConvFp16Test, 
Conv1D_asymmetric_padding) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1}, // dilations + 1, // group + vector{3}, // kernel_shape + vector{1, 0}, // pads + vector{1}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(1.f), MLFloat16(2.f), MLFloat16(3.f)}; + vector X_shape = {1, 1, 3}; + vector W = {MLFloat16(1.f), MLFloat16(1.f), MLFloat16(1.f)}; + vector W_shape = {1, 1, 3}; + vector B = {MLFloat16()}; + vector B_shape = {1}; + vector Y_shape = {1, 1, 2}; + auto expected_vals = {MLFloat16(3.f), MLFloat16(6.f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, Conv_AutoPad_with_non_default_strides) { + ConvOpAndTestAttributes attrs = { + "SAME_LOWER", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + vector{}, // pads + vector{2, 2}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f), + MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), MLFloat16(13.0f), MLFloat16(14.0f), + MLFloat16(15.0f), MLFloat16(16.0f), MLFloat16(17.0f), MLFloat16(18.0f), MLFloat16(19.0f), + MLFloat16(20.0f), MLFloat16(21.0f), MLFloat16(22.0f), MLFloat16(23.0f), MLFloat16(24.0f)}; + vector X_shape = {1, 1, 5, 5}; + + vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f)}; + vector W_shape = {1, 1, 3, 3}; + + auto expected_vals = {MLFloat16(12.0f), MLFloat16(27.0f), MLFloat16(24.0f), + MLFloat16(63.0f), MLFloat16(108.0f), MLFloat16(81.0f), + MLFloat16(72.0f), MLFloat16(117.0f), MLFloat16(84.0f)}; + vector Y_shape = {1, 1, 3, 3}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, 
W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + + +TEST(ConvFp16Test, Pointwise_2D) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + vector X = { + MLFloat16(-9.f), MLFloat16(1.f), MLFloat16(2.f), + MLFloat16(-5.f), MLFloat16(3.f), MLFloat16(-2.f), + MLFloat16(5.f), MLFloat16(-3.f), MLFloat16(1.f), + MLFloat16(1.f), MLFloat16(8.f), MLFloat16(-4.f), + MLFloat16(-1.f), MLFloat16(6.f), MLFloat16(7.f), + MLFloat16(-1.f), MLFloat16(4.f), MLFloat16(-5.f), + MLFloat16(-9.f), MLFloat16(1.f), MLFloat16(2.f), + MLFloat16(-5.f), MLFloat16(3.f), MLFloat16(-2.f), + MLFloat16(5.f), MLFloat16(-3.f), MLFloat16(1.f)}; + vector X_shape = {1, 3, 3, 3}; + vector W = {MLFloat16(2.f), MLFloat16(-3.f), MLFloat16(0.5f), + MLFloat16(0.25f), MLFloat16(-2.f), MLFloat16(-0.75f)}; + vector W_shape = {2, 3, 1, 1}; + vector Y_shape = {1, 2, 3, 3}; + auto expected_vals = { + MLFloat16(-25.5f), MLFloat16(-21.5f), MLFloat16(17.f), + MLFloat16(-9.5f), MLFloat16(-10.5f), MLFloat16(-26.f), + MLFloat16(15.5f), MLFloat16(-19.5f), MLFloat16(17.5f), + MLFloat16(2.5f), MLFloat16(-16.5f), MLFloat16(7.f), + MLFloat16(4.5f), MLFloat16(-13.5f), MLFloat16(-13.f), + MLFloat16(-0.5f), MLFloat16(-6.5f), MLFloat16(9.5f)}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + + +TEST(ConvFp16Test, Pointwise_3D) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1, 1}, // dilations + 1, // group + vector{1, 1, 1}, // kernel_shape + vector{0, 0, 0, 0, 0, 0}, // pads + vector{1, 1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(2 / 16.f), MLFloat16(3 / 16.f), MLFloat16(4 / 16.f), + MLFloat16(5 / 16.f), MLFloat16(6 / 16.f), 
MLFloat16(7 / 16.f), + MLFloat16(8 / 16.f), MLFloat16(9 / 16.f), MLFloat16(10 / 16.f), + MLFloat16(11 / 16.f), MLFloat16(12 / 16.f), MLFloat16(13 / 16.f), + MLFloat16(14 / 16.f), MLFloat16(15 / 16.f), MLFloat16(16 / 16.f), + MLFloat16(17 / 16.f), MLFloat16(18 / 16.f), MLFloat16(19 / 16.f), + MLFloat16(20 / 16.f), MLFloat16(21 / 16.f), MLFloat16(22 / 16.f), + MLFloat16(23 / 16.f), MLFloat16(24 / 16.f), MLFloat16(25 / 16.f), + MLFloat16(26 / 16.f), MLFloat16(27 / 16.f), MLFloat16(28 / 16.f)}; + vector X_shape = {1, 1, 3, 3, 3}; + + vector W = {MLFloat16(0.5f)}; + vector W_shape = {1, 1, 1, 1, 1}; + + auto expected_vals = { + MLFloat16(0.0625f), MLFloat16(0.09375f), MLFloat16(0.125f), + MLFloat16(0.15625f), MLFloat16(0.1875f), MLFloat16(0.21875f), + MLFloat16(0.25f), MLFloat16(0.28125f), MLFloat16(0.3125f), + MLFloat16(0.34375f), MLFloat16(0.375f), MLFloat16(0.40625f), + MLFloat16(0.4375f), MLFloat16(0.46875f), MLFloat16(0.5f), + MLFloat16(0.53125f), MLFloat16(0.5625f), MLFloat16(0.59375f), + MLFloat16(0.625f), MLFloat16(0.65625f), MLFloat16(0.6875f), + MLFloat16(0.71875f), MLFloat16(0.75f), MLFloat16(0.78125f), + MLFloat16(0.8125f), MLFloat16(0.84375f), MLFloat16(0.875f)}; + vector Y_shape = {1, 1, 3, 3, 3}; + + // Test with weight as initializer + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +#ifndef DISABLE_CONTRIB_OPS + +TEST(ConvFp16Test, Pointwise_Relu) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {}, // excluded EPs + "Relu" // activation + }; + + vector X = { + MLFloat16(-9.f), MLFloat16(1.f), MLFloat16(2.f), + MLFloat16(-5.f), MLFloat16(3.f), MLFloat16(-2.f), + MLFloat16(5.f), MLFloat16(-3.f), MLFloat16(1.f), + MLFloat16(1.f), MLFloat16(8.f), MLFloat16(-4.f), + MLFloat16(-1.f), MLFloat16(6.f), 
MLFloat16(7.f), + MLFloat16(-1.f), MLFloat16(4.f), MLFloat16(-5.f), + MLFloat16(-9.f), MLFloat16(1.f), MLFloat16(2.f), + MLFloat16(-5.f), MLFloat16(3.f), MLFloat16(-2.f), + MLFloat16(5.f), MLFloat16(-3.f), MLFloat16(1.f)}; + vector X_shape = {1, 3, 3, 3}; + vector W = {MLFloat16(2.f), MLFloat16(-3.f), MLFloat16(0.5f), + MLFloat16(0.25f), MLFloat16(-2.f), MLFloat16(-0.75f)}; + vector W_shape = {2, 3, 1, 1}; + vector Y_shape = {1, 2, 3, 3}; + auto expected_vals = { + MLFloat16(0.f), MLFloat16(0.f), MLFloat16(17.f), + MLFloat16(0.f), MLFloat16(0.f), MLFloat16(0.f), + MLFloat16(15.5f), MLFloat16(0.f), MLFloat16(17.5f), + MLFloat16(2.5f), MLFloat16(0.f), MLFloat16(7.f), + MLFloat16(4.5f), MLFloat16(0.f), MLFloat16(0.f), + MLFloat16(0.f), MLFloat16(0.f), MLFloat16(9.5f)}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); +} + +#endif // CONTRIB_OPS + + +#ifndef ENABLE_TRAINING +// Prepacking is disabled in full training build so no need to test the feature in a training build. 
+ +const onnxruntime::RunOptions run_options = []() { + onnxruntime::RunOptions options{}; + ORT_THROW_IF_ERROR(options.config_options.AddConfigEntry(kOpTesterRunOptionsConfigTestTunableOp, "true")); + return options; +}(); + +const constexpr auto run_with_tunable_op = &run_options; + +TEST(ConvFp16Test, SharedPrepackedWeights) { + OpTester test("Conv", 11); + + vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), + MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f), + MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f)}; + vector X_shape = {1, 1, 3, 3}; + vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f)}; + vector W_shape = {2, 1, 2, 2}; + vector Y_shape = {1, 2, 2, 2}; + vector B = {MLFloat16(1.0f), MLFloat16(-1.0f)}; + vector B_shape = {2}; + auto expected_vals = { + MLFloat16(13.0f), MLFloat16(17.0f), MLFloat16(25.0f), MLFloat16(29.0f), + MLFloat16(11.0f), MLFloat16(15.0f), MLFloat16(23.0f), MLFloat16(27.0f)}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W, true); + test.AddInput("B", B_shape, B, true); + test.AddOutput("Y", Y_shape, expected_vals, /*no sort*/ false, 0.002f, 0.0f); + + OrtValue w; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(W_shape), + W.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), w); + + SessionOptions so; + // Set up B as a shared initializer to be shared between sessions + ASSERT_EQ(so.AddInitializer("W", &w), Status::OK()); + + // We want all sessions running using this OpTester to be able to share pre-packed weights if applicable + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + // Pre-packing is limited just to the CPU EP for now and we will only test the CPU EP + // and we want to ensure that it is available in this build + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + 
return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + // Session 1 + { + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); + // Assert that no pre-packed weights have been shared thus far + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + auto number_of_elements_in_shared_prepacked_buffers_container = + test.GetNumPrePackedWeightsShared(); + // Assert that the number of elements in the shared container + // is the same as the number of weights that have been pre-packed + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_elements_in_shared_prepacked_buffers_container); + + // On some platforms/architectures MLAS may choose to not do any pre-packing and the number of elements + // that have been pre-packed will be zero in which case we do not continue with the testing + // of "sharing" of pre-packed weights as there are no pre-packed weights to be shared at all. 
+ if (number_of_pre_packed_weights_counter_session_1 == 0) + return; + + // Session 2 + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); + + // Assert that the same number of weights were pre-packed in both sessions + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2); + + // Assert that the number of pre-packed weights that were shared equals + // the number of pre-packed weights in the second session + ASSERT_EQ(number_of_pre_packed_weights_counter_session_2, + static_cast(number_of_shared_pre_packed_weights_counter)); + } +} + +#endif + + +} // namespace test +} // namespace onnxruntime + +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED \ No newline at end of file