From a24b41d92e2aaa51087f52dbf57d2d7f6e7dfd82 Mon Sep 17 00:00:00 2001 From: cloudhan Date: Fri, 23 Sep 2022 11:10:19 +0800 Subject: [PATCH] Move all TunableOp related facilities to EP level directory (#12857) Some Ops in EP directory instead of contrib_ops directory will require TunableOp. We will also need to add EP level session tuning options for it. So move all of that code at once. Also remove duplicated utility functions. --- cmake/onnxruntime_kernel_explorer.cmake | 3 +- .../rocm/bert/fast_gelu_impl_kernel.h | 5 +- .../rocm/bert/fast_gelu_tunable_op.h | 9 +-- .../rocm/bert/skip_layer_norm_op.h | 8 ++- onnxruntime/contrib_ops/rocm/bert/util.cc | 44 ------------ onnxruntime/contrib_ops/rocm/bert/util.h | 46 ------------ .../core/providers/rocm/tunable/gemm_ck.cuh | 72 +++++++++++++++++++ .../core/providers/rocm/tunable/gemm_common.h | 60 ++++++++++++++++ .../providers/rocm/tunable/gemm_rocblas.h | 59 +++++++++++++++ .../providers/rocm/tunable/gemm_tunable.cuh | 35 +++++++++ .../providers/rocm/tunable/tunable.h} | 9 ++- .../core/providers/rocm/tunable/util.cc | 41 +++++++++++ .../core/providers/rocm/tunable/util.h | 30 ++++++++ .../tools/kernel_explorer/device_array.h | 2 +- .../kernel_explorer_interface.h | 5 +- .../tools/kernel_explorer/kernels/gemm.cc | 8 ++- .../tools/kernel_explorer/kernels/gemm.h | 45 ------------ .../tools/kernel_explorer/kernels/gemm_ck.cc | 11 ++- .../tools/kernel_explorer/kernels/gemm_ck.h | 57 --------------- .../kernel_explorer/kernels/gemm_rocblas.cc | 8 ++- .../kernel_explorer/kernels/gemm_rocblas.h | 44 ------------ .../kernel_explorer/kernels/gemm_tunable.cc | 23 ++---- .../kernel_explorer/kernels/vector_add.cc | 6 +- .../kernels/vector_add_kernel.h | 12 ++-- 24 files changed, 361 insertions(+), 281 deletions(-) delete mode 100644 onnxruntime/contrib_ops/rocm/bert/util.cc delete mode 100644 onnxruntime/contrib_ops/rocm/bert/util.h create mode 100644 onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh create mode 100644 
onnxruntime/core/providers/rocm/tunable/gemm_common.h create mode 100644 onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h create mode 100644 onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh rename onnxruntime/{contrib_ops/rocm/bert/tunable_op.h => core/providers/rocm/tunable/tunable.h} (95%) create mode 100644 onnxruntime/core/providers/rocm/tunable/util.cc create mode 100644 onnxruntime/core/providers/rocm/tunable/util.h diff --git a/cmake/onnxruntime_kernel_explorer.cmake b/cmake/onnxruntime_kernel_explorer.cmake index ad91139297..918c7ae834 100644 --- a/cmake/onnxruntime_kernel_explorer.cmake +++ b/cmake/onnxruntime_kernel_explorer.cmake @@ -23,8 +23,7 @@ file(GLOB kernel_explorer_kernel_srcs CONFIGURE_DEPENDS "${KERNEL_EXPLORER_ROOT} onnxruntime_add_shared_library_module(kernel_explorer ${kernel_explorer_srcs} - ${kernel_explorer_kernel_srcs} - ${BERT_DIR}/util.cc) + ${kernel_explorer_kernel_srcs}) set_target_properties(kernel_explorer PROPERTIES PREFIX "_") target_include_directories(kernel_explorer PUBLIC $ diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl_kernel.h index 79d5eb0122..6c2b981f60 100644 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl_kernel.h +++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl_kernel.h @@ -3,7 +3,8 @@ #pragma once -#include "contrib_ops/rocm/bert/util.h" +#include "core/providers/rocm/tunable/util.h" +#include "core/providers/rocm/cu_inc/common.cuh" namespace onnxruntime { namespace contrib { @@ -34,7 +35,7 @@ __global__ void FastGeluKernel(int input_length, int bias_length, const T* input template __global__ void FastGeluKernelVec(int input_length, int bias_length, const T* input, const T* bias, T* output) { - using VecT = AlignedVector; + using VecT = onnxruntime::rocm::aligned_vector; const T a = T(0.5f); const T b = T(0.7978845608028654f); const T c = T(0.035677408136300125f); diff --git 
a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h index ff019e9c4d..0c151778be 100644 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h @@ -7,7 +7,8 @@ #include #include #include -#include "contrib_ops/rocm/bert/tunable_op.h" +#include "core/providers/rocm/tunable/tunable.h" +#include "core/providers/rocm/cu_inc/common.cuh" #include "contrib_ops/rocm/bert/fast_gelu_impl_kernel.h" namespace onnxruntime { @@ -15,7 +16,7 @@ namespace contrib { namespace rocm { template -struct FastGeluParams : OpParams { +struct FastGeluParams : onnxruntime::rocm::tunable::OpParams { FastGeluParams(hipStream_t stream, const T* input, const T* bias, T* output, int input_length, int bias_length) : OpParams(stream), input(input), bias(bias), output(output), input_length(input_length), bias_length(bias_length) {} @@ -39,7 +40,7 @@ Status FastGeluOp(const FastGeluParams* params) { (params->bias_length == 0 && params->input_length % VecSize == 0))); hipLaunchKernelGGL((FastGeluKernelVec), - dim3(CeilingDivision(params->input_length, ThreadsPerBlock * VecSize)), + dim3(onnxruntime::rocm::CeilDiv(params->input_length, ThreadsPerBlock * VecSize)), dim3(ThreadsPerBlock), 0, params->stream, params->input_length, params->bias_length, params->input, params->bias, params->output); @@ -56,7 +57,7 @@ Status FastGeluOp(const FastGeluParams* params) { this->ops_.emplace_back(FastGeluOp); template -class FastGeluTunableOp : public TunableOp> { +class FastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp> { public: FastGeluTunableOp() { ADD_OP(64); diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h index 36b99243e9..77fede8431 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h @@ -9,14 +9,15 @@ #include 
#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" -#include "contrib_ops/rocm/bert/tunable_op.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/tunable/tunable.h" namespace onnxruntime { namespace contrib { namespace rocm { template -struct SkipLayerNormParams : OpParams { +struct SkipLayerNormParams : onnxruntime::rocm::tunable::OpParams { SkipLayerNormParams(hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, const int ld, @@ -42,9 +43,10 @@ struct SkipLayerNormParams : OpParams { template Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { + using onnxruntime::rocm::CeilDiv; TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF( !((params->ld <= 1024 && params->ld % VecSize == 0 && params->ld == ThreadsPerBlock * VecSize))); - SkipLayerNormKernelSmall<<element_count, params->ld)), + SkipLayerNormKernelSmall<<element_count, params->ld)), dim3(ThreadsPerBlock), 0, params->stream>>>( params->ld, params->input, params->skip, diff --git a/onnxruntime/contrib_ops/rocm/bert/util.cc b/onnxruntime/contrib_ops/rocm/bert/util.cc deleted file mode 100644 index 04874d1204..0000000000 --- a/onnxruntime/contrib_ops/rocm/bert/util.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "contrib_ops/rocm/bert/util.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -int CeilingDivision(int n, int m) { - int r = (n - 1) / m + 1; - return r; -} - -Timer::Timer(hipStream_t stream): stream_(stream) { - HIP_CHECK(hipEventCreate(&start_)); - HIP_CHECK(hipEventCreate(&end_)); -} - -void Timer::Start() { - HIP_CHECK(hipDeviceSynchronize()); - HIP_CHECK(hipEventRecord(start_, stream_)); -} - -void Timer::End() { - HIP_CHECK(hipEventRecord(end_, stream_)); - HIP_CHECK(hipEventSynchronize(end_)); -} - -float Timer::Duration() { - float time; - // time is in ms with a resolution of 1 us - HIP_CHECK(hipEventElapsedTime(&time, start_, end_)); - return time; -} - -Timer::~Timer() { - HIP_CHECK(hipEventDestroy(start_)); - HIP_CHECK(hipEventDestroy(end_)); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/util.h b/onnxruntime/contrib_ops/rocm/bert/util.h deleted file mode 100644 index 16fe3d538f..0000000000 --- a/onnxruntime/contrib_ops/rocm/bert/util.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include -#include - -#define HIP_CHECK(expr) \ - do { \ - auto status = expr; \ - if (status != hipSuccess) { \ - std::cerr << hipGetErrorName(status); \ - std::abort(); \ - } \ - } while (0) - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -int CeilingDivision(int n, int m); - -template -struct alignas(sizeof(T) * VecSize) AlignedVector { - T val[VecSize]; -}; - -class Timer { - public: - explicit Timer(hipStream_t stream); - void Start(); - void End(); - float Duration(); - ~Timer(); - - private: - hipStream_t stream_; - hipEvent_t start_; - hipEvent_t end_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh new file mode 100644 index 0000000000..f3d3bad0ae --- /dev/null +++ b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "core/providers/rocm/tunable/gemm_common.h" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +struct DataTypeAdaptor { + using type = T; +}; + +template <> +struct DataTypeAdaptor { + using type = ck::half_t; +}; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Nop = ck::tensor_operation::element_wise::PassThrough; + +template +auto GetCKGemmTypeStringAndOps() { + using CKDataType = typename DataTypeAdaptor::type; + using DeviceGemm = ck::tensor_operation::device::DeviceGemm< + ALayout, BLayout, Row, + CKDataType, CKDataType, CKDataType, + Nop, Nop, Nop>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = impl->GetTypeString(); + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams* params) -> Status { + auto nop = Nop{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, + params->m, params->n, params->k, + params->lda, params->ldb, params->ldc, + nop, nop, nop); + TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->stream}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); + } + return ret; +} + +} // namespace internal +} // namespace blas +} // namespace 
tunable +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_common.h b/onnxruntime/core/providers/rocm/tunable/gemm_common.h new file mode 100644 index 0000000000..49b61627a5 --- /dev/null +++ b/onnxruntime/core/providers/rocm/tunable/gemm_common.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/tunable.h" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { + +enum class BlasOp { + N = 0, + T = 1, + NonTrans = 0, + Trans = 1, +}; + +inline std::string BlasOpToString(BlasOp op) { + switch (op) { + case BlasOp::N: + return "N"; + case BlasOp::T: + return "T"; + } +} + +// We don't assume the implementation is row-majored or column-majored. But for testing convenience, we assume all +// our wrappers have row-majored convention, since it is the native layout to numpy and pytorch. +template +struct GemmParams : tunable::OpParams { + std::string Signature() const override { + return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); + } + + rocblas_handle handle; + BlasOp opa; + BlasOp opb; + int64_t m; + int64_t n; + int64_t k; + T alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + T beta; + T* c; + int64_t ldc; +}; + +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h new file mode 100644 index 0000000000..3c0b07da0c --- /dev/null +++ b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/common/common.h" +#include "core/providers/rocm/shared_inc/fpgeneric.h" +#include "core/providers/rocm/tunable/gemm_common.h" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +// RAII style guard to set stream and restore original stream for rocblas_handle +class RocblasHandleStreamGuard { + public: + RocblasHandleStreamGuard(rocblas_handle handle, hipStream_t stream) : handle_{handle} { + ROCBLAS_CALL_THROW(rocblas_get_stream(handle_, &original_stream_)); + ROCBLAS_CALL_THROW(rocblas_set_stream(handle_, stream)); + } + + ~RocblasHandleStreamGuard() { + ROCBLAS_CALL_THROW(rocblas_set_stream(handle_, original_stream_)); + } + + ORT_DISALLOW_COPY_AND_ASSIGNMENT(RocblasHandleStreamGuard); + + private: + rocblas_handle handle_; + hipStream_t original_stream_; +}; + +template +Status RocBlasGemmOp(const GemmParams* params) { + RocblasHandleStreamGuard guard(params->handle, params->stream); + // NOTE: rocblas assumes the storage is column-majored, swapping A and B makes it have the same interface + // as those with row-majored convention. That is, if you treat the storage as row-majored but view the matrices as + // transposed, then by using the property Transpose(A*B) = Tranpose(B)*Transpose(A), the correctness is obvious. + auto status = rocblasGemmHelper( + params->handle, + params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, + params->opa == BlasOp::N ? 
rocblas_operation_none : rocblas_operation_transpose, + params->n, params->m, params->k, + &(params->alpha), + params->b, params->ldb, + params->a, params->lda, + &(params->beta), + params->c, params->ldc); + ORT_RETURN_IF(status != rocblas_status_success, rocblas_status_to_string(status)); + return Status::OK(); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh new file mode 100644 index 0000000000..dc185c793d --- /dev/null +++ b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/providers/rocm/tunable/gemm_ck.cuh" +#include "core/providers/rocm/tunable/gemm_rocblas.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/tunable.h" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +class GemmTunableOp : public tunable::TunableOp> { + public: + GemmTunableOp() { + this->ops_.emplace_back(RocBlasGemmOp); + for (auto&& [_, op] : GetCKGemmTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->ops_.emplace_back(std::move(op)); + } + } +}; + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/tunable_op.h b/onnxruntime/core/providers/rocm/tunable/tunable.h similarity index 95% rename from onnxruntime/contrib_ops/rocm/bert/tunable_op.h rename to onnxruntime/core/providers/rocm/tunable/tunable.h index 39bec11300..e033620ddb 100644 --- a/onnxruntime/contrib_ops/rocm/bert/tunable_op.h +++ b/onnxruntime/core/providers/rocm/tunable/tunable.h @@ -6,6 +6,7 @@ #include #include +#include #include 
#include #include @@ -16,11 +17,12 @@ #include #include "core/common/common.h" -#include "contrib_ops/rocm/bert/util.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/util.h" namespace onnxruntime { -namespace contrib { namespace rocm { +namespace tunable { struct OpParams { OpParams() : stream{} {} @@ -150,6 +152,7 @@ class TunableOp { } } ORT_ENFORCE(id >= 0, "Cannot found viable op"); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); return id; } @@ -166,6 +169,6 @@ class TunableOp { bool tuning_{false}; }; +} // namespace tunable } // namespace rocm -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/util.cc b/onnxruntime/core/providers/rocm/tunable/util.cc new file mode 100644 index 0000000000..3802488d60 --- /dev/null +++ b/onnxruntime/core/providers/rocm/tunable/util.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/rocm/tunable/util.h" + +#include "core/providers/rocm/shared_inc/rocm_call.h" + +namespace onnxruntime { +namespace rocm { +namespace tunable { + +Timer::Timer(hipStream_t stream) : stream_(stream) { + HIP_CALL_THROW(hipEventCreate(&start_)); + HIP_CALL_THROW(hipEventCreate(&end_)); +} + +void Timer::Start() { + HIP_CALL_THROW(hipDeviceSynchronize()); + HIP_CALL_THROW(hipEventRecord(start_, stream_)); +} + +void Timer::End() { + HIP_CALL_THROW(hipEventRecord(end_, stream_)); + HIP_CALL_THROW(hipEventSynchronize(end_)); +} + +float Timer::Duration() { + float time; + // time is in ms with a resolution of 1 us + HIP_CALL_THROW(hipEventElapsedTime(&time, start_, end_)); + return time; +} + +Timer::~Timer() { + HIP_CALL_THROW(hipEventDestroy(start_)); + HIP_CALL_THROW(hipEventDestroy(end_)); +} + +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/util.h b/onnxruntime/core/providers/rocm/tunable/util.h new file mode 100644 index 0000000000..85c2976f74 --- /dev/null +++ b/onnxruntime/core/providers/rocm/tunable/util.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace rocm { +namespace tunable { + +class Timer { + public: + explicit Timer(hipStream_t stream); + void Start(); + void End(); + float Duration(); + ~Timer(); + + private: + hipStream_t stream_; + hipEvent_t start_; + hipEvent_t end_; +}; + +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/device_array.h b/onnxruntime/python/tools/kernel_explorer/device_array.h index c27788acd8..366ad40e5e 100644 --- a/onnxruntime/python/tools/kernel_explorer/device_array.h +++ b/onnxruntime/python/tools/kernel_explorer/device_array.h @@ -6,7 +6,7 @@ #include #include #include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/bert/util.h" +#include "core/providers/rocm/tunable/util.h" namespace py = pybind11; diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h b/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h index 32285ab5b8..96dd3f135d 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h +++ b/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h @@ -4,9 +4,10 @@ #pragma once #include -#include "contrib_ops/rocm/bert/util.h" +#include "core/providers/rocm/tunable/tunable.h" +#include "core/providers/rocm/tunable/util.h" -using onnxruntime::contrib::rocm::Timer; +using onnxruntime::rocm::tunable::Timer; /// Wrapping around Op and TunableOp class IKernelExplorer { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm.cc b/onnxruntime/python/tools/kernel_explorer/kernels/gemm.cc index 9abae017de..8bc24f7a5c 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm.cc @@ -2,12 +2,16 @@ // Licensed under the MIT License. 
#include "python/tools/kernel_explorer/kernels/gemm.h" + +#include +#include + +#include "core/providers/rocm/tunable/gemm_common.h" #include "python/tools/kernel_explorer/kernels/gemm_ck.h" #include "python/tools/kernel_explorer/kernels/gemm_rocblas.h" #include "python/tools/kernel_explorer/kernels/gemm_tunable.h" -#include -#include +using BlasOp = onnxruntime::rocm::tunable::blas::BlasOp; namespace py = pybind11; diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm.h b/onnxruntime/python/tools/kernel_explorer/kernels/gemm.h index eaf81e85f1..3c6d0a5630 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm.h +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm.h @@ -3,57 +3,12 @@ #pragma once -#include -#include - #include -#include "contrib_ops/rocm/bert/tunable_op.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - namespace py = pybind11; namespace onnxruntime { -enum class BlasOp { - N, - T, -}; - -inline std::string BlasOpToString(BlasOp op) { - switch (op) { - case BlasOp::N: - return "N"; - case BlasOp::T: - return "T"; - } -} - -// We don't assume the implementation is row-majored or column-majored. But for testing convenience, we assume all -// our wrappers have row-majored convention, since it is the native layout to numpy and pytorch. 
-template -struct GemmParams : contrib::rocm::OpParams { - std::string Signature() const override { - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); - } - - rocblas_handle handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - T alpha; - T* a; - int64_t lda; - T* b; - int64_t ldb; - T beta; - T* c; - int64_t ldc; -}; - void InitGemm(py::module mod); } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_ck.cc b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_ck.cc index 188aa979bd..9623444998 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_ck.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_ck.cc @@ -11,6 +11,15 @@ #include #include +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/gemm_ck.cuh" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +using namespace onnxruntime::rocm::tunable::blas; +using namespace onnxruntime::rocm::tunable::blas::internal; + namespace py = pybind11; namespace onnxruntime { @@ -76,7 +85,7 @@ class CKGemm : public IKernelExplorer { private: using ParamsT = GemmParams; - using OpT = contrib::rocm::Op; + using OpT = rocm::tunable::Op; ParamsT params_; std::vector ops_; std::vector type_strings_; diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_ck.h b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_ck.h index 7058df89b0..ba0dfdcb64 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_ck.h +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_ck.h @@ -5,67 +5,10 @@ #include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include 
"ck/tensor_operation/gpu/device/device_gemm.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "python/tools/kernel_explorer/kernels/gemm.h" - namespace py = pybind11; namespace onnxruntime { -template -struct DataTypeAdaptor { - using type = T; -}; - -template <> -struct DataTypeAdaptor { - using type = ck::half_t; -}; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using Nop = ck::tensor_operation::element_wise::PassThrough; - -// to be moved to onnxruntime once we have a monolithicly tunable gemm wrapper and it is enabled for onnxruntime -template -auto GetCKGemmTypeStringAndOps() { - using CKDataType = typename DataTypeAdaptor::type; - using DeviceGemm = ck::tensor_operation::device::DeviceGemm< - ALayout, BLayout, Row, - CKDataType, CKDataType, CKDataType, - Nop, Nop, Nop>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = impl->GetTypeString(); - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams* params) -> Status { - auto nop = Nop{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, params->ldc, - nop, nop, nop); - TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); - invoker->Run(arg.get(), StreamConfig{params->stream}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); - } - return ret; -} - void InitComposableKernelGemm(py::module mod); } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_rocblas.cc 
b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_rocblas.cc index f2da8513d7..4d354d4d5c 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_rocblas.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_rocblas.cc @@ -10,7 +10,13 @@ #include #include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/gemm_rocblas.h" #include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +using namespace onnxruntime::rocm::tunable::blas; +using namespace onnxruntime::rocm::tunable::blas::internal; namespace py = pybind11; @@ -65,7 +71,7 @@ class RocBlasGemm : public IKernelExplorer { rocblas_handle rocblas_handle_; using ParamsT = GemmParams; - using OpT = contrib::rocm::Op; + using OpT = rocm::tunable::Op; ParamsT params_{}; OpT op_{RocBlasGemmOp}; diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_rocblas.h b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_rocblas.h index 4c4518d4d2..2d49783821 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_rocblas.h +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_rocblas.h @@ -5,54 +5,10 @@ #include -#include "core/common/common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "python/tools/kernel_explorer/kernels/gemm.h" - namespace py = pybind11; namespace onnxruntime { -// RAII style guard to set stream and restore original stream for rocblas_handle -class RocblasHandleStreamGuard { - public: - RocblasHandleStreamGuard(rocblas_handle handle, hipStream_t stream) : handle_{handle} { - ROCBLAS_CALL_THROW(rocblas_get_stream(handle_, &original_stream_)); - ROCBLAS_CALL_THROW(rocblas_set_stream(handle_, stream)); - } - - ~RocblasHandleStreamGuard() { - ROCBLAS_CALL_THROW(rocblas_set_stream(handle_, original_stream_)); - } - - ORT_DISALLOW_COPY_AND_ASSIGNMENT(RocblasHandleStreamGuard); - - private: - 
rocblas_handle handle_; - hipStream_t original_stream_; -}; - -// to be moved to onnxruntime once we have a monolithicly tunable gemm wrapper and it is enabled for onnxruntime -template -Status RocBlasGemmOp(const GemmParams* params) { - RocblasHandleStreamGuard guard(params->handle, params->stream); - // NOTE: rocblas assumes the storage is column-majored, swapping A and B makes it have the same interface - // as those with row-majored convention. That is, if you treat the storage as row-majored but view the matrices as - // transposed, then by using the property Transpose(A*B) = Tranpose(B)*Transpose(A), the correctness is obvious. - auto status = rocblasGemmHelper( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &(params->alpha), - params->b, params->ldb, - params->a, params->lda, - &(params->beta), - params->c, params->ldc); - ORT_RETURN_IF(status != rocblas_status_success, rocblas_status_to_string(status)); - return Status::OK(); -} - void InitRocBlasGemm(py::module mod); } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_tunable.cc b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_tunable.cc index e93759a856..6a76e45e6d 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_tunable.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_tunable.cc @@ -10,25 +10,16 @@ #include #include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/bert/tunable_op.h" -#include "python/tools/kernel_explorer/kernels/gemm.h" -#include "python/tools/kernel_explorer/kernels/gemm_ck.h" -#include "python/tools/kernel_explorer/kernels/gemm_rocblas.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/gemm_tunable.cuh" +#include "python/tools/kernel_explorer/device_array.h" +#include 
"python/tools/kernel_explorer/kernel_explorer_interface.h" + +using namespace onnxruntime::rocm::tunable::blas; +using namespace onnxruntime::rocm::tunable::blas::internal; namespace onnxruntime { -template -class GemmTunableOp : public contrib::rocm::TunableOp> { - public: - GemmTunableOp() { - this->ops_.emplace_back(RocBlasGemmOp); - for (auto&& [_, op] : GetCKGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->ops_.emplace_back(std::move(op)); - } - } -}; - template class GemmTunable : public IKernelExplorer { public: diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cc b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cc index 0986a74bda..daf481b01c 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cc @@ -9,7 +9,7 @@ #include -#include "contrib_ops/rocm/bert/tunable_op.h" +#include "core/providers/rocm/tunable/tunable.h" #include "python/tools/kernel_explorer/kernel_explorer_interface.h" #include "python/tools/kernel_explorer/kernels/vector_add_kernel.h" @@ -24,7 +24,7 @@ namespace onnxruntime { // Extend the OpParams so that all specializations have the same parameter passing interface template -struct VectorAddParams : contrib::rocm::OpParams { +struct VectorAddParams : rocm::tunable::OpParams { std::string Signature() const override { return std::to_string(n); } T* x; @@ -54,7 +54,7 @@ Status VectorAddOp(const VectorAddParams* params) { // A Tunable VectorAddOp is a collection of non-tunable VectorAddOps implementations that have variable performance // characteristics. Those implementations may be put into a C++ container for tuner to select. 
template -class VectorAddTunableOp : public contrib::rocm::TunableOp> { +class VectorAddTunableOp : public rocm::tunable::TunableOp> { public: VectorAddTunableOp() { ADD_OP(64); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_kernel.h b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_kernel.h index 5d0b44a4d5..874c833ff3 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_kernel.h +++ b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_kernel.h @@ -4,12 +4,14 @@ #pragma once #include + +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/tunable/util.h" #include "python/tools/kernel_explorer/device_array.h" #include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "contrib_ops/rocm/bert/util.h" -using onnxruntime::contrib::rocm::CeilingDivision; -using onnxruntime::contrib::rocm::AlignedVector; +using onnxruntime::rocm::CeilDiv; +using onnxruntime::rocm::aligned_vector; namespace onnxruntime { @@ -18,7 +20,7 @@ __global__ void VectorAddKernel(const T* __restrict__ x, const T* __restrict__ y, T* __restrict__ z, int n) { int i = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x; - using LoadT = AlignedVector; + using LoadT = aligned_vector; if (VecSize * i + VecSize - 1 < n) { T x_vec[VecSize]; @@ -50,7 +52,7 @@ __global__ void VectorAddKernel(const T* __restrict__ x, template Status LaunchVectorAdd(hipStream_t stream, const T* x, const T* y, T* z, int n) { hipLaunchKernelGGL((VectorAddKernel), - dim3(CeilingDivision(n, ThreadsPerBlock*VecSize)), + dim3(CeilDiv(n, ThreadsPerBlock*VecSize)), dim3(ThreadsPerBlock), 0, stream, x, y, z, n);