Revert "Enable ROCm to use tunable GEMM" (#13160)

Reverts microsoft/onnxruntime#12853 due to CI pipeline problem.
This commit is contained in:
cloudhan 2022-09-30 14:01:16 +08:00 committed by GitHub
parent c8781b77f6
commit c93cb8f949
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 16 additions and 213 deletions

View file

@ -1418,9 +1418,6 @@ if (onnxruntime_USE_ROCM)
#endif()
endif()
include(composable_kernel)
target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_composable_kernel_includes device_gemm_instance)
if(UNIX)
set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections")
target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync_cpp)

View file

@ -2,18 +2,13 @@
// Licensed under the MIT License.
#include "core/providers/rocm/math/gemm.h"
#include "core/providers/cpu/math/gemm_helper.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "core/providers/rocm/tunable/gemm.h"
namespace onnxruntime {
namespace rocm {
using tunable::blas::BlasOp;
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Gemm, \
@ -127,21 +122,24 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
}
}
return tunable::blas::column_major::Gemm(
false, Stream(),
HipT alpha = ToHipType<T>::FromFloat(alpha_);
HipT beta = ToHipType<T>::FromFloat(beta_);
// Gemm, note that HIP assumes col-major, so Y(N,M) = alpha * op(W) x op(X) + beta * Y
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
RocblasHandle(),
trans_B_ ? BlasOp::Trans : BlasOp::NonTrans,
trans_A_ ? BlasOp::Trans : BlasOp::NonTrans,
trans_B_ ? rocblas_operation_transpose : rocblas_operation_none,
trans_A_ ? rocblas_operation_transpose : rocblas_operation_none,
N, M, K,
alpha_,
&alpha,
reinterpret_cast<const HipT*>(W->Data<T>()),
(trans_B_ ? K : N),
reinterpret_cast<const HipT*>(X->Data<T>()),
(trans_A_ ? M : K),
// ideally we need to set the output buffer contents to 0 if bias is missing,
// but passing 0 for beta is cheaper and it will ignore any junk in the output buffer
B != nullptr ? beta_ : 0.0f,
out_data, N);
B != nullptr ? &beta : &zero,
out_data, N));
return Status::OK();
}
} // namespace rocm

View file

