diff --git a/onnxruntime/core/providers/rocm/rocm_call.cc b/onnxruntime/core/providers/rocm/rocm_call.cc index 2d0bb006a8..a73ef9b34b 100644 --- a/onnxruntime/core/providers/rocm/rocm_call.cc +++ b/onnxruntime/core/providers/rocm/rocm_call.cc @@ -170,9 +170,4 @@ template Status RocmCall(ncclResult_t retCode, const char* template void RocmCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); #endif -#ifdef USE_HIPBLASLT -template Status RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); -#endif - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h index 6554ed977c..486ce5bfb7 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -37,26 +37,26 @@ enum ActivationType { }; template -constexpr hipblasltDatatype_t HipBlasDataTypeFor(); +constexpr hipDataType HipBlasDataTypeFor(); template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_32F; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_32F; } template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_16F; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_16F; } template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_16B; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_16BF; } template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_64F; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_64F; } 
template @@ -108,7 +108,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp hipblasOperation_t trans_a = MapBlasOpToHipBlasLt(); hipblasOperation_t trans_b = MapBlasOpToHipBlasLt(); - hipblasltDatatype_t in_out_datatype = HipBlasDataTypeFor(); + hipDataType in_out_datatype = HipBlasDataTypeFor(); std::vector heuristic_result; HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle, @@ -119,7 +119,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp in_out_datatype, in_out_datatype, in_out_datatype, - HIPBLASLT_COMPUTE_F32, + HIPBLAS_COMPUTE_32F, heuristic_result)); HIPBLASLT_CALL_THROW(hipblasLtDestroy(handle)); @@ -161,7 +161,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda)); HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb)); HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLASLT_R_32F)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); int batch = GetBatchCountFromParams(params); if (batch > 1) { diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index ff246503e8..6a8154681e 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -21,6 +21,10 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("kCudaStreamCopyIn", "kHipStreamCopyIn") s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut") s = s.replace("kTotalCudaStreams", "kTotalHipStreams") + + # In ROCm 6.0, hipify-perl's -roc option also maps __half -> rocblas_half, which we don't want, so map it back. + s = s.replace("rocblas_half", "__half") + # these should be "hip" but it's easier to just use rocm to avoid complicated 
file renaming s = s.replace("CudaGraph", "RocmGraph") s = s.replace("CUDAGraph", "ROCMGraph")