[ROCm] Add GemmFastGelu TunableOp (#13589)

### Description
<!-- Describe your changes. -->

1. Update the rules for GemmFastGelu fusion, MatMul input x should >=
two dimension, input weight should == two dimension.
2. Add GemmFastGelu fusion test.
3. Add GemmFastGelu TunableOp, only contains the original
implementation(Gemm + FastGelu).


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
This commit is contained in:
PeixuanZuo 2022-11-22 12:58:01 +08:00 committed by GitHub
parent 45a895cdc3
commit 8f3c6ea0df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 808 additions and 150 deletions

View file

@ -3,18 +3,14 @@
#include "contrib_ops/rocm/bert/gemm_fast_gelu.h"
#include "contrib_ops/rocm/bert/fast_gelu_impl.h"
#include "contrib_ops/rocm/bert/transformer_common.h"
#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/rocm/math/matmul_impl.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using onnxruntime::rocm::MatMulImpl;
using onnxruntime::rocm::ToHipType;
#define REGISTER_KERNEL_TYPED(T) \
@ -48,34 +44,27 @@ Status GemmFastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(X->Shape(), W->Shape(), transa, transb, trans_batch_a, trans_batch_b, false));
auto gemm_buffer = GetScratchBuffer<T>(helper.OutputShape().Size());
Tensor* Y = ctx->Output(0, helper.OutputShape());
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0)
return Status::OK();
const float alpha = 1.0f;
const float zero = 0.0f;
// gemmfastgelu only support alpha == 1 and beta == 0
const HipT alpha = ToHipType<T>::FromFloat(1.0f);
const HipT beta = ToHipType<T>::FromFloat(0.0f);
if (MatMulImpl<T>(this, helper, reinterpret_cast<const T*>(X->Data<T>()),
reinterpret_cast<const T*>(W->Data<T>()),
reinterpret_cast<T*>(gemm_buffer.get()),
X->Shape(), W->Shape(),
transa, transb, trans_batch_a, trans_batch_b, alpha, zero) != Status::OK()) {
return Status(common::ONNXRUNTIME, common::FAIL);
}
int64_t fast_gelu_input_length = Y->Shape().Size();
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
return LaunchFastGeluKernel<HipT>(Stream(),
static_cast<int>(fast_gelu_input_length),
static_cast<int>(bias_length),
reinterpret_cast<HipT*>(gemm_buffer.get()),
(nullptr != bias) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
reinterpret_cast<HipT*>(Y->MutableData<T>()),
false);
return LaunchGemmFastGeluKernel<HipT>(
IsTunableOpEnabled(),
Stream(), RocblasHandle(),
transa, transb,
static_cast<int64_t>(helper.M()), static_cast<int64_t>(helper.N()), static_cast<int64_t>(helper.K()),
alpha,
reinterpret_cast<const HipT*>(X->Data<T>()), static_cast<int64_t>(helper.Lda(transa)),
reinterpret_cast<const HipT*>(W->Data<T>()), static_cast<int64_t>(helper.Ldb(transb)),
(nullptr != bias) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
beta,
reinterpret_cast<HipT*>(Y->MutableData<T>()), static_cast<int64_t>(helper.Ldc()));
}
} // namespace rocm

View file

