Fix CUDA10.2 Build Break for BFloat16 Change (#10331)

* fix build break on cuda 10.2

* fix linux build
This commit is contained in:
Vincent Wang 2022-01-20 18:17:28 +08:00 committed by GitHub
parent 4aa7cee0d8
commit 001cc53968
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -112,32 +112,25 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle,
}
}
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// BFloat16 overload of cublasGemmHelper, forwarded to cublasGemmEx.
//
// alpha and beta are dereferenced in this function and widened to float with
// ToFloat(); cuBLAS receives the addresses of those float locals. A, B and C
// are described to cuBLAS as CUDA_R_16BF, the compute type is
// CUBLAS_COMPUTE_32F and the algorithm is CUBLAS_GEMM_DEFAULT.
// The trailing cudaDeviceProp argument is accepted but not used.
//
// Returns the cublasStatus_t produced by cublasGemmEx.
//
// NOTE(review): alpha/beta are read on the calling thread without a null
// check, so they are assumed to be valid host pointers -- confirm with callers.
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m,
                                       int n, int k, const BFloat16* alpha, const BFloat16* A, int lda,
                                       const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc,
                                       const cudaDeviceProp& /*prop*/) {
  float h_a = alpha->ToFloat();
  float h_b = beta->ToFloat();

  // accumulating in FP32
  return cublasGemmEx(handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb, &h_b, C,
                      CUDA_R_16BF, ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
#else
// Fallback compiled when CUDA_VERSION is undefined or below 11000, where the
// bfloat16 cublasGemmEx path above is not built. Keeps the overload declared
// so callers still compile; ignores every argument and always returns
// CUBLAS_STATUS_NOT_SUPPORTED.
inline cublasStatus_t cublasGemmHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int,
                                       const BFloat16*, const BFloat16*, int, const BFloat16*, int, const BFloat16*,
                                       BFloat16*, int, const cudaDeviceProp&) {
  return CUBLAS_STATUS_NOT_SUPPORTED;
}
#endif
// batched gemm
inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle,
@ -236,34 +229,27 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle,
}
}
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// BFloat16 overload of cublasGemmBatchedHelper, forwarded to
// cublasGemmBatchedEx.
//
// alpha and beta are dereferenced in this function and widened to float with
// ToFloat(); cuBLAS receives the addresses of those float locals. Aarray,
// Barray and Carray are passed through as arrays of untyped pointers whose
// element type is declared to cuBLAS as CUDA_R_16BF. batch_count is forwarded
// unchanged. The trailing cudaDeviceProp argument is accepted but not used.
//
// Returns the cublasStatus_t produced by cublasGemmBatchedEx.
//
// NOTE(review): alpha/beta are read on the calling thread without a null
// check, so they are assumed to be valid host pointers -- confirm with callers.
inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
                                              int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[],
                                              int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta,
                                              BFloat16* Carray[], int ldc, int batch_count,
                                              const cudaDeviceProp& /*prop*/) {
  float h_a = alpha->ToFloat();
  float h_b = beta->ToFloat();

  // accumulating in FP32
  // NOTE(review): the non-batched BFloat16 helper passes CUBLAS_COMPUTE_32F as
  // the compute type while this call passes CUDA_R_32F -- confirm whether the
  // difference is intentional before unifying them.
  return cublasGemmBatchedEx(handle, transa, transb, m, n, k, &h_a, reinterpret_cast<const void**>(Aarray),
                             CUDA_R_16BF, lda, reinterpret_cast<const void**>(Barray), CUDA_R_16BF, ldb, &h_b,
                             reinterpret_cast<void**>(Carray), CUDA_R_16BF, ldc, batch_count, CUDA_R_32F,
                             CUBLAS_GEMM_DEFAULT);
}
#else
// Fallback compiled when CUDA_VERSION is undefined or below 11000, where the
// bfloat16 cublasGemmBatchedEx path above is not built. Keeps the overload
// declared so callers still compile; ignores every argument and always
// returns CUBLAS_STATUS_NOT_SUPPORTED.
inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int,
                                              const BFloat16*, const BFloat16*[], int, const BFloat16*[], int,
                                              const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&) {
  return CUBLAS_STATUS_NOT_SUPPORTED;
}
#endif
// strided batched gemm
inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
@ -425,36 +411,29 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
}
}
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// BFloat16 overload of cublasGemmStridedBatchedHelper, forwarded to
// cublasGemmStridedBatchedEx.
//
// alpha and beta are dereferenced in this function and widened to float with
// ToFloat(); cuBLAS receives the addresses of those float locals. A, B and C
// are described to cuBLAS as CUDA_R_16BF, and strideA/strideB/strideC and
// batch_count are forwarded unchanged. The trailing cudaDeviceProp argument
// is accepted but not used.
//
// Returns the cublasStatus_t produced by cublasGemmStridedBatchedEx.
//
// NOTE(review): alpha/beta are read on the calling thread without a null
// check, so they are assumed to be valid host pointers -- confirm with callers.
inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, cublasOperation_t transa,
                                                     cublasOperation_t transb, int m, int n, int k,
                                                     const BFloat16* alpha, const BFloat16* A, int lda,
                                                     long long int strideA, const BFloat16* B, int ldb,
                                                     long long int strideB, const BFloat16* beta, BFloat16* C, int ldc,
                                                     long long int strideC, int batch_count,
                                                     const cudaDeviceProp& /*prop*/) {
  float h_a = alpha->ToFloat();
  float h_b = beta->ToFloat();

  // accumulating in FP32
  // NOTE(review): the non-batched BFloat16 helper passes CUBLAS_COMPUTE_32F as
  // the compute type while this call passes CUDA_R_32F -- confirm whether the
  // difference is intentional before unifying them.
  return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B, CUDA_R_16BF,
                                    ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F,
                                    CUBLAS_GEMM_DEFAULT);
}
#else
// Fallback compiled when CUDA_VERSION is undefined or below 11000, where the
// bfloat16 cublasGemmStridedBatchedEx path above is not built. Keeps the
// overload declared so callers still compile; ignores every argument and
// always returns CUBLAS_STATUS_NOT_SUPPORTED.
inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int,
                                                     int, const BFloat16*, const BFloat16*, int, long long int,
                                                     const BFloat16*, int, long long int, const BFloat16*, BFloat16*,
                                                     int, long long int, int, const cudaDeviceProp&) {
  return CUBLAS_STATUS_NOT_SUPPORTED;
}
#endif
// transpose using geam
inline cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) {