MLAS: fix qgemm bus error with Android + ARM32 (#7250)

This commit is contained in:
Tracy Sharpe 2021-04-05 22:46:04 -07:00 committed by GitHub
parent fb40602ea2
commit a9dbb511fb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -1553,23 +1553,23 @@ MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_NEON>(
k -= 16;
}
uint32x4_t GatherVector = vmovq_n_u32(0);
while (k >= 4) {
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a0), GatherVector, 0);
uint32_t v0 = *reinterpret_cast<const uint32_t*>(a0);
a0 += 4;
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a1), GatherVector, 1);
uint32_t v1 = *reinterpret_cast<const uint32_t*>(a1);
a1 += 4;
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a2), GatherVector, 2);
uint32_t v2 = *reinterpret_cast<const uint32_t*>(a2);
a2 += 4;
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a3), GatherVector, 3);
uint32_t v3 = *reinterpret_cast<const uint32_t*>(a3);
a3 += 4;
uint8x16_t PackedVector = vreinterpretq_u8_u32(GatherVector);
vst1q_u8(D, PackedVector);
*reinterpret_cast<uint32_t*>(&D[0]) = v0;
*reinterpret_cast<uint32_t*>(&D[4]) = v1;
*reinterpret_cast<uint32_t*>(&D[8]) = v2;
*reinterpret_cast<uint32_t*>(&D[12]) = v3;
RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector));
RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vld1q_u8(D)));
D += 16;
k -= 4;
@@ -1633,19 +1633,18 @@ MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_NEON>(
size_t k = CountK;
uint32x2_t RowSums = vmov_n_u32(0);
uint32x2_t GatherVector = vmov_n_u32(0);
while (k >= 4) {
GatherVector = vld1_lane_u32(reinterpret_cast<const uint32_t*>(a0), GatherVector, 0);
uint32_t v0 = *reinterpret_cast<const uint32_t*>(a0);
a0 += 4;
GatherVector = vld1_lane_u32(reinterpret_cast<const uint32_t*>(a1), GatherVector, 1);
uint32_t v1 = *reinterpret_cast<const uint32_t*>(a1);
a1 += 4;
uint8x8_t PackedVector = vreinterpret_u8_u32(GatherVector);
vst1_u8(D, PackedVector);
*reinterpret_cast<uint32_t*>(&D[0]) = v0;
*reinterpret_cast<uint32_t*>(&D[4]) = v1;
RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector));
RowSums = vpadal_u16(RowSums, vpaddl_u8(vld1_u8(D)));
D += 8;
k -= 4;
@@ -2030,23 +2029,23 @@ MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_UDOT>(
k -= 16;
}
uint32x4_t GatherVector = vmovq_n_u32(0);
while (k >= 4) {
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a0), GatherVector, 0);
uint32_t v0 = *reinterpret_cast<const uint32_t*>(a0);
a0 += 4;
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a1), GatherVector, 1);
uint32_t v1 = *reinterpret_cast<const uint32_t*>(a1);
a1 += 4;
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a2), GatherVector, 2);
uint32_t v2 = *reinterpret_cast<const uint32_t*>(a2);
a2 += 4;
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a3), GatherVector, 3);
uint32_t v3 = *reinterpret_cast<const uint32_t*>(a3);
a3 += 4;
uint8x16_t PackedVector = vreinterpretq_u8_u32(GatherVector);
vst1q_u8(D, PackedVector);
*reinterpret_cast<uint32_t*>(&D[0]) = v0;
*reinterpret_cast<uint32_t*>(&D[4]) = v1;
*reinterpret_cast<uint32_t*>(&D[8]) = v2;
*reinterpret_cast<uint32_t*>(&D[12]) = v3;
RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector));
RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vld1q_u8(D)));
D += 16;
k -= 4;
@@ -2117,19 +2116,18 @@ MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_UDOT>(
size_t k = CountK;
uint32x2_t RowSums = vmov_n_u32(0);
uint32x2_t GatherVector = vmov_n_u32(0);
while (k >= 4) {
GatherVector = vld1_lane_u32(reinterpret_cast<const uint32_t*>(a0), GatherVector, 0);
uint32_t v0 = *reinterpret_cast<const uint32_t*>(a0);
a0 += 4;
GatherVector = vld1_lane_u32(reinterpret_cast<const uint32_t*>(a1), GatherVector, 1);
uint32_t v1 = *reinterpret_cast<const uint32_t*>(a1);
a1 += 4;
uint8x8_t PackedVector = vreinterpret_u8_u32(GatherVector);
vst1_u8(D, PackedVector);
*reinterpret_cast<uint32_t*>(&D[0]) = v0;
*reinterpret_cast<uint32_t*>(&D[4]) = v1;
RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector));
RowSums = vpadal_u16(RowSums, vpaddl_u8(vld1_u8(D)));
D += 8;
k -= 4;