@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h"
#include <hip/hip_fp16.h>
#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h"
#include "core/providers/rocm/tunable/gemm_common.h"
using onnxruntime::rocm::tunable::blas::BlasOp;
namespace onnxruntime {
namespace contrib {
namespace rocm {
// See it as row-major
template <typename T>
Status LaunchGemmFastGeluKernel(bool tuning,
hipStream_t stream,
rocblas_handle handle,
bool transa,
bool transb,
int64_t m,
int64_t n,
int64_t k,
const T alpha,
const T* a,
int64_t lda,
const T* b,
int64_t ldb,
const T* bias,
const T beta,
T* c,
int64_t ldc) {
GemmFastGeluParams<T> params;
params.tuning = tuning;
params.stream = stream;
params.handle = handle;
params.opa = transa ? BlasOp::Trans : BlasOp::NonTrans;
params.opb = transb ? BlasOp::Trans : BlasOp::NonTrans;
params.m = m;
params.n = n;
params.k = k;
params.alpha = alpha;
params.a = a;
params.lda = lda;
params.b = b;
params.ldb = ldb;
params.bias = bias;
params.beta = beta;
params.c = c;
params.ldc = ldc;
if (tuning) {
static GemmFastGeluTunableOp<T> op;
op.EnableTuning();
return op(&params);
}
return GemmFastGeluUnfused(&params);
}
#define SPECIALIZED_IMPL(T) \
template Status LaunchGemmFastGeluKernel<T>(bool tuning, \
hipStream_t stream, rocblas_handle handle, \
bool transa, bool transb, \
int64_t m, int64_t n, int64_t k, const T alpha, \
const T* a, int64_t lda, const T* b, int64_t ldb, \
const T* bias, const T beta, T* c, int64_t ldc);
SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(half)
SPECIALIZED_IMPL(BFloat16)
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_runtime.h>
#include "core/common/common.h"
#include "core/providers/rocm/rocm_common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
Status LaunchGemmFastGeluKernel(bool tuning,
hipStream_t stream,
rocblas_handle handle,
bool transa,
bool transb,
int64_t m,
int64_t n,
int64_t k,
const T alpha,
const T* a,
int64_t lda,
const T* b,
int64_t ldb,
const T* bias,
const T beta,
T* c,
int64_t ldc);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_runtime.h>
#include <memory>
#include <string>
#include "contrib_ops/rocm/bert/fast_gelu_impl.h"
#include "core/providers/rocm/tunable/gemm.h"
#include "core/providers/rocm/tunable/gemm_common.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
using onnxruntime::rocm::tunable::blas::BlasOp;
using onnxruntime::rocm::tunable::blas::BlasOpToString;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams {
std::string Signature() const override {
return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k);
}
rocblas_handle handle;
BlasOp opa;
BlasOp opb;
int64_t m;
int64_t n;
int64_t k;
T alpha;
const T* a;
int64_t lda;
const T* b;
int64_t ldb;
const T* bias;
T beta;
T* c;
int64_t ldc;
bool tuning{false};
};
template <typename T>
Status GemmFastGeluUnfused(const GemmFastGeluParams<T>* params) {
namespace column_major = onnxruntime::rocm::tunable::blas::column_major;
if (column_major::Gemm(params->tuning, params->stream, params->handle,
params->opb, params->opa,
params->n, params->m, params->k,
params->alpha, params->b, params->ldb, params->a, params->lda,
params->beta, params->c, params->ldc) != Status::OK()) {
return Status(common::ONNXRUNTIME, common::FAIL, "GemmFastGelu call column_major::Gemm failed");
}
int64_t fast_gelu_input_length = params->m * params->n;
int64_t bias_length = (params->bias != nullptr) ? params->n : 0;
// inplace computation
return LaunchFastGeluKernel<T>(params->stream,
static_cast<int>(fast_gelu_input_length),
static_cast<int>(bias_length),
params->c,
params->bias,
params->c,
params->tuning);
}
template <typename T>
class GemmFastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp<GemmFastGeluParams<T>> {
public:
GemmFastGeluTunableOp() {
this->ops_.emplace_back(GemmFastGeluUnfused<T>);
this->SetDefaultId(0);
}
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -74,7 +74,7 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
reinterpret_cast<T*>(Y->MutableData<T>()),
left_X->Shape(), right_X->Shape(),
transa, transb, trans_batch_a_, trans_batch_b_, alpha_, 0.0f) != Status::OK()) {
return Status(common::ONNXRUNTIME, common::FAIL);
return Status(common::ONNXRUNTIME, common::FAIL, "MatMulImpl failed");
}
return Status::OK();
}

View file

