From 89916fdb05a7b6cee83f49b3f72e09821eea5fc2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 9 Mar 2021 20:57:18 -0800 Subject: [PATCH] fix stream sync issue (#6954) --- onnxruntime/core/providers/cuda/integer_gemm.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/cuda/integer_gemm.cc b/onnxruntime/core/providers/cuda/integer_gemm.cc index 3bc3b3962e..09058edca7 100644 --- a/onnxruntime/core/providers/cuda/integer_gemm.cc +++ b/onnxruntime/core/providers/cuda/integer_gemm.cc @@ -20,6 +20,8 @@ Status GemmInt8(int m, int n, int k, ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null"); ORT_ENFORCE(cuda_kernel != nullptr, "kernel is null"); + cudaStream_t stream = cuda_kernel->Stream(); + // pad A and B to make their leading dimension be multiples of 32 // because cublasGemmEx requires: // 1. leading dimension is multiples of 4 @@ -31,7 +33,7 @@ Status GemmInt8(int m, int n, int k, if ((mask & lda_aligned) != 0) { lda_aligned = roundoff(lda, 32); a_padded = cuda_kernel->GetScratchBuffer(m * lda_aligned); - cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, 0); + cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, stream); } int ldb_aligned = ldb; @@ -39,11 +41,13 @@ Status GemmInt8(int m, int n, int k, if ((mask & ldb_aligned) != 0) { ldb_aligned = roundoff(ldb, 32); b_padded = cuda_kernel->GetScratchBuffer(k * ldb_aligned); - cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, 0); + cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, stream); } + cublasHandle_t cublas = cuda_kernel->CublasHandle(); + cublasSetStream(cublas, stream); CUBLAS_RETURN_IF_ERROR(cublasGemmEx( - cuda_kernel->CublasHandle(), + cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha,