Move all TunableOp related falicilities to EP level directory (#12857)

Some Ops in EP directory instead of contrib_ops directory will
require TunableOp. We will also need to add EP level session tuning
options for it. So move those code all at once.

Also remove duplicated utility functions.
This commit is contained in:
cloudhan 2022-09-23 11:10:19 +08:00 committed by GitHub
parent 8fb3f05cd6
commit a24b41d92e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 361 additions and 281 deletions

View file

@ -23,8 +23,7 @@ file(GLOB kernel_explorer_kernel_srcs CONFIGURE_DEPENDS "${KERNEL_EXPLORER_ROOT}
onnxruntime_add_shared_library_module(kernel_explorer
${kernel_explorer_srcs}
${kernel_explorer_kernel_srcs}
${BERT_DIR}/util.cc)
${kernel_explorer_kernel_srcs})
set_target_properties(kernel_explorer PROPERTIES PREFIX "_")
target_include_directories(kernel_explorer PUBLIC
$<TARGET_PROPERTY:onnxruntime_pybind11_state,INCLUDE_DIRECTORIES>

View file

@ -3,7 +3,8 @@
#pragma once
#include "contrib_ops/rocm/bert/util.h"
#include "core/providers/rocm/tunable/util.h"
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace contrib {
@ -34,7 +35,7 @@ __global__ void FastGeluKernel(int input_length, int bias_length, const T* input
template <typename T, unsigned TPB, int ILP>
__global__ void FastGeluKernelVec(int input_length, int bias_length, const T* input, const T* bias,
T* output) {
using VecT = AlignedVector<T, ILP>;
using VecT = onnxruntime::rocm::aligned_vector<T, ILP>;
const T a = T(0.5f);
const T b = T(0.7978845608028654f);
const T c = T(0.035677408136300125f);

View file

@ -7,7 +7,8 @@
#include <memory>
#include <string>
#include <vector>
#include "contrib_ops/rocm/bert/tunable_op.h"
#include "core/providers/rocm/tunable/tunable.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "contrib_ops/rocm/bert/fast_gelu_impl_kernel.h"
namespace onnxruntime {
@ -15,7 +16,7 @@ namespace contrib {
namespace rocm {
template <typename T>
struct FastGeluParams : OpParams {
struct FastGeluParams : onnxruntime::rocm::tunable::OpParams {
FastGeluParams(hipStream_t stream, const T* input, const T* bias, T* output, int input_length, int bias_length) :
OpParams(stream), input(input), bias(bias), output(output), input_length(input_length), bias_length(bias_length) {}
@ -39,7 +40,7 @@ Status FastGeluOp(const FastGeluParams<T>* params) {
(params->bias_length == 0 && params->input_length % VecSize == 0)));
hipLaunchKernelGGL((FastGeluKernelVec<T, ThreadsPerBlock, VecSize>),
dim3(CeilingDivision(params->input_length, ThreadsPerBlock * VecSize)),
dim3(onnxruntime::rocm::CeilDiv(params->input_length, ThreadsPerBlock * VecSize)),
dim3(ThreadsPerBlock),
0, params->stream,
params->input_length, params->bias_length, params->input, params->bias, params->output);
@ -56,7 +57,7 @@ Status FastGeluOp(const FastGeluParams<T>* params) {
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 16>);
template <typename T>
class FastGeluTunableOp : public TunableOp<FastGeluParams<T>> {
class FastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp<FastGeluParams<T>> {
public:
FastGeluTunableOp() {
ADD_OP(64);

View file

@ -9,14 +9,15 @@
#include <vector>
#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h"
#include "contrib_ops/rocm/bert/tunable_op.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/tunable/tunable.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
struct SkipLayerNormParams : OpParams {
struct SkipLayerNormParams : onnxruntime::rocm::tunable::OpParams {
SkipLayerNormParams(hipStream_t stream, T* output, const T* input,
const T* skip, const T* gamma, const T* beta,
const T* bias, float epsilon, const int ld,
@ -42,9 +43,10 @@ struct SkipLayerNormParams : OpParams {
template <typename T, int ThreadsPerBlock, int VecSize>
Status SkipLayerNormSmallOp(const SkipLayerNormParams<T>* params) {
using onnxruntime::rocm::CeilDiv;
TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF(
!((params->ld <= 1024 && params->ld % VecSize == 0 && params->ld == ThreadsPerBlock * VecSize)));
SkipLayerNormKernelSmall<T, ThreadsPerBlock, VecSize><<<dim3(CeilingDivision(params->element_count, params->ld)),
SkipLayerNormKernelSmall<T, ThreadsPerBlock, VecSize><<<dim3(CeilDiv(params->element_count, params->ld)),
dim3(ThreadsPerBlock),
0, params->stream>>>(
params->ld, params->input, params->skip,

View file

@ -1,44 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/bert/util.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
int CeilingDivision(int n, int m) {
int r = (n - 1) / m + 1;
return r;
}
Timer::Timer(hipStream_t stream): stream_(stream) {
HIP_CHECK(hipEventCreate(&start_));
HIP_CHECK(hipEventCreate(&end_));
}
void Timer::Start() {
HIP_CHECK(hipDeviceSynchronize());
HIP_CHECK(hipEventRecord(start_, stream_));
}
void Timer::End() {
HIP_CHECK(hipEventRecord(end_, stream_));
HIP_CHECK(hipEventSynchronize(end_));
}
float Timer::Duration() {
float time;
// time is in ms with a resolution of 1 us
HIP_CHECK(hipEventElapsedTime(&time, start_, end_));
return time;
}
Timer::~Timer() {
HIP_CHECK(hipEventDestroy(start_));
HIP_CHECK(hipEventDestroy(end_));
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,46 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_runtime.h>
#include <cstdlib>
#include <iostream>
#define HIP_CHECK(expr) \
do { \
auto status = expr; \
if (status != hipSuccess) { \
std::cerr << hipGetErrorName(status); \
std::abort(); \
} \
} while (0)
namespace onnxruntime {
namespace contrib {
namespace rocm {
int CeilingDivision(int n, int m);
template<typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) AlignedVector {
T val[VecSize];
};
class Timer {
public:
explicit Timer(hipStream_t stream);
void Start();
void End();
float Duration();
~Timer();
private:
hipStream_t stream_;
hipEvent_t start_;
hipEvent_t end_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "core/providers/rocm/tunable/gemm_common.h"
namespace onnxruntime {
namespace rocm {
namespace tunable {
namespace blas {
namespace internal {
template <typename T>
struct DataTypeAdaptor {
using type = T;
};
template <>
struct DataTypeAdaptor<half> {
using type = ck::half_t;
};
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Nop = ck::tensor_operation::element_wise::PassThrough;
template <typename T, typename ALayout, typename BLayout>
auto GetCKGemmTypeStringAndOps() {
using CKDataType = typename DataTypeAdaptor<T>::type;
using DeviceGemm = ck::tensor_operation::device::DeviceGemm<
ALayout, BLayout, Row,
CKDataType, CKDataType, CKDataType,
Nop, Nop, Nop>;
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceGemm>;
std::vector<std::pair<std::string, Op<GemmParams<T>>>> ret;
for (auto&& impl : InstanceFactory::GetInstances()) {
auto type_string = impl->GetTypeString();
auto invoker = impl->MakeInvokerPointer();
auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
auto nop = Nop{};
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
params->m, params->n, params->k,
params->lda, params->ldb, params->ldc,
nop, nop, nop);
TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support ", params->Signature());
invoker->Run(arg.get(), StreamConfig{params->stream});
return Status::OK();
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op)));
}
return ret;
}
} // namespace internal
} // namespace blas
} // namespace tunable
} // namespace rocm
} // namespace onnxruntime

View file

@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <vector>
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/tunable/tunable.h"
namespace onnxruntime {
namespace rocm {
namespace tunable {
namespace blas {
enum class BlasOp {
N = 0,
T = 1,
NonTrans = 0,
Trans = 1,
};
inline std::string BlasOpToString(BlasOp op) {
switch (op) {
case BlasOp::N:
return "N";
case BlasOp::T:
return "T";
}
}
// We don't assume the implementation is row-majored or column-majored. But for testing convenience, we assume all
// our wrappers have row-majored convention, since it is the native layout to numpy and pytorch.
template <typename T>
struct GemmParams : tunable::OpParams {
std::string Signature() const override {
return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k);
}
rocblas_handle handle;
BlasOp opa;
BlasOp opb;
int64_t m;
int64_t n;
int64_t k;
T alpha;
const T* a;
int64_t lda;
const T* b;
int64_t ldb;
T beta;
T* c;
int64_t ldc;
};
} // namespace blas
} // namespace tunable
} // namespace rocm
} // namespace onnxruntime

View file

@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "core/providers/rocm/tunable/gemm_common.h"
namespace onnxruntime {
namespace rocm {
namespace tunable {
namespace blas {
namespace internal {
// RAII style guard to set stream and restore original stream for rocblas_handle
class RocblasHandleStreamGuard {
public:
RocblasHandleStreamGuard(rocblas_handle handle, hipStream_t stream) : handle_{handle} {
ROCBLAS_CALL_THROW(rocblas_get_stream(handle_, &original_stream_));
ROCBLAS_CALL_THROW(rocblas_set_stream(handle_, stream));
}
~RocblasHandleStreamGuard() {
ROCBLAS_CALL_THROW(rocblas_set_stream(handle_, original_stream_));
}
ORT_DISALLOW_COPY_AND_ASSIGNMENT(RocblasHandleStreamGuard);
private:
rocblas_handle handle_;
hipStream_t original_stream_;
};
template <typename T>
Status RocBlasGemmOp(const GemmParams<T>* params) {
RocblasHandleStreamGuard guard(params->handle, params->stream);
// NOTE: rocblas assumes the storage is column-majored, swapping A and B makes it have the same interface
// as those with row-majored convention. That is, if you treat the storage as row-majored but view the matrices as
// transposed, then by using the property Transpose(A*B) = Tranpose(B)*Transpose(A), the correctness is obvious.
auto status = rocblasGemmHelper(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&(params->alpha),
params->b, params->ldb,
params->a, params->lda,
&(params->beta),
params->c, params->ldc);
ORT_RETURN_IF(status != rocblas_status_success, rocblas_status_to_string(status));
return Status::OK();
}
} // namespace internal
} // namespace blas
} // namespace tunable
} // namespace rocm
} // namespace onnxruntime

View file

@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <utility>
#include "core/providers/rocm/tunable/gemm_ck.cuh"
#include "core/providers/rocm/tunable/gemm_rocblas.h"
#include "core/providers/rocm/tunable/gemm_common.h"
#include "core/providers/rocm/tunable/tunable.h"
namespace onnxruntime {
namespace rocm {
namespace tunable {
namespace blas {
namespace internal {
template <typename T, typename ALayout, typename BLayout>
class GemmTunableOp : public tunable::TunableOp<GemmParams<T>> {
public:
GemmTunableOp() {
this->ops_.emplace_back(RocBlasGemmOp<T>);
for (auto&& [_, op] : GetCKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
ORT_UNUSED_PARAMETER(_);
this->ops_.emplace_back(std::move(op));
}
}
};
} // namespace internal
} // namespace blas
} // namespace tunable
} // namespace rocm
} // namespace onnxruntime

View file

@ -6,6 +6,7 @@
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <chrono>
#include <functional>
#include <limits>
#include <memory>
@ -16,11 +17,12 @@
#include <vector>
#include "core/common/common.h"
#include "contrib_ops/rocm/bert/util.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/tunable/util.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
namespace tunable {
struct OpParams {
OpParams() : stream{} {}
@ -150,6 +152,7 @@ class TunableOp {
}
}
ORT_ENFORCE(id >= 0, "Cannot found viable op");
std::this_thread::sleep_for(std::chrono::milliseconds(50));
return id;
}
@ -166,6 +169,6 @@ class TunableOp {
bool tuning_{false};
};
} // namespace tunable
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/tunable/util.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
namespace onnxruntime {
namespace rocm {
namespace tunable {
Timer::Timer(hipStream_t stream) : stream_(stream) {
HIP_CALL_THROW(hipEventCreate(&start_));
HIP_CALL_THROW(hipEventCreate(&end_));
}
void Timer::Start() {
HIP_CALL_THROW(hipDeviceSynchronize());
HIP_CALL_THROW(hipEventRecord(start_, stream_));
}
void Timer::End() {
HIP_CALL_THROW(hipEventRecord(end_, stream_));
HIP_CALL_THROW(hipEventSynchronize(end_));
}
float Timer::Duration() {
float time;
// time is in ms with a resolution of 1 us
HIP_CALL_THROW(hipEventElapsedTime(&time, start_, end_));
return time;
}
Timer::~Timer() {
HIP_CALL_THROW(hipEventDestroy(start_));
HIP_CALL_THROW(hipEventDestroy(end_));
}
} // namespace tunable
} // namespace rocm
} // namespace onnxruntime

View file

@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_runtime.h>
#include <cstdlib>
#include <iostream>
namespace onnxruntime {
namespace rocm {
namespace tunable {
class Timer {
public:
explicit Timer(hipStream_t stream);
void Start();
void End();
float Duration();
~Timer();
private:
hipStream_t stream_;
hipEvent_t start_;
hipEvent_t end_;
};
} // namespace tunable
} // namespace rocm
} // namespace onnxruntime

View file

@ -6,7 +6,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/util.h"
#include "core/providers/rocm/tunable/util.h"
namespace py = pybind11;

View file

@ -4,9 +4,10 @@
#pragma once
#include <hip/hip_runtime.h>
#include "contrib_ops/rocm/bert/util.h"
#include "core/providers/rocm/tunable/tunable.h"
#include "core/providers/rocm/tunable/util.h"
using onnxruntime::contrib::rocm::Timer;
using onnxruntime::rocm::tunable::Timer;
/// Wrapping around Op and TunableOp
class IKernelExplorer {

View file

@ -2,12 +2,16 @@
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/gemm.h"
#include <pybind11/pybind11.h>
#include <type_traits>
#include "core/providers/rocm/tunable/gemm_common.h"
#include "python/tools/kernel_explorer/kernels/gemm_ck.h"
#include "python/tools/kernel_explorer/kernels/gemm_rocblas.h"
#include "python/tools/kernel_explorer/kernels/gemm_tunable.h"
#include <type_traits>
#include <pybind11/pybind11.h>
using BlasOp = onnxruntime::rocm::tunable::blas::BlasOp;
namespace py = pybind11;

View file

@ -3,57 +3,12 @@
#pragma once
#include <string>
#include <vector>
#include <pybind11/pybind11.h>
#include "contrib_ops/rocm/bert/tunable_op.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
namespace py = pybind11;
namespace onnxruntime {
enum class BlasOp {
N,
T,
};
inline std::string BlasOpToString(BlasOp op) {
switch (op) {
case BlasOp::N:
return "N";
case BlasOp::T:
return "T";
}
}
// We don't assume the implementation is row-majored or column-majored. But for testing convenience, we assume all
// our wrappers have row-majored convention, since it is the native layout to numpy and pytorch.
template <typename T>
struct GemmParams : contrib::rocm::OpParams {
std::string Signature() const override {
return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k);
}
rocblas_handle handle;
BlasOp opa;
BlasOp opb;
int64_t m;
int64_t n;
int64_t k;
T alpha;
T* a;
int64_t lda;
T* b;
int64_t ldb;
T beta;
T* c;
int64_t ldc;
};
void InitGemm(py::module mod);
} // namespace onnxruntime

View file

@ -11,6 +11,15 @@
#include <utility>
#include <vector>
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/tunable/gemm_common.h"
#include "core/providers/rocm/tunable/gemm_ck.cuh"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
using namespace onnxruntime::rocm::tunable::blas;
using namespace onnxruntime::rocm::tunable::blas::internal;
namespace py = pybind11;
namespace onnxruntime {
@ -76,7 +85,7 @@ class CKGemm : public IKernelExplorer {
private:
using ParamsT = GemmParams<T>;
using OpT = contrib::rocm::Op<ParamsT>;
using OpT = rocm::tunable::Op<ParamsT>;
ParamsT params_;
std::vector<OpT> ops_;
std::vector<std::string> type_strings_;

View file

@ -5,67 +5,10 @@
#include <pybind11/pybind11.h>
#include <string>
#include <utility>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "python/tools/kernel_explorer/kernels/gemm.h"
namespace py = pybind11;
namespace onnxruntime {
template <typename T>
struct DataTypeAdaptor {
using type = T;
};
template <>
struct DataTypeAdaptor<half> {
using type = ck::half_t;
};
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Nop = ck::tensor_operation::element_wise::PassThrough;
// to be moved to onnxruntime once we have a monolithicly tunable gemm wrapper and it is enabled for onnxruntime
template <typename T, typename ALayout, typename BLayout>
auto GetCKGemmTypeStringAndOps() {
using CKDataType = typename DataTypeAdaptor<T>::type;
using DeviceGemm = ck::tensor_operation::device::DeviceGemm<
ALayout, BLayout, Row,
CKDataType, CKDataType, CKDataType,
Nop, Nop, Nop>;
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceGemm>;
std::vector<std::pair<std::string, contrib::rocm::Op<GemmParams<T>>>> ret;
for (auto&& impl : InstanceFactory::GetInstances()) {
auto type_string = impl->GetTypeString();
auto invoker = impl->MakeInvokerPointer();
auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
auto nop = Nop{};
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
params->m, params->n, params->k,
params->lda, params->ldb, params->ldc,
nop, nop, nop);
TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support ", params->Signature());
invoker->Run(arg.get(), StreamConfig{params->stream});
return Status::OK();
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op)));
}
return ret;
}
void InitComposableKernelGemm(py::module mod);
} // namespace onnxruntime

View file

@ -10,7 +10,13 @@
#include <vector>
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/tunable/gemm_common.h"
#include "core/providers/rocm/tunable/gemm_rocblas.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
using namespace onnxruntime::rocm::tunable::blas;
using namespace onnxruntime::rocm::tunable::blas::internal;
namespace py = pybind11;
@ -65,7 +71,7 @@ class RocBlasGemm : public IKernelExplorer {
rocblas_handle rocblas_handle_;
using ParamsT = GemmParams<T>;
using OpT = contrib::rocm::Op<ParamsT>;
using OpT = rocm::tunable::Op<ParamsT>;
ParamsT params_{};
OpT op_{RocBlasGemmOp<T>};

View file

@ -5,54 +5,10 @@
#include <pybind11/pybind11.h>
#include "core/common/common.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "python/tools/kernel_explorer/kernels/gemm.h"
namespace py = pybind11;
namespace onnxruntime {
// RAII style guard to set stream and restore original stream for rocblas_handle
class RocblasHandleStreamGuard {
public:
RocblasHandleStreamGuard(rocblas_handle handle, hipStream_t stream) : handle_{handle} {
ROCBLAS_CALL_THROW(rocblas_get_stream(handle_, &original_stream_));
ROCBLAS_CALL_THROW(rocblas_set_stream(handle_, stream));
}
~RocblasHandleStreamGuard() {
ROCBLAS_CALL_THROW(rocblas_set_stream(handle_, original_stream_));
}
ORT_DISALLOW_COPY_AND_ASSIGNMENT(RocblasHandleStreamGuard);
private:
rocblas_handle handle_;
hipStream_t original_stream_;
};
// to be moved to onnxruntime once we have a monolithicly tunable gemm wrapper and it is enabled for onnxruntime
template <typename T>
Status RocBlasGemmOp(const GemmParams<T>* params) {
RocblasHandleStreamGuard guard(params->handle, params->stream);
// NOTE: rocblas assumes the storage is column-majored, swapping A and B makes it have the same interface
// as those with row-majored convention. That is, if you treat the storage as row-majored but view the matrices as
// transposed, then by using the property Transpose(A*B) = Tranpose(B)*Transpose(A), the correctness is obvious.
auto status = rocblasGemmHelper(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&(params->alpha),
params->b, params->ldb,
params->a, params->lda,
&(params->beta),
params->c, params->ldc);
ORT_RETURN_IF(status != rocblas_status_success, rocblas_status_to_string(status));
return Status::OK();
}
void InitRocBlasGemm(py::module mod);
} // namespace onnxruntime

View file

@ -10,25 +10,16 @@
#include <vector>
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/tunable_op.h"
#include "python/tools/kernel_explorer/kernels/gemm.h"
#include "python/tools/kernel_explorer/kernels/gemm_ck.h"
#include "python/tools/kernel_explorer/kernels/gemm_rocblas.h"
#include "core/providers/rocm/tunable/gemm_common.h"
#include "core/providers/rocm/tunable/gemm_tunable.cuh"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
using namespace onnxruntime::rocm::tunable::blas;
using namespace onnxruntime::rocm::tunable::blas::internal;
namespace onnxruntime {
template <typename T, typename ALayout, typename BLayout>
class GemmTunableOp : public contrib::rocm::TunableOp<GemmParams<T>> {
public:
GemmTunableOp() {
this->ops_.emplace_back(RocBlasGemmOp<T>);
for (auto&& [_, op] : GetCKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
ORT_UNUSED_PARAMETER(_);
this->ops_.emplace_back(std::move(op));
}
}
};
template <typename T, typename ALayout, typename BLayout>
class GemmTunable : public IKernelExplorer {
public:

View file

@ -9,7 +9,7 @@
#include <string>
#include "contrib_ops/rocm/bert/tunable_op.h"
#include "core/providers/rocm/tunable/tunable.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
#include "python/tools/kernel_explorer/kernels/vector_add_kernel.h"
@ -24,7 +24,7 @@ namespace onnxruntime {
// Extend the OpParams so that all specializations have the same parameter passing interface
template <typename T>
struct VectorAddParams : contrib::rocm::OpParams {
struct VectorAddParams : rocm::tunable::OpParams {
std::string Signature() const override { return std::to_string(n); }
T* x;
@ -54,7 +54,7 @@ Status VectorAddOp(const VectorAddParams<T>* params) {
// A Tunable VectorAddOp is a collection of non-tunable VectorAddOps implementations that have variable performance
// characteristics. Those implementations may be put into a C++ container for tuner to select.
template <typename T>
class VectorAddTunableOp : public contrib::rocm::TunableOp<VectorAddParams<T>> {
class VectorAddTunableOp : public rocm::tunable::TunableOp<VectorAddParams<T>> {
public:
VectorAddTunableOp() {
ADD_OP(64);

View file

@ -4,12 +4,14 @@
#pragma once
#include <hip/hip_runtime.h>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/tunable/util.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
#include "contrib_ops/rocm/bert/util.h"
using onnxruntime::contrib::rocm::CeilingDivision;
using onnxruntime::contrib::rocm::AlignedVector;
using onnxruntime::rocm::CeilDiv;
using onnxruntime::rocm::aligned_vector;
namespace onnxruntime {
@ -18,7 +20,7 @@ __global__ void VectorAddKernel(const T* __restrict__ x,
const T* __restrict__ y,
T* __restrict__ z, int n) {
int i = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
using LoadT = AlignedVector<T, VecSize>;
using LoadT = aligned_vector<T, VecSize>;
if (VecSize * i + VecSize - 1 < n) {
T x_vec[VecSize];
@ -50,7 +52,7 @@ __global__ void VectorAddKernel(const T* __restrict__ x,
template <typename T, int ThreadsPerBlock, int VecSize>
Status LaunchVectorAdd(hipStream_t stream, const T* x, const T* y, T* z, int n) {
hipLaunchKernelGGL((VectorAddKernel<T, VecSize>),
dim3(CeilingDivision(n, ThreadsPerBlock*VecSize)),
dim3(CeilDiv(n, ThreadsPerBlock*VecSize)),
dim3(ThreadsPerBlock),
0, stream,
x, y, z, n);