@ -1,10 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Modifications: Remove cudaDeviceProp in LaunchFastGeluKernel.
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/math/matmul_impl.h"
#include "core/providers/rocm/rocm_allocator.h"
@ -137,58 +133,18 @@ Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper,
return Status::OK();
}
template Status MatMulImpl<float>(const RocmKernel* op,
MatMulComputeHelper& helper,
const float* left_x_data,
const float* right_x_data,
float* output_y_data,
const TensorShape& left_shape,
const TensorShape& right_shape,
bool transa,
bool transb,
bool trans_batch_a,
bool trans_batch_b,
const float t_alpha,
const float t_zero);
template Status MatMulImpl<double>(const RocmKernel* op,
MatMulComputeHelper& helper,
const double* left_x_data,
const double* right_x_data,
double* output_y_data,
const TensorShape& left_shape,
const TensorShape& right_shape,
bool transa,
bool transb,
bool trans_batch_a,
bool trans_batch_b,
const float t_alpha,
const float t_zero);
template Status MatMulImpl<MLFloat16>(const RocmKernel* op,
MatMulComputeHelper& helper,
const MLFloat16* left_x_data,
const MLFloat16* right_x_data,
MLFloat16* output_y_data,
const TensorShape& left_shape,
const TensorShape& right_shape,
bool transa,
bool transb,
bool trans_batch_a,
bool trans_batch_b,
const float t_alpha,
const float t_zero);
template Status MatMulImpl<BFloat16>(const RocmKernel* op,
MatMulComputeHelper& helper,
const BFloat16* left_x_data,
const BFloat16* right_x_data,
BFloat16* output_y_data,
const TensorShape& left_shape,
const TensorShape& right_shape,
bool transa,
bool transb,
bool trans_batch_a,
bool trans_batch_b,
const float t_alpha,
const float t_zero);
#define SPECIALIZED_IMPL(T) \
template Status MatMulImpl<T>(const RocmKernel* op, MatMulComputeHelper& helper, \
const T* left_x_data, const T* right_x_data, T* output_y_data, \
const TensorShape& left_shape, const TensorShape& right_shape, \
bool transa, bool transb, \
bool trans_batch_a, bool trans_batch_b, \
const float t_alpha, const float t_zero);
SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
SPECIALIZED_IMPL(MLFloat16)
SPECIALIZED_IMPL(BFloat16)
} // namespace rocm
} // namespace onnxruntime

View file

@ -1,10 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Modifications: Remove cudaDeviceProp in LaunchFastGeluKernel.
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/fpgeneric.h"

View file

@ -8,6 +8,7 @@
#include "python/tools/kernel_explorer/kernels/rocm/fast_gelu.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm.h"
#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
namespace py = pybind11;
@ -22,6 +23,7 @@ PYBIND11_MODULE(_kernel_explorer, m) {
InitFastGelu(m);
InitGemm(m);
InitSkipLayerNorm(m);
InitGemmFastGelu(m);
#endif
}

View file

