mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
MLAS: fix qgemm bus error with Android + ARM32 (#7250)
This commit is contained in:
parent
fb40602ea2
commit
a9dbb511fb
1 changed files with 28 additions and 30 deletions
|
|
@ -1553,23 +1553,23 @@ MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_NEON>(
|
|||
k -= 16;
|
||||
}
|
||||
|
||||
uint32x4_t GatherVector = vmovq_n_u32(0);
|
||||
|
||||
while (k >= 4) {
|
||||
|
||||
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a0), GatherVector, 0);
|
||||
uint32_t v0 = *reinterpret_cast<const uint32_t*>(a0);
|
||||
a0 += 4;
|
||||
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a1), GatherVector, 1);
|
||||
uint32_t v1 = *reinterpret_cast<const uint32_t*>(a1);
|
||||
a1 += 4;
|
||||
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a2), GatherVector, 2);
|
||||
uint32_t v2 = *reinterpret_cast<const uint32_t*>(a2);
|
||||
a2 += 4;
|
||||
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a3), GatherVector, 3);
|
||||
uint32_t v3 = *reinterpret_cast<const uint32_t*>(a3);
|
||||
a3 += 4;
|
||||
|
||||
uint8x16_t PackedVector = vreinterpretq_u8_u32(GatherVector);
|
||||
vst1q_u8(D, PackedVector);
|
||||
*reinterpret_cast<uint32_t*>(&D[0]) = v0;
|
||||
*reinterpret_cast<uint32_t*>(&D[4]) = v1;
|
||||
*reinterpret_cast<uint32_t*>(&D[8]) = v2;
|
||||
*reinterpret_cast<uint32_t*>(&D[12]) = v3;
|
||||
|
||||
RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector));
|
||||
RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vld1q_u8(D)));
|
||||
|
||||
D += 16;
|
||||
k -= 4;
|
||||
|
|
@ -1633,19 +1633,18 @@ MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_NEON>(
|
|||
|
||||
size_t k = CountK;
|
||||
uint32x2_t RowSums = vmov_n_u32(0);
|
||||
uint32x2_t GatherVector = vmov_n_u32(0);
|
||||
|
||||
while (k >= 4) {
|
||||
|
||||
GatherVector = vld1_lane_u32(reinterpret_cast<const uint32_t*>(a0), GatherVector, 0);
|
||||
uint32_t v0 = *reinterpret_cast<const uint32_t*>(a0);
|
||||
a0 += 4;
|
||||
GatherVector = vld1_lane_u32(reinterpret_cast<const uint32_t*>(a1), GatherVector, 1);
|
||||
uint32_t v1 = *reinterpret_cast<const uint32_t*>(a1);
|
||||
a1 += 4;
|
||||
|
||||
uint8x8_t PackedVector = vreinterpret_u8_u32(GatherVector);
|
||||
vst1_u8(D, PackedVector);
|
||||
*reinterpret_cast<uint32_t*>(&D[0]) = v0;
|
||||
*reinterpret_cast<uint32_t*>(&D[4]) = v1;
|
||||
|
||||
RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector));
|
||||
RowSums = vpadal_u16(RowSums, vpaddl_u8(vld1_u8(D)));
|
||||
|
||||
D += 8;
|
||||
k -= 4;
|
||||
|
|
@ -2030,23 +2029,23 @@ MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_UDOT>(
|
|||
k -= 16;
|
||||
}
|
||||
|
||||
uint32x4_t GatherVector = vmovq_n_u32(0);
|
||||
|
||||
while (k >= 4) {
|
||||
|
||||
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a0), GatherVector, 0);
|
||||
uint32_t v0 = *reinterpret_cast<const uint32_t*>(a0);
|
||||
a0 += 4;
|
||||
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a1), GatherVector, 1);
|
||||
uint32_t v1 = *reinterpret_cast<const uint32_t*>(a1);
|
||||
a1 += 4;
|
||||
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a2), GatherVector, 2);
|
||||
uint32_t v2 = *reinterpret_cast<const uint32_t*>(a2);
|
||||
a2 += 4;
|
||||
GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a3), GatherVector, 3);
|
||||
uint32_t v3 = *reinterpret_cast<const uint32_t*>(a3);
|
||||
a3 += 4;
|
||||
|
||||
uint8x16_t PackedVector = vreinterpretq_u8_u32(GatherVector);
|
||||
vst1q_u8(D, PackedVector);
|
||||
*reinterpret_cast<uint32_t*>(&D[0]) = v0;
|
||||
*reinterpret_cast<uint32_t*>(&D[4]) = v1;
|
||||
*reinterpret_cast<uint32_t*>(&D[8]) = v2;
|
||||
*reinterpret_cast<uint32_t*>(&D[12]) = v3;
|
||||
|
||||
RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector));
|
||||
RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vld1q_u8(D)));
|
||||
|
||||
D += 16;
|
||||
k -= 4;
|
||||
|
|
@ -2117,19 +2116,18 @@ MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_UDOT>(
|
|||
|
||||
size_t k = CountK;
|
||||
uint32x2_t RowSums = vmov_n_u32(0);
|
||||
uint32x2_t GatherVector = vmov_n_u32(0);
|
||||
|
||||
while (k >= 4) {
|
||||
|
||||
GatherVector = vld1_lane_u32(reinterpret_cast<const uint32_t*>(a0), GatherVector, 0);
|
||||
uint32_t v0 = *reinterpret_cast<const uint32_t*>(a0);
|
||||
a0 += 4;
|
||||
GatherVector = vld1_lane_u32(reinterpret_cast<const uint32_t*>(a1), GatherVector, 1);
|
||||
uint32_t v1 = *reinterpret_cast<const uint32_t*>(a1);
|
||||
a1 += 4;
|
||||
|
||||
uint8x8_t PackedVector = vreinterpret_u8_u32(GatherVector);
|
||||
vst1_u8(D, PackedVector);
|
||||
*reinterpret_cast<uint32_t*>(&D[0]) = v0;
|
||||
*reinterpret_cast<uint32_t*>(&D[4]) = v1;
|
||||
|
||||
RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector));
|
||||
RowSums = vpadal_u16(RowSums, vpaddl_u8(vld1_u8(D)));
|
||||
|
||||
D += 8;
|
||||
k -= 4;
|
||||
|
|
|
|||
Loading…
Reference in a new issue