From a9dbb511fb1fcf28ade60604f3c59f4ba4958318 Mon Sep 17 00:00:00 2001 From: Tracy Sharpe <42477615+tracysh@users.noreply.github.com> Date: Mon, 5 Apr 2021 22:46:04 -0700 Subject: [PATCH] MLAS: fix qgemm bus error with Android + ARM32 (#7250) --- onnxruntime/core/mlas/lib/qgemm.cpp | 58 ++++++++++++++--------------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index a6c83e5951..b64669564c 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -1553,23 +1553,23 @@ MlasGemmU8X8CopyPackA( k -= 16; } - uint32x4_t GatherVector = vmovq_n_u32(0); - while (k >= 4) { - GatherVector = vld1q_lane_u32(reinterpret_cast(a0), GatherVector, 0); + uint32_t v0 = *reinterpret_cast(a0); a0 += 4; - GatherVector = vld1q_lane_u32(reinterpret_cast(a1), GatherVector, 1); + uint32_t v1 = *reinterpret_cast(a1); a1 += 4; - GatherVector = vld1q_lane_u32(reinterpret_cast(a2), GatherVector, 2); + uint32_t v2 = *reinterpret_cast(a2); a2 += 4; - GatherVector = vld1q_lane_u32(reinterpret_cast(a3), GatherVector, 3); + uint32_t v3 = *reinterpret_cast(a3); a3 += 4; - uint8x16_t PackedVector = vreinterpretq_u8_u32(GatherVector); - vst1q_u8(D, PackedVector); + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[4]) = v1; + *reinterpret_cast(&D[8]) = v2; + *reinterpret_cast(&D[12]) = v3; - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector)); + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vld1q_u8(D))); D += 16; k -= 4; @@ -1633,19 +1633,18 @@ MlasGemmU8X8CopyPackA( size_t k = CountK; uint32x2_t RowSums = vmov_n_u32(0); - uint32x2_t GatherVector = vmov_n_u32(0); while (k >= 4) { - GatherVector = vld1_lane_u32(reinterpret_cast(a0), GatherVector, 0); + uint32_t v0 = *reinterpret_cast(a0); a0 += 4; - GatherVector = vld1_lane_u32(reinterpret_cast(a1), GatherVector, 1); + uint32_t v1 = *reinterpret_cast(a1); a1 += 4; - uint8x8_t PackedVector = vreinterpret_u8_u32(GatherVector); - vst1_u8(D, PackedVector); + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[4]) = v1; - RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector)); + RowSums = vpadal_u16(RowSums, vpaddl_u8(vld1_u8(D))); D += 8; k -= 4; @@ -2030,23 +2029,23 @@ MlasGemmU8X8CopyPackA( k -= 16; } - uint32x4_t GatherVector = vmovq_n_u32(0); - while (k >= 4) { - GatherVector = vld1q_lane_u32(reinterpret_cast(a0), GatherVector, 0); + uint32_t v0 = *reinterpret_cast(a0); a0 += 4; - GatherVector = vld1q_lane_u32(reinterpret_cast(a1), GatherVector, 1); + uint32_t v1 = *reinterpret_cast(a1); a1 += 4; - GatherVector = vld1q_lane_u32(reinterpret_cast(a2), GatherVector, 2); + uint32_t v2 = *reinterpret_cast(a2); a2 += 4; - GatherVector = vld1q_lane_u32(reinterpret_cast(a3), GatherVector, 3); + uint32_t v3 = *reinterpret_cast(a3); a3 += 4; - uint8x16_t PackedVector = vreinterpretq_u8_u32(GatherVector); - vst1q_u8(D, PackedVector); + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[4]) = v1; + *reinterpret_cast(&D[8]) = v2; + *reinterpret_cast(&D[12]) = v3; - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector)); + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vld1q_u8(D))); D += 16; k -= 4; @@ -2117,19 +2116,18 @@ MlasGemmU8X8CopyPackA( size_t k = CountK; uint32x2_t RowSums = vmov_n_u32(0); - uint32x2_t GatherVector = vmov_n_u32(0); while (k >= 4) { - GatherVector = vld1_lane_u32(reinterpret_cast(a0), GatherVector, 0); + uint32_t v0 = *reinterpret_cast(a0); a0 += 4; - GatherVector = vld1_lane_u32(reinterpret_cast(a1), GatherVector, 1); + uint32_t v1 = *reinterpret_cast(a1); a1 += 4; - uint8x8_t PackedVector = vreinterpret_u8_u32(GatherVector); - vst1_u8(D, PackedVector); + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[4]) = v1; - RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector)); + RowSums = vpadal_u16(RowSums, vpaddl_u8(vld1_u8(D))); D += 8; k -= 4;