@ -0,0 +1,155 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import re
import sys
from itertools import product
import kernel_explorer as ke
import numpy as np
import pytest
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
def dtype_to_funcs(dtype):
type_map = {
"float16": list(filter(lambda x: re.search("GemmFastGelu.*_half", x), dir(ke))),
"float32": list(filter(lambda x: re.search("GemmFastGelu.*_float", x), dir(ke))),
}
return type_map[dtype]
def fast_gelu(x, bias):
x = x + bias
y = 0.5 * x * (1 + np.tanh(0.797885 * x + 0.035677 * x * x * x))
return y
# TODO The test method needs update.
def _test_gemmfastgelu(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False):
assert dtype in ["float16", "float32"]
a_shape = (k, m) if transa else (m, k)
b_shape = (n, k) if transb else (k, n)
np.random.seed(0)
a = (np.random.rand(*a_shape)).astype(dtype).astype("float64")
b = (np.random.rand(*b_shape)).astype(dtype).astype("float64")
bias = (np.random.rand(n)).astype(dtype)
temp_c = (a.T if transa else a) @ (b.T if transb else b)
bound = get_gemm_bound(dtype, a, b, temp_c, transa, transb)
temp_c = temp_c.astype(dtype)
ref_c = fast_gelu(temp_c, bias)
a = a.astype(dtype)
b = b.astype(dtype)
my_c = np.zeros((m, n), dtype=dtype)
dev_a = ke.DeviceArray(a)
dev_b = ke.DeviceArray(b)
dev_bias = ke.DeviceArray(bias)
dev_c = ke.DeviceArray(my_c)
opa = ke.blas_op.T if transa else ke.blas_op.N
opb = ke.blas_op.T if transb else ke.blas_op.N
lda = a_shape[1]
ldb = b_shape[1]
alpha = 1.0
beta = 0.0
my_func = getattr(ke, func)
my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n)
if my_op.IsSupported():
my_op.Run()
dev_c.UpdateHostNumpyArray()
print(
f"{func:<50} : dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {bound}"
)
np.testing.assert_allclose(my_c, ref_c, rtol=max(bound, 1e-2))
dtypes = ["float16", "float32"]
all_transabs = list(product([True, False], repeat=2))
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
@pytest.mark.parametrize("transab", all_transabs)
def test_gemmfastgelu_bert_cases(dtype, size, transab):
for func in dtype_to_funcs(dtype):
_test_gemmfastgelu(func, dtype, *size, *transab)
def profile_gemmfastgelu_func(func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool):
a_shape = (k, m) if transa else (m, k)
b_shape = (n, k) if transb else (k, n)
np.random.seed(0)
a = (np.random.rand(*a_shape) * 2 - 1).astype(dtype)
b = (np.random.rand(*b_shape) * 2 - 1).astype(dtype)
my_c = np.zeros((m, n), dtype=dtype)
bias = np.random.rand(n).astype(dtype)
dev_a = ke.DeviceArray(a)
dev_b = ke.DeviceArray(b)
dev_bias = ke.DeviceArray(bias)
dev_c = ke.DeviceArray(my_c)
opa = ke.blas_op.T if transa else ke.blas_op.N
opb = ke.blas_op.T if transb else ke.blas_op.N
lda = a_shape[1]
ldb = b_shape[1]
alpha = 1.0
beta = 0.0
my_func = getattr(ke, func)
my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n)
if my_op.IsSupported():
my_op.Run()
dev_c.UpdateHostNumpyArray()
time_ms = my_op.Profile()
time_us = time_ms * 1000
# only counts gemm tflops because fastgelu is low order term (7 * n).
tflops = (m * k * n * 2) / (time_ms * 1e-3) / 1e12
print(
f"{func:<50} {dtype} {transab_to_suffix((transa, transb))}",
f"m={m:<4} n={n:<4} k={k:<4} {time_us:>8.4f} us {tflops:>5.2f} tflops",
)
def profile_with_args(transa, transb, dtype, m, n, k):
for func in dtype_to_funcs(dtype):
profile_gemmfastgelu_func(func, dtype, m, n, k, transa, transb)
def profile():
for dtype in dtypes:
for m, n, k in get_gemm_bert_sizes(full=True):
profile_with_args(False, False, dtype, m, n, k)
print()
print()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
group = parser.add_argument_group("profile with args")
group.add_argument("transa", choices="NT")
group.add_argument("transb", choices="NT")
group.add_argument("dtype", choices=dtypes)
group.add_argument("m", type=int)
group.add_argument("n", type=int)
group.add_argument("k", type=int)
if len(sys.argv) == 1:
profile()
else:
args = parser.parse_args()
profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k)

View file

