Add ck's streamk and splitk gemm impl (#17280)

This commit is contained in:
cloudhan 2023-09-04 11:49:07 +08:00 committed by GitHub
parent 5e747071be
commit 6ea3908db4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 3 deletions

View file

@ -1696,6 +1696,8 @@ if (onnxruntime_USE_ROCM)
device_gemm_instance
device_gemm_add_fastgelu_instance
device_gemm_fastgelu_instance
device_gemm_splitk_instance
device_gemm_streamk_instance
device_batched_gemm_instance
device_softmax_instance
)

View file

@ -13,6 +13,8 @@
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
@ -50,9 +52,8 @@ auto GetCKGemmTypeStringAndOps() {
auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
auto one = ToHipType<T>::FromFloat(1.0f);
auto zero = ToHipType<T>::FromFloat(0.0f);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->alpha != one || params->beta != zero,
impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature());
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero,
impl->GetTypeString(), " only supports alpha == 1 and beta == 0");
auto nop = Nop{};
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
@ -69,6 +70,80 @@ auto GetCKGemmTypeStringAndOps() {
return ret;
}
template <typename T, typename ALayout, typename BLayout>
auto GetCKStreamKGemmTypeStringAndOps() {
using CKDataType = typename CKDataTypeAdaptor<T>::type;
using DeviceGemm = ck::tensor_operation::device::DeviceGemmStreamK<
ALayout, BLayout, Row,
CKDataType, CKDataType, CKDataType,
Nop, Nop, Nop>;
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceGemm>;
std::vector<std::pair<std::string, Op<GemmParams<T>>>> ret;
for (auto&& impl : InstanceFactory::GetInstances()) {
auto type_string = impl->GetTypeString();
auto invoker = impl->MakeInvokerPointer();
auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
auto one = ToHipType<T>::FromFloat(1.0f);
auto zero = ToHipType<T>::FromFloat(0.0f);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero,
impl->GetTypeString(), " only supports alpha == 1 and beta == 0");
auto nop = Nop{};
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
params->m, params->n, params->k,
params->lda, params->ldb, params->ldc,
nop, nop, nop);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support ", params->Signature());
invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
return Status::OK();
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op)));
}
return ret;
}
template <typename T, typename ALayout, typename BLayout>
auto GetCKSplitKGemmTypeStringAndOps() {
using CKDataType = typename CKDataTypeAdaptor<T>::type;
using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK<
ALayout, BLayout, Row,
CKDataType, CKDataType, CKDataType,
Nop, Nop, Nop>;
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceGemm>;
std::vector<std::pair<std::string, Op<GemmParams<T>>>> ret;
for (auto num_split : {4, 16, 64}) {
auto instances = InstanceFactory::GetInstances();
for (auto&& impl : instances) {
auto type_string = impl->GetTypeString() + "_SplitK" + std::to_string(num_split);
auto invoker = impl->MakeInvokerPointer();
auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->k < 128 * num_split, "k=", params->k, " is too small, it makes no sense to use this split-k gemm.");
auto one = ToHipType<T>::FromFloat(1.0f);
auto zero = ToHipType<T>::FromFloat(0.0f);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero,
impl->GetTypeString(), " only supports alpha == 1 and beta == 0");
auto nop = Nop{};
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
params->m, params->n, params->k,
params->lda, params->ldb, params->ldc,
nop, nop, nop, num_split);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support ", params->Signature());
invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
return Status::OK();
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op)));
}
}
return ret;
}
template <typename T, typename ALayout, typename BLayout>
auto GetCKStridedBatchedGemmTypeStringAndOps() {
using CKDataType = typename CKDataTypeAdaptor<T>::type;

View file

@ -58,6 +58,15 @@ class GemmTunableOp : public TunableOp<GemmParams<T>> {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
for (auto&& [_, op] : GetCKStreamKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
for (auto&& [_, op] : GetCKSplitKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif
}

View file

@ -60,6 +60,14 @@ class CKGemm : public IKernelExplorer {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
for (auto&& [type_string, op] : GetCKStreamKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
for (auto&& [type_string, op] : GetCKSplitKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
ORT_ENFORCE(!ops_.empty());
}