diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 3a100f5b70..95c3114c82 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -4,7 +4,6 @@ #include "core/providers/cpu/cpu_execution_provider.h" #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" -#include "core/mlas/inc/mlas.h" #ifndef DISABLE_CONTRIB_OPS #include "contrib_ops/cpu/cpu_contrib_kernels.h" @@ -284,8 +283,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QLinearConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearConv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 11, Dropout); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression); @@ -989,12 +987,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, -#if defined(MLAS_TARGET_AMD64_IX86) - BuildKernelCreateInfo, -#endif + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo -class QLinearConv; +#if defined(MLAS_TARGET_AMD64_IX86) -template <> -class QLinearConv : public OpKernel { +class QLinearConv : public OpKernel { public: - explicit QLinearConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {} + explicit QLinearConv(const OpKernelInfo& info) : OpKernel(info), + conv_attrs_(info), + is_W_signed_(false), + is_W_packed_(false) { + } + + Status Compute(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override; + + private: + static void ReorderFilter(const uint8_t* input, + uint8_t* output, + size_t output_channels, + size_t input_channels, + size_t kernel_size) { + for (size_t k = 0; k < kernel_size; k++) { + for (size_t ic = 0; ic < input_channels; ic++) { + for (size_t oc = 0; oc < output_channels; oc++) { + size_t index = (oc * input_channels * kernel_size) + (ic * kernel_size) + k; + *output++ = input[index]; + } + } + } + } + + ConvAttributes conv_attrs_; + TensorShape W_shape_; +#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 + BufferUniquePtr packed_W_buffer_; + size_t packed_W_size_; +#endif + BufferUniquePtr reordered_W_buffer_; + bool is_W_signed_; + bool is_W_packed_; +}; + +ONNX_CPU_OPERATOR_KERNEL( + QLinearConv, + 10, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + QLinearConv); + +Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, bool& is_packed) { + is_packed = false; + + // Support packing the weight matrix. + if (input_idx != 3) { + return Status::OK(); + } + + const auto& shape = tensor.Shape().GetDims(); + size_t rank = shape.size(); + if (rank <= 2) { + return Status::OK(); + } + + if (shape[0] % conv_attrs_.group != 0) { + return Status::OK(); + } + + // Note: The tensor has already been allocated with this tensor shape, so all + // shape indices are guaranteed to fit inside size_t. + const size_t output_channels = static_cast(shape[0]); + const size_t group_input_channels = static_cast(shape[1]); + const size_t kernel_size = + static_cast(std::accumulate(shape.data() + 2, shape.data() + rank, 1LL, std::multiplies())); + + const size_t group_count = static_cast(conv_attrs_.group); + const size_t group_output_channels = output_channels / group_count; + const size_t kernel_dim = group_input_channels * kernel_size; + + const auto* Wdata = static_cast(tensor.DataRaw()); + W_shape_ = shape; + is_W_signed_ = tensor.IsDataType(); + + auto alloc = Info().GetAllocator(0, OrtMemTypeDefault); + +#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 + // Don't pack the filter buffer if the MlasConvDepthwise path is used. + if (group_input_channels != 1 && group_output_channels != 1) { + packed_W_size_ = MlasGemmPackBSize(group_output_channels, kernel_dim, true); + + if (packed_W_size_ != 0) { + auto* packed_W = static_cast(alloc->Alloc(SafeInt(group_count) * packed_W_size_)); + packed_W_buffer_ = BufferUniquePtr(packed_W, BufferDeleter(alloc)); + + // Allocate a temporary buffer to hold the reordered oihw->ohwi filter for + // a single group. + // + // Note: The size of this buffer is less than or equal to the size of the original + // weight tensor, so the allocation size is guaranteed to fit inside size_t. + auto* group_reordered_W = static_cast(alloc->Alloc(group_output_channels * group_input_channels * kernel_size)); + BufferUniquePtr group_reordered_W_buffer(group_reordered_W, BufferDeleter(alloc)); + + const size_t W_offset = group_output_channels * kernel_dim; + + for (int64_t group_id = 0; group_id < conv_attrs_.group; ++group_id) { + ReorderFilter(Wdata, group_reordered_W, group_output_channels, group_input_channels, kernel_size); + MlasGemmPackB(group_output_channels, kernel_dim, group_reordered_W, group_output_channels, is_W_signed_, packed_W); + packed_W += packed_W_size_; + Wdata += W_offset; + } + + is_W_packed_ = true; + is_packed = true; + return Status::OK(); + } + } +#endif + + auto* reordered_W = static_cast(alloc->Alloc(SafeInt(sizeof(uint8_t)) * output_channels * group_input_channels * kernel_size)); + reordered_W_buffer_ = BufferUniquePtr(reordered_W, BufferDeleter(alloc)); + + ReorderFilter(Wdata, reordered_W, output_channels, group_input_channels, kernel_size); + + is_W_packed_ = true; + is_packed = true; + return Status::OK(); +} + +Status QLinearConv::Compute(OpKernelContext* context) const { + const Tensor* X = context->Input(0); + const Tensor* W = is_W_packed_ ? nullptr : context->Input(3); + const auto& W_shape = is_W_packed_ ? W_shape_ : W->Shape(); + const bool is_W_signed = (W != nullptr) ? W->IsDataType() : is_W_signed_; + + const int64_t N = X->Shape()[0]; + const int64_t M = W_shape[0]; + + // validate offsets + const Tensor* X_zero_point = context->Input(2); + const Tensor* W_zero_point = context->Input(5); + const Tensor* Y_zero_point = context->Input(7); + ORT_ENFORCE(IsScalarOr1ElementVector(X_zero_point), + "QLinearConv : input zero point must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(Y_zero_point), + "QLinearConv : result zero point must be a scalar or 1D tensor of size 1"); + + auto X_zero_point_value = *(X_zero_point->template Data()); + auto Y_zero_point_value = *(Y_zero_point->template Data()); + + uint8_t W_zero_point_value; + const auto& W_zero_point_shape = W_zero_point->Shape(); + if (W_zero_point_shape.NumDimensions() == 0 || + (W_zero_point_shape.NumDimensions() == 1 && (W_zero_point_shape[0] == 1 || W_zero_point_shape[0] == M))) { + const int64_t W_zero_point_size = W_zero_point_shape.Size(); + const auto* W_zero_point_data = static_cast(W_zero_point->DataRaw()); + if (is_W_signed) { + W_zero_point_value = 0; + for (int64_t i = 0; i < W_zero_point_size; i++) { + if (W_zero_point_data[i] != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "QLinearConv : filter zero point must be zero"); + } + } + } else { + W_zero_point_value = W_zero_point_data[0]; + for (int64_t i = 1; i < W_zero_point_size; i++) { + if (W_zero_point_data[i] != W_zero_point_value) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "QLinearConv : filter zero point must be constant"); + } + } + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "QLinearConv : filter zero point shape invalid"); + } + + // validate scale + const Tensor* X_scale = context->Input(1); + const Tensor* W_scale = context->Input(4); + const Tensor* Y_scale = context->Input(6); + ORT_ENFORCE(IsScalarOr1ElementVector(X_scale), + "QLinearConv : input scale must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(Y_scale), + "QLinearConv : result scale must be a scalar or 1D tensor of size 1"); + + auto X_scale_value = *(X_scale->template Data()); + auto Y_scale_value = *(Y_scale->template Data()); + + std::vector output_scales; + const auto& W_scale_shape = W_scale->Shape(); + if (W_scale_shape.NumDimensions() == 0 || + (W_scale_shape.NumDimensions() == 1 && (W_scale_shape[0] == 1 || W_scale_shape[0] == M))) { + const int64_t W_scale_size = W_scale_shape.Size(); + const auto* W_scale_data = W_scale->template Data(); + output_scales.resize(static_cast(W_scale_size)); + for (int64_t i = 0; i < W_scale_size; i++) { + output_scales[i] = (X_scale_value * W_scale_data[i] / Y_scale_value); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "QLinearConv : filter scale shape invalid"); + } + + const Tensor* B = context->Input(8); + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W_shape)); + + std::vector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W_shape, kernel_shape)); + + const size_t kernel_rank = kernel_shape.size(); + + std::vector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(kernel_rank * 2, 0); + } + std::vector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(kernel_rank, 1); + } + std::vector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(kernel_rank, 1); + } + + std::vector Y_dims({N, M}); + TensorShape input_shape = X->Shape().Slice(2); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); + Tensor* Y = context->Output(0, TensorShape(Y_dims)); + TensorShape output_shape = Y->Shape().Slice(2); + + // Bail out early if one of the dimensions is zero. + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + const int64_t input_image_size = input_shape.Size(); + const int64_t output_image_size = output_shape.Size(); + const int64_t kernel_size = TensorShape(kernel_shape).Size(); + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + + // Handle the case of a dynamic weight filter. + BufferUniquePtr reordered_W_buffer; + uint8_t* reordered_W = nullptr; + bool use_reordered_W = true; +#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 + if (packed_W_buffer_) { + use_reordered_W = false; + } +#endif + if (use_reordered_W) { + if (reordered_W_buffer_) { + reordered_W = static_cast(reordered_W_buffer_.get()); + } else { + // Weight tensor was not constant or prepacking is disabled. + reordered_W = static_cast(alloc->Alloc(SafeInt(sizeof(uint8_t)) * W_shape.Size())); + reordered_W_buffer = BufferUniquePtr(reordered_W, BufferDeleter(alloc)); + ReorderFilter(static_cast(W->DataRaw()), + reordered_W, + static_cast(M), + static_cast(W_shape[1]), + static_cast(kernel_size)); + } + } + + int64_t group_count = conv_attrs_.group; + int64_t group_input_channels = W_shape[1]; + int64_t group_output_channels = M / group_count; + + // Test for depthwise convolution. + const bool is_depthwise_conv = (use_reordered_W && group_input_channels == 1 && group_output_channels == 1); + if (is_depthwise_conv) { + // Update the input and output channels to the number of groups in order to + // reuse as much of the below standard convolution path. + group_input_channels = group_count; + group_output_channels = group_count; + group_count = 1; + } + + const int64_t X_offset = group_input_channels * input_image_size; + const int64_t Y_offset = group_output_channels * output_image_size; + const int64_t kernel_dim = group_input_channels * kernel_size; + const int64_t col_buffer_size = kernel_dim * output_image_size; + + // Use an intermediate int32_t buffer for the GEMM computation before + // requantizing to the output type. + auto gemm_output_data = alloc->Alloc(SafeInt(sizeof(int32_t)) * Y_offset); + BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc)); + auto* gemm_output = static_cast(gemm_output_buffer.get()); + + const auto* Xdata = X->template Data(); + const auto* Bdata = B != nullptr ? B->template Data() : nullptr; + auto* Ydata = Y->template MutableData(); + + auto* transpose_input = static_cast(alloc->Alloc(SafeInt(sizeof(uint8_t)) * X_offset)); + BufferUniquePtr transpose_input_buffer(transpose_input, BufferDeleter(alloc)); + + auto* transpose_output = static_cast(alloc->Alloc(SafeInt(sizeof(uint8_t)) * Y_offset)); + BufferUniquePtr transpose_output_buffer(transpose_output, BufferDeleter(alloc)); + + BufferUniquePtr col_buffer; + + // Pointwise convolutions can use the original input tensor in place, + // otherwise a temporary buffer is required for the im2col transform. + if (kernel_size != 1 || !conv_attrs_.HasStridesOneAndNoPadding()) { + auto* col_data = alloc->Alloc(SafeInt(sizeof(uint8_t)) * col_buffer_size); + col_buffer = BufferUniquePtr(col_data, BufferDeleter(alloc)); + } + auto* col_buffer_data = static_cast(col_buffer.get()); + + // Replicate the logic from MlasGemmU8X8Schedule to control the number of + // worker threads used for the convolution. + constexpr int32_t maximum_thread_count = 16; + constexpr double thread_complexity = static_cast(64 * 1024); + + const double complexity = static_cast(output_image_size) * + static_cast(group_output_channels) * + static_cast(kernel_dim); + + int32_t thread_count = maximum_thread_count; + if (complexity < thread_complexity * maximum_thread_count) { + thread_count = static_cast(complexity / thread_complexity) + 1; + } + if (thread_count > output_image_size) { + // Ensure that every thread produces at least one output. + thread_count = static_cast(output_image_size); + } + + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + thread_count = std::min(thread_count, concurrency::ThreadPool::DegreeOfParallelism(thread_pool)); + + for (int64_t image_id = 0; image_id < N; ++image_id) { + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + // Transpose the input from channels first (NCHW) to channels last (NHWC). + MlasTranspose(Xdata, + transpose_input, + static_cast(group_input_channels), + static_cast(input_image_size)); + + if (col_buffer_data != nullptr) { + if (kernel_rank > 2) { + math::Im2colNd()( + transpose_input, + input_shape.GetDims().data(), + output_shape.GetDims().data(), + kernel_dim, + kernel_shape.data(), + strides.data(), + dilations.data(), + pads.data(), + static_cast(kernel_rank), + col_buffer_data, + false, + X_zero_point_value); + } + } + + auto conv_worker = [&](ptrdiff_t batch) { + auto work = concurrency::ThreadPool::PartitionWork(batch, thread_count, static_cast(output_image_size)); + int64_t output_start = static_cast(work.start); + int64_t output_count = static_cast(work.end - work.start); + + // Prepare the im2col transformation or use the input buffer directly for + // pointwise convolutions. + uint8_t* worker_gemm_input; + if (col_buffer_data != nullptr) { + worker_gemm_input = col_buffer_data + output_start * kernel_dim; + if (kernel_rank == 2) { + math::Im2col()( + transpose_input, + group_input_channels, + input_shape[0], + input_shape[1], + kernel_shape[0], + kernel_shape[1], + dilations[0], + dilations[1], + pads[0], + pads[1], + strides[0], + strides[1], + output_shape[1], + output_start, + output_count, + worker_gemm_input, + X_zero_point_value); + } else if (kernel_rank == 1) { + math::Im2col()( + transpose_input, + group_input_channels, + 1, + input_shape[0], + 1, + kernel_shape[0], + 1, + dilations[0], + 0, + pads[0], + 1, + strides[0], + output_shape[0], + output_start, + output_count, + worker_gemm_input, + X_zero_point_value); + } + } else { + worker_gemm_input = transpose_input + output_start * kernel_dim; + } + + auto* worker_gemm_output = gemm_output + output_start * group_output_channels; + auto* worker_transpose_output = transpose_output + output_start * group_output_channels; + + if (is_depthwise_conv) { + if (is_W_signed) { + MlasConvDepthwise(worker_gemm_input, + X_zero_point_value, + reinterpret_cast(reordered_W), + static_cast(W_zero_point_value), + worker_gemm_output, + static_cast(group_output_channels), + static_cast(output_count), + static_cast(kernel_size)); + } else { + MlasConvDepthwise(worker_gemm_input, + X_zero_point_value, + reordered_W, + W_zero_point_value, + worker_gemm_output, + static_cast(group_output_channels), + static_cast(output_count), + static_cast(kernel_size)); + } + } else { +#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 + if (packed_W_buffer_) { + MlasGemm(static_cast(output_count), + static_cast(group_output_channels), + static_cast(kernel_dim), + worker_gemm_input, + static_cast(kernel_dim), + X_zero_point_value, + static_cast(packed_W_buffer_.get()) + group_id * packed_W_size_, + W_zero_point_value, + is_W_signed, + worker_gemm_output, + static_cast(group_output_channels), + nullptr); + } else +#endif + { + MlasGemm(static_cast(output_count), + static_cast(group_output_channels), + static_cast(kernel_dim), + worker_gemm_input, + static_cast(kernel_dim), + X_zero_point_value, + reordered_W + group_id * group_output_channels, + static_cast(M), + W_zero_point_value, + is_W_signed, + worker_gemm_output, + static_cast(group_output_channels), + nullptr); + } + } + + if (output_scales.size() == 1) { + MlasRequantizeOutputColumn(worker_gemm_output, + worker_transpose_output, + Bdata != nullptr ? Bdata + group_id * group_output_channels : nullptr, + static_cast(output_count), + static_cast(group_output_channels), + output_scales[0], + Y_zero_point_value); + } else { + MlasRequantizeOutputColumn(worker_gemm_output, + worker_transpose_output, + Bdata != nullptr ? Bdata + group_id * group_output_channels : nullptr, + static_cast(output_count), + static_cast(group_output_channels), + output_scales.data() + group_id * group_output_channels, + Y_zero_point_value); + } + }; + + concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, thread_count, conv_worker); + + // Transpose the output from channels last (NHWC) to channels first (NCHW). + MlasTranspose(transpose_output, + Ydata, + static_cast(output_image_size), + static_cast(group_output_channels)); + + Xdata += X_offset; + Ydata += Y_offset; + } + } + + return Status::OK(); +} + +#else + +class QLinearConv : public OpKernel { + public: + explicit QLinearConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {} Status Compute(OpKernelContext* context) const override; @@ -27,18 +526,17 @@ class QLinearConv : public OpKernel { ConvAttributes conv_attrs_; }; -ONNX_CPU_OPERATOR_TYPED_KERNEL( +ONNX_CPU_OPERATOR_KERNEL( QLinearConv, 10, - uint8_t, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()) .TypeConstraint("T3", DataTypeImpl::GetTensorType()) .TypeConstraint("T4", DataTypeImpl::GetTensorType()), - QLinearConv); + QLinearConv); -Status QLinearConv::Compute(OpKernelContext* context) const { +Status QLinearConv::Compute(OpKernelContext* context) const { const Tensor* X = context->Input(0); const Tensor* W = context->Input(3); @@ -237,484 +735,6 @@ Status QLinearConv::Compute(OpKernelContext* context) const { return Status::OK(); } -#if defined(MLAS_TARGET_AMD64_IX86) - -template <> -class QLinearConv : public OpKernel { - public: - explicit QLinearConv(const OpKernelInfo& info) : OpKernel(info), - conv_attrs_(info), - is_W_signed_(false), - is_W_packed_(false) { - } - - Status Compute(OpKernelContext* context) const override; - Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override; - - private: - static void ReorderFilter(const uint8_t* input, - uint8_t* output, - size_t output_channels, - size_t input_channels, - size_t kernel_size); - - ConvAttributes conv_attrs_; - TensorShape W_shape_; -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 - BufferUniquePtr packed_W_buffer_; - size_t packed_W_size_; -#endif - BufferUniquePtr reordered_W_buffer_; - bool is_W_signed_; - bool is_W_packed_; -}; - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - QLinearConv, - 10, - int8_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) - .TypeConstraint("T3", DataTypeImpl::GetTensorType()) - .TypeConstraint("T4", DataTypeImpl::GetTensorType()), - QLinearConv); - -void QLinearConv::ReorderFilter(const uint8_t* input, - uint8_t* output, - size_t output_channels, - size_t input_channels, - size_t kernel_size) { - for (size_t k = 0; k < kernel_size; k++) { - for (size_t ic = 0; ic < input_channels; ic++) { - for (size_t oc = 0; oc < output_channels; oc++) { - size_t index = (oc * input_channels * kernel_size) + (ic * kernel_size) + k; - *output++ = input[index]; - } - } - } -} - -Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, bool& is_packed) { - is_packed = false; - - // Support packing the weight matrix. - if (input_idx != 3) { - return Status::OK(); - } - - const auto& shape = tensor.Shape(); - size_t rank = shape.NumDimensions(); - if (rank != 4) { - return Status::OK(); - } - - if (shape[0] % conv_attrs_.group != 0) { - return Status::OK(); - } - - // Note: The tensor has already been allocated with this tensor shape, so all - // shape indices are guaranteed to fit inside size_t. - const size_t output_channels = static_cast(shape[0]); - const size_t group_input_channels = static_cast(shape[1]); - const size_t kernel_size = static_cast(shape[2] * shape[3]); - - const size_t group_count = static_cast(conv_attrs_.group); - const size_t group_output_channels = output_channels / group_count; - const size_t kernel_dim = group_input_channels * kernel_size; - - const auto* Wdata = static_cast(tensor.DataRaw()); - W_shape_ = shape; - is_W_signed_ = tensor.IsDataType(); - - auto alloc = Info().GetAllocator(0, OrtMemTypeDefault); - -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 - // Don't pack the filter buffer if the MlasConvDepthwise path is used. - if (group_input_channels != 1 && group_output_channels != 1) { - packed_W_size_ = MlasGemmPackBSize(group_output_channels, kernel_dim, true); - - if (packed_W_size_ != 0) { - auto* packed_W = static_cast(alloc->Alloc(SafeInt(group_count) * packed_W_size_)); - packed_W_buffer_ = BufferUniquePtr(packed_W, BufferDeleter(alloc)); - - // Allocate a temporary buffer to hold the reordered oihw->ohwi filter for - // a single group. - // - // Note: The size of this buffer is less than or equal to the size of the original - // weight tensor, so the allocation size is guaranteed to fit inside size_t. - auto* group_reordered_W = static_cast(alloc->Alloc(group_output_channels * group_input_channels * kernel_size)); - BufferUniquePtr group_reordered_W_buffer(group_reordered_W, BufferDeleter(alloc)); - - const size_t W_offset = group_output_channels * kernel_dim; - - for (int64_t group_id = 0; group_id < conv_attrs_.group; ++group_id) { - ReorderFilter(Wdata, group_reordered_W, group_output_channels, group_input_channels, kernel_size); - MlasGemmPackB(group_output_channels, kernel_dim, group_reordered_W, group_output_channels, is_W_signed_, packed_W); - packed_W += packed_W_size_; - Wdata += W_offset; - } - - is_W_packed_ = true; - is_packed = true; - return Status::OK(); - } - } -#endif - - auto* reordered_W = static_cast(alloc->Alloc(SafeInt(sizeof(uint8_t)) * shape.Size())); - reordered_W_buffer_ = BufferUniquePtr(reordered_W, BufferDeleter(alloc)); - - ReorderFilter(Wdata, reordered_W, output_channels, group_input_channels, kernel_size); - - is_W_packed_ = true; - is_packed = true; - return Status::OK(); -} - -Status QLinearConv::Compute(OpKernelContext* context) const { - const Tensor* X = context->Input(0); - const Tensor* W = is_W_packed_ ? nullptr : context->Input(3); - const auto& W_shape = is_W_packed_ ? W_shape_ : W->Shape(); - - const int64_t N = X->Shape()[0]; - const int64_t M = W_shape[0]; - - // validate offsets - const Tensor* X_zero_point = context->Input(2); - const Tensor* W_zero_point = context->Input(5); - const Tensor* Y_zero_point = context->Input(7); - ORT_ENFORCE(IsScalarOr1ElementVector(X_zero_point), - "QLinearConv : input zero point must be a scalar or 1D tensor of size 1"); - ORT_ENFORCE(IsScalarOr1ElementVector(Y_zero_point), - "QLinearConv : result zero point must be a scalar or 1D tensor of size 1"); - - auto X_zero_point_value = *(X_zero_point->template Data()); - auto Y_zero_point_value = *(Y_zero_point->template Data()); - - const auto& W_zero_point_shape = W_zero_point->Shape(); - if (W_zero_point_shape.NumDimensions() == 0 || - (W_zero_point_shape.NumDimensions() == 1 && (W_zero_point_shape[0] == 1 || W_zero_point_shape[0] == M))) { - const int64_t W_zero_point_size = W_zero_point_shape.Size(); - const auto* W_zero_point_data = W_zero_point->template Data(); - for (int64_t i = 0; i < W_zero_point_size; i++) { - if (W_zero_point_data[i] != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "QLinearConv : filter zero point must be zero"); - } - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "QLinearConv : filter zero point shape invalid"); - } - - // validate scale - const Tensor* X_scale = context->Input(1); - const Tensor* W_scale = context->Input(4); - const Tensor* Y_scale = context->Input(6); - ORT_ENFORCE(IsScalarOr1ElementVector(X_scale), - "QLinearConv : input scale must be a scalar or 1D tensor of size 1"); - ORT_ENFORCE(IsScalarOr1ElementVector(Y_scale), - "QLinearConv : result scale must be a scalar or 1D tensor of size 1"); - - auto X_scale_value = *(X_scale->template Data()); - auto Y_scale_value = *(Y_scale->template Data()); - - std::vector output_scales; - const auto& W_scale_shape = W_scale->Shape(); - if (W_scale_shape.NumDimensions() == 0 || - (W_scale_shape.NumDimensions() == 1 && (W_scale_shape[0] == 1 || W_scale_shape[0] == M))) { - const int64_t W_scale_size = W_scale_shape.Size(); - const auto* W_scale_data = W_scale->template Data(); - output_scales.resize(static_cast(W_scale_size)); - for (int64_t i = 0; i < W_scale_size; i++) { - output_scales[i] = (X_scale_value * W_scale_data[i] / Y_scale_value); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "QLinearConv : filter scale shape invalid"); - } - - const Tensor* B = context->Input(8); - - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W_shape)); - - std::vector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W_shape, kernel_shape)); - - const size_t kernel_rank = kernel_shape.size(); - - std::vector pads(conv_attrs_.pads); - if (pads.empty()) { - pads.resize(kernel_rank * 2, 0); - } - std::vector dilations(conv_attrs_.dilations); - if (dilations.empty()) { - dilations.resize(kernel_rank, 1); - } - std::vector strides(conv_attrs_.strides); - if (strides.empty()) { - strides.resize(kernel_rank, 1); - } - - std::vector Y_dims({N, M}); - TensorShape input_shape = X->Shape().Slice(2); - ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); - Tensor* Y = context->Output(0, TensorShape(Y_dims)); - TensorShape output_shape = Y->Shape().Slice(2); - - // Bail out early if one of the dimensions is zero. - if (Y->Shape().Size() == 0) { - return Status::OK(); - } - - const int64_t input_image_size = input_shape.Size(); - const int64_t output_image_size = output_shape.Size(); - const int64_t kernel_size = TensorShape(kernel_shape).Size(); - - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); - - // Handle the case of a dynamic weight filter. - BufferUniquePtr reordered_W_buffer; - uint8_t* reordered_W = nullptr; - bool use_reordered_W = true; - bool is_W_signed = is_W_signed_; -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 - if (packed_W_buffer_) { - use_reordered_W = false; - } -#endif - if (use_reordered_W) { - if (reordered_W_buffer_) { - reordered_W = static_cast(reordered_W_buffer_.get()); - } else { - // Weight tensor was not constant or prepacking is disabled. - reordered_W = static_cast(alloc->Alloc(SafeInt(sizeof(uint8_t)) * W_shape.Size())); - reordered_W_buffer = BufferUniquePtr(reordered_W, BufferDeleter(alloc)); - ReorderFilter(static_cast(W->DataRaw()), - reordered_W, - static_cast(M), - static_cast(W_shape[1]), - static_cast(kernel_size)); - is_W_signed = W->IsDataType(); - } - } - - int64_t group_count = conv_attrs_.group; - int64_t group_input_channels = W_shape[1]; - int64_t group_output_channels = M / group_count; - - // Test for depthwise convolution. - const bool is_depthwise_conv = (use_reordered_W && group_input_channels == 1 && group_output_channels == 1); - if (is_depthwise_conv) { - // Update the input and output channels to the number of groups in order to - // reuse as much of the below standard convolution path. - group_input_channels = group_count; - group_output_channels = group_count; - group_count = 1; - } - - const int64_t X_offset = group_input_channels * input_image_size; - const int64_t Y_offset = group_output_channels * output_image_size; - const int64_t kernel_dim = group_input_channels * kernel_size; - const int64_t col_buffer_size = kernel_dim * output_image_size; - - // Use an intermediate int32_t buffer for the GEMM computation before - // requantizing to the output type. - auto gemm_output_data = alloc->Alloc(SafeInt(sizeof(int32_t)) * Y_offset); - BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc)); - auto* gemm_output = static_cast(gemm_output_buffer.get()); - - const auto* Xdata = X->template Data(); - const auto* Bdata = B != nullptr ? B->template Data() : nullptr; - auto* Ydata = Y->template MutableData(); - - auto* transpose_input = static_cast(alloc->Alloc(SafeInt(sizeof(uint8_t)) * X_offset)); - BufferUniquePtr transpose_input_buffer(transpose_input, BufferDeleter(alloc)); - - auto* transpose_output = static_cast(alloc->Alloc(SafeInt(sizeof(uint8_t)) * Y_offset)); - BufferUniquePtr transpose_output_buffer(transpose_output, BufferDeleter(alloc)); - - BufferUniquePtr col_buffer; - - // Pointwise convolutions can use the original input tensor in place, - // otherwise a temporary buffer is required for the im2col transform. - if (kernel_size != 1 || !conv_attrs_.HasStridesOneAndNoPadding()) { - auto* col_data = alloc->Alloc(SafeInt(sizeof(uint8_t)) * col_buffer_size); - col_buffer = BufferUniquePtr(col_data, BufferDeleter(alloc)); - } - auto* col_buffer_data = static_cast(col_buffer.get()); - - // Replicate the logic from MlasGemmU8X8Schedule to control the number of - // worker threads used for the convolution. - constexpr int32_t maximum_thread_count = 16; - constexpr double thread_complexity = static_cast(64 * 1024); - - const double complexity = static_cast(output_image_size) * - static_cast(group_output_channels) * - static_cast(kernel_dim); - - int32_t thread_count = maximum_thread_count; - if (complexity < thread_complexity * maximum_thread_count) { - thread_count = static_cast(complexity / thread_complexity) + 1; - } - if (thread_count > output_image_size) { - // Ensure that every thread produces at least one output. - thread_count = static_cast(output_image_size); - } - - concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); - thread_count = std::min(thread_count, concurrency::ThreadPool::DegreeOfParallelism(thread_pool)); - - for (int64_t image_id = 0; image_id < N; ++image_id) { - for (int64_t group_id = 0; group_id < group_count; ++group_id) { - // Transpose the input from channels first (NCHW) to channels last (NHWC). - MlasTranspose(Xdata, - transpose_input, - static_cast(group_input_channels), - static_cast(input_image_size)); - - if (kernel_rank != 2 && col_buffer_data != nullptr) { - // Try big Im2ColNd in this case, parallel it later if needed - math::Im2colNd()( - transpose_input, - input_shape.GetDims().data(), - output_shape.GetDims().data(), - kernel_dim, - kernel_shape.data(), - strides.data(), - dilations.data(), - pads.data(), - static_cast(kernel_rank), - col_buffer_data, - false, - X_zero_point_value); - } - - auto conv_worker = [&](ptrdiff_t batch) { - auto work = concurrency::ThreadPool::PartitionWork(batch, thread_count, static_cast(output_image_size)); - int64_t output_start = static_cast(work.start); - int64_t output_count = static_cast(work.end - work.start); - - // Prepare the im2col transformation or use the input buffer directly for - // pointwise convolutions. - uint8_t* worker_gemm_input; - if (col_buffer_data != nullptr) { - worker_gemm_input = col_buffer_data + output_start * kernel_dim; - if (kernel_rank == 2) { - math::Im2col()( - transpose_input, - group_input_channels, - input_shape[0], - input_shape[1], - kernel_shape[0], - kernel_shape[1], - dilations[0], - dilations[1], - pads[0], - pads[1], - strides[0], - strides[1], - output_shape[1], - output_start, - output_count, - worker_gemm_input, - X_zero_point_value); - } - } else { - worker_gemm_input = transpose_input + output_start * kernel_dim; - } - - auto* worker_gemm_output = gemm_output + output_start * group_output_channels; - auto* worker_transpose_output = transpose_output + output_start * group_output_channels; - - if (is_depthwise_conv) { - if (is_W_signed) { - MlasConvDepthwise(worker_gemm_input, - X_zero_point_value, - reinterpret_cast(reordered_W), - static_cast(0), - worker_gemm_output, - static_cast(group_output_channels), - static_cast(output_count), - static_cast(kernel_size)); - } else { - MlasConvDepthwise(worker_gemm_input, - X_zero_point_value, - reordered_W, - static_cast(0), - worker_gemm_output, - static_cast(group_output_channels), - static_cast(output_count), - static_cast(kernel_size)); - } - } else { -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 - if (packed_W_buffer_) { - MlasGemm(static_cast(output_count), - static_cast(group_output_channels), - static_cast(kernel_dim), - worker_gemm_input, - static_cast(kernel_dim), - X_zero_point_value, - static_cast(packed_W_buffer_.get()) + group_id * packed_W_size_, - 0, - is_W_signed, - worker_gemm_output, - static_cast(group_output_channels), - nullptr); - } else -#endif - { - MlasGemm(static_cast(output_count), - static_cast(group_output_channels), - static_cast(kernel_dim), - worker_gemm_input, - static_cast(kernel_dim), - X_zero_point_value, - reordered_W + group_id * group_output_channels, - static_cast(M), - 0, - is_W_signed, - worker_gemm_output, - static_cast(group_output_channels), - nullptr); - } - } - - if (output_scales.size() == 1) { - MlasRequantizeOutputColumn(worker_gemm_output, - worker_transpose_output, - Bdata != nullptr ? Bdata + group_id * group_output_channels : nullptr, - static_cast(output_count), - static_cast(group_output_channels), - output_scales[0], - Y_zero_point_value); - } else { - MlasRequantizeOutputColumn(worker_gemm_output, - worker_transpose_output, - Bdata != nullptr ? Bdata + group_id * group_output_channels : nullptr, - static_cast(output_count), - static_cast(group_output_channels), - output_scales.data() + group_id * group_output_channels, - Y_zero_point_value); - } - }; - - concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, thread_count, conv_worker); - - // Transpose the output from channels last (NHWC) to channels first (NCHW). - MlasTranspose(transpose_output, - Ydata, - static_cast(output_image_size), - static_cast(group_output_channels)); - - Xdata += X_offset; - Ydata += Y_offset; - } - } - - return Status::OK(); -} - #endif } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc index 9f6cc00b2e..ecf038bf22 100644 --- a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc @@ -401,6 +401,7 @@ class QLinearConvOpTester { const int64_t kernel_size = std::accumulate( kernel_shape, kernel_shape + kernel_rank, 1LL, std::multiplies()); const int32_t X_zero_point = X_.zero_point_; + const int32_t W_zero_point = W_.zero_point_; const T1* Xdata = X_.data_.data(); T1* Ydata = Y_data.data(); @@ -434,7 +435,7 @@ class QLinearConvOpTester { input_offset *= input_shape[axis]; input_offset += input_dim; } - int32_t w_value = static_cast(*weight_data++); + int32_t w_value = static_cast(*weight_data++) - W_zero_point; if (!is_padding) { int32_t x_value = static_cast(input_image[input_offset]) - X_zero_point; sum += x_value * w_value; @@ -507,7 +508,11 @@ class QLinearConvOpTester { } void GenerateRandomWeights(const std::vector& shape, float scale, T2 zero_point) { - GenerateRandom(W_, shape, scale, zero_point, -63, 63); + if (std::is_signed::value) { + GenerateRandom(W_, shape, scale, zero_point, -63, 63); + } else { + GenerateRandom(W_, shape, scale, zero_point, 0, 255); + } } void SetWeightScales(const std::vector& scales) { @@ -600,6 +605,15 @@ TEST(QLinearConvTest, Conv2D_U8S8_Pointwise) { test.Run(); } +TEST(QLinearConvTest, Conv2D_U8U8_Pointwise) { + QLinearConvOpTester test; + test.GenerateRandomInput({3, 24, 19, 19}, .05f, 4); + test.GenerateRandomWeights({32, 24, 1, 1}, .105f, 126); + test.GenerateRandomBias(); + test.SetOutputScaleAndZeroPoint(.75f, 114); + test.Run(); +} + TEST(QLinearConvTest, Conv3D_U8S8_Pointwise) { QLinearConvOpTester test; test.GenerateRandomInput({2, 2, 15, 11, 6}, .05f, 4); @@ -708,7 +722,7 @@ TEST(QLinearConvTest, Conv2D_U8S8_Groups_PerChannel) { test.Run(); } -TEST(QLinearConvTest, Conv2D_U8S8_Depthwise5x5) { +TEST(QLinearConvTest, Conv2D_U8S8_Depthwise) { QLinearConvOpTester test; test.GenerateRandomInput({1, 24, 25, 25}, .03f, 12); test.GenerateRandomWeights({24, 1, 5, 5}, .10f, 0); @@ -719,12 +733,22 @@ TEST(QLinearConvTest, Conv2D_U8S8_Depthwise5x5) { test.Run(); } -TEST(QLinearConvTest, Conv2D_U8S8_Depthwise1x1) { +TEST(QLinearConvTest, Conv2D_U8U8_Depthwise) { + QLinearConvOpTester test; + test.GenerateRandomInput({1, 30, 25, 25}, .03f, 12); + test.GenerateRandomWeights({30, 1, 3, 3}, .10f, 167); + test.GenerateRandomBias(); + test.SetPads({2, 0, 2, 0}); + test.SetGroups(30); + test.SetOutputScaleAndZeroPoint(.76f, 88); + test.Run(); +} + +TEST(QLinearConvTest, Conv2D_U8S8_DepthwisePointwise) { // Tests the combination of using the depthwise convolution path along with the // pointed convolution optimization that avoids im2col. QLinearConvOpTester test; test.GenerateRandomInput({1, 27, 18, 18}, .03f, 12); - test.GenerateRandomInput({1, 27, 4, 4}, .03f, 12); test.GenerateRandomWeights({27, 1, 1, 1}, .05f, 0); test.GenerateRandomBias(); test.SetGroups(27); @@ -732,6 +756,16 @@ TEST(QLinearConvTest, Conv2D_U8S8_Depthwise1x1) { test.Run(); } +TEST(QLinearConvTest, Conv3D_U8S8_Depthwise) { + QLinearConvOpTester test; + test.GenerateRandomInput({1, 16, 15, 11, 13}, .02f, 135); + test.GenerateRandomWeights({16, 1, 3, 3, 3}, .09f, 0); + test.GenerateRandomBias(); + test.SetGroups(16); + test.SetOutputScaleAndZeroPoint(.85f, 112); + test.Run(); +} + #endif } // namespace