@ -9,7 +9,6 @@
#include "core/providers/rocm/rocm_allocator.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/rocm/tunable/gemm.h"
namespace onnxruntime {
namespace rocm {
@ -84,16 +83,13 @@ Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper,
int64_t stride_A, stride_B, stride_C, batch_count;
if (helper.OutputOffsets().size() == 1) {
using tunable::blas::BlasOp;
BlasOp transA = transa ? BlasOp::Trans : BlasOp::NonTrans;
BlasOp transB = transb ? BlasOp::Trans : BlasOp::NonTrans;
return tunable::blas::column_major::Gemm(
false, op->Stream(),
op->RocblasHandle(), transB, transA, static_cast<int64_t>(helper.N()),
static_cast<int64_t>(helper.M()), static_cast<int64_t>(helper.K()), t_alpha,
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
op->RocblasHandle(), transB, transA, static_cast<int>(helper.N()),
static_cast<int>(helper.M()), static_cast<int>(helper.K()), &alpha,
reinterpret_cast<const HipT*>(right_x_data), ldb,
reinterpret_cast<const HipT*>(left_x_data), lda, t_zero,
reinterpret_cast<HipT*>(output_y_data), ldc);
reinterpret_cast<const HipT*>(left_x_data), lda, &zero,
reinterpret_cast<HipT*>(output_y_data), ldc));
return Status::OK();
} else if (CanUseStridedBatchedGemm(left_shape, right_shape,
transa, transb, trans_batch_a, trans_batch_b,
stride_A, stride_B, stride_C, batch_count)) {

View file

@ -1,119 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define _GEMM_H_KEEP_SIGNATURE_DEFINES
#include "core/providers/rocm/tunable/gemm.h"
#include <type_traits>
#include <utility>
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "core/providers/rocm/tunable/gemm_rocblas.h"
#include "core/providers/rocm/tunable/gemm_tunable.cuh"
namespace onnxruntime {
namespace rocm {
namespace tunable {
namespace blas {
namespace row_major {
template <typename T, typename ScalarT>
inline GEMM(T, ScalarT) {
GemmParams<T> params;
params.stream = stream;
params.handle = handle;
params.opa = opa;
params.opb = opb;
params.m = m;
params.n = n;
params.k = k;
if constexpr (!std::is_same_v<T, ScalarT> && std::is_same_v<ScalarT, float>) {
params.alpha = ToHipType<T>::FromFloat(std::forward<T>(alpha));
} else {
params.alpha = alpha;
}
params.a = a;
params.lda = lda;
params.b = b;
params.ldb = ldb;
if constexpr (!std::is_same_v<T, ScalarT> && std::is_same_v<ScalarT, float>) {
params.beta = ToHipType<T>::FromFloat(std::forward<T>(beta));
} else {
params.beta = beta;
}
params.c = c;
params.ldc = ldc;
if (tunable) {
if (opa == BlasOp::N && opb == BlasOp::N) {
static internal::GemmTunableOp<T, internal::Row, internal::Row> gemm{};
gemm.EnableTuning();
return gemm(&params);
} else if (opa == BlasOp::T && opb == BlasOp::N) {
static internal::GemmTunableOp<T, internal::Col, internal::Row> gemm{};
gemm.EnableTuning();
return gemm(&params);
} else if (opa == BlasOp::N && opb == BlasOp::T) {
static internal::GemmTunableOp<T, internal::Row, internal::Col> gemm{};
gemm.EnableTuning();
return gemm(&params);
} else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ {
static internal::GemmTunableOp<T, internal::Col, internal::Col> gemm{};
gemm.EnableTuning();
return gemm(&params);
}
}
return internal::RocBlasGemmOp(&params);
}
#define CALL_GEMM(T, ScalarT) \
Gemm<T, ScalarT>(tunable, stream, handle, \
opa, opb, \
m, n, k, \
alpha, a, lda, b, ldb, \
beta, c, ldc)
// clang-format off
GEMM(double, double ) { return CALL_GEMM(double, double ); }
GEMM(float, float ) { return CALL_GEMM(float, float ); }
GEMM(half, half ) { return CALL_GEMM(half, half ); }
GEMM(BFloat16, BFloat16) { return CALL_GEMM(BFloat16, BFloat16); }
GEMM(double, float ) { return CALL_GEMM(double, float ); }
GEMM(half, float ) { return CALL_GEMM(half, float ); }
GEMM(BFloat16, float ) { return CALL_GEMM(BFloat16, float ); }
// clang-format on
#undef CALL_GEMM
} // namespace row_major
namespace column_major {
#define CALL_GEMM_WITH_AB_SWAPPED(T, ScalarT) \
row_major::Gemm<T, ScalarT>(tunable, stream, handle, \
opb, opa, \
n, m, k, \
alpha, b, ldb, a, lda, \
beta, c, ldc)
// clang-format off
GEMM(double, double ) { return CALL_GEMM_WITH_AB_SWAPPED(double, double ); }
GEMM(float, float ) { return CALL_GEMM_WITH_AB_SWAPPED(float, float ); }
GEMM(half, half ) { return CALL_GEMM_WITH_AB_SWAPPED(half, half ); }
GEMM(BFloat16, BFloat16) { return CALL_GEMM_WITH_AB_SWAPPED(BFloat16, BFloat16); }
GEMM(double, float ) { return CALL_GEMM_WITH_AB_SWAPPED(double, float ); }
GEMM(half, float ) { return CALL_GEMM_WITH_AB_SWAPPED(half, float ); }
GEMM(BFloat16, float ) { return CALL_GEMM_WITH_AB_SWAPPED(BFloat16, float ); }
// clang-format on
#undef CALL_GEMM_WITH_AB_SWAPPED
} // namespace column_major
} // namespace blas
} // namespace tunable
} // namespace rocm
} // namespace onnxruntime

View file

@ -1,58 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/status.h"
#include "core/framework/float16.h"
#include "core/providers/rocm/tunable/gemm_common.h"
namespace onnxruntime {
namespace rocm {
namespace tunable {
namespace blas {
#define GEMM(T, ScalarT) \
common::Status Gemm( \
bool tunable, hipStream_t stream, rocblas_handle handle, \
BlasOp opa, BlasOp opb, \
std::int64_t m, std::int64_t n, std::int64_t k, \
ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \
ScalarT beta, T* c, std::int64_t ldc)
namespace row_major {
GEMM(double, double);
GEMM(float, float);
GEMM(half, half);
GEMM(BFloat16, BFloat16);
GEMM(double, float);
GEMM(half, float);
GEMM(BFloat16, float);
} // namespace row_major
// TODO(anyone): the caller should not need to swap the params a and b manually, but all the current callsites are
// doing so. It is cumbersome and unintuitive. At the moment, this namespace only ease the porting from old direct
// rocblas_gemm* calls to tunable gemm calls. After all porting of all callsites, if there is no column_major usecase
// left, then we shall remove this namespace, finally.
namespace column_major {
GEMM(double, double);
GEMM(float, float);
GEMM(half, half);
GEMM(BFloat16, BFloat16);
GEMM(double, float);
GEMM(half, float);
GEMM(BFloat16, float);
} // namespace column_major
} // namespace blas
} // namespace tunable
} // namespace rocm
} // namespace onnxruntime
#ifndef _GEMM_H_KEEP_SIGNATURE_DEFINES
#undef GEMM
#endif

View file

@ -31,11 +31,6 @@ struct DataTypeAdaptor<half> {
using type = ck::half_t;
};
template <>
struct DataTypeAdaptor<BFloat16> {
using type = ck::bhalf16_t;
};
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@ -55,12 +50,6 @@ auto GetCKGemmTypeStringAndOps() {
auto type_string = impl->GetTypeString();
auto invoker = impl->MakeInvokerPointer();
auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
auto one = ToHipType<T>::FromFloat(1.0f);
auto zero = ToHipType<T>::FromFloat(0.0f);
TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF(
params->alpha != one || params->beta != zero,
impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature());
auto nop = Nop{};
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
params->m, params->n, params->k,