mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
fix issue when build with hipblasLt on rocm6.1 (#22553)
### Description <!-- Describe your changes. --> hipblasLt library is released with rocm6.x, and current onnxruntime's code need some modifications to match new hipblasLt API. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
7ad78733e6
commit
dd28f09ce2
3 changed files with 16 additions and 17 deletions
|
|
@ -170,9 +170,4 @@ template Status RocmCall<ncclResult_t, false>(ncclResult_t retCode, const char*
|
|||
template void RocmCall<ncclResult_t, true>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
|
||||
#endif
|
||||
|
||||
#ifdef USE_HIPBLASLT
|
||||
template Status RocmCall<hipblasStatus_t, false>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
|
||||
template void RocmCall<hipblasStatus_t, true>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
|
||||
#endif
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -37,26 +37,26 @@ enum ActivationType {
|
|||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr hipblasltDatatype_t HipBlasDataTypeFor();
|
||||
constexpr hipDataType HipBlasDataTypeFor();
|
||||
|
||||
template <>
|
||||
constexpr hipblasltDatatype_t HipBlasDataTypeFor<float>() {
|
||||
return HIPBLASLT_R_32F;
|
||||
constexpr hipDataType HipBlasDataTypeFor<float>() {
|
||||
return HIP_R_32F;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasltDatatype_t HipBlasDataTypeFor<half>() {
|
||||
return HIPBLASLT_R_16F;
|
||||
constexpr hipDataType HipBlasDataTypeFor<half>() {
|
||||
return HIP_R_16F;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasltDatatype_t HipBlasDataTypeFor<BFloat16>() {
|
||||
return HIPBLASLT_R_16B;
|
||||
constexpr hipDataType HipBlasDataTypeFor<BFloat16>() {
|
||||
return HIP_R_16BF;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
|
||||
return HIPBLASLT_R_64F;
|
||||
constexpr hipDataType HipBlasDataTypeFor<double>() {
|
||||
return HIP_R_64F;
|
||||
}
|
||||
|
||||
template <BlasOp Op>
|
||||
|
|
@ -108,7 +108,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
|
|||
|
||||
hipblasOperation_t trans_a = MapBlasOpToHipBlasLt<OpB>();
|
||||
hipblasOperation_t trans_b = MapBlasOpToHipBlasLt<OpA>();
|
||||
hipblasltDatatype_t in_out_datatype = HipBlasDataTypeFor<T>();
|
||||
hipDataType in_out_datatype = HipBlasDataTypeFor<T>();
|
||||
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
|
||||
|
||||
HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle,
|
||||
|
|
@ -119,7 +119,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
|
|||
in_out_datatype,
|
||||
in_out_datatype,
|
||||
in_out_datatype,
|
||||
HIPBLASLT_COMPUTE_F32,
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
heuristic_result));
|
||||
HIPBLASLT_CALL_THROW(hipblasLtDestroy(handle));
|
||||
|
||||
|
|
@ -161,7 +161,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
|
|||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLASLT_R_32F));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
|
||||
|
||||
int batch = GetBatchCountFromParams<T>(params);
|
||||
if (batch > 1) {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path):
|
|||
s = s.replace("kCudaStreamCopyIn", "kHipStreamCopyIn")
|
||||
s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut")
|
||||
s = s.replace("kTotalCudaStreams", "kTotalHipStreams")
|
||||
|
||||
# in rocm 6.0, hipify-perl, the -roc option also maps __half -> rocblas_half which we don't want
|
||||
s = s.replace("rocblas_half", "__half")
|
||||
|
||||
# these should be "hip" but it's easier to just use rocm to avoid complicated file renaming
|
||||
s = s.replace("CudaGraph", "RocmGraph")
|
||||
s = s.replace("CUDAGraph", "ROCMGraph")
|
||||
|
|
|
|||
Loading…
Reference in a new issue