@ -9,6 +9,7 @@ from itertools import product
import kernel_explorer as ke
import numpy as np
import pytest
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
def dtype_to_suffix(dtype):
@ -18,15 +19,6 @@ def dtype_to_suffix(dtype):
}[dtype]
def transab_to_suffix(transab):
return {
(True, True): "TT",
(True, False): "TN",
(False, True): "NT",
(False, False): "NN",
}[tuple(transab)]
def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False):
assert dtype in ["float32", "float16"]
@ -38,18 +30,7 @@ def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=Fa
b = (np.random.rand(*b_shape) + 0.5).astype(dtype).astype("float64")
ref_c = (a.T if transa else a) @ (b.T if transb else b)
# The machine epsilon, unit roundoff, the smallest positive floating point number n such that the floating point
# number that represents 1 + n is greater than 1.
machine_eps = 2.0 ** -(24 if dtype == "float32" else 11)
# The following implements error bound 5.7 in paper I. C. Ipsen and H. Zhou, “Probabilistic error analysis for
# Inner Products,” SIAM Journal on Matrix Analysis and Applications, vol. 41, no. 4, pp. 17261741, 2020.
# NOTE: the bound is not tight for float16 when k is large
absa_mul_absb = np.abs(a.T if transa else a) @ np.abs(b.T if transb else b)
coeff = np.max(absa_mul_absb / np.abs(ref_c))
gamma_2k = (1.0 + machine_eps) ** (2 * k) - 1.0
bound_5_7 = coeff * np.sqrt(np.log(2 / 1e-10) * machine_eps * gamma_2k / 2)
bound = bound_5_7
bound = get_gemm_bound(dtype, a, b, ref_c, transa, transb)
a = a.astype(dtype)
b = b.astype(dtype)
@ -92,47 +73,17 @@ def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=Fa
dtypes = ["float32", "float16"]
all_transabs = list(product([True, False], repeat=2))
all_basic_sizes = list(product([1, 3, 4, 16, 127, 128, 129, 133, 1024], repeat=3))
def get_bert_sizes(full=True):
bert_base_sizes = [
# m, n, k
(384, 768, 768),
(384, 768, 768 * 3),
(384, 768, 768 * 4),
(384, 768 * 4, 768),
(384, 1024, 1024),
(384, 1024, 1024 * 3),
(384, 1024, 1024 * 4),
(384, 1024 * 4, 1024),
]
# we then multiply m with the batch size
if full:
batch_sizes = [1, 64]
else:
batch_sizes = [1]
bert_sizes = []
for bsz in batch_sizes:
bert_sizes.extend([(m * bsz, n, k) for m, n, k in bert_base_sizes])
return bert_sizes
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("size", all_basic_sizes + get_bert_sizes(full=False))
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=True) + get_gemm_bert_sizes(full=False))
@pytest.mark.parametrize("transab", all_transabs)
def test_rocblas_gemm_all_cases(dtype, size, transab):
_test_gemm(getattr(ke, "RocblasGemm_" + dtype_to_suffix(dtype)), dtype, *size, *transab)
# ck has various impls to be tested, use the full basic cases will result too many cases to test.
# So we use a reduced combination here.
reduced_basic_sizes = list(product([1, 4, 127, 133], [3, 16, 128], [3, 129, 1024]))
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("size", reduced_basic_sizes + get_bert_sizes(full=False))
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
@pytest.mark.parametrize("transab", all_transabs)
def test_ck_gemm_bert_cases(dtype, size, transab):
wrapper_name = "CKGemm_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab))
@ -141,7 +92,7 @@ def test_ck_gemm_bert_cases(dtype, size, transab):
# Tunable is basically wrapped around of rocblas and ck gemm, so no need for full tests
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("size", reduced_basic_sizes + get_bert_sizes(full=False))
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
@pytest.mark.parametrize("transab", all_transabs)
def test_gemm_tunable_bert_cases(dtype, size, transab):
wrapper_name = "GemmTunable_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab))
@ -192,7 +143,7 @@ def profile_with_args(transa, transb, dtype, m, n, k):
def profile():
for dtype in dtypes:
for m, n, k in get_bert_sizes(full=True):
for m, n, k in get_gemm_bert_sizes(full=True):
profile_with_args(False, False, dtype, m, n, k)
print()
print()

View file

