mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Add ck's streamk and splitk gemm impl (#17280)
This commit is contained in:
parent
5e747071be
commit
6ea3908db4
4 changed files with 97 additions and 3 deletions
|
|
@ -1696,6 +1696,8 @@ if (onnxruntime_USE_ROCM)
|
|||
device_gemm_instance
|
||||
device_gemm_add_fastgelu_instance
|
||||
device_gemm_fastgelu_instance
|
||||
device_gemm_splitk_instance
|
||||
device_gemm_streamk_instance
|
||||
device_batched_gemm_instance
|
||||
device_softmax_instance
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@
|
|||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
|
|
@ -50,9 +52,8 @@ auto GetCKGemmTypeStringAndOps() {
|
|||
auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
|
||||
auto one = ToHipType<T>::FromFloat(1.0f);
|
||||
auto zero = ToHipType<T>::FromFloat(0.0f);
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
params->alpha != one || params->beta != zero,
|
||||
impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature());
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero,
|
||||
impl->GetTypeString(), " only supports alpha == 1 and beta == 0");
|
||||
|
||||
auto nop = Nop{};
|
||||
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
|
||||
|
|
@ -69,6 +70,80 @@ auto GetCKGemmTypeStringAndOps() {
|
|||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, typename ALayout, typename BLayout>
|
||||
auto GetCKStreamKGemmTypeStringAndOps() {
|
||||
using CKDataType = typename CKDataTypeAdaptor<T>::type;
|
||||
using DeviceGemm = ck::tensor_operation::device::DeviceGemmStreamK<
|
||||
ALayout, BLayout, Row,
|
||||
CKDataType, CKDataType, CKDataType,
|
||||
Nop, Nop, Nop>;
|
||||
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceGemm>;
|
||||
|
||||
std::vector<std::pair<std::string, Op<GemmParams<T>>>> ret;
|
||||
for (auto&& impl : InstanceFactory::GetInstances()) {
|
||||
auto type_string = impl->GetTypeString();
|
||||
auto invoker = impl->MakeInvokerPointer();
|
||||
auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
|
||||
auto one = ToHipType<T>::FromFloat(1.0f);
|
||||
auto zero = ToHipType<T>::FromFloat(0.0f);
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero,
|
||||
impl->GetTypeString(), " only supports alpha == 1 and beta == 0");
|
||||
|
||||
auto nop = Nop{};
|
||||
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
|
||||
params->m, params->n, params->k,
|
||||
params->lda, params->ldb, params->ldc,
|
||||
nop, nop, nop);
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
|
||||
impl->GetTypeString(), " does not support ", params->Signature());
|
||||
invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
|
||||
return Status::OK();
|
||||
};
|
||||
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op)));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, typename ALayout, typename BLayout>
|
||||
auto GetCKSplitKGemmTypeStringAndOps() {
|
||||
using CKDataType = typename CKDataTypeAdaptor<T>::type;
|
||||
using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK<
|
||||
ALayout, BLayout, Row,
|
||||
CKDataType, CKDataType, CKDataType,
|
||||
Nop, Nop, Nop>;
|
||||
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceGemm>;
|
||||
|
||||
std::vector<std::pair<std::string, Op<GemmParams<T>>>> ret;
|
||||
for (auto num_split : {4, 16, 64}) {
|
||||
auto instances = InstanceFactory::GetInstances();
|
||||
for (auto&& impl : instances) {
|
||||
auto type_string = impl->GetTypeString() + "_SplitK" + std::to_string(num_split);
|
||||
auto invoker = impl->MakeInvokerPointer();
|
||||
auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
params->k < 128 * num_split, "k=", params->k, " is too small, it makes no sense to use this split-k gemm.");
|
||||
|
||||
auto one = ToHipType<T>::FromFloat(1.0f);
|
||||
auto zero = ToHipType<T>::FromFloat(0.0f);
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero,
|
||||
impl->GetTypeString(), " only supports alpha == 1 and beta == 0");
|
||||
|
||||
auto nop = Nop{};
|
||||
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
|
||||
params->m, params->n, params->k,
|
||||
params->lda, params->ldb, params->ldc,
|
||||
nop, nop, nop, num_split);
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
|
||||
impl->GetTypeString(), " does not support ", params->Signature());
|
||||
invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
|
||||
return Status::OK();
|
||||
};
|
||||
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op)));
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, typename ALayout, typename BLayout>
|
||||
auto GetCKStridedBatchedGemmTypeStringAndOps() {
|
||||
using CKDataType = typename CKDataTypeAdaptor<T>::type;
|
||||
|
|
|
|||
|
|
@ -58,6 +58,15 @@ class GemmTunableOp : public TunableOp<GemmParams<T>> {
|
|||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
|
||||
for (auto&& [_, op] : GetCKStreamKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
for (auto&& [_, op] : GetCKSplitKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -60,6 +60,14 @@ class CKGemm : public IKernelExplorer {
|
|||
type_strings_.emplace_back(std::move(type_string));
|
||||
ops_.emplace_back(std::move(op));
|
||||
}
|
||||
for (auto&& [type_string, op] : GetCKStreamKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
|
||||
type_strings_.emplace_back(std::move(type_string));
|
||||
ops_.emplace_back(std::move(op));
|
||||
}
|
||||
for (auto&& [type_string, op] : GetCKSplitKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
|
||||
type_strings_.emplace_back(std::move(type_string));
|
||||
ops_.emplace_back(std::move(op));
|
||||
}
|
||||
ORT_ENFORCE(!ops_.empty());
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue