fix issue when build with hipblasLt on rocm6.1 (#22553)

### Description
<!-- Describe your changes. -->

hipblasLt library is released with rocm6.x, and current onnxruntime's
code need some modifications to match new hipblasLt API.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
kailums 2024-10-28 13:57:08 +08:00 committed by GitHub
parent 7ad78733e6
commit dd28f09ce2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 16 additions and 17 deletions

View file

@ -170,9 +170,4 @@ template Status RocmCall<ncclResult_t, false>(ncclResult_t retCode, const char*
template void RocmCall<ncclResult_t, true>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
#endif
#ifdef USE_HIPBLASLT
template Status RocmCall<hipblasStatus_t, false>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
template void RocmCall<hipblasStatus_t, true>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
#endif
} // namespace onnxruntime

View file

@ -37,26 +37,26 @@ enum ActivationType {
};
template <typename T>
constexpr hipblasltDatatype_t HipBlasDataTypeFor();
constexpr hipDataType HipBlasDataTypeFor();
template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<float>() {
return HIPBLASLT_R_32F;
constexpr hipDataType HipBlasDataTypeFor<float>() {
return HIP_R_32F;
}
template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<half>() {
return HIPBLASLT_R_16F;
constexpr hipDataType HipBlasDataTypeFor<half>() {
return HIP_R_16F;
}
template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<BFloat16>() {
return HIPBLASLT_R_16B;
constexpr hipDataType HipBlasDataTypeFor<BFloat16>() {
return HIP_R_16BF;
}
template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
return HIPBLASLT_R_64F;
constexpr hipDataType HipBlasDataTypeFor<double>() {
return HIP_R_64F;
}
template <BlasOp Op>
@ -108,7 +108,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
hipblasOperation_t trans_a = MapBlasOpToHipBlasLt<OpB>();
hipblasOperation_t trans_b = MapBlasOpToHipBlasLt<OpA>();
hipblasltDatatype_t in_out_datatype = HipBlasDataTypeFor<T>();
hipDataType in_out_datatype = HipBlasDataTypeFor<T>();
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle,
@ -119,7 +119,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLASLT_COMPUTE_F32,
HIPBLAS_COMPUTE_32F,
heuristic_result));
HIPBLASLT_CALL_THROW(hipblasLtDestroy(handle));
@ -161,7 +161,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLASLT_R_32F));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
int batch = GetBatchCountFromParams<T>(params);
if (batch > 1) {

View file

@ -21,6 +21,10 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path):
s = s.replace("kCudaStreamCopyIn", "kHipStreamCopyIn")
s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut")
s = s.replace("kTotalCudaStreams", "kTotalHipStreams")
# in rocm 6.0, hipify-perl, the -roc option also maps __half -> rocblas_half which we don't want
s = s.replace("rocblas_half", "__half")
# these should be "hip" but it's easier to just use rocm to avoid complicated file renaming
s = s.replace("CudaGraph", "RocmGraph")
s = s.replace("CUDAGraph", "ROCMGraph")