@ -0,0 +1,146 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
#include <hip/hip_fp16.h>
#include <pybind11/pybind11.h>
#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
using onnxruntime::rocm::tunable::blas::BlasOp;
namespace py = pybind11;
namespace onnxruntime {
template <typename T>
class GemmFastGeluUnfused : public IKernelExplorer {
public:
GemmFastGeluUnfused(BlasOp opa, BlasOp opb,
int64_t m, int64_t n, int64_t k,
double alpha,
DeviceArray& a, int64_t lda,
DeviceArray& b, int64_t ldb,
DeviceArray& bias,
double beta,
DeviceArray& c, int64_t ldc) : params_{} {
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
params_.tuning = true;
params_.stream = Stream();
params_.handle = rocblas_handle_;
params_.opa = opa;
params_.opb = opb;
params_.m = m;
params_.n = n;
params_.k = k;
params_.alpha = alpha;
params_.a = static_cast<T*>(a.ptr());
params_.lda = lda;
params_.b = static_cast<T*>(b.ptr());
params_.ldb = ldb;
params_.bias = static_cast<T*>(bias.ptr());
params_.beta = beta;
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
}
~GemmFastGeluUnfused() {
ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_));
rocblas_handle_ = nullptr;
}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::GemmFastGeluUnfused<T>(&params_)));
}
bool IsSupported() {
Status status = contrib::rocm::GemmFastGeluUnfused<T>(&params_);
return status.IsOK();
}
private:
using ParamsT = contrib::rocm::GemmFastGeluParams<T>;
ParamsT params_{};
rocblas_handle rocblas_handle_;
};
template <typename T>
class GemmFastGeluTunableOp : public IKernelExplorer {
public:
GemmFastGeluTunableOp(BlasOp opa, BlasOp opb,
int64_t m, int64_t n, int64_t k,
double alpha,
DeviceArray& a, int64_t lda,
DeviceArray& b, int64_t ldb,
DeviceArray& bias,
double beta,
DeviceArray& c, int64_t ldc) : params_{} {
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
params_.tuning = true;
params_.stream = Stream();
params_.handle = rocblas_handle_;
params_.opa = opa;
params_.opb = opb;
params_.m = m;
params_.n = n;
params_.k = k;
params_.alpha = alpha;
params_.a = static_cast<T*>(a.ptr());
params_.lda = lda;
params_.b = static_cast<T*>(b.ptr());
params_.ldb = ldb;
params_.bias = static_cast<T*>(bias.ptr());
params_.beta = beta;
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
op_.EnableTuning();
}
~GemmFastGeluTunableOp() {
ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_));
rocblas_handle_ = nullptr;
}
void Run() override {
ORT_THROW_IF_ERROR((op_(&params_)));
}
bool IsSupported() {
Status status = op_(&params_);
return status.IsOK();
}
private:
using ParamsT = contrib::rocm::GemmFastGeluParams<T>;
ParamsT params_{};
rocblas_handle rocblas_handle_;
contrib::rocm::GemmFastGeluTunableOp<T> op_{};
};
#define REGISTER_OP(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
double, \
DeviceArray&, int64_t, \
DeviceArray&, int64_t, \
DeviceArray&, \
double, \
DeviceArray&, int64_t>()) \
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Run", &name<type>::Run) \
.def("Profile", &name<type>::Profile) \
.def("IsSupported", &name<type>::IsSupported);
void InitGemmFastGelu(py::module m) {
REGISTER_OP(GemmFastGeluUnfused, float)
REGISTER_OP(GemmFastGeluUnfused, half)
REGISTER_OP(GemmFastGeluTunableOp, float)
REGISTER_OP(GemmFastGeluTunableOp, half)
}
} // namespace onnxruntime

View file

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitGemmFastGelu(py::module mod);
} // namespace onnxruntime

View file

@ -0,0 +1,68 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from itertools import product
import numpy as np
def transab_to_suffix(transab):
return {
(True, True): "TT",
(True, False): "TN",
(False, True): "NT",
(False, False): "NN",
}[tuple(transab)]
def get_gemm_bound(dtype: str, a: np.ndarray, b: np.ndarray, c: np.ndarray, transa: bool, transb: bool):
k = b.shape[1] if transb else b.shape[0]
# The machine epsilon, unit roundoff, the smallest positive floating point number n such that the floating point
# number that represents 1 + n is greater than 1.
machine_eps = 2.0 ** -(24 if dtype == "float32" else 11)
# The following implements error bound 5.7 in paper I. C. Ipsen and H. Zhou, “Probabilistic error analysis for
# Inner Products,” SIAM Journal on Matrix Analysis and Applications, vol. 41, no. 4, pp. 17261741, 2020.
# NOTE: the bound is not tight for float16 when k is large
absa_mul_absb = np.abs(a.T if transa else a) @ np.abs(b.T if transb else b)
coeff = np.max(absa_mul_absb / np.abs(c))
gamma_2k = (1.0 + machine_eps) ** (2 * k) - 1.0
bound_5_7 = coeff * np.sqrt(np.log(2 / 1e-10) * machine_eps * gamma_2k / 2)
bound = bound_5_7
return bound
def get_gemm_bert_sizes(full=True):
bert_base_sizes = [
# m, n, k
(384, 768, 768),
(384, 768, 768 * 3),
(384, 768, 768 * 4),
(384, 768 * 4, 768),
(384, 1024, 1024),
(384, 1024, 1024 * 3),
(384, 1024, 1024 * 4),
(384, 1024 * 4, 1024),
]
# we then multiply m with the batch size
if full:
batch_sizes = [1, 64]
else:
batch_sizes = [1]
bert_sizes = []
for bsz in batch_sizes:
bert_sizes.extend([(m * bsz, n, k) for m, n, k in bert_base_sizes])
return bert_sizes
def get_gemm_basic_sizes(full=True):
if full:
return list(product([1, 3, 4, 16, 127, 128, 129, 133, 1024], repeat=3))
# ck has various impls to be tested, use the full basic cases will result too many cases to test.
# So we use a reduced combination here.
return list(product([1, 4, 127, 133], [3, 16, 128], [3, 129, 1024]))

View file

