mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
fix stream sync issue (#6954)
This commit is contained in:
parent
bdaea1d9ae
commit
89916fdb05
1 changed files with 7 additions and 3 deletions
|
|
@ -20,6 +20,8 @@ Status GemmInt8(int m, int n, int k,
|
|||
ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null");
|
||||
ORT_ENFORCE(cuda_kernel != nullptr, "kernel is null");
|
||||
|
||||
cudaStream_t stream = cuda_kernel->Stream();
|
||||
|
||||
// pad A and B to make their leading dimension be multiples of 32
|
||||
// because cublasGemmEx requires:
|
||||
// 1. leading dimension is multiples of 4
|
||||
|
|
@ -31,7 +33,7 @@ Status GemmInt8(int m, int n, int k,
|
|||
if ((mask & lda_aligned) != 0) {
|
||||
lda_aligned = roundoff(lda, 32);
|
||||
a_padded = cuda_kernel->GetScratchBuffer<int8_t>(m * lda_aligned);
|
||||
cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, 0);
|
||||
cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, stream);
|
||||
}
|
||||
|
||||
int ldb_aligned = ldb;
|
||||
|
|
@ -39,11 +41,13 @@ Status GemmInt8(int m, int n, int k,
|
|||
if ((mask & ldb_aligned) != 0) {
|
||||
ldb_aligned = roundoff(ldb, 32);
|
||||
b_padded = cuda_kernel->GetScratchBuffer<int8_t>(k * ldb_aligned);
|
||||
cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, 0);
|
||||
cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, stream);
|
||||
}
|
||||
|
||||
cublasHandle_t cublas = cuda_kernel->CublasHandle();
|
||||
cublasSetStream(cublas, stream);
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmEx(
|
||||
cuda_kernel->CublasHandle(),
|
||||
cublas,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n, m, k,
|
||||
&alpha,
|
||||
|
|
|
|||
Loading…
Reference in a new issue