diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc index d19fd2377b..486272ad6b 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc @@ -3,18 +3,14 @@ #include "contrib_ops/rocm/bert/gemm_fast_gelu.h" -#include "contrib_ops/rocm/bert/fast_gelu_impl.h" -#include "contrib_ops/rocm/bert/transformer_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" #include "core/providers/cpu/math/matmul_helper.h" -#include "core/providers/rocm/math/matmul_impl.h" #include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" namespace onnxruntime { namespace contrib { namespace rocm { -using onnxruntime::rocm::MatMulImpl; using onnxruntime::rocm::ToHipType; #define REGISTER_KERNEL_TYPED(T) \ @@ -48,34 +44,27 @@ Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(X->Shape(), W->Shape(), transa, transb, trans_batch_a, trans_batch_b, false)); - auto gemm_buffer = GetScratchBuffer(helper.OutputShape().Size()); Tensor* Y = ctx->Output(0, helper.OutputShape()); // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - const float alpha = 1.0f; - const float zero = 0.0f; + // gemmfastgelu only support alpha == 1 and beta == 0 + const HipT alpha = ToHipType::FromFloat(1.0f); + const HipT beta = ToHipType::FromFloat(0.0f); - if (MatMulImpl(this, helper, reinterpret_cast(X->Data()), - reinterpret_cast(W->Data()), - reinterpret_cast(gemm_buffer.get()), - X->Shape(), W->Shape(), - transa, transb, trans_batch_a, trans_batch_b, alpha, zero) != Status::OK()) { - return Status(common::ONNXRUNTIME, common::FAIL); - } - - int64_t fast_gelu_input_length = Y->Shape().Size(); - int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); - - return LaunchFastGeluKernel(Stream(), - static_cast(fast_gelu_input_length), - static_cast(bias_length), - reinterpret_cast(gemm_buffer.get()), - (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, - reinterpret_cast(Y->MutableData()), - false); + return LaunchGemmFastGeluKernel( + IsTunableOpEnabled(), + Stream(), RocblasHandle(), + transa, transb, + static_cast(helper.M()), static_cast(helper.N()), static_cast(helper.K()), + alpha, + reinterpret_cast(X->Data()), static_cast(helper.Lda(transa)), + reinterpret_cast(W->Data()), static_cast(helper.Ldb(transb)), + (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, + beta, + reinterpret_cast(Y->MutableData()), static_cast(helper.Ldc())); } } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu new file mode 100644 index 0000000000..1317a2ccbd --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" + +#include + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h" +#include "core/providers/rocm/tunable/gemm_common.h" + +using onnxruntime::rocm::tunable::blas::BlasOp; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +// See it as row-major +template +Status LaunchGemmFastGeluKernel(bool tuning, + hipStream_t stream, + rocblas_handle handle, + bool transa, + bool transb, + int64_t m, + int64_t n, + int64_t k, + const T alpha, + const T* a, + int64_t lda, + const T* b, + int64_t ldb, + const T* bias, + const T beta, + T* c, + int64_t ldc) { + GemmFastGeluParams params; + params.tuning = tuning; + params.stream = stream; + params.handle = handle; + params.opa = transa ? BlasOp::Trans : BlasOp::NonTrans; + params.opb = transb ? BlasOp::Trans : BlasOp::NonTrans; + + params.m = m; + params.n = n; + params.k = k; + params.alpha = alpha; + params.a = a; + params.lda = lda; + params.b = b; + params.ldb = ldb; + params.bias = bias; + params.beta = beta; + params.c = c; + params.ldc = ldc; + + if (tuning) { + static GemmFastGeluTunableOp op; + op.EnableTuning(); + return op(¶ms); + } + + return GemmFastGeluUnfused(¶ms); +} + +#define SPECIALIZED_IMPL(T) \ + template Status LaunchGemmFastGeluKernel(bool tuning, \ + hipStream_t stream, rocblas_handle handle, \ + bool transa, bool transb, \ + int64_t m, int64_t n, int64_t k, const T alpha, \ + const T* a, int64_t lda, const T* b, int64_t ldb, \ + const T* bias, const T beta, T* c, int64_t ldc); + +SPECIALIZED_IMPL(float) +SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(BFloat16) + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h new file mode 100644 index 0000000000..765a0c96a9 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/common/common.h" +#include "core/providers/rocm/rocm_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +Status LaunchGemmFastGeluKernel(bool tuning, + hipStream_t stream, + rocblas_handle handle, + bool transa, + bool transb, + int64_t m, + int64_t n, + int64_t k, + const T alpha, + const T* a, + int64_t lda, + const T* b, + int64_t ldb, + const T* bias, + const T beta, + T* c, + int64_t ldc); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h new file mode 100644 index 0000000000..2ce6040b5c --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "contrib_ops/rocm/bert/fast_gelu_impl.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +using onnxruntime::rocm::tunable::blas::BlasOp; +using onnxruntime::rocm::tunable::blas::BlasOpToString; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams { + std::string Signature() const override { + return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); + } + rocblas_handle handle; + BlasOp opa; + BlasOp opb; + int64_t m; + int64_t n; + int64_t k; + T alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + const T* bias; + T beta; + T* c; + int64_t ldc; + bool tuning{false}; +}; + +template +Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { + namespace column_major = onnxruntime::rocm::tunable::blas::column_major; + if (column_major::Gemm(params->tuning, params->stream, params->handle, + params->opb, params->opa, + params->n, params->m, params->k, + params->alpha, params->b, params->ldb, params->a, params->lda, + params->beta, params->c, params->ldc) != Status::OK()) { + return Status(common::ONNXRUNTIME, common::FAIL, "GemmFastGelu call column_major::Gemm failed"); + } + + int64_t fast_gelu_input_length = params->m * params->n; + int64_t bias_length = (params->bias != nullptr) ? params->n : 0; + + // inplace computation + return LaunchFastGeluKernel(params->stream, + static_cast(fast_gelu_input_length), + static_cast(bias_length), + params->c, + params->bias, + params->c, + params->tuning); +} + +template +class GemmFastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp> { + public: + GemmFastGeluTunableOp() { + this->ops_.emplace_back(GemmFastGeluUnfused); + + this->SetDefaultId(0); + } +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/matmul.cc b/onnxruntime/core/providers/rocm/math/matmul.cc index f9c299ac11..d8c183b72f 100644 --- a/onnxruntime/core/providers/rocm/math/matmul.cc +++ b/onnxruntime/core/providers/rocm/math/matmul.cc @@ -74,7 +74,7 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { reinterpret_cast(Y->MutableData()), left_X->Shape(), right_X->Shape(), transa, transb, trans_batch_a_, trans_batch_b_, alpha_, 0.0f) != Status::OK()) { - return Status(common::ONNXRUNTIME, common::FAIL); + return Status(common::ONNXRUNTIME, common::FAIL, "MatMulImpl failed"); } return Status::OK(); } diff --git a/onnxruntime/core/providers/rocm/math/matmul_impl.cc b/onnxruntime/core/providers/rocm/math/matmul_impl.cc index 6c88949e8f..3e08898adc 100644 --- a/onnxruntime/core/providers/rocm/math/matmul_impl.cc +++ b/onnxruntime/core/providers/rocm/math/matmul_impl.cc @@ -1,10 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// Modifications: Remove cudaDeviceProp in LaunchFastGeluKernel. -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. - #include "core/providers/rocm/math/matmul_impl.h" #include "core/providers/rocm/rocm_allocator.h" @@ -137,58 +133,18 @@ Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper, return Status::OK(); } -template Status MatMulImpl(const RocmKernel* op, - MatMulComputeHelper& helper, - const float* left_x_data, - const float* right_x_data, - float* output_y_data, - const TensorShape& left_shape, - const TensorShape& right_shape, - bool transa, - bool transb, - bool trans_batch_a, - bool trans_batch_b, - const float t_alpha, - const float t_zero); -template Status MatMulImpl(const RocmKernel* op, - MatMulComputeHelper& helper, - const double* left_x_data, - const double* right_x_data, - double* output_y_data, - const TensorShape& left_shape, - const TensorShape& right_shape, - bool transa, - bool transb, - bool trans_batch_a, - bool trans_batch_b, - const float t_alpha, - const float t_zero); -template Status MatMulImpl(const RocmKernel* op, - MatMulComputeHelper& helper, - const MLFloat16* left_x_data, - const MLFloat16* right_x_data, - MLFloat16* output_y_data, - const TensorShape& left_shape, - const TensorShape& right_shape, - bool transa, - bool transb, - bool trans_batch_a, - bool trans_batch_b, - const float t_alpha, - const float t_zero); -template Status MatMulImpl(const RocmKernel* op, - MatMulComputeHelper& helper, - const BFloat16* left_x_data, - const BFloat16* right_x_data, - BFloat16* output_y_data, - const TensorShape& left_shape, - const TensorShape& right_shape, - bool transa, - bool transb, - bool trans_batch_a, - bool trans_batch_b, - const float t_alpha, - const float t_zero); +#define SPECIALIZED_IMPL(T) \ + template Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper, \ + const T* left_x_data, const T* right_x_data, T* output_y_data, \ + const TensorShape& left_shape, const TensorShape& right_shape, \ + bool transa, bool transb, \ + bool trans_batch_a, bool trans_batch_b, \ + const float t_alpha, const float t_zero); + +SPECIALIZED_IMPL(float) +SPECIALIZED_IMPL(double) +SPECIALIZED_IMPL(MLFloat16) +SPECIALIZED_IMPL(BFloat16) } // namespace rocm } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/matmul_impl.h b/onnxruntime/core/providers/rocm/math/matmul_impl.h index f22f3b8eac..01754189d4 100644 --- a/onnxruntime/core/providers/rocm/math/matmul_impl.h +++ b/onnxruntime/core/providers/rocm/math/matmul_impl.h @@ -1,10 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// Modifications: Remove cudaDeviceProp in LaunchFastGeluKernel. -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. - #pragma once #include "core/providers/rocm/shared_inc/fpgeneric.h" diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc index 07d0502c6f..ac638a0229 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc @@ -8,6 +8,7 @@ #include "python/tools/kernel_explorer/kernels/rocm/fast_gelu.h" #include "python/tools/kernel_explorer/kernels/rocm/gemm.h" #include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h" +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h" namespace py = pybind11; @@ -22,6 +23,7 @@ PYBIND11_MODULE(_kernel_explorer, m) { InitFastGelu(m); InitGemm(m); InitSkipLayerNorm(m); + InitGemmFastGelu(m); #endif } diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py new file mode 100644 index 0000000000..6a21da480b --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py @@ -0,0 +1,155 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import re +import sys +from itertools import product + +import kernel_explorer as ke +import numpy as np +import pytest +from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: re.search("GemmFastGelu.*_half", x), dir(ke))), + "float32": list(filter(lambda x: re.search("GemmFastGelu.*_float", x), dir(ke))), + } + return type_map[dtype] + + +def fast_gelu(x, bias): + x = x + bias + y = 0.5 * x * (1 + np.tanh(0.797885 * x + 0.035677 * x * x * x)) + return y + + +# TODO The test method needs update. +def _test_gemmfastgelu(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False): + assert dtype in ["float16", "float32"] + + a_shape = (k, m) if transa else (m, k) + b_shape = (n, k) if transb else (k, n) + + np.random.seed(0) + a = (np.random.rand(*a_shape)).astype(dtype).astype("float64") + b = (np.random.rand(*b_shape)).astype(dtype).astype("float64") + bias = (np.random.rand(n)).astype(dtype) + temp_c = (a.T if transa else a) @ (b.T if transb else b) + + bound = get_gemm_bound(dtype, a, b, temp_c, transa, transb) + + temp_c = temp_c.astype(dtype) + ref_c = fast_gelu(temp_c, bias) + + a = a.astype(dtype) + b = b.astype(dtype) + + my_c = np.zeros((m, n), dtype=dtype) + dev_a = ke.DeviceArray(a) + dev_b = ke.DeviceArray(b) + dev_bias = ke.DeviceArray(bias) + dev_c = ke.DeviceArray(my_c) + + opa = ke.blas_op.T if transa else ke.blas_op.N + opb = ke.blas_op.T if transb else ke.blas_op.N + lda = a_shape[1] + ldb = b_shape[1] + alpha = 1.0 + beta = 0.0 + my_func = getattr(ke, func) + my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n) + + if my_op.IsSupported(): + my_op.Run() + dev_c.UpdateHostNumpyArray() + + print( + f"{func:<50} : dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {bound}" + ) + + np.testing.assert_allclose(my_c, ref_c, rtol=max(bound, 1e-2)) + + +dtypes = ["float16", "float32"] +all_transabs = list(product([True, False], repeat=2)) + + +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) +@pytest.mark.parametrize("transab", all_transabs) +def test_gemmfastgelu_bert_cases(dtype, size, transab): + for func in dtype_to_funcs(dtype): + _test_gemmfastgelu(func, dtype, *size, *transab) + + +def profile_gemmfastgelu_func(func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool): + a_shape = (k, m) if transa else (m, k) + b_shape = (n, k) if transb else (k, n) + + np.random.seed(0) + a = (np.random.rand(*a_shape) * 2 - 1).astype(dtype) + b = (np.random.rand(*b_shape) * 2 - 1).astype(dtype) + my_c = np.zeros((m, n), dtype=dtype) + bias = np.random.rand(n).astype(dtype) + + dev_a = ke.DeviceArray(a) + dev_b = ke.DeviceArray(b) + dev_bias = ke.DeviceArray(bias) + dev_c = ke.DeviceArray(my_c) + + opa = ke.blas_op.T if transa else ke.blas_op.N + opb = ke.blas_op.T if transb else ke.blas_op.N + lda = a_shape[1] + ldb = b_shape[1] + alpha = 1.0 + beta = 0.0 + my_func = getattr(ke, func) + my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n) + + if my_op.IsSupported(): + my_op.Run() + dev_c.UpdateHostNumpyArray() + + time_ms = my_op.Profile() + time_us = time_ms * 1000 + # only counts gemm tflops because fastgelu is low order term (7 * n). + tflops = (m * k * n * 2) / (time_ms * 1e-3) / 1e12 + print( + f"{func:<50} {dtype} {transab_to_suffix((transa, transb))}", + f"m={m:<4} n={n:<4} k={k:<4} {time_us:>8.4f} us {tflops:>5.2f} tflops", + ) + + +def profile_with_args(transa, transb, dtype, m, n, k): + for func in dtype_to_funcs(dtype): + profile_gemmfastgelu_func(func, dtype, m, n, k, transa, transb) + + +def profile(): + for dtype in dtypes: + for m, n, k in get_gemm_bert_sizes(full=True): + profile_with_args(False, False, dtype, m, n, k) + print() + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("transa", choices="NT") + group.add_argument("transb", choices="NT") + group.add_argument("dtype", choices=dtypes) + group.add_argument("m", type=int) + group.add_argument("n", type=int) + group.add_argument("k", type=int) + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py index e03c54b32d..f5a1b84ade 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py @@ -9,6 +9,7 @@ from itertools import product import kernel_explorer as ke import numpy as np import pytest +from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix def dtype_to_suffix(dtype): @@ -18,15 +19,6 @@ def dtype_to_suffix(dtype): }[dtype] -def transab_to_suffix(transab): - return { - (True, True): "TT", - (True, False): "TN", - (False, True): "NT", - (False, False): "NN", - }[tuple(transab)] - - def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False): assert dtype in ["float32", "float16"] @@ -38,18 +30,7 @@ def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=Fa b = (np.random.rand(*b_shape) + 0.5).astype(dtype).astype("float64") ref_c = (a.T if transa else a) @ (b.T if transb else b) - # The machine epsilon, unit roundoff, the smallest positive floating point number n such that the floating point - # number that represents 1 + n is greater than 1. - machine_eps = 2.0 ** -(24 if dtype == "float32" else 11) - - # The following implements error bound 5.7 in paper I. C. Ipsen and H. Zhou, “Probabilistic error analysis for - # Inner Products,” SIAM Journal on Matrix Analysis and Applications, vol. 41, no. 4, pp. 1726–1741, 2020. - # NOTE: the bound is not tight for float16 when k is large - absa_mul_absb = np.abs(a.T if transa else a) @ np.abs(b.T if transb else b) - coeff = np.max(absa_mul_absb / np.abs(ref_c)) - gamma_2k = (1.0 + machine_eps) ** (2 * k) - 1.0 - bound_5_7 = coeff * np.sqrt(np.log(2 / 1e-10) * machine_eps * gamma_2k / 2) - bound = bound_5_7 + bound = get_gemm_bound(dtype, a, b, ref_c, transa, transb) a = a.astype(dtype) b = b.astype(dtype) @@ -92,47 +73,17 @@ def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=Fa dtypes = ["float32", "float16"] all_transabs = list(product([True, False], repeat=2)) -all_basic_sizes = list(product([1, 3, 4, 16, 127, 128, 129, 133, 1024], repeat=3)) - - -def get_bert_sizes(full=True): - bert_base_sizes = [ - # m, n, k - (384, 768, 768), - (384, 768, 768 * 3), - (384, 768, 768 * 4), - (384, 768 * 4, 768), - (384, 1024, 1024), - (384, 1024, 1024 * 3), - (384, 1024, 1024 * 4), - (384, 1024 * 4, 1024), - ] - - # we then multiply m with the batch size - if full: - batch_sizes = [1, 64] - else: - batch_sizes = [1] - bert_sizes = [] - for bsz in batch_sizes: - bert_sizes.extend([(m * bsz, n, k) for m, n, k in bert_base_sizes]) - return bert_sizes @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("size", all_basic_sizes + get_bert_sizes(full=False)) +@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=True) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transab", all_transabs) def test_rocblas_gemm_all_cases(dtype, size, transab): _test_gemm(getattr(ke, "RocblasGemm_" + dtype_to_suffix(dtype)), dtype, *size, *transab) -# ck has various impls to be tested, use the full basic cases will result too many cases to test. -# So we use a reduced combination here. -reduced_basic_sizes = list(product([1, 4, 127, 133], [3, 16, 128], [3, 129, 1024])) - - @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("size", reduced_basic_sizes + get_bert_sizes(full=False)) +@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transab", all_transabs) def test_ck_gemm_bert_cases(dtype, size, transab): wrapper_name = "CKGemm_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab)) @@ -141,7 +92,7 @@ def test_ck_gemm_bert_cases(dtype, size, transab): # Tunable is basically wrapped around of rocblas and ck gemm, so no need for full tests @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("size", reduced_basic_sizes + get_bert_sizes(full=False)) +@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transab", all_transabs) def test_gemm_tunable_bert_cases(dtype, size, transab): wrapper_name = "GemmTunable_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab)) @@ -192,7 +143,7 @@ def profile_with_args(transa, transb, dtype, m, n, k): def profile(): for dtype in dtypes: - for m, n, k in get_bert_sizes(full=True): + for m, n, k in get_gemm_bert_sizes(full=True): profile_with_args(False, False, dtype, m, n, k) print() print() diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cu new file mode 100644 index 0000000000..37b7230257 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cu @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h" + +#include +#include + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +using onnxruntime::rocm::tunable::blas::BlasOp; + +namespace py = pybind11; + +namespace onnxruntime { + +template +class GemmFastGeluUnfused : public IKernelExplorer { + public: + GemmFastGeluUnfused(BlasOp opa, BlasOp opb, + int64_t m, int64_t n, int64_t k, + double alpha, + DeviceArray& a, int64_t lda, + DeviceArray& b, int64_t ldb, + DeviceArray& bias, + double beta, + DeviceArray& c, int64_t ldc) : params_{} { + ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); + params_.tuning = true; + params_.stream = Stream(); + params_.handle = rocblas_handle_; + params_.opa = opa; + params_.opb = opb; + params_.m = m; + params_.n = n; + params_.k = k; + params_.alpha = alpha; + params_.a = static_cast(a.ptr()); + params_.lda = lda; + params_.b = static_cast(b.ptr()); + params_.ldb = ldb; + params_.bias = static_cast(bias.ptr()); + params_.beta = beta; + params_.c = static_cast(c.ptr()); + params_.ldc = ldc; + } + + ~GemmFastGeluUnfused() { + ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); + rocblas_handle_ = nullptr; + } + + void Run() override { + ORT_THROW_IF_ERROR((contrib::rocm::GemmFastGeluUnfused(¶ms_))); + } + + bool IsSupported() { + Status status = contrib::rocm::GemmFastGeluUnfused(¶ms_); + return status.IsOK(); + } + + private: + using ParamsT = contrib::rocm::GemmFastGeluParams; + ParamsT params_{}; + rocblas_handle rocblas_handle_; +}; + +template +class GemmFastGeluTunableOp : public IKernelExplorer { + public: + GemmFastGeluTunableOp(BlasOp opa, BlasOp opb, + int64_t m, int64_t n, int64_t k, + double alpha, + DeviceArray& a, int64_t lda, + DeviceArray& b, int64_t ldb, + DeviceArray& bias, + double beta, + DeviceArray& c, int64_t ldc) : params_{} { + ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); + params_.tuning = true; + params_.stream = Stream(); + params_.handle = rocblas_handle_; + params_.opa = opa; + params_.opb = opb; + params_.m = m; + params_.n = n; + params_.k = k; + params_.alpha = alpha; + params_.a = static_cast(a.ptr()); + params_.lda = lda; + params_.b = static_cast(b.ptr()); + params_.ldb = ldb; + params_.bias = static_cast(bias.ptr()); + params_.beta = beta; + params_.c = static_cast(c.ptr()); + params_.ldc = ldc; + + op_.EnableTuning(); + } + + ~GemmFastGeluTunableOp() { + ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); + rocblas_handle_ = nullptr; + } + + void Run() override { + ORT_THROW_IF_ERROR((op_(¶ms_))); + } + + bool IsSupported() { + Status status = op_(¶ms_); + return status.IsOK(); + } + + private: + using ParamsT = contrib::rocm::GemmFastGeluParams; + ParamsT params_{}; + rocblas_handle rocblas_handle_; + contrib::rocm::GemmFastGeluTunableOp op_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Run", &name::Run) \ + .def("Profile", &name::Profile) \ + .def("IsSupported", &name::IsSupported); + +void InitGemmFastGelu(py::module m) { + REGISTER_OP(GemmFastGeluUnfused, float) + REGISTER_OP(GemmFastGeluUnfused, half) + + REGISTER_OP(GemmFastGeluTunableOp, float) + REGISTER_OP(GemmFastGeluTunableOp, half) +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h new file mode 100644 index 0000000000..4a0cebe8cd --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace py = pybind11; + +namespace onnxruntime { + +void InitGemmFastGelu(py::module mod); + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py new file mode 100644 index 0000000000..aabe747dee --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py @@ -0,0 +1,68 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from itertools import product + +import numpy as np + + +def transab_to_suffix(transab): + return { + (True, True): "TT", + (True, False): "TN", + (False, True): "NT", + (False, False): "NN", + }[tuple(transab)] + + +def get_gemm_bound(dtype: str, a: np.ndarray, b: np.ndarray, c: np.ndarray, transa: bool, transb: bool): + k = b.shape[1] if transb else b.shape[0] + # The machine epsilon, unit roundoff, the smallest positive floating point number n such that the floating point + # number that represents 1 + n is greater than 1. + machine_eps = 2.0 ** -(24 if dtype == "float32" else 11) + + # The following implements error bound 5.7 in paper I. C. Ipsen and H. Zhou, “Probabilistic error analysis for + # Inner Products,” SIAM Journal on Matrix Analysis and Applications, vol. 41, no. 4, pp. 1726–1741, 2020. + # NOTE: the bound is not tight for float16 when k is large + absa_mul_absb = np.abs(a.T if transa else a) @ np.abs(b.T if transb else b) + coeff = np.max(absa_mul_absb / np.abs(c)) + gamma_2k = (1.0 + machine_eps) ** (2 * k) - 1.0 + bound_5_7 = coeff * np.sqrt(np.log(2 / 1e-10) * machine_eps * gamma_2k / 2) + bound = bound_5_7 + + return bound + + +def get_gemm_bert_sizes(full=True): + bert_base_sizes = [ + # m, n, k + (384, 768, 768), + (384, 768, 768 * 3), + (384, 768, 768 * 4), + (384, 768 * 4, 768), + (384, 1024, 1024), + (384, 1024, 1024 * 3), + (384, 1024, 1024 * 4), + (384, 1024 * 4, 1024), + ] + + # we then multiply m with the batch size + if full: + batch_sizes = [1, 64] + else: + batch_sizes = [1] + bert_sizes = [] + for bsz in batch_sizes: + bert_sizes.extend([(m * bsz, n, k) for m, n, k in bert_base_sizes]) + return bert_sizes + + +def get_gemm_basic_sizes(full=True): + if full: + return list(product([1, 3, 4, 16, 127, 128, 129, 133, 1024], repeat=3)) + + # ck has various impls to be tested, use the full basic cases will result too many cases to test. + # So we use a reduced combination here. + return list(product([1, 4, 127, 133], [3, 16, 128], [3, 129, 1024])) diff --git a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py index e946cc1f52..f1d803a3cc 100644 --- a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py @@ -4,10 +4,11 @@ # -------------------------------------------------------------------------- from logging import getLogger +from typing import Dict, List, Union from fusion_base import Fusion from fusion_utils import NumpyHelper -from onnx import helper +from onnx import NodeProto, TensorProto, helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -16,8 +17,35 @@ logger = getLogger(__name__) class FusionGemmFastGelu(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "GemmFastGelu", "FastGelu", "GemmFastGelu") + self.shape_infer = None + self.shape_infer_done = False - def fuse(self, node, input_name_to_nodes, output_name_to_node): + def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]: + if tensor_proto.type.tensor_type.HasField("shape"): + return len(tensor_proto.type.tensor_type.shape.dim) + else: + return None + + def get_dimensions(self, input_name: str) -> Union[int, None]: + graph_input = self.model.find_graph_input(input_name) + if graph_input: + return self.get_dimensions_from_tensor_proto(graph_input) + + if not self.shape_infer_done: + self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer_done = True + + if self.shape_infer is not None: + return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name]) + + return None + + def fuse( + self, + node: NodeProto, + input_name_to_nodes: Dict[str, List[NodeProto]], + output_name_to_node: Dict[str, NodeProto], + ): """ This pattern is from PyTorch bert model Fuse MatMul with FastGelu into one node: @@ -34,20 +62,24 @@ class FusionGemmFastGelu(Fusion): return matmul = match_nodes[0] - weight = None - # matmul weight should be two dimension + # matmul input X should >= two dimension, input weight should be two dimension weight_index = -1 + x_dims = 0 + weight = None + for i, input in enumerate(matmul.input): initializer = self.model.get_initializer(input) if initializer is None: - continue - weight_index = i - weight = NumpyHelper.to_array(initializer) - break + x_dims = self.get_dimensions(matmul.input[i]) + else: + weight_index = i + weight = NumpyHelper.to_array(initializer) if weight is None: return if len(weight.shape) != 2: return + if x_dims < len(weight.shape): + return # bias weight should be one dimension bias_index = -1 diff --git a/onnxruntime/test/python/transformers/test_data/models/gemmfastgelu_nobias_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/gemmfastgelu_nobias_opt.onnx new file mode 100644 index 0000000000..c493f8665d Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/gemmfastgelu_nobias_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gemmfastgelu_withbias_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/gemmfastgelu_withbias_opt.onnx new file mode 100644 index 0000000000..344cc747bc Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/gemmfastgelu_withbias_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_gemmfastgelu_fusion.py b/onnxruntime/test/python/transformers/test_gemmfastgelu_fusion.py new file mode 100644 index 0000000000..0c5d12c905 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_gemmfastgelu_fusion.py @@ -0,0 +1,153 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import unittest +from typing import List + +import numpy as np +import onnx +from onnx import TensorProto, helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +onnxdomain = onnx.OperatorSetIdProto() +onnxdomain.version = 12 +# The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" +msdomain = onnx.OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = "com.microsoft" +opsets = [onnxdomain, msdomain] + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +def create_MatMul_FastGelu_withoutBias(batch_size, m, n, k): + # MatMul + FastGelu + nodes = [ + helper.make_node("MatMul", ["input", "matmul_weight"], ["fastgelu_input"], "matmul"), + ] + fastgelu_node = helper.make_node("FastGelu", ["fastgelu_input"], ["output"], "fastgelu") + fastgelu_node.domain = "com.microsoft" + nodes.append(fastgelu_node) + + initializers = [float_tensor("matmul_weight", [k, n])] # initializers + + graph = helper.make_graph( + [node for node in nodes if node], + "GemmFastGeluNoBiasModel", # name + [ # inputs + helper.make_tensor_value_info( + "input", + TensorProto.FLOAT, + [batch_size, m, k], + ) + ], + [ # outputs + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + [batch_size, m, n], + ), + ], + initializers, + ) + + return helper.make_model(graph) + + +def create_MatMul_FastGelu_withBias(batch_size, m, n, k): + # MatMul + FastGelu + nodes = [ + helper.make_node("MatMul", ["input", "matmul_weight"], ["fastgelu_input"], "matmul"), + ] + fastgelu_node = helper.make_node("FastGelu", ["fastgelu_input", "fastgelu_bias"], ["output"], "fastgelu") + fastgelu_node.domain = "com.microsoft" + nodes.append(fastgelu_node) + + initializers = [float_tensor("matmul_weight", [k, n]), float_tensor("fastgelu_bias", [n])] # initializers + + graph = helper.make_graph( + [node for node in nodes if node], + "GemmFastGeluWithBiasModel", # name + [ # inputs + helper.make_tensor_value_info( + "input", + TensorProto.FLOAT, + [batch_size, m, k], + ) + ], + [ # outputs + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + [batch_size, m, n], + ), + ], + initializers, + ) + + return helper.make_model(graph, opset_imports=opsets) + + +class TestFusion(unittest.TestCase): + def verify_fusion(self, optimized_model, expected_model_filename): + optimized_model.topological_sort() + + expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename) + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort() + + self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph)) + + def test_gemmfastgelu_fusion_withoutbias(self): + model = create_MatMul_FastGelu_withoutBias(32, 128, 64, 1024) + dir = "." + model_path = os.path.join(dir, "gemmfastgelu_nobias.onnx") + onnx.save(model, model_path) + + fusion_opt = FusionOptions("bert") + fusion_opt.enable_gemm_fast_gelu = True + optimized_model = optimize_model(input=model_path, optimization_options=fusion_opt) + os.remove(model_path) + + self.verify_fusion(optimized_model, "gemmfastgelu_nobias_opt.onnx") + + def test_gemmfastgelu_fusion_withbias(self): + model = create_MatMul_FastGelu_withBias(32, 128, 64, 1024) + dir = "." + model_path = os.path.join(dir, "gemmfastgelu_withbias.onnx") + onnx.save(model, model_path) + + fusion_opt = FusionOptions("bert") + fusion_opt.enable_gemm_fast_gelu = True + optimized_model = optimize_model(input=model_path, optimization_options=fusion_opt) + + os.remove(model_path) + + self.verify_fusion(optimized_model, "gemmfastgelu_withbias_opt.onnx") + + +if __name__ == "__main__": + unittest.main()