mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Fix CUDA10.2 Build Break for BFloat16 Change (#10331)
* Fix build break on CUDA 10.2
* Fix Linux build
This commit is contained in:
parent
4aa7cee0d8
commit
001cc53968
1 changed files with 45 additions and 66 deletions
|
|
@ -112,32 +112,25 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle,
|
|||
}
|
||||
}
|
||||
|
||||
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle,
|
||||
cublasOperation_t transa,
|
||||
cublasOperation_t transb,
|
||||
int m, int n, int k,
|
||||
const BFloat16* alpha,
|
||||
const BFloat16* A, int lda,
|
||||
const BFloat16* B, int ldb,
|
||||
const BFloat16* beta,
|
||||
BFloat16* C, int ldc,
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
// BFloat16 overload of cublasGemmHelper, built on cublasGemmEx.
// Compiled only when CUDA_VERSION >= 11000 (see the surrounding #if).
//
// alpha / beta : dereferenced here and widened to float before being handed
//                to cuBLAS by address.
// A, B, C      : passed through to cuBLAS and declared as CUDA_R_16BF.
// prop         : unused.
//
// Returns the status reported by cublasGemmEx.
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m,
                                       int n, int k, const BFloat16* alpha, const BFloat16* A, int lda,
                                       const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc,
                                       const cudaDeviceProp& /*prop*/) {
  float h_a = alpha->ToFloat();
  float h_b = beta->ToFloat();

  // accumulating in FP32
  return cublasGemmEx(handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb, &h_b, C,
                      CUDA_R_16BF, ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
|
||||
#else
|
||||
// Fallback BFloat16 overload of cublasGemmHelper for builds where CUDA_VERSION
// is undefined or below 11000 (the #else branch of the surrounding #if).
// Every argument is ignored; the call always reports CUBLAS_STATUS_NOT_SUPPORTED.
inline cublasStatus_t cublasGemmHelper(cublasHandle_t /*handle*/,
                                       cublasOperation_t /*transa*/,
                                       cublasOperation_t /*transb*/,
                                       int /*m*/, int /*n*/, int /*k*/,
                                       const BFloat16* /*alpha*/,
                                       const BFloat16* /*A*/, int /*lda*/,
                                       const BFloat16* /*B*/, int /*ldb*/,
                                       const BFloat16* /*beta*/,
                                       BFloat16* /*C*/, int /*ldc*/,
                                       const cudaDeviceProp& /*prop*/) {
  return CUBLAS_STATUS_NOT_SUPPORTED;
}
|
||||
#endif
|
||||
|
||||
// batched gemm
|
||||
inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle,
|
||||
|
|
@ -236,34 +229,27 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle,
|
|||
}
|
||||
}
|
||||
|
||||
inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle,
|
||||
cublasOperation_t transa,
|
||||
cublasOperation_t transb,
|
||||
int m, int n, int k,
|
||||
const BFloat16* alpha,
|
||||
const BFloat16* Aarray[], int lda,
|
||||
const BFloat16* Barray[], int ldb,
|
||||
const BFloat16* beta,
|
||||
BFloat16* Carray[], int ldc,
|
||||
int batch_count,
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
// BFloat16 overload of cublasGemmBatchedHelper, built on cublasGemmBatchedEx.
// Compiled only when CUDA_VERSION >= 11000 (see the surrounding #if).
//
// alpha / beta           : dereferenced here and widened to float before being
//                          handed to cuBLAS by address.
// Aarray, Barray, Carray : pointer arrays forwarded to cuBLAS as void pointer
//                          arrays; the elements are declared as CUDA_R_16BF.
// batch_count            : forwarded unchanged.
// prop                   : unused.
//
// Returns the status reported by cublasGemmBatchedEx.
inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
                                              int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[],
                                              int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta,
                                              BFloat16* Carray[], int ldc, int batch_count,
                                              const cudaDeviceProp& /*prop*/) {
  float h_a = alpha->ToFloat();
  float h_b = beta->ToFloat();

  // accumulating in FP32
  // The compute type is given as cublasComputeType_t (CUBLAS_COMPUTE_32F), matching
  // the non-batched BFloat16 helper, rather than the cudaDataType value CUDA_R_32F.
  return cublasGemmBatchedEx(handle, transa, transb, m, n, k, &h_a, reinterpret_cast<const void**>(Aarray),
                             CUDA_R_16BF, lda, reinterpret_cast<const void**>(Barray), CUDA_R_16BF, ldb, &h_b,
                             reinterpret_cast<void**>(Carray), CUDA_R_16BF, ldc, batch_count, CUBLAS_COMPUTE_32F,
                             CUBLAS_GEMM_DEFAULT);
}
|
||||
#else
|
||||
// Fallback BFloat16 overload of cublasGemmBatchedHelper for builds where
// CUDA_VERSION is undefined or below 11000 (the #else branch of the surrounding #if).
// Every argument is ignored; the call always reports CUBLAS_STATUS_NOT_SUPPORTED.
inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t /*handle*/,
                                              cublasOperation_t /*transa*/,
                                              cublasOperation_t /*transb*/,
                                              int /*m*/, int /*n*/, int /*k*/,
                                              const BFloat16* /*alpha*/,
                                              const BFloat16*[] /*Aarray*/, int /*lda*/,
                                              const BFloat16*[] /*Barray*/, int /*ldb*/,
                                              const BFloat16* /*beta*/,
                                              BFloat16*[] /*Carray*/, int /*ldc*/,
                                              int /*batch_count*/,
                                              const cudaDeviceProp& /*prop*/) {
  return CUBLAS_STATUS_NOT_SUPPORTED;
}
|
||||
#endif
|
||||
|
||||
// strided batched gemm
|
||||
inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
|
||||
|
|
@ -425,36 +411,29 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
|
|||
}
|
||||
}
|
||||
|
||||
inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
|
||||
cublasOperation_t transa,
|
||||
cublasOperation_t transb,
|
||||
int m, int n, int k,
|
||||
const BFloat16* alpha,
|
||||
const BFloat16* A, int lda,
|
||||
long long int strideA,
|
||||
const BFloat16* B, int ldb,
|
||||
long long int strideB,
|
||||
const BFloat16* beta,
|
||||
BFloat16* C, int ldc,
|
||||
long long int strideC,
|
||||
int batch_count,
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
// BFloat16 overload of cublasGemmStridedBatchedHelper, built on
// cublasGemmStridedBatchedEx. Compiled only when CUDA_VERSION >= 11000
// (see the surrounding #if).
//
// alpha / beta              : dereferenced here and widened to float before
//                             being handed to cuBLAS by address.
// A, B, C                   : passed through to cuBLAS and declared as CUDA_R_16BF.
// strideA, strideB, strideC : forwarded unchanged alongside the matching matrix.
// batch_count               : forwarded unchanged.
// prop                      : unused.
//
// Returns the status reported by cublasGemmStridedBatchedEx.
inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, cublasOperation_t transa,
                                                     cublasOperation_t transb, int m, int n, int k,
                                                     const BFloat16* alpha, const BFloat16* A, int lda,
                                                     long long int strideA, const BFloat16* B, int ldb,
                                                     long long int strideB, const BFloat16* beta, BFloat16* C, int ldc,
                                                     long long int strideC, int batch_count,
                                                     const cudaDeviceProp& /*prop*/) {
  float h_a = alpha->ToFloat();
  float h_b = beta->ToFloat();

  // accumulating in FP32
  // The compute type is given as cublasComputeType_t (CUBLAS_COMPUTE_32F), matching
  // the non-batched BFloat16 helper, rather than the cudaDataType value CUDA_R_32F.
  return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B,
                                    CUDA_R_16BF, ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count,
                                    CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
|
||||
#else
|
||||
// Fallback BFloat16 overload of cublasGemmStridedBatchedHelper for builds where
// CUDA_VERSION is undefined or below 11000 (the #else branch of the surrounding #if).
// Every argument is ignored; the call always reports CUBLAS_STATUS_NOT_SUPPORTED.
inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t /*handle*/,
                                                     cublasOperation_t /*transa*/,
                                                     cublasOperation_t /*transb*/,
                                                     int /*m*/, int /*n*/, int /*k*/,
                                                     const BFloat16* /*alpha*/,
                                                     const BFloat16* /*A*/, int /*lda*/,
                                                     long long int /*strideA*/,
                                                     const BFloat16* /*B*/, int /*ldb*/,
                                                     long long int /*strideB*/,
                                                     const BFloat16* /*beta*/,
                                                     BFloat16* /*C*/, int /*ldc*/,
                                                     long long int /*strideC*/,
                                                     int /*batch_count*/,
                                                     const cudaDeviceProp& /*prop*/) {
  return CUBLAS_STATUS_NOT_SUPPORTED;
}
|
||||
#endif
|
||||
|
||||
// transpose using geam
|
||||
inline cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue