fix stream sync issue (#6954)

This commit is contained in:
Tianlei Wu 2021-03-09 20:57:18 -08:00 committed by GitHub
parent bdaea1d9ae
commit 89916fdb05
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -20,6 +20,8 @@ Status GemmInt8(int m, int n, int k,
ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null");
ORT_ENFORCE(cuda_kernel != nullptr, "kernel is null");
cudaStream_t stream = cuda_kernel->Stream();
// pad A and B to make their leading dimension be multiples of 32
// because cublasGemmEx requires:
// 1. leading dimension is multiples of 4
@ -31,7 +33,7 @@ Status GemmInt8(int m, int n, int k,
if ((mask & lda_aligned) != 0) {
lda_aligned = roundoff(lda, 32);
a_padded = cuda_kernel->GetScratchBuffer<int8_t>(m * lda_aligned);
cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, 0);
cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, stream);
}
int ldb_aligned = ldb;
@ -39,11 +41,13 @@ Status GemmInt8(int m, int n, int k,
if ((mask & ldb_aligned) != 0) {
ldb_aligned = roundoff(ldb, 32);
b_padded = cuda_kernel->GetScratchBuffer<int8_t>(k * ldb_aligned);
cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, 0);
cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, stream);
}
cublasHandle_t cublas = cuda_kernel->CublasHandle();
cublasSetStream(cublas, stream);
CUBLAS_RETURN_IF_ERROR(cublasGemmEx(
cuda_kernel->CublasHandle(),
cublas,
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,