From c93cb8f9493aa81ed64c16da11df8bcae7fc0194 Mon Sep 17 00:00:00 2001 From: cloudhan Date: Fri, 30 Sep 2022 14:01:16 +0800 Subject: [PATCH] Revert "Enable ROCm to use tunable GEMM" (#13160) Reverts microsoft/onnxruntime#12853 due to CI pipeline problem. --- cmake/onnxruntime_providers.cmake | 3 - onnxruntime/core/providers/rocm/math/gemm.cc | 22 ++-- .../core/providers/rocm/math/matmul_impl.cc | 16 +-- .../core/providers/rocm/tunable/gemm.cu | 119 ------------------ .../core/providers/rocm/tunable/gemm.h | 58 --------- .../core/providers/rocm/tunable/gemm_ck.cuh | 11 -- 6 files changed, 16 insertions(+), 213 deletions(-) delete mode 100644 onnxruntime/core/providers/rocm/tunable/gemm.cu delete mode 100644 onnxruntime/core/providers/rocm/tunable/gemm.h diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index d3991c097f..aa77faa327 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1418,9 +1418,6 @@ if (onnxruntime_USE_ROCM) #endif() endif() - include(composable_kernel) - target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_composable_kernel_includes device_gemm_instance) - if(UNIX) set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections") target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync_cpp) diff --git a/onnxruntime/core/providers/rocm/math/gemm.cc b/onnxruntime/core/providers/rocm/math/gemm.cc index 4b9530348c..b0a32b5acc 100644 --- a/onnxruntime/core/providers/rocm/math/gemm.cc +++ b/onnxruntime/core/providers/rocm/math/gemm.cc @@ -2,18 +2,13 @@ // Licensed under the MIT License. #include "core/providers/rocm/math/gemm.h" - #include "core/providers/cpu/math/gemm_helper.h" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm.h" - namespace onnxruntime { namespace rocm { -using tunable::blas::BlasOp; - #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ Gemm, \ @@ -127,21 +122,24 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { } } - return tunable::blas::column_major::Gemm( - false, Stream(), + HipT alpha = ToHipType::FromFloat(alpha_); + HipT beta = ToHipType::FromFloat(beta_); + // Gemm, note that HIP assumes col-major, so Y(N,M) = alpha * op(W) x op(X) + beta * Y + ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper( RocblasHandle(), - trans_B_ ? BlasOp::Trans : BlasOp::NonTrans, - trans_A_ ? BlasOp::Trans : BlasOp::NonTrans, + trans_B_ ? rocblas_operation_transpose : rocblas_operation_none, + trans_A_ ? rocblas_operation_transpose : rocblas_operation_none, N, M, K, - alpha_, + &alpha, reinterpret_cast(W->Data()), (trans_B_ ? K : N), reinterpret_cast(X->Data()), (trans_A_ ? M : K), // ideally we need to set the output buffer contents to 0 if bias is missing, // but passing 0 for beta is cheaper and it will ignore any junk in the output buffer - B != nullptr ? beta_ : 0.0f, - out_data, N); + B != nullptr ? &beta : &zero, + out_data, N)); + return Status::OK(); } } // namespace rocm diff --git a/onnxruntime/core/providers/rocm/math/matmul_impl.cc b/onnxruntime/core/providers/rocm/math/matmul_impl.cc index cf9ce8dbc1..931dde56d8 100644 --- a/onnxruntime/core/providers/rocm/math/matmul_impl.cc +++ b/onnxruntime/core/providers/rocm/math/matmul_impl.cc @@ -9,7 +9,6 @@ #include "core/providers/rocm/rocm_allocator.h" #include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/tunable/gemm.h" namespace onnxruntime { namespace rocm { @@ -84,16 +83,13 @@ Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper, int64_t stride_A, stride_B, stride_C, batch_count; if (helper.OutputOffsets().size() == 1) { - using tunable::blas::BlasOp; - BlasOp transA = transa ? BlasOp::Trans : BlasOp::NonTrans; - BlasOp transB = transb ? BlasOp::Trans : BlasOp::NonTrans; - return tunable::blas::column_major::Gemm( - false, op->Stream(), - op->RocblasHandle(), transB, transA, static_cast(helper.N()), - static_cast(helper.M()), static_cast(helper.K()), t_alpha, + ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper( + op->RocblasHandle(), transB, transA, static_cast(helper.N()), + static_cast(helper.M()), static_cast(helper.K()), &alpha, reinterpret_cast(right_x_data), ldb, - reinterpret_cast(left_x_data), lda, t_zero, - reinterpret_cast(output_y_data), ldc); + reinterpret_cast(left_x_data), lda, &zero, + reinterpret_cast(output_y_data), ldc)); + return Status::OK(); } else if (CanUseStridedBatchedGemm(left_shape, right_shape, transa, transb, trans_batch_a, trans_batch_b, stride_A, stride_B, stride_C, batch_count)) { diff --git a/onnxruntime/core/providers/rocm/tunable/gemm.cu b/onnxruntime/core/providers/rocm/tunable/gemm.cu deleted file mode 100644 index 8d8d7b9302..0000000000 --- a/onnxruntime/core/providers/rocm/tunable/gemm.cu +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define _GEMM_H_KEEP_SIGNATURE_DEFINES -#include "core/providers/rocm/tunable/gemm.h" - -#include -#include - -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm_rocblas.h" -#include "core/providers/rocm/tunable/gemm_tunable.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { - -namespace row_major { - -template -inline GEMM(T, ScalarT) { - GemmParams params; - params.stream = stream; - params.handle = handle; - - params.opa = opa; - params.opb = opb; - params.m = m; - params.n = n; - params.k = k; - if constexpr (!std::is_same_v && std::is_same_v) { - params.alpha = ToHipType::FromFloat(std::forward(alpha)); - } else { - params.alpha = alpha; - } - params.a = a; - params.lda = lda; - params.b = b; - params.ldb = ldb; - if constexpr (!std::is_same_v && std::is_same_v) { - params.beta = ToHipType::FromFloat(std::forward(beta)); - } else { - params.beta = beta; - } - params.c = c; - params.ldc = ldc; - - if (tunable) { - if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::GemmTunableOp gemm{}; - gemm.EnableTuning(); - return gemm(¶ms); - } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::GemmTunableOp gemm{}; - gemm.EnableTuning(); - return gemm(¶ms); - } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::GemmTunableOp gemm{}; - gemm.EnableTuning(); - return gemm(¶ms); - } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::GemmTunableOp gemm{}; - gemm.EnableTuning(); - return gemm(¶ms); - } - } - - return internal::RocBlasGemmOp(¶ms); -} - -#define CALL_GEMM(T, ScalarT) \ - Gemm(tunable, stream, handle, \ - opa, opb, \ - m, n, k, \ - alpha, a, lda, b, ldb, \ - beta, c, ldc) - -// clang-format off -GEMM(double, double ) { return CALL_GEMM(double, double ); } -GEMM(float, float ) { return CALL_GEMM(float, float ); } -GEMM(half, half ) { return CALL_GEMM(half, half ); } -GEMM(BFloat16, BFloat16) { return CALL_GEMM(BFloat16, BFloat16); } -GEMM(double, float ) { return CALL_GEMM(double, float ); } -GEMM(half, float ) { return CALL_GEMM(half, float ); } -GEMM(BFloat16, float ) { return CALL_GEMM(BFloat16, float ); } -// clang-format on - -#undef CALL_GEMM - -} // namespace row_major - -namespace column_major { - -#define CALL_GEMM_WITH_AB_SWAPPED(T, ScalarT) \ - row_major::Gemm(tunable, stream, handle, \ - opb, opa, \ - n, m, k, \ - alpha, b, ldb, a, lda, \ - beta, c, ldc) - -// clang-format off -GEMM(double, double ) { return CALL_GEMM_WITH_AB_SWAPPED(double, double ); } -GEMM(float, float ) { return CALL_GEMM_WITH_AB_SWAPPED(float, float ); } -GEMM(half, half ) { return CALL_GEMM_WITH_AB_SWAPPED(half, half ); } -GEMM(BFloat16, BFloat16) { return CALL_GEMM_WITH_AB_SWAPPED(BFloat16, BFloat16); } -GEMM(double, float ) { return CALL_GEMM_WITH_AB_SWAPPED(double, float ); } -GEMM(half, float ) { return CALL_GEMM_WITH_AB_SWAPPED(half, float ); } -GEMM(BFloat16, float ) { return CALL_GEMM_WITH_AB_SWAPPED(BFloat16, float ); } -// clang-format on - -#undef CALL_GEMM_WITH_AB_SWAPPED - -} // namespace column_major - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm.h b/onnxruntime/core/providers/rocm/tunable/gemm.h deleted file mode 100644 index ca39ca2d91..0000000000 --- a/onnxruntime/core/providers/rocm/tunable/gemm.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/status.h" -#include "core/framework/float16.h" -#include "core/providers/rocm/tunable/gemm_common.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { - -#define GEMM(T, ScalarT) \ - common::Status Gemm( \ - bool tunable, hipStream_t stream, rocblas_handle handle, \ - BlasOp opa, BlasOp opb, \ - std::int64_t m, std::int64_t n, std::int64_t k, \ - ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ - ScalarT beta, T* c, std::int64_t ldc) - -namespace row_major { - -GEMM(double, double); -GEMM(float, float); -GEMM(half, half); -GEMM(BFloat16, BFloat16); -GEMM(double, float); -GEMM(half, float); -GEMM(BFloat16, float); - -} // namespace row_major - -// TODO(anyone): the caller should not need to swap the params a and b manually, but all the current callsites are -// doing so. It is cumbersome and unintuitive. At the moment, this namespace only ease the porting from old direct -// rocblas_gemm* calls to tunable gemm calls. After all porting of all callsites, if there is no column_major usecase -// left, then we shall remove this namespace, finally. -namespace column_major { - -GEMM(double, double); -GEMM(float, float); -GEMM(half, half); -GEMM(BFloat16, BFloat16); -GEMM(double, float); -GEMM(half, float); -GEMM(BFloat16, float); - -} // namespace column_major - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#ifndef _GEMM_H_KEEP_SIGNATURE_DEFINES -#undef GEMM -#endif diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh index c08a866fd5..f3d3bad0ae 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh @@ -31,11 +31,6 @@ struct DataTypeAdaptor { using type = ck::half_t; }; -template <> -struct DataTypeAdaptor { - using type = ck::bhalf16_t; -}; - using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -55,12 +50,6 @@ auto GetCKGemmTypeStringAndOps() { auto type_string = impl->GetTypeString(); auto invoker = impl->MakeInvokerPointer(); auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature()); - auto nop = Nop{}; auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, params->m, params->n, params->k,