Merge remote-tracking branch 'upstream/master' into DmlDev

This commit is contained in:
ISS Build Account 2020-07-01 22:50:58 +00:00
commit 6ee6fdbbba
4 changed files with 19 additions and 24 deletions

View file

@ -12,6 +12,19 @@
#endif
namespace onnxruntime {
// Quantized GEMM fallback implemented with Eigen expressions.
//
// A_data is mapped as an M x K matrix, B_data as K x N and Y_data as M x N,
// all through the row-major map aliases. Each output element is built as
//   sum_k A[m,k]*B[k,n] + K*a_offset*b_offset
//     - b_offset * rowsum(A)[m] - a_offset * colsum(B)[n]
// which is the expanded form of sum_k (A[m,k]-a_offset)*(B[k,n]-b_offset).
// All arithmetic on the matrices is carried out after casting elements to TY.
template <typename TA, typename TB, typename TY>
void QGemmWithEigen(const TA* A_data, const TB* B_data, TY* Y_data, int M, int N, int K, TA a_offset, TB b_offset) {
  const auto lhs = ConstEigenMatrixMapRowMajor<TA>(A_data, M, K);
  const auto rhs = ConstEigenMatrixMapRowMajor<TB>(B_data, K, N);
  auto out = EigenMatrixMapRowMajor<TY>(Y_data, M, N);

  // Operands widened to the accumulator type (lazy Eigen expressions).
  const auto lhs_wide = lhs.template cast<TY>();
  const auto rhs_wide = rhs.template cast<TY>();

  // Per-row / per-column correction terms for the zero-point offsets.
  const auto row_correction = (lhs_wide.rowwise().sum()) * static_cast<TY>(b_offset);
  const auto col_correction = (rhs_wide.colwise().sum()) * static_cast<TY>(a_offset);

  // Constant term contributed by the product of both offsets over K terms.
  const TY offset_term = static_cast<TY>(K * a_offset * b_offset);

  out = lhs_wide * rhs_wide + offset_term * ConstEigenMatrixMapRowMajor<TY>::Ones(M, N);
  out.colwise() -= row_correction;
  out.rowwise() -= col_correction;
}
template <>
void QGemm<uint8_t, int8_t, int32_t>(
int M,
@ -29,20 +42,12 @@ void QGemm<uint8_t, int8_t, int32_t>(
#ifdef MLAS_SUPPORTS_GEMM_U8X8
MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool);
#else
ORT_UNUSED_PARAMETER(M);
ORT_UNUSED_PARAMETER(N);
ORT_UNUSED_PARAMETER(K);
ORT_UNUSED_PARAMETER(lhs_data);
ORT_UNUSED_PARAMETER(lda);
ORT_UNUSED_PARAMETER(lhs_offset);
ORT_UNUSED_PARAMETER(rhs_data);
ORT_UNUSED_PARAMETER(ldb);
ORT_UNUSED_PARAMETER(rhs_offset);
ORT_UNUSED_PARAMETER(result_data);
ORT_UNUSED_PARAMETER(ldc);
ORT_UNUSED_PARAMETER(thread_pool);
ORT_NOT_IMPLEMENTED("MatMulInteger: activation uint8 and weight int8 not supported on ARM");
ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For Eigen only RowMajor*RowMajor=RowMajor format is supported");
QGemmWithEigen<uint8_t, int8_t, int32_t>(lhs_data, rhs_data, result_data, M, N, K, lhs_offset, rhs_offset);
#endif
}
@ -104,8 +109,7 @@ void QGemm(
#endif
}
template
void QGemm<uint8_t, int8_t>(
template void QGemm<uint8_t, int8_t>(
int M,
int N,
int K,
@ -121,8 +125,7 @@ void QGemm<uint8_t, int8_t>(
const float* bias,
concurrency::ThreadPool* thread_pool);
template
void QGemm<uint8_t, uint8_t>(
template void QGemm<uint8_t, uint8_t>(
int M,
int N,
int K,

View file

@ -49,13 +49,11 @@ void TestDynamicQuantizeMatMul(const std::vector<int64_t>& A_dims,
}
// Runs TestDynamicQuantizeMatMul<int8_t> on a {4, 128} A input and a
// {128, 128} B input against testdata/dynamic_quantize_matmul_int8.onnx.
// The body is compiled only when MLAS_SUPPORTS_GEMM_U8X8 is defined;
// otherwise the test is empty.
TEST(DynamicQuantizeMatMul, Int8_test) {
#ifdef MLAS_SUPPORTS_GEMM_U8X8
std::vector<int64_t> A_dims{4, 128};
std::vector<int64_t> B_dims{128, 128};
// NOTE(review): Y_dims is declared but not passed to the helper below.
std::vector<int64_t> Y_dims{4, 128};
TestDynamicQuantizeMatMul<int8_t>(A_dims, B_dims, "testdata/dynamic_quantize_matmul_int8.onnx");
#endif
}
TEST(DynamicQuantizeMatMul, UInt8_test) {

View file

@ -135,11 +135,9 @@ void RunQAttention(const std::vector<float>& input_data, // input:
execution_providers.push_back(DefaultCudaExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
} else {
#ifdef MLAS_SUPPORTS_GEMM_U8X8
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
#endif
}
}

View file

@ -302,13 +302,9 @@ void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_ze
}
}
// RUN_MATMUL_INTEGER_U8S8(M, N, K): when MLAS_SUPPORTS_GEMM_U8X8 is defined,
// expands to two RunMatMulIntegerU8S8Test calls for the given sizes, one with
// non_zero_zp = false and one with non_zero_zp = true. Otherwise it expands
// to nothing, so tests using it become no-ops.
// The expansion is two statements and already ends in ';'.
#ifdef MLAS_SUPPORTS_GEMM_U8X8
#define RUN_MATMUL_INTEGER_U8S8(M, N, K) \
RunMatMulIntegerU8S8Test(M, N, K, false /*non_zero_zp*/); \
RunMatMulIntegerU8S8Test(M, N, K, true /*non_zero_zp*/);
#else
#define RUN_MATMUL_INTEGER_U8S8(M, N, K)
#endif // MLAS_SUPPORTS_GEMM_U8X8
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_Scalar) {
RUN_MATMUL_INTEGER_U8S8(1, 1, 32);