diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 1907512847..94c907aa50 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1696,6 +1696,8 @@ if (onnxruntime_USE_ROCM) device_gemm_instance device_gemm_add_fastgelu_instance device_gemm_fastgelu_instance + device_gemm_splitk_instance + device_gemm_streamk_instance device_batched_gemm_instance device_softmax_instance ) diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh index 3e6f1612f2..86d023886c 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh @@ -13,6 +13,8 @@ #include "ck/ck.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp" @@ -50,9 +52,8 @@ auto GetCKGemmTypeStringAndOps() { auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams* params) -> Status { auto one = ToHipType::FromFloat(1.0f); auto zero = ToHipType::FromFloat(0.0f); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature()); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0"); auto nop = Nop{}; auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, @@ -69,6 +70,80 @@ auto GetCKGemmTypeStringAndOps() { return ret; } +template +auto GetCKStreamKGemmTypeStringAndOps() { + using CKDataType = typename CKDataTypeAdaptor::type; + using DeviceGemm = ck::tensor_operation::device::DeviceGemmStreamK< + ALayout, BLayout, Row, + CKDataType, CKDataType, CKDataType, + Nop, Nop, Nop>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = impl->GetTypeString(); + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams* params) -> Status { + auto one = ToHipType::FromFloat(1.0f); + auto zero = ToHipType::FromFloat(0.0f); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0"); + + auto nop = Nop{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, + params->m, params->n, params->k, + params->lda, params->ldb, params->ldc, + nop, nop, nop); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); + } + return ret; +} + +template +auto GetCKSplitKGemmTypeStringAndOps() { + using CKDataType = typename CKDataTypeAdaptor::type; + using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK< + ALayout, BLayout, Row, + CKDataType, CKDataType, CKDataType, + Nop, Nop, Nop>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto num_split : {4, 16, 64}) { + auto instances = InstanceFactory::GetInstances(); + for (auto&& impl : instances) { + auto type_string = impl->GetTypeString() + "_SplitK" + std::to_string(num_split); + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->k < 128 * num_split, "k=", params->k, " is too small, it makes no sense to use this split-k gemm."); + + auto one = ToHipType::FromFloat(1.0f); + auto zero = ToHipType::FromFloat(0.0f); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0"); + + auto nop = Nop{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, + params->m, params->n, params->k, + params->lda, params->ldb, params->ldc, + nop, nop, nop, num_split); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); + } + } + return ret; +} + template auto GetCKStridedBatchedGemmTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh index d39fa3e662..dbef772f8c 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh @@ -58,6 +58,15 @@ class GemmTunableOp : public TunableOp> { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } + + for (auto&& [_, op] : GetCKStreamKGemmTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + for (auto&& [_, op] : GetCKSplitKGemmTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } #endif } diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu index 2420353714..6707892cca 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu @@ -60,6 +60,14 @@ class CKGemm : public IKernelExplorer { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } + for (auto&& [type_string, op] : GetCKStreamKGemmTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } + for (auto&& [type_string, op] : GetCKSplitKGemmTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } ORT_ENFORCE(!ops_.empty()); }