mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Revert "Enable ROCm to use tunable GEMM" (#13160)
Reverts microsoft/onnxruntime#12853 due to CI pipeline problem.
This commit is contained in:
parent
c8781b77f6
commit
c93cb8f949
6 changed files with 16 additions and 213 deletions
|
|
@ -1418,9 +1418,6 @@ if (onnxruntime_USE_ROCM)
|
|||
#endif()
|
||||
endif()
|
||||
|
||||
include(composable_kernel)
|
||||
target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_composable_kernel_includes device_gemm_instance)
|
||||
|
||||
if(UNIX)
|
||||
set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections")
|
||||
target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync_cpp)
|
||||
|
|
|
|||
|
|
@ -2,18 +2,13 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/rocm/math/gemm.h"
|
||||
|
||||
#include "core/providers/cpu/math/gemm_helper.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/rocm/tunable/gemm.h"
|
||||
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
using tunable::blas::BlasOp;
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
Gemm, \
|
||||
|
|
@ -127,21 +122,24 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
}
|
||||
}
|
||||
|
||||
return tunable::blas::column_major::Gemm(
|
||||
false, Stream(),
|
||||
HipT alpha = ToHipType<T>::FromFloat(alpha_);
|
||||
HipT beta = ToHipType<T>::FromFloat(beta_);
|
||||
// Gemm, note that HIP assumes col-major, so Y(N,M) = alpha * op(W) x op(X) + beta * Y
|
||||
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
|
||||
RocblasHandle(),
|
||||
trans_B_ ? BlasOp::Trans : BlasOp::NonTrans,
|
||||
trans_A_ ? BlasOp::Trans : BlasOp::NonTrans,
|
||||
trans_B_ ? rocblas_operation_transpose : rocblas_operation_none,
|
||||
trans_A_ ? rocblas_operation_transpose : rocblas_operation_none,
|
||||
N, M, K,
|
||||
alpha_,
|
||||
&alpha,
|
||||
reinterpret_cast<const HipT*>(W->Data<T>()),
|
||||
(trans_B_ ? K : N),
|
||||
reinterpret_cast<const HipT*>(X->Data<T>()),
|
||||
(trans_A_ ? M : K),
|
||||
// ideally we need to set the output buffer contents to 0 if bias is missing,
|
||||
// but passing 0 for beta is cheaper and it will ignore any junk in the output buffer
|
||||
B != nullptr ? beta_ : 0.0f,
|
||||
out_data, N);
|
||||
B != nullptr ? &beta : &zero,
|
||||
out_data, N));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@
|
|||
|
||||
#include "core/providers/rocm/rocm_allocator.h"
|
||||
#include "core/providers/rocm/rocm_kernel.h"
|
||||
#include "core/providers/rocm/tunable/gemm.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
|
@ -84,16 +83,13 @@ Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper,
|
|||
int64_t stride_A, stride_B, stride_C, batch_count;
|
||||
|
||||
if (helper.OutputOffsets().size() == 1) {
|
||||
using tunable::blas::BlasOp;
|
||||
BlasOp transA = transa ? BlasOp::Trans : BlasOp::NonTrans;
|
||||
BlasOp transB = transb ? BlasOp::Trans : BlasOp::NonTrans;
|
||||
return tunable::blas::column_major::Gemm(
|
||||
false, op->Stream(),
|
||||
op->RocblasHandle(), transB, transA, static_cast<int64_t>(helper.N()),
|
||||
static_cast<int64_t>(helper.M()), static_cast<int64_t>(helper.K()), t_alpha,
|
||||
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
|
||||
op->RocblasHandle(), transB, transA, static_cast<int>(helper.N()),
|
||||
static_cast<int>(helper.M()), static_cast<int>(helper.K()), &alpha,
|
||||
reinterpret_cast<const HipT*>(right_x_data), ldb,
|
||||
reinterpret_cast<const HipT*>(left_x_data), lda, t_zero,
|
||||
reinterpret_cast<HipT*>(output_y_data), ldc);
|
||||
reinterpret_cast<const HipT*>(left_x_data), lda, &zero,
|
||||
reinterpret_cast<HipT*>(output_y_data), ldc));
|
||||
return Status::OK();
|
||||
} else if (CanUseStridedBatchedGemm(left_shape, right_shape,
|
||||
transa, transb, trans_batch_a, trans_batch_b,
|
||||
stride_A, stride_B, stride_C, batch_count)) {
|
||||
|
|
|
|||
|
|
@ -1,119 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#define _GEMM_H_KEEP_SIGNATURE_DEFINES
|
||||
#include "core/providers/rocm/tunable/gemm.h"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "core/providers/rocm/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/rocm/tunable/gemm_rocblas.h"
|
||||
#include "core/providers/rocm/tunable/gemm_tunable.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
namespace tunable {
|
||||
namespace blas {
|
||||
|
||||
namespace row_major {
|
||||
|
||||
template <typename T, typename ScalarT>
|
||||
inline GEMM(T, ScalarT) {
|
||||
GemmParams<T> params;
|
||||
params.stream = stream;
|
||||
params.handle = handle;
|
||||
|
||||
params.opa = opa;
|
||||
params.opb = opb;
|
||||
params.m = m;
|
||||
params.n = n;
|
||||
params.k = k;
|
||||
if constexpr (!std::is_same_v<T, ScalarT> && std::is_same_v<ScalarT, float>) {
|
||||
params.alpha = ToHipType<T>::FromFloat(std::forward<T>(alpha));
|
||||
} else {
|
||||
params.alpha = alpha;
|
||||
}
|
||||
params.a = a;
|
||||
params.lda = lda;
|
||||
params.b = b;
|
||||
params.ldb = ldb;
|
||||
if constexpr (!std::is_same_v<T, ScalarT> && std::is_same_v<ScalarT, float>) {
|
||||
params.beta = ToHipType<T>::FromFloat(std::forward<T>(beta));
|
||||
} else {
|
||||
params.beta = beta;
|
||||
}
|
||||
params.c = c;
|
||||
params.ldc = ldc;
|
||||
|
||||
if (tunable) {
|
||||
if (opa == BlasOp::N && opb == BlasOp::N) {
|
||||
static internal::GemmTunableOp<T, internal::Row, internal::Row> gemm{};
|
||||
gemm.EnableTuning();
|
||||
return gemm(¶ms);
|
||||
} else if (opa == BlasOp::T && opb == BlasOp::N) {
|
||||
static internal::GemmTunableOp<T, internal::Col, internal::Row> gemm{};
|
||||
gemm.EnableTuning();
|
||||
return gemm(¶ms);
|
||||
} else if (opa == BlasOp::N && opb == BlasOp::T) {
|
||||
static internal::GemmTunableOp<T, internal::Row, internal::Col> gemm{};
|
||||
gemm.EnableTuning();
|
||||
return gemm(¶ms);
|
||||
} else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ {
|
||||
static internal::GemmTunableOp<T, internal::Col, internal::Col> gemm{};
|
||||
gemm.EnableTuning();
|
||||
return gemm(¶ms);
|
||||
}
|
||||
}
|
||||
|
||||
return internal::RocBlasGemmOp(¶ms);
|
||||
}
|
||||
|
||||
#define CALL_GEMM(T, ScalarT) \
|
||||
Gemm<T, ScalarT>(tunable, stream, handle, \
|
||||
opa, opb, \
|
||||
m, n, k, \
|
||||
alpha, a, lda, b, ldb, \
|
||||
beta, c, ldc)
|
||||
|
||||
// clang-format off
|
||||
GEMM(double, double ) { return CALL_GEMM(double, double ); }
|
||||
GEMM(float, float ) { return CALL_GEMM(float, float ); }
|
||||
GEMM(half, half ) { return CALL_GEMM(half, half ); }
|
||||
GEMM(BFloat16, BFloat16) { return CALL_GEMM(BFloat16, BFloat16); }
|
||||
GEMM(double, float ) { return CALL_GEMM(double, float ); }
|
||||
GEMM(half, float ) { return CALL_GEMM(half, float ); }
|
||||
GEMM(BFloat16, float ) { return CALL_GEMM(BFloat16, float ); }
|
||||
// clang-format on
|
||||
|
||||
#undef CALL_GEMM
|
||||
|
||||
} // namespace row_major
|
||||
|
||||
namespace column_major {
|
||||
|
||||
#define CALL_GEMM_WITH_AB_SWAPPED(T, ScalarT) \
|
||||
row_major::Gemm<T, ScalarT>(tunable, stream, handle, \
|
||||
opb, opa, \
|
||||
n, m, k, \
|
||||
alpha, b, ldb, a, lda, \
|
||||
beta, c, ldc)
|
||||
|
||||
// clang-format off
|
||||
GEMM(double, double ) { return CALL_GEMM_WITH_AB_SWAPPED(double, double ); }
|
||||
GEMM(float, float ) { return CALL_GEMM_WITH_AB_SWAPPED(float, float ); }
|
||||
GEMM(half, half ) { return CALL_GEMM_WITH_AB_SWAPPED(half, half ); }
|
||||
GEMM(BFloat16, BFloat16) { return CALL_GEMM_WITH_AB_SWAPPED(BFloat16, BFloat16); }
|
||||
GEMM(double, float ) { return CALL_GEMM_WITH_AB_SWAPPED(double, float ); }
|
||||
GEMM(half, float ) { return CALL_GEMM_WITH_AB_SWAPPED(half, float ); }
|
||||
GEMM(BFloat16, float ) { return CALL_GEMM_WITH_AB_SWAPPED(BFloat16, float ); }
|
||||
// clang-format on
|
||||
|
||||
#undef CALL_GEMM_WITH_AB_SWAPPED
|
||||
|
||||
} // namespace column_major
|
||||
|
||||
} // namespace blas
|
||||
} // namespace tunable
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/status.h"
|
||||
#include "core/framework/float16.h"
|
||||
#include "core/providers/rocm/tunable/gemm_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
namespace tunable {
|
||||
namespace blas {
|
||||
|
||||
#define GEMM(T, ScalarT) \
|
||||
common::Status Gemm( \
|
||||
bool tunable, hipStream_t stream, rocblas_handle handle, \
|
||||
BlasOp opa, BlasOp opb, \
|
||||
std::int64_t m, std::int64_t n, std::int64_t k, \
|
||||
ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \
|
||||
ScalarT beta, T* c, std::int64_t ldc)
|
||||
|
||||
namespace row_major {
|
||||
|
||||
GEMM(double, double);
|
||||
GEMM(float, float);
|
||||
GEMM(half, half);
|
||||
GEMM(BFloat16, BFloat16);
|
||||
GEMM(double, float);
|
||||
GEMM(half, float);
|
||||
GEMM(BFloat16, float);
|
||||
|
||||
} // namespace row_major
|
||||
|
||||
// TODO(anyone): the caller should not need to swap the params a and b manually, but all the current callsites are
|
||||
// doing so. It is cumbersome and unintuitive. At the moment, this namespace only ease the porting from old direct
|
||||
// rocblas_gemm* calls to tunable gemm calls. After all porting of all callsites, if there is no column_major usecase
|
||||
// left, then we shall remove this namespace, finally.
|
||||
namespace column_major {
|
||||
|
||||
GEMM(double, double);
|
||||
GEMM(float, float);
|
||||
GEMM(half, half);
|
||||
GEMM(BFloat16, BFloat16);
|
||||
GEMM(double, float);
|
||||
GEMM(half, float);
|
||||
GEMM(BFloat16, float);
|
||||
|
||||
} // namespace column_major
|
||||
|
||||
} // namespace blas
|
||||
} // namespace tunable
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
||||
#ifndef _GEMM_H_KEEP_SIGNATURE_DEFINES
|
||||
#undef GEMM
|
||||
#endif
|
||||
|
|
@ -31,11 +31,6 @@ struct DataTypeAdaptor<half> {
|
|||
using type = ck::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeAdaptor<BFloat16> {
|
||||
using type = ck::bhalf16_t;
|
||||
};
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
|
|
@ -55,12 +50,6 @@ auto GetCKGemmTypeStringAndOps() {
|
|||
auto type_string = impl->GetTypeString();
|
||||
auto invoker = impl->MakeInvokerPointer();
|
||||
auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
|
||||
auto one = ToHipType<T>::FromFloat(1.0f);
|
||||
auto zero = ToHipType<T>::FromFloat(0.0f);
|
||||
TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF(
|
||||
params->alpha != one || params->beta != zero,
|
||||
impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature());
|
||||
|
||||
auto nop = Nop{};
|
||||
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
|
||||
params->m, params->n, params->k,
|
||||
|
|
|
|||
Loading…
Reference in a new issue