@ -4,10 +4,11 @@
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List, Union
from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import helper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
@ -16,8 +17,35 @@ logger = getLogger(__name__)
class FusionGemmFastGelu(Fusion):
def __init__(self, model: OnnxModel):
super().__init__(model, "GemmFastGelu", "FastGelu", "GemmFastGelu")
self.shape_infer = None
self.shape_infer_done = False
def fuse(self, node, input_name_to_nodes, output_name_to_node):
def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
if tensor_proto.type.tensor_type.HasField("shape"):
return len(tensor_proto.type.tensor_type.shape.dim)
else:
return None
def get_dimensions(self, input_name: str) -> Union[int, None]:
graph_input = self.model.find_graph_input(input_name)
if graph_input:
return self.get_dimensions_from_tensor_proto(graph_input)
if not self.shape_infer_done:
self.shape_infer = self.model.infer_runtime_shape({}, update=True)
self.shape_infer_done = True
if self.shape_infer is not None:
return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])
return None
def fuse(
self,
node: NodeProto,
input_name_to_nodes: Dict[str, List[NodeProto]],
output_name_to_node: Dict[str, NodeProto],
):
"""
This pattern is from PyTorch bert model
Fuse MatMul with FastGelu into one node:
@ -34,20 +62,24 @@ class FusionGemmFastGelu(Fusion):
return
matmul = match_nodes[0]
weight = None
# matmul weight should be two dimension
# matmul input X should >= two dimension, input weight should be two dimension
weight_index = -1
x_dims = 0
weight = None
for i, input in enumerate(matmul.input):
initializer = self.model.get_initializer(input)
if initializer is None:
continue
weight_index = i
weight = NumpyHelper.to_array(initializer)
break
x_dims = self.get_dimensions(matmul.input[i])
else:
weight_index = i
weight = NumpyHelper.to_array(initializer)
if weight is None:
return
if len(weight.shape) != 2:
return
if x_dims < len(weight.shape):
return
# bias weight should be one dimension
bias_index = -1

View file

@ -0,0 +1,153 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os
import unittest
from typing import List
import numpy as np
import onnx
from onnx import TensorProto, helper
from parity_utilities import find_transformers_source
if find_transformers_source():
from fusion_options import FusionOptions
from onnx_model import OnnxModel
from optimizer import optimize_model
else:
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.onnx_model import OnnxModel
from onnxruntime.transformers.optimizer import optimize_model
onnxdomain = onnx.OperatorSetIdProto()
onnxdomain.version = 12
# The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
onnxdomain.domain = ""
msdomain = onnx.OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = "com.microsoft"
opsets = [onnxdomain, msdomain]
def float_tensor(name: str, shape: List[int], random=False):
low = 0.0
high = 1.0
total_elements = 1
for x in shape:
total_elements *= x
weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements
return helper.make_tensor(name, TensorProto.FLOAT, shape, weights)
def create_MatMul_FastGelu_withoutBias(batch_size, m, n, k):
# MatMul + FastGelu
nodes = [
helper.make_node("MatMul", ["input", "matmul_weight"], ["fastgelu_input"], "matmul"),
]
fastgelu_node = helper.make_node("FastGelu", ["fastgelu_input"], ["output"], "fastgelu")
fastgelu_node.domain = "com.microsoft"
nodes.append(fastgelu_node)
initializers = [float_tensor("matmul_weight", [k, n])] # initializers
graph = helper.make_graph(
[node for node in nodes if node],
"GemmFastGeluNoBiasModel", # name
[ # inputs
helper.make_tensor_value_info(
"input",
TensorProto.FLOAT,
[batch_size, m, k],
)
],
[ # outputs
helper.make_tensor_value_info(
"output",
TensorProto.FLOAT,
[batch_size, m, n],
),
],
initializers,
)
return helper.make_model(graph)
def create_MatMul_FastGelu_withBias(batch_size, m, n, k):
# MatMul + FastGelu
nodes = [
helper.make_node("MatMul", ["input", "matmul_weight"], ["fastgelu_input"], "matmul"),
]
fastgelu_node = helper.make_node("FastGelu", ["fastgelu_input", "fastgelu_bias"], ["output"], "fastgelu")
fastgelu_node.domain = "com.microsoft"
nodes.append(fastgelu_node)
initializers = [float_tensor("matmul_weight", [k, n]), float_tensor("fastgelu_bias", [n])] # initializers
graph = helper.make_graph(
[node for node in nodes if node],
"GemmFastGeluWithBiasModel", # name
[ # inputs
helper.make_tensor_value_info(
"input",
TensorProto.FLOAT,
[batch_size, m, k],
)
],
[ # outputs
helper.make_tensor_value_info(
"output",
TensorProto.FLOAT,
[batch_size, m, n],
),
],
initializers,
)
return helper.make_model(graph, opset_imports=opsets)
class TestFusion(unittest.TestCase):
def verify_fusion(self, optimized_model, expected_model_filename):
optimized_model.topological_sort()
expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename)
expected_model = OnnxModel(onnx.load(expected_model_path))
expected_model.topological_sort()
self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))
def test_gemmfastgelu_fusion_withoutbias(self):
model = create_MatMul_FastGelu_withoutBias(32, 128, 64, 1024)
dir = "."
model_path = os.path.join(dir, "gemmfastgelu_nobias.onnx")
onnx.save(model, model_path)
fusion_opt = FusionOptions("bert")
fusion_opt.enable_gemm_fast_gelu = True
optimized_model = optimize_model(input=model_path, optimization_options=fusion_opt)
os.remove(model_path)
self.verify_fusion(optimized_model, "gemmfastgelu_nobias_opt.onnx")
def test_gemmfastgelu_fusion_withbias(self):
model = create_MatMul_FastGelu_withBias(32, 128, 64, 1024)
dir = "."
model_path = os.path.join(dir, "gemmfastgelu_withbias.onnx")
onnx.save(model, model_path)
fusion_opt = FusionOptions("bert")
fusion_opt.enable_gemm_fast_gelu = True
optimized_model = optimize_model(input=model_path, optimization_options=fusion_opt)
os.remove(model_path)
self.verify_fusion(optimized_model, "gemmfastgelu_withbias_opt.onnx")
if __name__ == "__main__":
unittest.main()