diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index e35c83ba45..0ba4694c32 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -88,6 +88,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp ) set(mlas_platform_preprocess_srcs @@ -382,6 +383,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -391,6 +393,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index f8f07b6e28..67af00beab 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -142,6 +142,8 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_{}; size_t packed_b_size_{0}; + IAllocatorUniquePtr scales_fp32_{}; + IAllocatorUniquePtr bias_fp32_{}; bool has_zp_input_{false}; #if defined(ORT_NEURAL_SPEED) @@ -175,30 +177,9 @@ class MatMulNBits final : public OpKernel { const MatMulComputeHelper& helper) const { ORT_THROW("ComputeBPacked is not supported for T1 type."); } - - void PackScale(const Tensor& tensor) { - ORT_THROW("PackScale is not supported for T1 type."); - } }; -#ifdef MLAS_TARGET_AMD64_IX86 -template <> -void MatMulNBits::PackScale(const Tensor& tensor) { - auto sptr = tensor.Data(); - std::vector scales_v(static_cast(tensor.Shape().Size())); - MlasConvertHalfToFloatBuffer(sptr, &scales_v[0], scales_v.size()); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), &scales_v[0], - has_zp_input_, nullptr, nullptr); -} - -template <> -void MatMulNBits::PackScale(const Tensor& tensor) { - auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, - has_zp_input_, nullptr, nullptr); -} -#endif - +#if defined(ORT_NEURAL_SPEED) template Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -207,7 +188,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All if (has_g_idx_ || has_unquantized_zero_point_) { return Status::OK(); } -#if defined(ORT_NEURAL_SPEED) if (!all_constant_) { return Status::OK(); @@ -259,8 +239,21 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All is_packed = true; } + return Status::OK(); +} + #else // defined(ORT_NEURAL_SPEED) + +template +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); + is_packed = false; + if (has_g_idx_ || has_unquantized_zero_point_) { + return Status::OK(); + } + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return Status::OK(); } @@ -276,20 +269,77 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } else if (compute_type_ == CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - PackScale(tensor); + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, + has_zp_input_, nullptr, nullptr); is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } -#endif +#endif // MLAS_TARGET_AMD64_IX86 } -#endif // defined(ORT_NEURAL_SPEED) return Status::OK(); } +template <> +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + ORT_UNUSED_PARAMETER(prepacked_weights); + + if (input_idx == InputIndex::scales || input_idx == InputIndex::bias) { + auto sptr = tensor.Data(); + auto tensor_size = static_cast(tensor.Shape().Size()); + auto ptr = IAllocator::MakeUniquePtr(alloc, tensor_size, true); + MlasConvertHalfToFloatBuffer(sptr, ptr.get(), tensor_size); + if (input_idx == InputIndex::scales) { + scales_fp32_ = std::move(ptr); + } else { + bias_fp32_ = std::move(ptr); + } + } + + is_packed = false; + if (has_g_idx_ || has_unquantized_zero_point_) { + return Status::OK(); + } + + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + return Status::OK(); + } + if (input_idx == InputIndex::B) { + packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); + if (packed_b_size_ == 0) { + return Status::OK(); + } + auto qptr = tensor.DataRaw(); + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), + nullptr, has_zp_input_, nullptr, nullptr); + is_packed = true; + } else if (compute_type_ == CompInt8) { +#ifdef MLAS_TARGET_AMD64_IX86 + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + scales_fp32_.get(), has_zp_input_, nullptr, nullptr); + is_packed = false; + } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + nullptr, has_zp_input_, zptr, nullptr); + is_packed = false; + } +#endif // MLAS_TARGET_AMD64_IX86 + } + + return Status::OK(); +} + +#endif // !defined(ORT_NEURAL_SPEED) + template Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { @@ -348,7 +398,8 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { - workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + // Use reserve since no caching is needed + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size, true); } InlinedVector data(batch_count); @@ -397,22 +448,36 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { - workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + // Use reserve since no caching is needed + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size, true); } - auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); - MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); + auto a_size = static_cast(a->Shape().Size()); + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, a_size, true); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), a_size); - auto tmp_scales_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(scales->Shape().Size())); - MlasConvertHalfToFloatBuffer(scales_data, tmp_scales_data_ptr.get(), static_cast(scales->Shape().Size())); - - std::vector bias_data_v; - if (bias_data != nullptr) { - bias_data_v.resize(static_cast(bias->Shape().Size())); - MlasConvertHalfToFloatBuffer(bias_data, &bias_data_v[0], bias_data_v.size()); + float* scales_ptr = nullptr; + if (!scales_fp32_) { + auto scales_temp = IAllocator::MakeUniquePtr(allocator, static_cast(scales->Shape().Size()), true); + MlasConvertHalfToFloatBuffer(scales_data, scales_temp.get(), static_cast(scales->Shape().Size())); + scales_ptr = scales_temp.get(); + } else { + scales_ptr = scales_fp32_.get(); } - std::vector C_v(static_cast(y->Shape().Size())); + float* bias_ptr = nullptr; + if (bias_data) { + if (!bias_fp32_) { + auto bias_temp = IAllocator::MakeUniquePtr(allocator, static_cast(bias->Shape().Size()), true); + MlasConvertHalfToFloatBuffer(bias_data, bias_temp.get(), static_cast(bias->Shape().Size())); + bias_ptr = bias_temp.get(); + } else { + bias_ptr = bias_fp32_.get(); + } + } + + size_t c_size = static_cast(y->Shape().Size()); + std::vector c_v(c_size); InlinedVector data(batch_count); for (size_t i = 0; i < batch_count; ++i) { @@ -424,15 +489,15 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, } #endif data[i].PackedQuantBData = static_cast(packed_b_.get()); - data[i].QuantBScale = tmp_scales_data_ptr.get(); + data[i].QuantBScale = scales_ptr; data[i].QuantBZeroPoint = zero_points_data; - data[i].Bias = bias_data != nullptr ? &bias_data_v[0] : nullptr; - data[i].C = &C_v[0] + helper.OutputOffsets()[i]; + data[i].Bias = bias ? bias_ptr : nullptr; + data[i].C = c_v.data() + helper.OutputOffsets()[i]; data[i].ldc = N; } MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), thread_pool); - MlasConvertFloatToHalfBuffer(&C_v[0], y_data, C_v.size()); + MlasConvertFloatToHalfBuffer(c_v.data(), y_data, c_size); return Status::OK(); } @@ -461,7 +526,8 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const size_t lda = helper.Lda(false); const size_t ldb = helper.Ldb(true); - auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + // TODO(fajin): move B dequant to prepack + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_, true); if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { // dequantize b, only 4b quantization is supported for now @@ -561,12 +627,6 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); auto* y_data = y->MutableData(); - const float* scales_data_; - std::vector scales_data_v; - scales_data_v.resize(static_cast(scales->Shape().Size())); - MlasConvertHalfToFloatBuffer(scales_data, &scales_data_v[0], scales_data_v.size()); - scales_data_ = &scales_data_v[0]; - const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); const size_t N = static_cast(helper.N()); @@ -574,14 +634,25 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const size_t lda = helper.Lda(false); const size_t ldb = helper.Ldb(true); - auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + float* scales_ptr = nullptr; + if (!scales_fp32_) { + auto scales_size = static_cast(scales->Shape().Size()); + auto temp_scales = IAllocator::MakeUniquePtr(allocator, scales_size, true); + MlasConvertHalfToFloatBuffer(scales_data, temp_scales.get(), scales_size); + scales_ptr = temp_scales.get(); + } else { + scales_ptr = scales_fp32_.get(); + } + + // TODO(fajin): move B dequant to prepack + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_, true); if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { // dequantize b, only 4b quantization is supported for now MlasDequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data_, // quantization scales + scales_ptr, // quantization scales static_cast(zero_points_data), // quantization zero points static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -595,7 +666,7 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data_, // quantization scales + scales_ptr, // quantization scales static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size @@ -607,7 +678,7 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data_, // quantization scales + scales_ptr, // quantization scales static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size @@ -623,9 +694,14 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, #endif std::vector data(batch_count); - auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); - MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); - auto tmp_c_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(y->Shape().Size())); + + auto a_size = static_cast(a->Shape().Size()); + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, a_size, true); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), a_size); + + auto c_size = static_cast(y->Shape().Size()); + auto tmp_c_ptr = IAllocator::MakeUniquePtr(allocator, c_size, true); + for (size_t i = 0; i < batch_count; i++) { data[i].BIsPacked = false; data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; @@ -640,24 +716,28 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, // if there is a bias input, copy bias values into C and set beta to 1.0f if (bias) { - auto tmp_bias_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(bias->Shape().Size())); - MlasConvertHalfToFloatBuffer(bias->Data(), - tmp_bias_data_ptr.get(), - static_cast(bias->Shape().Size())); + float* bias_ptr = nullptr; + const size_t bias_size = static_cast(bias->Shape().Size()); + if (!bias_fp32_) { + auto bias_temp = IAllocator::MakeUniquePtr(allocator, bias_size, true); + MlasConvertHalfToFloatBuffer(bias->Data(), bias_temp.get(), bias_size); + bias_ptr = bias_temp.get(); + } else { + bias_ptr = bias_fp32_.get(); + } for (size_t i = 0; i < batch_count; ++i) { float* C_row = data[i].C; const size_t ldc = data[i].ldc; for (size_t m = 0; m < M; ++m) { - std::copy(tmp_bias_data_ptr.get(), tmp_bias_data_ptr.get() + bias->Shape().Size(), C_row); + std::copy(bias_ptr, bias_ptr + bias_size, C_row); C_row += ldc; } data[i].beta = 1.0f; } } - MlasGemmBatch(CblasNoTrans, CblasTrans, - M, N, K, data.data(), batch_count, thread_pool); - MlasConvertFloatToHalfBuffer(tmp_c_ptr.get(), y_data, static_cast(y->Shape().Size())); + MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), batch_count, thread_pool); + MlasConvertFloatToHalfBuffer(tmp_c_ptr.get(), y_data, c_size); return Status::OK(); } diff --git a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp b/onnxruntime/core/mlas/lib/fp16_neon_common.cpp new file mode 100644 index 0000000000..29734c2277 --- /dev/null +++ b/onnxruntime/core/mlas/lib/fp16_neon_common.cpp @@ -0,0 +1,164 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + fp16_neon_common.cpp + +Abstract: + + This module implements the common kernels for ARM NEON specific to float16. + +--*/ + +#include "mlasi.h" + +#include "arm_neon.h" + +// This file is enabled in cmake only if ARM64 is defined and not on Apple platforms +// The cmake condition is equivalent to MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64. +// Therefore omit the MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 macro in this file. + +MLAS_FORCEINLINE +size_t +StoreFp32Lane(float* dest, float32x4_t src, size_t count) +{ + if (count == 3) { + vst1q_lane_f32(dest + 0, src, 0); + vst1q_lane_f32(dest + 1, src, 1); + vst1q_lane_f32(dest + 2, src, 2); + return 3; + } else if (count == 2) { + vst1q_lane_f32(dest + 0, src, 0); + vst1q_lane_f32(dest + 1, src, 1); + return 2; + } else if (count == 1) { + vst1q_lane_f32(dest + 0, src, 0); + return 1; + } + + return 0; +} + +void +MlasCastF16ToF32KernelNeon(const unsigned short* src, float* dest, size_t count) +{ + // 4 float16 alignment + auto* src_aligned = reinterpret_cast((reinterpret_cast(src) + 7) & ~7); + auto pre_count = std::min(static_cast(src_aligned - src), count); + size_t i = 0; + + // Handle leading unaligned src + if (pre_count > 0) { + float16x4_t fp16v4; + std::memcpy(&fp16v4, src, pre_count * sizeof(unsigned short)); + float32x4_t fp32v4 = vcvt_f32_f16(fp16v4); + + i = StoreFp32Lane(dest, fp32v4, pre_count); + } + + // aligned src + for (; i + 7 < count; i += 8) + { + float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i)); + float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0); + vst1q_f32(dest + i, fp32v4_0); + + float16x4_t fp16v4_1 = vreinterpret_f16_u16(vld1_u16(src + i + 4)); + float32x4_t fp32v4_1 = vcvt_f32_f16(fp16v4_1); + vst1q_f32(dest + i + 4, fp32v4_1); + } + + if (i + 3 < count) + { + float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i)); + float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0); + vst1q_f32(dest + i, fp32v4_0); + i += 4; + } + + // Handle trailing unaligned src + auto post_count = count - i; + if (post_count > 0) + { + float16x4_t fp16v4; + std::memcpy(&fp16v4, src + i, post_count * sizeof(unsigned short)); + float32x4_t fp32v4 = vcvt_f32_f16(fp16v4); + + StoreFp32Lane(dest + i, fp32v4, post_count); + } +} + +MLAS_FORCEINLINE +size_t +StoreU16Lane(unsigned short* dest, uint16x4_t src, size_t count) +{ + if (count == 3) { + vst1_lane_u16(dest + 0, src, 0); + vst1_lane_u16(dest + 1, src, 1); + vst1_lane_u16(dest + 2, src, 2); + return 3; + } else if (count == 2) { + vst1_lane_u16(dest + 0, src, 0); + vst1_lane_u16(dest + 1, src, 1); + return 2; + } else if (count == 1) { + vst1_lane_u16(dest + 0, src, 0); + return 1; + } + + return 0; +} + +void +MlasCastF32ToF16KernelNeon(const float* src, unsigned short* dest, size_t count) +{ + // 4 float32 alignment + auto* src_aligned = reinterpret_cast((reinterpret_cast(src) + 15) & ~15); + auto pre_count = std::min(static_cast(src_aligned - src), count); + size_t i = 0; + + // Handle leading unaligned src + if (pre_count > 0) + { + float32x4_t fp32v4; + std::memcpy(&fp32v4, src, pre_count * sizeof(float)); + uint16x4_t u16v4 = vreinterpret_u16_f16(vcvt_f16_f32(fp32v4)); + + i = StoreU16Lane(dest, u16v4, pre_count); + } + + // aligned src + for (; i + 7 < count; i += 8) + { + float32x4_t fp32v4_0 = vld1q_f32(src + i); + float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0); + vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0)); + + float32x4_t fp32v4_1 = vld1q_f32(src + i + 4); + float16x4_t fp16v4_1 = vcvt_f16_f32(fp32v4_1); + vst1_u16(dest + i + 4, vreinterpret_u16_f16(fp16v4_1)); + } + + if (i + 3 < count) + { + float32x4_t fp32v4_0 = vld1q_f32(src + i); + float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0); + vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0)); + i += 4; + } + + // Handle trailing unaligned src + auto post_count = count - i; + if (post_count > 0) + { + float32x4_t fp32v4; + std::memcpy(&fp32v4, src + i, post_count * sizeof(float)); + uint16x4_t u16v4 = vreinterpret_u16_f16(vcvt_f16_f32(fp32v4)); + + StoreU16Lane(dest + i, u16v4, post_count); + } +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 96ba8c6c92..13ea8d96c2 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -893,6 +893,10 @@ extern "C" { MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelAvx2; #endif +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelNeon; + MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelNeon; +#endif } // @@ -2603,4 +2607,3 @@ MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueH static_assert(std::is_same_v || std::is_same_v); *Output = static_cast(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF)); } - diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 102d605227..23d29fd02f 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -20,7 +20,7 @@ Abstract: #include #include -#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) #if defined(__linux__) #include #elif defined(_AIX) @@ -576,6 +576,11 @@ Return Value: } #endif +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon; + this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelNeon; +#endif + #endif // MLAS_TARGET_ARM64 #if defined(MLAS_TARGET_POWER) this->GemmFloatKernel = MlasSgemmKernel; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp index ca64ebe3b1..12ddc42506 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -12,6 +12,7 @@ Abstract: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to + input type T1 as float32 and MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. --*/ diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index ec5cdbc752..0d62ea37b7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -12,6 +12,7 @@ Abstract: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to + input type T1 as float32 and MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. --*/ diff --git a/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp b/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp new file mode 100644 index 0000000000..1dccbe44aa --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp @@ -0,0 +1,54 @@ +#include "bench_util.h" +#include "core/mlas/lib/mlasi.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +void BM_ConvertF16ToF32(benchmark::State& state) { + bool aligned = static_cast(state.range(0)); + const size_t count = 1 << 18; + auto src = RandomVectorUniform(count, 0, 60000); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + float* dst_start = aligned ? reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + // Warm up + MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); + + for (auto _ : state) { + MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); + } +} + +void BM_ConvertF32ToF16(benchmark::State& state) { + bool aligned = static_cast(state.range(0)); + const size_t count = 1 << 18; + auto src = RandomVectorUniform(count, -30000.0f, 30000.0f); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + unsigned short* dst_start = aligned ? reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + // Warm up + MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); + + for (auto _ : state) { + MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); + } +} + +BENCHMARK(BM_ConvertF16ToF32) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"aligned"}); + b->ArgsProduct({{0, 1}}); + }); + +BENCHMARK(BM_ConvertF32ToF16) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"aligned"}); + b->ArgsProduct({{0, 1}}); + }); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp new file mode 100644 index 0000000000..243752bbea --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp @@ -0,0 +1,82 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sqnbitgemm_neon_fp16.cpp + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM on ARM CPU with input A type T1 fp16. + +--*/ + +#include + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonFp16CastTest : public MlasTestBase { + private: + void TestFp16ToFp32(size_t count) { + std::vector src(count); + std::vector dest(count); + + for (size_t i = 0; i < count; i++) { + src[i] = static_cast(i); + } + + MlasCastF16ToF32KernelNeon(src.data(), dest.data(), count); + + for (size_t i = 0; i < count; i++) { + if ((src[i] & 0x1c00) == 0x1c00) continue; // skip inf and nan + ASSERT_EQ(dest[i], MLAS_FP16::FromBits(src[i]).ToFloat()); + } + } + + void TestFp32ToFp16(size_t count) { + std::vector src(count); + std::vector dest(count); + + for (size_t i = 0; i < count; i++) { + src[i] = static_cast(i) + 0.125f; + } + + MlasCastF32ToF16KernelNeon(src.data(), dest.data(), count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(dest[i], MLAS_FP16(src[i]).val); + } + } + + public: + static const char* GetTestSuiteName() { + return "NeonFp16Cast"; + } + + void ExecuteShort(void) override { + TestFp16ToFp32(1 << 16); + TestFp16ToFp32(1); + TestFp16ToFp32(4); + TestFp16ToFp32(7); + TestFp32ToFp16(1 << 16); + TestFp32ToFp16(3); + TestFp32ToFp16(4); + TestFp32ToFp16(6); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)