[ROCm] TunableOp: add hipBLASLt tuning logic (#16338)

### Description
- Add hipBLASLt tuning logic in place of the default hipBLASLt
implementation;
- Add kernel explorer support for hipBLASLt.

related operators: Gemm, StridedBatchedGemm, and GemmFastGelu.

Algorithms that require extra workspace are temporarily marked as unsupported.
Workspace support will be added in a later PR, because it changes the GemmParams
definition and affects multiple files.
This commit is contained in:
mindest 2023-07-14 08:20:58 +08:00 committed by GitHub
parent a3fc04ba74
commit 810512c658
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 490 additions and 183 deletions

View file

@ -70,7 +70,10 @@ class GemmFastGeluTunableOp : public TunableOp<GemmFastGeluParams<T>> {
#endif
#ifdef USE_HIPBLASLT
this->RegisterOp(HipBlasLtGemmFastGeluOp<T>);
for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps<T, ALayout, BLayout>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif
}
};

View file

@ -5,6 +5,10 @@
#ifdef USE_HIPBLASLT
#include <hipblaslt/hipblaslt.h>
#include <hipblaslt/hipblaslt-ext.hpp>
#include "core/providers/rocm/tunable/gemm_ck.cuh"
#include "core/providers/rocm/rocm_execution_provider.h"
#include "core/providers/rocm/rocm_stream_handle.h"
#endif
#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h"
@ -22,13 +26,6 @@ using onnxruntime::contrib::rocm::blas::GemmFastGeluParams;
#ifdef USE_HIPBLASLT
// For large K and small M/N, K dim will be split to multiple workgroups and buffers,
// which will require additional workspace. Here we set the max workspace size to 32MB.
constexpr const size_t kHipBlasLtMaxWorkSpaceSizeInBytes = 32 * 1024 * 1024;
// We only keep one heuristic result here. Note that for tuned input sizes, the first result
// will be the most performant one; but in untuned cases, this is not guaranteed.
constexpr const int kHeuristicResultCount = 1;
enum ActivationType {
NONE = 0,
RELU = 1,
@ -36,184 +33,234 @@ enum ActivationType {
};
template <typename T>
constexpr hipblasDatatype_t HipBlasDataTypeFor(const T*);
constexpr hipblasDatatype_t HipBlasDataTypeFor();
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor(const float*) {
constexpr hipblasDatatype_t HipBlasDataTypeFor<float>() {
return HIPBLAS_R_32F;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor(const half*) {
constexpr hipblasDatatype_t HipBlasDataTypeFor<half>() {
return HIPBLAS_R_16F;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor(const BFloat16*) {
constexpr hipblasDatatype_t HipBlasDataTypeFor<BFloat16>() {
return HIPBLAS_R_16B;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor(const double*) {
constexpr hipblasDatatype_t HipBlasDataTypeFor<double>() {
return HIPBLAS_R_64F;
}
// Translates a CK layout tag into a hipBLASLt transpose flag.
// Row yields HIPBLAS_OP_N; every other layout type yields HIPBLAS_OP_T.
template <typename Layout>
constexpr hipblasOperation_t MapCKLayoutToHipBlasLt() {
  constexpr bool kIsRowLayout = std::is_same_v<Layout, Row>;
  return kIsRowLayout ? HIPBLAS_OP_N : HIPBLAS_OP_T;
}
template <typename T, typename ParamsT>
Status HipBlasLtMatMul(const ParamsT* params, int64_t batch, ActivationType activation_type = ActivationType::NONE,
bool enable_bias = false, const T* d_bias = nullptr,
bool enable_scaleD = false, const T* d_scaleD = nullptr) {
int GetBatchCountFromParams(const ParamsT* params) {
ORT_UNUSED_PARAMETER(params);
return 1;
}
// Overload for strided batched GEMM params: the batch count is read from the params.
template <typename T>
int GetBatchCountFromParams(const StridedBatchedGemmParams<T>* params) {
  const auto& strided_params = *params;
  return strided_params.batch;
}
// Generic fallback: this overload never reads the params and always reports
// "no bias" by returning nullptr.
template <typename T, typename ParamsT>
const T* GetBiasFromParams(const ParamsT* params) {
  ORT_UNUSED_PARAMETER(params);
  const T* no_bias = nullptr;
  return no_bias;
}
// Overload for GemmFastGelu params: returns the bias pointer stored in the params
// (which may itself be nullptr).
template <typename T>
const T* GetBiasFromParams(const GemmFastGeluParams<T>* params) {
  const T* bias = params->bias;
  return bias;
}
template <typename T, typename ParamsT>
std::string TypeStringFor() {
if constexpr (std::is_same_v<ParamsT, GemmParams<T>>) {
return "Gemm";
} else if constexpr (std::is_same_v<ParamsT, StridedBatchedGemmParams<T>>) {
return "StridedBatchedGemm";
} else if constexpr (std::is_same_v<ParamsT, GemmFastGeluParams<T>>) {
return "GemmFastGelu";
}
return "UnknownType";
}
template <typename T, typename ALayout, typename BLayout, typename ParamsT>
auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationType::NONE) {
hipblasLtHandle_t handle;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtCreate(&handle));
HIPBLASLT_CALL_THROW(hipblasLtCreate(&handle));
// Note: properties of original matrices A and B are swapped.
int64_t lda = (params->opb == BlasOp::N) ? params->n : params->k;
int64_t ldb = (params->opa == BlasOp::N) ? params->k : params->m;
int64_t ldc = params->n;
int64_t stride_a = (params->opb == BlasOp::N) ? lda * params->k : lda * params->n;
int64_t stride_b = (params->opa == BlasOp::N) ? ldb * params->m : ldb * params->k;
int64_t stride_c = ldc * params->m;
float alpha = static_cast<float>(params->alpha);
float beta = static_cast<float>(params->beta);
int row_a, col_a, row_b, col_b, row_c, col_c;
row_a = lda;
col_a = (params->opb == BlasOp::N) ? params->k : params->n;
row_b = ldb;
col_b = (params->opa == BlasOp::N) ? params->m : params->k;
row_c = ldc;
col_c = params->m;
hipblasOperation_t trans_a = MapCKLayoutToHipBlasLt<BLayout>();
hipblasOperation_t trans_b = MapCKLayoutToHipBlasLt<ALayout>();
hipblasDatatype_t in_out_datatype = HipBlasDataTypeFor<T>();
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
hipblasDatatype_t in_out_datatype = HipBlasDataTypeFor(params->a);
hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
hipblasLtMatmulDesc_t matmul;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle,
hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
trans_a,
trans_b,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLASLT_COMPUTE_F32,
heuristic_result));
HIPBLASLT_CALL_THROW(hipblasLtDestroy(handle));
if (batch > 1) {
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
int returned_algo_count = heuristic_result.size();
std::vector<std::pair<std::string, Op<ParamsT>>> ret;
for (int i = 0; i < returned_algo_count; i++) {
hipblasLtMatmulAlgo_t algo = heuristic_result[i].algo;
auto hipblaslt_gemm_op = [=](const ParamsT* params) -> Status {
hipblasLtHandle_t op_handle;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtCreate(&op_handle));
// Note: properties of original matrices A and B are swapped.
int64_t lda = (params->opb == BlasOp::N) ? params->n : params->k;
int64_t ldb = (params->opa == BlasOp::N) ? params->k : params->m;
int64_t ldc = params->n;
int64_t stride_a = (params->opb == BlasOp::N) ? lda * params->k : lda * params->n;
int64_t stride_b = (params->opa == BlasOp::N) ? ldb * params->m : ldb * params->k;
int64_t stride_c = ldc * params->m;
float alpha = static_cast<float>(params->alpha);
float beta = static_cast<float>(params->beta);
int row_a, col_a, row_b, col_b, row_c, col_c;
row_a = lda;
col_a = (params->opb == BlasOp::N) ? params->k : params->n;
row_b = ldb;
col_b = (params->opa == BlasOp::N) ? params->m : params->k;
row_c = ldc;
col_c = params->m;
hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
hipblasLtMatmulDesc_t matmul;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
int batch = GetBatchCountFromParams<T>(params);
if (batch > 1) {
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
}
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t)));
// Deduce enable_bias from params
auto d_bias = GetBiasFromParams<T>(params);
bool enable_bias = d_bias != nullptr;
hipblasLtEpilogue_t epilogue;
switch (activation_type) {
case ActivationType::NONE:
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_BIAS : HIPBLASLT_EPILOGUE_DEFAULT;
break;
case ActivationType::RELU:
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_RELU_BIAS : HIPBLASLT_EPILOGUE_RELU;
break;
case ActivationType::GELU:
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_GELU_BIAS : HIPBLASLT_EPILOGUE_GELU;
break;
default:
throw std::runtime_error("Unsupported activation type for HipBlasLtMatMul");
}
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
if (enable_bias) {
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(void*)));
}
size_t workspace_size = 0;
hipblasLtMatmulAlgo_t algo_i = algo;
auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
matmul,
&alpha,
mat_a,
mat_b,
&beta,
mat_c,
mat_c,
algo_i,
workspace_size);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status != HIPBLAS_STATUS_SUCCESS, "hipBLASLt find_all: algo not supported, index ", std::to_string(i));
// TODO: support workspace in next PR
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
workspace_size > 0, "hipBLASLt find_all: extra workspace not supported for now.");
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmul(op_handle,
matmul,
&alpha,
params->b,
mat_a,
params->a,
mat_b,
&beta,
params->c,
mat_c,
params->c,
mat_c,
&algo_i,
nullptr,
0,
params->stream));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescDestroy(matmul));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_a));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_b));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_c));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtDestroy(op_handle));
return Status::OK();
};
std::string type_string = onnxruntime::MakeString(TypeStringFor<T, ParamsT>(), "HipBlasLt_", i);
ret.emplace_back(type_string, std::move(hipblaslt_gemm_op));
}
hipblasOperation_t trans_a = (params->opb == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T;
hipblasOperation_t trans_b = (params->opa == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t)));
hipblasLtEpilogue_t epilogue;
switch (activation_type) {
case ActivationType::NONE:
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_BIAS : HIPBLASLT_EPILOGUE_DEFAULT;
break;
case ActivationType::RELU:
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_RELU_BIAS : HIPBLASLT_EPILOGUE_RELU;
break;
case ActivationType::GELU:
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_GELU_BIAS : HIPBLASLT_EPILOGUE_GELU;
break;
default:
throw std::runtime_error("Unsupported activation type for HipBlasLtMatMul");
}
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
if (enable_bias) {
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(void*)));
}
if (enable_scaleD) {
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scaleD, sizeof(void*)));
}
hipblasLtMatmulPreference_t pref;
void* workspace;
size_t max_workspace_size = kHipBlasLtMaxWorkSpaceSizeInBytes;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceCreate(&pref));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceSetAttribute(
pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &max_workspace_size, sizeof(max_workspace_size)));
hipblasLtMatmulHeuristicResult_t heuristic_result[kHeuristicResultCount] = {};
int ret_algo_count = 0;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle,
matmul,
mat_a,
mat_b,
mat_c,
mat_c,
pref,
kHeuristicResultCount,
heuristic_result,
&ret_algo_count));
assert(ret_algo_count > 0);
size_t workspace_size = heuristic_result[0].workspaceSize;
if (workspace_size > 0) {
HIP_RETURN_IF_ERROR(hipMallocAsync(&workspace, workspace_size, params->stream));
}
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmul(handle,
matmul,
&alpha,
params->b,
mat_a,
params->a,
mat_b,
&beta,
params->c,
mat_c,
params->c,
mat_c,
&heuristic_result[0].algo,
workspace,
workspace_size,
params->stream));
if (workspace_size > 0) {
HIP_RETURN_IF_ERROR(hipFreeAsync(workspace, params->stream));
}
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceDestroy(pref));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescDestroy(matmul));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_a));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_b));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_c));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtDestroy(handle));
return Status::OK();
return ret;
}
template <typename T>
Status HipBlasLtGemmOp(const GemmParams<T>* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v<T, double>), "hipBLASLt does not support double inputs");
return HipBlasLtMatMul<T, GemmParams<T>>(params, /*batch=*/1);
template <typename T, typename ALayout, typename BLayout>
auto GetHipBlasLtGemmTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, GemmParams<T>>();
}
template <typename T>
Status HipBlasLtStridedBatchedGemmOp(const StridedBatchedGemmParams<T>* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v<T, double>), "hipBLASLt does not support double inputs");
return HipBlasLtMatMul<T, StridedBatchedGemmParams<T>>(params, params->batch);
};
template <typename T, typename ALayout, typename BLayout>
auto GetHipBlasLtStridedBatchedGemmTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, StridedBatchedGemmParams<T>>();
}
template <typename T>
Status HipBlasLtGemmFastGeluOp(const GemmFastGeluParams<T>* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v<T, double>), "hipBLASLt does not support double inputs");
bool enable_bias = nullptr != params->bias;
return HipBlasLtMatMul<T, GemmFastGeluParams<T>>(params, /*batch=*/1, ActivationType::GELU,
enable_bias, params->bias);
};
template <typename T, typename ALayout, typename BLayout>
auto GetHipBlasLtGemmFastGeluTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, GemmFastGeluParams<T>>(ActivationType::GELU);
}
#endif // USE_HIPBLASLT

