From 1942e40e05ade03a3aeecfa2435f03ccf897415e Mon Sep 17 00:00:00 2001 From: Jing Fang <126209182+fajin-corp@users.noreply.github.com> Date: Thu, 26 Sep 2024 20:55:40 +0000 Subject: [PATCH] [ARM64] MatMulNBits: use neon intrinsics to convert between fp16 and fp32 (#22195) ### Description For the fp16 A type, the fallback operation is to convert the data to fp32 and compute in fp32. Added a NEON intrinsics version to speed up the conversion. Store address alignment and loop unrolling have an insignificant impact on latency, so they are omitted. |Benchmark | Time | CPU | |--------------|---------------------------------------------|--------------------| |M_ConvertF16ToF32/baseline/real_time | 1076961 ns | 1083398 ns | |M_ConvertF16ToF32/aligned:0/real_time | 46785 ns | 46516 ns | |M_ConvertF16ToF32/aligned:1/real_time | 46631 ns | 46391 ns | |M_ConvertF16ToF32_unroll2/aligned:0/real_time | 44074 ns | 44392 ns | |M_ConvertF16ToF32_unroll2/aligned:1/real_time | 44726 ns | 45226 ns | |M_ConvertF32ToF16/baseline/real_time | 520109 ns | 527329 ns | |M_ConvertF32ToF16/aligned:0/real_time | 73610 ns | 74015 ns | |M_ConvertF32ToF16/aligned:1/real_time | 71557 ns | 71525 ns | |M_ConvertF32ToF16_unroll2/aligned:0/real_time | 64227 ns | 63374 ns | |M_ConvertF32ToF16_unroll2/aligned:1/real_time | 67428 ns | 67989 ns | ### Motivation and Context Speed up the fallback implementation of fp16 MatMulNBits. --- cmake/onnxruntime_mlas.cmake | 3 + .../cpu/quantization/matmul_nbits.cc | 208 ++++++++++++------ .../core/mlas/lib/fp16_neon_common.cpp | 164 ++++++++++++++ onnxruntime/core/mlas/lib/mlasi.h | 5 +- onnxruntime/core/mlas/lib/platform.cpp | 7 +- .../mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp | 1 + .../mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | 1 + .../mlas/bench/bench_fp16_neon_common.cpp | 54 +++++ .../unittest/test_sqnbitgemm_neon_fp16.cpp | 82 +++++++ 9 files changed, 459 insertions(+), 66 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/fp16_neon_common.cpp create mode 100644 
onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index e35c83ba45..0ba4694c32 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -88,6 +88,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp ) set(mlas_platform_preprocess_srcs @@ -382,6 +383,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -391,6 +393,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index f8f07b6e28..67af00beab 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -142,6 +142,8 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_{}; size_t packed_b_size_{0}; 
+ IAllocatorUniquePtr scales_fp32_{}; + IAllocatorUniquePtr bias_fp32_{}; bool has_zp_input_{false}; #if defined(ORT_NEURAL_SPEED) @@ -175,30 +177,9 @@ class MatMulNBits final : public OpKernel { const MatMulComputeHelper& helper) const { ORT_THROW("ComputeBPacked is not supported for T1 type."); } - - void PackScale(const Tensor& tensor) { - ORT_THROW("PackScale is not supported for T1 type."); - } }; -#ifdef MLAS_TARGET_AMD64_IX86 -template <> -void MatMulNBits::PackScale(const Tensor& tensor) { - auto sptr = tensor.Data(); - std::vector scales_v(static_cast(tensor.Shape().Size())); - MlasConvertHalfToFloatBuffer(sptr, &scales_v[0], scales_v.size()); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), &scales_v[0], - has_zp_input_, nullptr, nullptr); -} - -template <> -void MatMulNBits::PackScale(const Tensor& tensor) { - auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, - has_zp_input_, nullptr, nullptr); -} -#endif - +#if defined(ORT_NEURAL_SPEED) template Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -207,7 +188,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All if (has_g_idx_ || has_unquantized_zero_point_) { return Status::OK(); } -#if defined(ORT_NEURAL_SPEED) if (!all_constant_) { return Status::OK(); @@ -259,8 +239,21 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All is_packed = true; } + return Status::OK(); +} + #else // defined(ORT_NEURAL_SPEED) + +template +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); + is_packed = false; + if (has_g_idx_ || has_unquantized_zero_point_) { + return Status::OK(); + } + if 
(!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return Status::OK(); } @@ -276,20 +269,77 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } else if (compute_type_ == CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - PackScale(tensor); + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, + has_zp_input_, nullptr, nullptr); is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } -#endif +#endif // MLAS_TARGET_AMD64_IX86 } -#endif // defined(ORT_NEURAL_SPEED) return Status::OK(); } +template <> +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + ORT_UNUSED_PARAMETER(prepacked_weights); + + if (input_idx == InputIndex::scales || input_idx == InputIndex::bias) { + auto sptr = tensor.Data(); + auto tensor_size = static_cast(tensor.Shape().Size()); + auto ptr = IAllocator::MakeUniquePtr(alloc, tensor_size, true); + MlasConvertHalfToFloatBuffer(sptr, ptr.get(), tensor_size); + if (input_idx == InputIndex::scales) { + scales_fp32_ = std::move(ptr); + } else { + bias_fp32_ = std::move(ptr); + } + } + + is_packed = false; + if (has_g_idx_ || has_unquantized_zero_point_) { + return Status::OK(); + } + + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + return Status::OK(); + } + if (input_idx == InputIndex::B) { + packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); + if (packed_b_size_ == 0) { + return Status::OK(); + } + auto qptr = tensor.DataRaw(); + packed_b_ = 
IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), + nullptr, has_zp_input_, nullptr, nullptr); + is_packed = true; + } else if (compute_type_ == CompInt8) { +#ifdef MLAS_TARGET_AMD64_IX86 + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + scales_fp32_.get(), has_zp_input_, nullptr, nullptr); + is_packed = false; + } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + nullptr, has_zp_input_, zptr, nullptr); + is_packed = false; + } +#endif // MLAS_TARGET_AMD64_IX86 + } + + return Status::OK(); +} + +#endif // !defined(ORT_NEURAL_SPEED) + template Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { @@ -348,7 +398,8 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { - workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + // Use reserve since no caching is needed + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size, true); } InlinedVector data(batch_count); @@ -397,22 +448,36 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { - workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + // Use reserve since no caching is needed + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size, true); } - auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); 
- MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); + auto a_size = static_cast(a->Shape().Size()); + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, a_size, true); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), a_size); - auto tmp_scales_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(scales->Shape().Size())); - MlasConvertHalfToFloatBuffer(scales_data, tmp_scales_data_ptr.get(), static_cast(scales->Shape().Size())); - - std::vector bias_data_v; - if (bias_data != nullptr) { - bias_data_v.resize(static_cast(bias->Shape().Size())); - MlasConvertHalfToFloatBuffer(bias_data, &bias_data_v[0], bias_data_v.size()); + float* scales_ptr = nullptr; + if (!scales_fp32_) { + auto scales_temp = IAllocator::MakeUniquePtr(allocator, static_cast(scales->Shape().Size()), true); + MlasConvertHalfToFloatBuffer(scales_data, scales_temp.get(), static_cast(scales->Shape().Size())); + scales_ptr = scales_temp.get(); + } else { + scales_ptr = scales_fp32_.get(); } - std::vector C_v(static_cast(y->Shape().Size())); + float* bias_ptr = nullptr; + if (bias_data) { + if (!bias_fp32_) { + auto bias_temp = IAllocator::MakeUniquePtr(allocator, static_cast(bias->Shape().Size()), true); + MlasConvertHalfToFloatBuffer(bias_data, bias_temp.get(), static_cast(bias->Shape().Size())); + bias_ptr = bias_temp.get(); + } else { + bias_ptr = bias_fp32_.get(); + } + } + + size_t c_size = static_cast(y->Shape().Size()); + std::vector c_v(c_size); InlinedVector data(batch_count); for (size_t i = 0; i < batch_count; ++i) { @@ -424,15 +489,15 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, } #endif data[i].PackedQuantBData = static_cast(packed_b_.get()); - data[i].QuantBScale = tmp_scales_data_ptr.get(); + data[i].QuantBScale = scales_ptr; data[i].QuantBZeroPoint = zero_points_data; - data[i].Bias = bias_data != nullptr ? &bias_data_v[0] : nullptr; - data[i].C = &C_v[0] + helper.OutputOffsets()[i]; + data[i].Bias = bias ? 
bias_ptr : nullptr; + data[i].C = c_v.data() + helper.OutputOffsets()[i]; data[i].ldc = N; } MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), thread_pool); - MlasConvertFloatToHalfBuffer(&C_v[0], y_data, C_v.size()); + MlasConvertFloatToHalfBuffer(c_v.data(), y_data, c_size); return Status::OK(); } @@ -461,7 +526,8 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const size_t lda = helper.Lda(false); const size_t ldb = helper.Ldb(true); - auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + // TODO(fajin): move B dequant to prepack + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_, true); if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { // dequantize b, only 4b quantization is supported for now @@ -561,12 +627,6 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); auto* y_data = y->MutableData(); - const float* scales_data_; - std::vector scales_data_v; - scales_data_v.resize(static_cast(scales->Shape().Size())); - MlasConvertHalfToFloatBuffer(scales_data, &scales_data_v[0], scales_data_v.size()); - scales_data_ = &scales_data_v[0]; - const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); const size_t N = static_cast(helper.N()); @@ -574,14 +634,25 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const size_t lda = helper.Lda(false); const size_t ldb = helper.Ldb(true); - auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + float* scales_ptr = nullptr; + if (!scales_fp32_) { + auto scales_size = static_cast(scales->Shape().Size()); + auto temp_scales = IAllocator::MakeUniquePtr(allocator, scales_size, true); + MlasConvertHalfToFloatBuffer(scales_data, temp_scales.get(), scales_size); + scales_ptr = temp_scales.get(); + } else { + scales_ptr = 
scales_fp32_.get(); + } + + // TODO(fajin): move B dequant to prepack + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_, true); if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { // dequantize b, only 4b quantization is supported for now MlasDequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data_, // quantization scales + scales_ptr, // quantization scales static_cast(zero_points_data), // quantization zero points static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -595,7 +666,7 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data_, // quantization scales + scales_ptr, // quantization scales static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size @@ -607,7 +678,7 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data_, // quantization scales + scales_ptr, // quantization scales static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size @@ -623,9 +694,14 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, #endif std::vector data(batch_count); - auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); - MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); - auto tmp_c_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(y->Shape().Size())); + + auto a_size = static_cast(a->Shape().Size()); + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, a_size, true); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), a_size); + + auto c_size = 
static_cast(y->Shape().Size()); + auto tmp_c_ptr = IAllocator::MakeUniquePtr(allocator, c_size, true); + for (size_t i = 0; i < batch_count; i++) { data[i].BIsPacked = false; data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; @@ -640,24 +716,28 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, // if there is a bias input, copy bias values into C and set beta to 1.0f if (bias) { - auto tmp_bias_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(bias->Shape().Size())); - MlasConvertHalfToFloatBuffer(bias->Data(), - tmp_bias_data_ptr.get(), - static_cast(bias->Shape().Size())); + float* bias_ptr = nullptr; + const size_t bias_size = static_cast(bias->Shape().Size()); + if (!bias_fp32_) { + auto bias_temp = IAllocator::MakeUniquePtr(allocator, bias_size, true); + MlasConvertHalfToFloatBuffer(bias->Data(), bias_temp.get(), bias_size); + bias_ptr = bias_temp.get(); + } else { + bias_ptr = bias_fp32_.get(); + } for (size_t i = 0; i < batch_count; ++i) { float* C_row = data[i].C; const size_t ldc = data[i].ldc; for (size_t m = 0; m < M; ++m) { - std::copy(tmp_bias_data_ptr.get(), tmp_bias_data_ptr.get() + bias->Shape().Size(), C_row); + std::copy(bias_ptr, bias_ptr + bias_size, C_row); C_row += ldc; } data[i].beta = 1.0f; } } - MlasGemmBatch(CblasNoTrans, CblasTrans, - M, N, K, data.data(), batch_count, thread_pool); - MlasConvertFloatToHalfBuffer(tmp_c_ptr.get(), y_data, static_cast(y->Shape().Size())); + MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), batch_count, thread_pool); + MlasConvertFloatToHalfBuffer(tmp_c_ptr.get(), y_data, c_size); return Status::OK(); } diff --git a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp b/onnxruntime/core/mlas/lib/fp16_neon_common.cpp new file mode 100644 index 0000000000..29734c2277 --- /dev/null +++ b/onnxruntime/core/mlas/lib/fp16_neon_common.cpp @@ -0,0 +1,164 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. 
+ +Module Name: + + fp16_neon_common.cpp + +Abstract: + + This module implements the common kernels for ARM NEON specific to float16. + +--*/ + +#include "mlasi.h" + +#include "arm_neon.h" + +// This file is enabled in cmake only if ARM64 is defined and not on Apple platforms +// The cmake condition is equivalent to MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64. +// Therefore omit the MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 macro in this file. + +MLAS_FORCEINLINE +size_t +StoreFp32Lane(float* dest, float32x4_t src, size_t count) +{ + if (count == 3) { + vst1q_lane_f32(dest + 0, src, 0); + vst1q_lane_f32(dest + 1, src, 1); + vst1q_lane_f32(dest + 2, src, 2); + return 3; + } else if (count == 2) { + vst1q_lane_f32(dest + 0, src, 0); + vst1q_lane_f32(dest + 1, src, 1); + return 2; + } else if (count == 1) { + vst1q_lane_f32(dest + 0, src, 0); + return 1; + } + + return 0; +} + +void +MlasCastF16ToF32KernelNeon(const unsigned short* src, float* dest, size_t count) +{ + // 4 float16 alignment + auto* src_aligned = reinterpret_cast((reinterpret_cast(src) + 7) & ~7); + auto pre_count = std::min(static_cast(src_aligned - src), count); + size_t i = 0; + + // Handle leading unaligned src + if (pre_count > 0) { + float16x4_t fp16v4; + std::memcpy(&fp16v4, src, pre_count * sizeof(unsigned short)); + float32x4_t fp32v4 = vcvt_f32_f16(fp16v4); + + i = StoreFp32Lane(dest, fp32v4, pre_count); + } + + // aligned src + for (; i + 7 < count; i += 8) + { + float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i)); + float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0); + vst1q_f32(dest + i, fp32v4_0); + + float16x4_t fp16v4_1 = vreinterpret_f16_u16(vld1_u16(src + i + 4)); + float32x4_t fp32v4_1 = vcvt_f32_f16(fp16v4_1); + vst1q_f32(dest + i + 4, fp32v4_1); + } + + if (i + 3 < count) + { + float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i)); + float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0); + vst1q_f32(dest + i, fp32v4_0); + i += 4; + } + + // Handle 
trailing unaligned src + auto post_count = count - i; + if (post_count > 0) + { + float16x4_t fp16v4; + std::memcpy(&fp16v4, src + i, post_count * sizeof(unsigned short)); + float32x4_t fp32v4 = vcvt_f32_f16(fp16v4); + + StoreFp32Lane(dest + i, fp32v4, post_count); + } +} + +MLAS_FORCEINLINE +size_t +StoreU16Lane(unsigned short* dest, uint16x4_t src, size_t count) +{ + if (count == 3) { + vst1_lane_u16(dest + 0, src, 0); + vst1_lane_u16(dest + 1, src, 1); + vst1_lane_u16(dest + 2, src, 2); + return 3; + } else if (count == 2) { + vst1_lane_u16(dest + 0, src, 0); + vst1_lane_u16(dest + 1, src, 1); + return 2; + } else if (count == 1) { + vst1_lane_u16(dest + 0, src, 0); + return 1; + } + + return 0; +} + +void +MlasCastF32ToF16KernelNeon(const float* src, unsigned short* dest, size_t count) +{ + // 4 float32 alignment + auto* src_aligned = reinterpret_cast((reinterpret_cast(src) + 15) & ~15); + auto pre_count = std::min(static_cast(src_aligned - src), count); + size_t i = 0; + + // Handle leading unaligned src + if (pre_count > 0) + { + float32x4_t fp32v4; + std::memcpy(&fp32v4, src, pre_count * sizeof(float)); + uint16x4_t u16v4 = vreinterpret_u16_f16(vcvt_f16_f32(fp32v4)); + + i = StoreU16Lane(dest, u16v4, pre_count); + } + + // aligned src + for (; i + 7 < count; i += 8) + { + float32x4_t fp32v4_0 = vld1q_f32(src + i); + float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0); + vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0)); + + float32x4_t fp32v4_1 = vld1q_f32(src + i + 4); + float16x4_t fp16v4_1 = vcvt_f16_f32(fp32v4_1); + vst1_u16(dest + i + 4, vreinterpret_u16_f16(fp16v4_1)); + } + + if (i + 3 < count) + { + float32x4_t fp32v4_0 = vld1q_f32(src + i); + float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0); + vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0)); + i += 4; + } + + // Handle trailing unaligned src + auto post_count = count - i; + if (post_count > 0) + { + float32x4_t fp32v4; + std::memcpy(&fp32v4, src + i, post_count * sizeof(float)); + uint16x4_t u16v4 = 
vreinterpret_u16_f16(vcvt_f16_f32(fp32v4)); + + StoreU16Lane(dest + i, u16v4, post_count); + } +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 96ba8c6c92..13ea8d96c2 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -893,6 +893,10 @@ extern "C" { MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelAvx2; #endif +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelNeon; + MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelNeon; +#endif } // @@ -2603,4 +2607,3 @@ MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueH static_assert(std::is_same_v || std::is_same_v); *Output = static_cast(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF)); } - diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 102d605227..23d29fd02f 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -20,7 +20,7 @@ Abstract: #include #include -#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) #if defined(__linux__) #include #elif defined(_AIX) @@ -576,6 +576,11 @@ Return Value: } #endif +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon; + this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelNeon; +#endif + #endif // MLAS_TARGET_ARM64 #if defined(MLAS_TARGET_POWER) this->GemmFloatKernel = MlasSgemmKernel; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp index ca64ebe3b1..12ddc42506 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -12,6 +12,7 @@ Abstract: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to + input type T1 as float32 and 
MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. --*/ diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index ec5cdbc752..0d62ea37b7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -12,6 +12,7 @@ Abstract: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to + input type T1 as float32 and MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. --*/ diff --git a/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp b/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp new file mode 100644 index 0000000000..1dccbe44aa --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp @@ -0,0 +1,54 @@ +#include "bench_util.h" +#include "core/mlas/lib/mlasi.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +void BM_ConvertF16ToF32(benchmark::State& state) { + bool aligned = static_cast(state.range(0)); + const size_t count = 1 << 18; + auto src = RandomVectorUniform(count, 0, 60000); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + float* dst_start = aligned ? reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + // Warm up + MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); + + for (auto _ : state) { + MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); + } +} + +void BM_ConvertF32ToF16(benchmark::State& state) { + bool aligned = static_cast(state.range(0)); + const size_t count = 1 << 18; + auto src = RandomVectorUniform(count, -30000.0f, 30000.0f); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + unsigned short* dst_start = aligned ? 
reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + // Warm up + MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); + + for (auto _ : state) { + MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); + } +} + +BENCHMARK(BM_ConvertF16ToF32) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"aligned"}); + b->ArgsProduct({{0, 1}}); + }); + +BENCHMARK(BM_ConvertF32ToF16) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"aligned"}); + b->ArgsProduct({{0, 1}}); + }); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp new file mode 100644 index 0000000000..243752bbea --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp @@ -0,0 +1,82 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sqnbitgemm_neon_fp16.cpp + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM on ARM CPU with input A type T1 fp16. 
+ +--*/ + +#include + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonFp16CastTest : public MlasTestBase { + private: + void TestFp16ToFp32(size_t count) { + std::vector src(count); + std::vector dest(count); + + for (size_t i = 0; i < count; i++) { + src[i] = static_cast(i); + } + + MlasCastF16ToF32KernelNeon(src.data(), dest.data(), count); + + for (size_t i = 0; i < count; i++) { + if ((src[i] & 0x1c00) == 0x1c00) continue; // skip inf and nan + ASSERT_EQ(dest[i], MLAS_FP16::FromBits(src[i]).ToFloat()); + } + } + + void TestFp32ToFp16(size_t count) { + std::vector src(count); + std::vector dest(count); + + for (size_t i = 0; i < count; i++) { + src[i] = static_cast(i) + 0.125f; + } + + MlasCastF32ToF16KernelNeon(src.data(), dest.data(), count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(dest[i], MLAS_FP16(src[i]).val); + } + } + + public: + static const char* GetTestSuiteName() { + return "NeonFp16Cast"; + } + + void ExecuteShort(void) override { + TestFp16ToFp32(1 << 16); + TestFp16ToFp32(1); + TestFp16ToFp32(4); + TestFp16ToFp32(7); + TestFp32ToFp16(1 << 16); + TestFp32ToFp16(3); + TestFp32ToFp16(4); + TestFp32ToFp16(6); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)