merge from internal master

This commit is contained in:
Tracy Sharpe 2018-11-28 12:06:33 -08:00
parent bd50598d17
commit 39fc17281c

View file

@ -293,15 +293,15 @@ Return Value:
#if defined(MLAS_NEON_INTRINSICS)
vst4q_f32(D, vld4q_f32(b));
#else
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(b);
MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(b + 4);
MLAS_FLOAT32X4 t2 = MlasLoadFloat32x4(b + 8);
MLAS_FLOAT32X4 t3 = MlasLoadFloat32x4(b + 12);
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);
MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[4]);
MLAS_FLOAT32X4 t2 = MlasLoadFloat32x4(&b[8]);
MLAS_FLOAT32X4 t3 = MlasLoadFloat32x4(&b[12]);
MlasStoreAlignedFloat32x4(D, t0);
MlasStoreAlignedFloat32x4(D + 4, t1);
MlasStoreAlignedFloat32x4(D + 8, t2);
MlasStoreAlignedFloat32x4(D + 12, t3);
MlasStoreAlignedFloat32x4(&D[0], t0);
MlasStoreAlignedFloat32x4(&D[4], t1);
MlasStoreAlignedFloat32x4(&D[8], t2);
MlasStoreAlignedFloat32x4(&D[12], t3);
#endif
D += 16;
@ -335,7 +335,7 @@ Return Value:
const float* b = B;
#if defined(MLAS_NEON_INTRINSICS)
vst4q_f32(D, ZeroFloat32x4x4);
vst4q_f32(d, ZeroFloat32x4x4);
#else
MlasStoreAlignedFloat32x4(d, ZeroFloat32x4);
MlasStoreAlignedFloat32x4(d + 4, ZeroFloat32x4);
@ -376,7 +376,7 @@ Return Value:
}
if ((CountX & 1) != 0) {
*d = *b;
d[0] = b[0];
}
D += 16;
@ -387,11 +387,10 @@ Return Value:
}
}
#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS)
template<unsigned N>
inline
void
MlasSgemmTransposePackB16x4(
MlasSgemmTransposePackBNx4(
float* D,
const float* B,
size_t ldb
@ -403,7 +402,7 @@ Routine Description:
This routine transposes elements from the source matrix to the destination
packed buffer.
4 columns of 16 rows from the source matrix are transposed to 16 columns of 4
4 columns of N rows from the source matrix are transposed to N columns of 4
rows in the destination packed buffer.
Arguments:
@ -420,7 +419,7 @@ Return Value:
--*/
{
for (unsigned n = 0; n < 4; n++) {
for (unsigned n = 0; n < N / 4; n++) {
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&B[ldb * 0]);
MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&B[ldb * 1]);
@ -436,8 +435,20 @@ Return Value:
t1 = o0.val[1];
t2 = o1.val[0];
t3 = o1.val[1];
#elif defined(MLAS_SSE2_INTRINSICS)
// N.B. The MSVC version of _MM_TRANSPOSE4_PS uses shufps, which is
// slightly larger than the sequence below, so manually expand the
// matrix transpose.
__m128 z0 = _mm_unpacklo_ps(t0, t1);
__m128 z1 = _mm_unpackhi_ps(t0, t1);
__m128 z2 = _mm_unpacklo_ps(t2, t3);
__m128 z3 = _mm_unpackhi_ps(t2, t3);
t0 = _mm_movelh_ps(z0, z2);
t1 = _mm_movehl_ps(z2, z0);
t2 = _mm_movelh_ps(z1, z3);
t3 = _mm_movehl_ps(z3, z1);
#else
_MM_TRANSPOSE4_PS(t0, t1, t2, t3);
#error Unsupported architecture.
#endif
MlasStoreAlignedFloat32x4(&D[0], t0);
@ -450,8 +461,6 @@ Return Value:
}
}
#endif
void
MlasSgemmTransposePackB(
float* D,
@ -508,19 +517,19 @@ Return Value:
SgemmTransposePackB16x4Routine(&D[0], &b[0], ldb);
b += 4;
D += 16 * 4;
b += 4;
x -= 4;
}
#elif defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS)
#else
while (x >= 4) {
MlasSgemmTransposePackB16x4(&D[0], &b[0], ldb);
MlasSgemmTransposePackBNx4<16>(&D[0], &b[0], ldb);
b += 4;
D += 16 * 4;
b += 4;
x -= 4;
}
@ -571,26 +580,128 @@ Return Value:
CountY -= 16;
}
//
Handle the fewer than 16 remaining rows as a special case.
//
if (CountY > 0) {
//
// Zero pad the remaining entries of the packed buffer.
//
MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4();
memset(D, 0, CountX * 16 * sizeof(float));
size_t x = CountX;
//
// Transpose elements from matrix B into the packed buffer for the
// remaining rows.
// Transpose 4 columns at a time.
//
if (CountY >= 8) {
while (x >= 4) {
float* d = D;
const float* b = B;
size_t x = CountX;
do {
if ((CountY & 8) != 0) {
MlasSgemmTransposePackBNx4<8>(&d[0], &b[0], ldb);
d += 8;
b += ldb * 8;
} else {
MlasStoreAlignedFloat32x4(&d[8], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[12], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[24], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[28], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[40], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[44], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[56], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[60], ZeroFloat32x4);
}
if ((CountY & 4) != 0) {
MlasSgemmTransposePackBNx4<4>(&d[0], &b[0], ldb);
d += 4;
b += ldb * 4;
} else {
MlasStoreAlignedFloat32x4(&d[4], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[20], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[36], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[52], ZeroFloat32x4);
}
MlasStoreAlignedFloat32x4(&d[0], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[16], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[32], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[48], ZeroFloat32x4);
if ((CountY & 2) != 0) {
#if defined(MLAS_NEON_INTRINSICS)
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);
MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[ldb]);
MlasStoreLaneFloat32x4<0>(&d[0], t0);
MlasStoreLaneFloat32x4<0>(&d[1], t1);
MlasStoreLaneFloat32x4<1>(&d[16], t0);
MlasStoreLaneFloat32x4<1>(&d[17], t1);
MlasStoreLaneFloat32x4<2>(&d[32], t0);
MlasStoreLaneFloat32x4<2>(&d[33], t1);
MlasStoreLaneFloat32x4<3>(&d[48], t0);
MlasStoreLaneFloat32x4<3>(&d[49], t1);
#elif defined(MLAS_SSE2_INTRINSICS)
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);
MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[ldb]);
__m128 v0 = _mm_unpacklo_ps(t0, t1);
__m128 v1 = _mm_unpackhi_ps(t0, t1);
_mm_storel_pi((__m64*)&d[0], v0);
_mm_storeh_pi((__m64*)&d[16], v0);
_mm_storel_pi((__m64*)&d[32], v1);
_mm_storeh_pi((__m64*)&d[48], v1);
#else
#error Unsupported architecture.
#endif
d += 2;
b += ldb * 2;
}
if ((CountY & 1) != 0) {
#if defined(MLAS_NEON_INTRINSICS)
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);
MlasStoreLaneFloat32x4<0>(&d[0], t0);
MlasStoreLaneFloat32x4<1>(&d[16], t0);
MlasStoreLaneFloat32x4<2>(&d[32], t0);
MlasStoreLaneFloat32x4<3>(&d[48], t0);
#else
d[0] = b[0];
d[16] = b[1];
d[32] = b[2];
d[48] = b[3];
#endif
}
D += 16 * 4;
B += 4;
x -= 4;
}
//
// Transpose the remaining columns.
//
while (x > 0) {
float* d = D;
const float* b = B;
if ((CountY & 8) != 0) {
float t0 = b[0];
float t1 = b[ldb];
@ -610,24 +721,16 @@ Return Value:
d[6] = t6;
d[7] = t7;
d += 16;
b += 1;
x--;
d += 8;
b += ldb * 8;
} while (x > 0);
} else {
D += 8;
B += ldb * 8;
CountY -= 8;
}
MlasStoreAlignedFloat32x4(&d[8], ZeroFloat32x4);
MlasStoreAlignedFloat32x4(&d[12], ZeroFloat32x4);
}
if (CountY >= 4) {
float* d = D;
const float* b = B;
size_t x = CountX;
do {
if ((CountY & 4) != 0) {
float t0 = b[0];
float t1 = b[ldb];
@ -639,24 +742,17 @@ Return Value:
d[2] = t2;
d[3] = t3;
d += 16;
b += 1;
x--;
d += 4;
b += ldb * 4;
} while (x > 0);
} else {
D += 4;
B += ldb * 4;
CountY -= 4;
}
MlasStoreAlignedFloat32x4(&d[4], ZeroFloat32x4);
}
if (CountY >= 2) {
MlasStoreAlignedFloat32x4(d, ZeroFloat32x4);
float* d = D;
const float* b = B;
size_t x = CountX;
do {
if ((CountY & 2) != 0) {
float t0 = b[0];
float t1 = b[ldb];
@ -664,32 +760,17 @@ Return Value:
d[0] = t0;
d[1] = t1;
d += 16;
b += 1;
x--;
} while (x > 0);
D += 2;
B += ldb * 2;
CountY -= 2;
}
if (CountY >= 1) {
float* d = D;
const float* b = B;
size_t x = CountX;
do {
d += 2;
b += ldb * 2;
}
if ((CountY & 1) != 0) {
d[0] = b[0];
}
d += 16;
b += 1;
x--;
} while (x > 0);
D += 16;
B += 1;
x--;
}
}
}