View file

@ -40,7 +40,10 @@ class GemmTunableOp : public TunableOp<GemmParams<T>> {
this->RegisterOp(RocBlasGemmOp<T>);
#ifdef USE_HIPBLASLT
this->RegisterOp(HipBlasLtGemmOp<T>);
for (auto&& [_, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif
#ifdef USE_ROCBLAS_EXTENSION_API
@ -141,7 +144,10 @@ class StridedBatchedGemmTunableOp : public TunableOp<StridedBatchedGemmParams<T>
this->RegisterOp(RocBlasStridedBatchedGemmOp<T>);
#ifdef USE_HIPBLASLT
this->RegisterOp(HipBlasLtStridedBatchedGemmOp<T>);
for (auto&& [_, op] : GetHipBlasLtStridedBatchedGemmTypeStringAndOps<T, ALayout, BLayout>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif
#ifdef USE_ROCBLAS_EXTENSION_API

View file

@ -166,7 +166,10 @@ def profile_with_args(transa, transb, dtype, m, n, k, sort):
profile_gemmfastgelu_func(
getattr(ke, "GemmFastGeluTunable" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb
)
profile_gemmfastgelu_func(getattr(ke, "GemmFastGeluHipBlasLt" + dtype_suffix), dtype, m, n, k, transa, transb)
if ke.is_hipblaslt_available():
profile_gemmfastgelu_func(
getattr(ke, "GemmFastGeluHipBlasLt" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb
)
def profile():

View file

@ -179,6 +179,10 @@ def profile_with_args(dtype, transa, transb, m, n, k, sort):
profile_gemm_func(getattr(ke, "RocBlasGemm" + dtype_suffix), dtype, transa, transb, m, n, k)
profile_gemm_func(getattr(ke, "CKGemm" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k)
profile_gemm_func(getattr(ke, "GemmTunable" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k)
if ke.is_hipblaslt_available():
profile_gemm_func(
getattr(ke, "GemmHipBlasLt" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k
)
print()

View file

@ -20,7 +20,10 @@ namespace py = pybind11;
namespace onnxruntime {
#ifdef USE_HIPBLASLT
template <typename T>
using namespace rocm::tunable::blas::internal;
template <typename T, typename ALayout, typename BLayout>
class GemmFastGeluHipBlasLt : public IKernelExplorer {
public:
GemmFastGeluHipBlasLt(BlasOp opa, BlasOp opb,
@ -33,6 +36,8 @@ class GemmFastGeluHipBlasLt : public IKernelExplorer {
DeviceArray& c, int64_t ldc) : params_{} {
params_.tuning_ctx = TuningContext();
params_.stream = Stream();
// rocblas handle is not used for hipBLASLt
params_.handle = nullptr;
params_.opa = opa;
params_.opb = opb;
params_.m = m;
@ -47,44 +52,67 @@ class GemmFastGeluHipBlasLt : public IKernelExplorer {
params_.beta = beta;
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
for (auto&& [type_string, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps<T, ALayout, BLayout>()) {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
ORT_ENFORCE(!ops_.empty());
}
void Run() override {
ORT_THROW_IF_ERROR((rocm::tunable::blas::internal::HipBlasLtGemmFastGeluOp<T>(&params_)));
ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
}
std::vector<std::string> ListOps() const {
return {"GemmFastGeluHipBlasLt"};
return type_strings_;
}
bool SelectOp(const std::string& name) {
Status status = rocm::tunable::blas::internal::HipBlasLtGemmFastGeluOp<T>(&params_);
return status.IsOK() && name == "GemmFastGeluHipBlasLt";
for (size_t i = 0; i < ops_.size(); i++) {
if (type_strings_[i] == name) {
selected_op_ = i;
Status status = ops_[i](&params_);
return status.IsOK();
}
}
ORT_THROW("Cannot find implementation ", name);
}
private:
using ParamsT = contrib::rocm::blas::GemmFastGeluParams<T>;
ParamsT params_{};
using OpT = Op<ParamsT>;
ParamsT params_;
std::vector<OpT> ops_;
std::vector<std::string> type_strings_;
size_t selected_op_{};
};
#define REGISTER_OP(type) \
py::class_<GemmFastGeluHipBlasLt<type>>(m, "GemmFastGeluHipBlasLt_" #type) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
double, \
DeviceArray&, int64_t, \
DeviceArray&, int64_t, \
DeviceArray&, \
double, \
DeviceArray&, int64_t>()) \
.def("SetRepeats", &GemmFastGeluHipBlasLt<type>::SetRepeats) \
.def("Run", &GemmFastGeluHipBlasLt<type>::Run) \
.def("Profile", &GemmFastGeluHipBlasLt<type>::Profile) \
.def("ListOps", &GemmFastGeluHipBlasLt<type>::ListOps) \
.def("SelectOp", &GemmFastGeluHipBlasLt<type>::SelectOp);
#define REGISTER_OP(type, alayout, blayout, layout_string) \
py::class_<GemmFastGeluHipBlasLt<type, alayout, blayout>>(m, "GemmFastGeluHipBlasLt_" #type "_" layout_string) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
double, \
DeviceArray&, int64_t, \
DeviceArray&, int64_t, \
DeviceArray&, \
double, \
DeviceArray&, int64_t>()) \
.def("SetRepeats", &GemmFastGeluHipBlasLt<type, alayout, blayout>::SetRepeats) \
.def("Profile", &GemmFastGeluHipBlasLt<type, alayout, blayout>::Profile) \
.def("Run", &GemmFastGeluHipBlasLt<type, alayout, blayout>::Run) \
.def("ListOps", &GemmFastGeluHipBlasLt<type, alayout, blayout>::ListOps) \
.def("SelectOp", &GemmFastGeluHipBlasLt<type, alayout, blayout>::SelectOp);
#define REGISTER_OP_FOR_ALL_TRANSAB(type) \
REGISTER_OP(type, Row, Row, "NN"); \
REGISTER_OP(type, Row, Col, "NT"); \
REGISTER_OP(type, Col, Row, "TN"); \
REGISTER_OP(type, Col, Col, "TT");
KE_REGISTER(m) {
REGISTER_OP(float)
REGISTER_OP(half)
REGISTER_OP_FOR_ALL_TRANSAB(float);
REGISTER_OP_FOR_ALL_TRANSAB(half);
}
#endif // USE_HIPBLASLT

View file

@ -0,0 +1,212 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <pybind11/stl.h>
#include <string>
#include <vector>
#ifdef USE_HIPBLASLT
#include "core/providers/rocm/tunable/gemm_hipblaslt.h"
#endif
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/tunable/gemm_common.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
using namespace onnxruntime::rocm::tunable::blas;
namespace py = pybind11;
namespace onnxruntime {
#ifdef USE_HIPBLASLT
using namespace rocm::tunable::blas::internal;
// Kernel explorer wrapper for hipBLASLt GEMM.
// Every (type string, op) pair returned by GetHipBlasLtGemmTypeStringAndOps is kept
// as an individually selectable implementation.
template <typename T, typename ALayout, typename BLayout>
class GemmHipBlasLt : public IKernelExplorer {
 public:
  GemmHipBlasLt(BlasOp opa, BlasOp opb,
                int64_t m, int64_t n, int64_t k,
                double alpha,
                DeviceArray& a, int64_t lda,
                DeviceArray& b, int64_t ldb,
                double beta,
                DeviceArray& c, int64_t ldc)
      : params_{} {
    params_.tuning_ctx = TuningContext();
    params_.stream = Stream();
    // rocblas handle is not used for hipBLASLt
    params_.handle = nullptr;

    // Transposition flags and problem shape.
    params_.opa = opa;
    params_.opb = opb;
    params_.m = m;
    params_.n = n;
    params_.k = k;

    // Scalars, device buffers and leading dimensions.
    params_.alpha = alpha;
    params_.a = static_cast<T*>(a.ptr());
    params_.lda = lda;
    params_.b = static_cast<T*>(b.ptr());
    params_.ldb = ldb;
    params_.beta = beta;
    params_.c = static_cast<T*>(c.ptr());
    params_.ldc = ldc;

    // Split the (name, op) pairs into two parallel vectors indexed by the same position.
    auto candidates = GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>();
    for (auto& candidate : candidates) {
      type_strings_.emplace_back(std::move(candidate.first));
      ops_.emplace_back(std::move(candidate.second));
    }
    ORT_ENFORCE(!ops_.empty());
  }

  // Runs the currently selected op and throws if it reports an error.
  void Run() override {
    auto& current_op = ops_[selected_op_];
    ORT_THROW_IF_ERROR(current_op(&params_));
  }

  // Returns the type strings of all collected ops.
  std::vector<std::string> ListOps() const {
    return type_strings_;
  }

  // Selects the op whose type string equals `name`, invokes it once on the stored
  // params and returns whether that invocation succeeded.
  // Throws when no op has that name.
  bool SelectOp(const std::string& name) {
    size_t index = 0;
    while (index < ops_.size() && type_strings_[index] != name) {
      ++index;
    }
    if (index == ops_.size()) {
      ORT_THROW("Cannot find implementation ", name);
    }
    selected_op_ = index;
    Status status = ops_[index](&params_);
    return status.IsOK();
  }

 private:
  using ParamsT = GemmParams<T>;
  using OpT = Op<ParamsT>;
  ParamsT params_;
  std::vector<OpT> ops_;
  std::vector<std::string> type_strings_;
  size_t selected_op_{};
};
// Kernel explorer wrapper for hipBLASLt strided batched GEMM.
// Every (type string, op) pair returned by GetHipBlasLtStridedBatchedGemmTypeStringAndOps
// is kept as an individually selectable implementation.
template <typename T, typename ALayout, typename BLayout>
class StridedBatchedGemmHipBlasLt : public IKernelExplorer {
 public:
  StridedBatchedGemmHipBlasLt(
      BlasOp opa, BlasOp opb,
      int64_t m, int64_t n, int64_t k,
      double alpha,
      DeviceArray& a, int64_t lda, int64_t stride_a,
      DeviceArray& b, int64_t ldb, int64_t stride_b,
      double beta,
      DeviceArray& c, int64_t ldc, int64_t stride_c,
      int64_t batch)
      : params_{} {
    params_.tuning_ctx = TuningContext();
    params_.stream = Stream();
    // rocblas handle is not used for hipBLASLt
    params_.handle = nullptr;

    // Transposition flags and problem shape.
    params_.opa = opa;
    params_.opb = opb;
    params_.m = m;
    params_.n = n;
    params_.k = k;

    // Scalars, device buffers, leading dimensions and per-batch strides.
    params_.alpha = alpha;
    params_.a = static_cast<T*>(a.ptr());
    params_.lda = lda;
    params_.stride_a = stride_a;
    params_.b = static_cast<T*>(b.ptr());
    params_.ldb = ldb;
    params_.stride_b = stride_b;
    params_.beta = beta;
    params_.c = static_cast<T*>(c.ptr());
    params_.ldc = ldc;
    params_.stride_c = stride_c;
    params_.batch = batch;

    // Split the (name, op) pairs into two parallel vectors indexed by the same position.
    auto candidates = GetHipBlasLtStridedBatchedGemmTypeStringAndOps<T, ALayout, BLayout>();
    for (auto& candidate : candidates) {
      type_strings_.emplace_back(std::move(candidate.first));
      ops_.emplace_back(std::move(candidate.second));
    }
    ORT_ENFORCE(!ops_.empty());
  }

  // Runs the currently selected op and throws if it reports an error.
  void Run() override {
    auto& current_op = ops_[selected_op_];
    ORT_THROW_IF_ERROR(current_op(&params_));
  }

  // Returns the type strings of all collected ops.
  std::vector<std::string> ListOps() const {
    return type_strings_;
  }

  // Selects the op whose type string equals `name`, invokes it once on the stored
  // params and returns whether that invocation succeeded.
  // Throws when no op has that name.
  bool SelectOp(const std::string& name) {
    size_t index = 0;
    while (index < ops_.size() && type_strings_[index] != name) {
      ++index;
    }
    if (index == ops_.size()) {
      ORT_THROW("Cannot find implementation ", name);
    }
    selected_op_ = index;
    Status status = ops_[index](&params_);
    return status.IsOK();
  }

 private:
  using ParamsT = StridedBatchedGemmParams<T>;
  using OpT = Op<ParamsT>;
  ParamsT params_;
  std::vector<OpT> ops_;
  std::vector<std::string> type_strings_;
  size_t selected_op_{};
};
// Declares a pybind11 class for type<dtype, alayout, blayout> named
// "<type>_<dtype>_<layout_string>" and binds the SetRepeats, Profile, Run,
// ListOps and SelectOp methods. The expansion has no trailing semicolon so
// that the using macro can chain further .def(...) calls onto it.
#define REGISTER_OP_COMMON(type, dtype, alayout, blayout, layout_string) \
py::class_<type<dtype, alayout, blayout>>(m, #type "_" #dtype "_" layout_string) \
.def("SetRepeats", &type<dtype, alayout, blayout>::SetRepeats) \
.def("Profile", &type<dtype, alayout, blayout>::Profile) \
.def("Run", &type<dtype, alayout, blayout>::Run) \
.def("ListOps", &type<dtype, alayout, blayout>::ListOps) \
.def("SelectOp", &type<dtype, alayout, blayout>::SelectOp)
// Registers GemmHipBlasLt<dtype, alayout, blayout> with the common method bindings
// plus a constructor binding whose argument list mirrors the GemmHipBlasLt
// constructor (opa, opb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc).
#define REGISTER_GEMM_HIPBLASLT(dtype, alayout, blayout, layout_string) \
REGISTER_OP_COMMON(GemmHipBlasLt, dtype, alayout, blayout, layout_string) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
double, \
DeviceArray&, int64_t, \
DeviceArray&, int64_t, \
double, \
DeviceArray&, int64_t>());
// Registers GemmHipBlasLt for one dtype across all four A/B layout combinations,
// using the suffixes "NN", "NT", "TN" and "TT" (Row -> N, Col -> T).
#define REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(dtype) \
REGISTER_GEMM_HIPBLASLT(dtype, Row, Row, "NN"); \
REGISTER_GEMM_HIPBLASLT(dtype, Row, Col, "NT"); \
REGISTER_GEMM_HIPBLASLT(dtype, Col, Row, "TN"); \
REGISTER_GEMM_HIPBLASLT(dtype, Col, Col, "TT");
// Registers StridedBatchedGemmHipBlasLt<dtype, alayout, blayout> with the common
// method bindings plus a constructor binding whose argument list mirrors the
// StridedBatchedGemmHipBlasLt constructor (each matrix additionally takes a
// stride, and the final int64_t is the batch count).
#define REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, alayout, blayout, layout_string) \
REGISTER_OP_COMMON(StridedBatchedGemmHipBlasLt, dtype, alayout, blayout, layout_string) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
double, \
DeviceArray&, int64_t, int64_t, \
DeviceArray&, int64_t, int64_t, \
double, \
DeviceArray&, int64_t, int64_t, \
int64_t>());
// Registers StridedBatchedGemmHipBlasLt for one dtype across all four A/B layout
// combinations, using the suffixes "NN", "NT", "TN" and "TT" (Row -> N, Col -> T).
#define REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(dtype) \
REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Row, Row, "NN"); \
REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Row, Col, "NT"); \
REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Col, Row, "TN"); \
REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Col, Col, "TT");
// Registers the hipBLASLt Gemm and StridedBatchedGemm explorer classes for
// float and half, each across all four layout combinations.
KE_REGISTER(m) {
REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(float);
REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(half);
REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(float);
REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(half);
}
#endif // USE_HIPBLASLT
} // namespace onnxruntime

View file

@ -215,10 +215,14 @@ def profile_with_args(dtype, transa, transb, m, n, k, batch, sort):
fn_rocblas = getattr(ke, "RocBlasStridedBatchedGemm" + dtype_suffix)
fn_ck = getattr(ke, "CKStridedBatchedGemm" + dtype_suffix + transab_suffix)
fn_tunable = getattr(ke, "StridedBatchedGemmTunable" + dtype_suffix + transab_suffix)
if ke.is_hipblaslt_available():
fn_hipblaslt = getattr(ke, "StridedBatchedGemmHipBlasLt" + dtype_suffix + transab_suffix)
with ke.benchmark(sort):
profile_gemm_func(fn_rocblas, dtype, transa, transb, m, n, k, batch)
profile_gemm_func(fn_ck, dtype, transa, transb, m, n, k, batch)
profile_gemm_func(fn_tunable, dtype, transa, transb, m, n, k, batch)
if ke.is_hipblaslt_available():
profile_gemm_func(fn_hipblaslt, dtype, transa, transb, m, n, k, batch)
print()