mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
[ROCm] Add GemmFastGelu TunableOp (#13589)
### Description <!-- Describe your changes. --> 1. Update the rules for GemmFastGelu fusion: the MatMul input x must have at least two dimensions, and the input weight must have exactly two dimensions. 2. Add a GemmFastGelu fusion test. 3. Add the GemmFastGelu TunableOp, which currently contains only the original implementation (Gemm + FastGelu). ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
This commit is contained in:
parent
45a895cdc3
commit
8f3c6ea0df
17 changed files with 808 additions and 150 deletions
|
|
@ -3,18 +3,14 @@
|
|||
|
||||
#include "contrib_ops/rocm/bert/gemm_fast_gelu.h"
|
||||
|
||||
#include "contrib_ops/rocm/bert/fast_gelu_impl.h"
|
||||
#include "contrib_ops/rocm/bert/transformer_common.h"
|
||||
#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h"
|
||||
#include "core/providers/cpu/math/matmul_helper.h"
|
||||
#include "core/providers/rocm/math/matmul_impl.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/shared_inc/fpgeneric.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
using onnxruntime::rocm::MatMulImpl;
|
||||
using onnxruntime::rocm::ToHipType;
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
|
|
@ -48,34 +44,27 @@ Status GemmFastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
MatMulComputeHelper helper;
|
||||
ORT_RETURN_IF_ERROR(helper.Compute(X->Shape(), W->Shape(), transa, transb, trans_batch_a, trans_batch_b, false));
|
||||
|
||||
auto gemm_buffer = GetScratchBuffer<T>(helper.OutputShape().Size());
|
||||
Tensor* Y = ctx->Output(0, helper.OutputShape());
|
||||
|
||||
// Bail out early if the output is going to be empty
|
||||
if (Y->Shape().Size() == 0)
|
||||
return Status::OK();
|
||||
|
||||
const float alpha = 1.0f;
|
||||
const float zero = 0.0f;
|
||||
// gemmfastgelu only support alpha == 1 and beta == 0
|
||||
const HipT alpha = ToHipType<T>::FromFloat(1.0f);
|
||||
const HipT beta = ToHipType<T>::FromFloat(0.0f);
|
||||
|
||||
if (MatMulImpl<T>(this, helper, reinterpret_cast<const T*>(X->Data<T>()),
|
||||
reinterpret_cast<const T*>(W->Data<T>()),
|
||||
reinterpret_cast<T*>(gemm_buffer.get()),
|
||||
X->Shape(), W->Shape(),
|
||||
transa, transb, trans_batch_a, trans_batch_b, alpha, zero) != Status::OK()) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL);
|
||||
}
|
||||
|
||||
int64_t fast_gelu_input_length = Y->Shape().Size();
|
||||
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
|
||||
|
||||
return LaunchFastGeluKernel<HipT>(Stream(),
|
||||
static_cast<int>(fast_gelu_input_length),
|
||||
static_cast<int>(bias_length),
|
||||
reinterpret_cast<HipT*>(gemm_buffer.get()),
|
||||
(nullptr != bias) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
|
||||
reinterpret_cast<HipT*>(Y->MutableData<T>()),
|
||||
false);
|
||||
return LaunchGemmFastGeluKernel<HipT>(
|
||||
IsTunableOpEnabled(),
|
||||
Stream(), RocblasHandle(),
|
||||
transa, transb,
|
||||
static_cast<int64_t>(helper.M()), static_cast<int64_t>(helper.N()), static_cast<int64_t>(helper.K()),
|
||||
alpha,
|
||||
reinterpret_cast<const HipT*>(X->Data<T>()), static_cast<int64_t>(helper.Lda(transa)),
|
||||
reinterpret_cast<const HipT*>(W->Data<T>()), static_cast<int64_t>(helper.Ldb(transb)),
|
||||
(nullptr != bias) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
|
||||
beta,
|
||||
reinterpret_cast<HipT*>(Y->MutableData<T>()), static_cast<int64_t>(helper.Ldc()));
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
|
|
|||
79
onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu
Normal file
79
onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h"
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h"
|
||||
#include "core/providers/rocm/tunable/gemm_common.h"
|
||||
|
||||
using onnxruntime::rocm::tunable::blas::BlasOp;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
// See it as row-major
|
||||
template <typename T>
|
||||
Status LaunchGemmFastGeluKernel(bool tuning,
|
||||
hipStream_t stream,
|
||||
rocblas_handle handle,
|
||||
bool transa,
|
||||
bool transb,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t k,
|
||||
const T alpha,
|
||||
const T* a,
|
||||
int64_t lda,
|
||||
const T* b,
|
||||
int64_t ldb,
|
||||
const T* bias,
|
||||
const T beta,
|
||||
T* c,
|
||||
int64_t ldc) {
|
||||
GemmFastGeluParams<T> params;
|
||||
params.tuning = tuning;
|
||||
params.stream = stream;
|
||||
params.handle = handle;
|
||||
params.opa = transa ? BlasOp::Trans : BlasOp::NonTrans;
|
||||
params.opb = transb ? BlasOp::Trans : BlasOp::NonTrans;
|
||||
|
||||
params.m = m;
|
||||
params.n = n;
|
||||
params.k = k;
|
||||
params.alpha = alpha;
|
||||
params.a = a;
|
||||
params.lda = lda;
|
||||
params.b = b;
|
||||
params.ldb = ldb;
|
||||
params.bias = bias;
|
||||
params.beta = beta;
|
||||
params.c = c;
|
||||
params.ldc = ldc;
|
||||
|
||||
if (tuning) {
|
||||
static GemmFastGeluTunableOp<T> op;
|
||||
op.EnableTuning();
|
||||
return op(¶ms);
|
||||
}
|
||||
|
||||
return GemmFastGeluUnfused(¶ms);
|
||||
}
|
||||
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template Status LaunchGemmFastGeluKernel<T>(bool tuning, \
|
||||
hipStream_t stream, rocblas_handle handle, \
|
||||
bool transa, bool transb, \
|
||||
int64_t m, int64_t n, int64_t k, const T alpha, \
|
||||
const T* a, int64_t lda, const T* b, int64_t ldb, \
|
||||
const T* bias, const T beta, T* c, int64_t ldc);
|
||||
|
||||
SPECIALIZED_IMPL(float)
|
||||
SPECIALIZED_IMPL(half)
|
||||
SPECIALIZED_IMPL(BFloat16)
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
36
onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h
Normal file
36
onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
Status LaunchGemmFastGeluKernel(bool tuning,
|
||||
hipStream_t stream,
|
||||
rocblas_handle handle,
|
||||
bool transa,
|
||||
bool transb,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t k,
|
||||
const T alpha,
|
||||
const T* a,
|
||||
int64_t lda,
|
||||
const T* b,
|
||||
int64_t ldb,
|
||||
const T* bias,
|
||||
const T beta,
|
||||
T* c,
|
||||
int64_t ldc);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "contrib_ops/rocm/bert/fast_gelu_impl.h"
|
||||
#include "core/providers/rocm/tunable/gemm.h"
|
||||
#include "core/providers/rocm/tunable/gemm_common.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
|
||||
using onnxruntime::rocm::tunable::blas::BlasOp;
|
||||
using onnxruntime::rocm::tunable::blas::BlasOpToString;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams {
|
||||
std::string Signature() const override {
|
||||
return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k);
|
||||
}
|
||||
rocblas_handle handle;
|
||||
BlasOp opa;
|
||||
BlasOp opb;
|
||||
int64_t m;
|
||||
int64_t n;
|
||||
int64_t k;
|
||||
T alpha;
|
||||
const T* a;
|
||||
int64_t lda;
|
||||
const T* b;
|
||||
int64_t ldb;
|
||||
const T* bias;
|
||||
T beta;
|
||||
T* c;
|
||||
int64_t ldc;
|
||||
bool tuning{false};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status GemmFastGeluUnfused(const GemmFastGeluParams<T>* params) {
|
||||
namespace column_major = onnxruntime::rocm::tunable::blas::column_major;
|
||||
if (column_major::Gemm(params->tuning, params->stream, params->handle,
|
||||
params->opb, params->opa,
|
||||
params->n, params->m, params->k,
|
||||
params->alpha, params->b, params->ldb, params->a, params->lda,
|
||||
params->beta, params->c, params->ldc) != Status::OK()) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "GemmFastGelu call column_major::Gemm failed");
|
||||
}
|
||||
|
||||
int64_t fast_gelu_input_length = params->m * params->n;
|
||||
int64_t bias_length = (params->bias != nullptr) ? params->n : 0;
|
||||
|
||||
// inplace computation
|
||||
return LaunchFastGeluKernel<T>(params->stream,
|
||||
static_cast<int>(fast_gelu_input_length),
|
||||
static_cast<int>(bias_length),
|
||||
params->c,
|
||||
params->bias,
|
||||
params->c,
|
||||
params->tuning);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class GemmFastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp<GemmFastGeluParams<T>> {
|
||||
public:
|
||||
GemmFastGeluTunableOp() {
|
||||
this->ops_.emplace_back(GemmFastGeluUnfused<T>);
|
||||
|
||||
this->SetDefaultId(0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -74,7 +74,7 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
reinterpret_cast<T*>(Y->MutableData<T>()),
|
||||
left_X->Shape(), right_X->Shape(),
|
||||
transa, transb, trans_batch_a_, trans_batch_b_, alpha_, 0.0f) != Status::OK()) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL);
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "MatMulImpl failed");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// Modifications: Remove cudaDeviceProp in LaunchFastGeluKernel.
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/rocm/math/matmul_impl.h"
|
||||
|
||||
#include "core/providers/rocm/rocm_allocator.h"
|
||||
|
|
@ -137,58 +133,18 @@ Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template Status MatMulImpl<float>(const RocmKernel* op,
|
||||
MatMulComputeHelper& helper,
|
||||
const float* left_x_data,
|
||||
const float* right_x_data,
|
||||
float* output_y_data,
|
||||
const TensorShape& left_shape,
|
||||
const TensorShape& right_shape,
|
||||
bool transa,
|
||||
bool transb,
|
||||
bool trans_batch_a,
|
||||
bool trans_batch_b,
|
||||
const float t_alpha,
|
||||
const float t_zero);
|
||||
template Status MatMulImpl<double>(const RocmKernel* op,
|
||||
MatMulComputeHelper& helper,
|
||||
const double* left_x_data,
|
||||
const double* right_x_data,
|
||||
double* output_y_data,
|
||||
const TensorShape& left_shape,
|
||||
const TensorShape& right_shape,
|
||||
bool transa,
|
||||
bool transb,
|
||||
bool trans_batch_a,
|
||||
bool trans_batch_b,
|
||||
const float t_alpha,
|
||||
const float t_zero);
|
||||
template Status MatMulImpl<MLFloat16>(const RocmKernel* op,
|
||||
MatMulComputeHelper& helper,
|
||||
const MLFloat16* left_x_data,
|
||||
const MLFloat16* right_x_data,
|
||||
MLFloat16* output_y_data,
|
||||
const TensorShape& left_shape,
|
||||
const TensorShape& right_shape,
|
||||
bool transa,
|
||||
bool transb,
|
||||
bool trans_batch_a,
|
||||
bool trans_batch_b,
|
||||
const float t_alpha,
|
||||
const float t_zero);
|
||||
template Status MatMulImpl<BFloat16>(const RocmKernel* op,
|
||||
MatMulComputeHelper& helper,
|
||||
const BFloat16* left_x_data,
|
||||
const BFloat16* right_x_data,
|
||||
BFloat16* output_y_data,
|
||||
const TensorShape& left_shape,
|
||||
const TensorShape& right_shape,
|
||||
bool transa,
|
||||
bool transb,
|
||||
bool trans_batch_a,
|
||||
bool trans_batch_b,
|
||||
const float t_alpha,
|
||||
const float t_zero);
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template Status MatMulImpl<T>(const RocmKernel* op, MatMulComputeHelper& helper, \
|
||||
const T* left_x_data, const T* right_x_data, T* output_y_data, \
|
||||
const TensorShape& left_shape, const TensorShape& right_shape, \
|
||||
bool transa, bool transb, \
|
||||
bool trans_batch_a, bool trans_batch_b, \
|
||||
const float t_alpha, const float t_zero);
|
||||
|
||||
SPECIALIZED_IMPL(float)
|
||||
SPECIALIZED_IMPL(double)
|
||||
SPECIALIZED_IMPL(MLFloat16)
|
||||
SPECIALIZED_IMPL(BFloat16)
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,10 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// Modifications: Remove cudaDeviceProp in LaunchFastGeluKernel.
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/rocm/shared_inc/fpgeneric.h"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "python/tools/kernel_explorer/kernels/rocm/fast_gelu.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
|
@ -22,6 +23,7 @@ PYBIND11_MODULE(_kernel_explorer, m) {
|
|||
InitFastGelu(m);
|
||||
InitGemm(m);
|
||||
InitSkipLayerNorm(m);
|
||||
InitGemmFastGelu(m);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,155 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import re
|
||||
import sys
|
||||
from itertools import product
|
||||
|
||||
import kernel_explorer as ke
|
||||
import numpy as np
|
||||
import pytest
|
||||
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
|
||||
|
||||
|
||||
def dtype_to_funcs(dtype):
    """Return the names exported by ``ke`` that implement GemmFastGelu for ``dtype``.

    ``dtype`` must be "float16" or "float32"; any other value raises ``KeyError``.
    """
    suffix_by_dtype = {
        "float16": "half",
        "float32": "float",
    }
    pattern = re.compile("GemmFastGelu.*_" + suffix_by_dtype[dtype])
    return [name for name in dir(ke) if pattern.search(name)]
|
||||
|
||||
|
||||
def fast_gelu(x, bias):
    """Reference FastGelu: add ``bias`` to ``x``, then apply the tanh-based approximation."""
    shifted = x + bias
    inner = 0.797885 * shifted + 0.035677 * shifted * shifted * shifted
    return 0.5 * shifted * (1 + np.tanh(inner))
|
||||
|
||||
|
||||
# TODO The test method needs update.
def _test_gemmfastgelu(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False):
    """Run the kernel-explorer op named ``func`` and compare it with a numpy reference.

    The reference computes the matmul in float64, rounds it to ``dtype`` and then
    applies ``fast_gelu`` with a random bias of length ``n``.
    """
    assert dtype in ["float16", "float32"]

    a_shape = (k, m) if transa else (m, k)
    b_shape = (n, k) if transb else (k, n)

    np.random.seed(0)
    # Round the inputs to dtype first so the float64 reference sees the values the device sees.
    a = (np.random.rand(*a_shape)).astype(dtype).astype("float64")
    b = (np.random.rand(*b_shape)).astype(dtype).astype("float64")
    bias = (np.random.rand(n)).astype(dtype)
    temp_c = (a.T if transa else a) @ (b.T if transb else b)

    bound = get_gemm_bound(dtype, a, b, temp_c, transa, transb)

    temp_c = temp_c.astype(dtype)
    ref_c = fast_gelu(temp_c, bias)

    a = a.astype(dtype)
    b = b.astype(dtype)

    my_c = np.zeros((m, n), dtype=dtype)
    dev_a = ke.DeviceArray(a)
    dev_b = ke.DeviceArray(b)
    dev_bias = ke.DeviceArray(bias)
    dev_c = ke.DeviceArray(my_c)

    opa = ke.blas_op.T if transa else ke.blas_op.N
    opb = ke.blas_op.T if transb else ke.blas_op.N
    lda = a_shape[1]
    ldb = b_shape[1]
    alpha = 1.0
    beta = 0.0
    my_func = getattr(ke, func)
    my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n)

    # NOTE(review): the report and the comparison are assumed to run only for supported
    # configurations; confirm the nesting against the upstream file.
    if my_op.IsSupported():
        my_op.Run()
        dev_c.UpdateHostNumpyArray()

        print(
            f"{func:<50} : dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {bound}"
        )

        np.testing.assert_allclose(my_c, ref_c, rtol=max(bound, 1e-2))
|
||||
|
||||
|
||||
# Element types and transpose combinations covered by the tests and the profiler.
dtypes = ["float16", "float32"]
all_transabs = list(product([True, False], repeat=2))


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
@pytest.mark.parametrize("transab", all_transabs)
def test_gemmfastgelu_bert_cases(dtype, size, transab):
    """Check every GemmFastGelu implementation found for ``dtype`` on one (size, transab) case."""
    for func in dtype_to_funcs(dtype):
        _test_gemmfastgelu(func, dtype, *size, *transab)
|
||||
|
||||
|
||||
def profile_gemmfastgelu_func(func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool):
    """Profile the kernel-explorer op named ``func`` on random inputs and print time and TFLOPS."""
    a_shape = (k, m) if transa else (m, k)
    b_shape = (n, k) if transb else (k, n)

    np.random.seed(0)
    a = (np.random.rand(*a_shape) * 2 - 1).astype(dtype)
    b = (np.random.rand(*b_shape) * 2 - 1).astype(dtype)
    my_c = np.zeros((m, n), dtype=dtype)
    bias = np.random.rand(n).astype(dtype)

    dev_a = ke.DeviceArray(a)
    dev_b = ke.DeviceArray(b)
    dev_bias = ke.DeviceArray(bias)
    dev_c = ke.DeviceArray(my_c)

    opa = ke.blas_op.T if transa else ke.blas_op.N
    opb = ke.blas_op.T if transb else ke.blas_op.N
    lda = a_shape[1]
    ldb = b_shape[1]
    alpha = 1.0
    beta = 0.0
    my_func = getattr(ke, func)
    my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n)

    # NOTE(review): profiling is assumed to run only for supported configurations;
    # confirm the nesting against the upstream file.
    if my_op.IsSupported():
        my_op.Run()
        dev_c.UpdateHostNumpyArray()

        time_ms = my_op.Profile()
        time_us = time_ms * 1000
        # only counts gemm tflops because fastgelu is low order term (7 * n).
        tflops = (m * k * n * 2) / (time_ms * 1e-3) / 1e12
        print(
            f"{func:<50} {dtype} {transab_to_suffix((transa, transb))}",
            f"m={m:<4} n={n:<4} k={k:<4} {time_us:>8.4f} us {tflops:>5.2f} tflops",
        )
|
||||
|
||||
|
||||
def profile_with_args(transa, transb, dtype, m, n, k):
    """Profile every GemmFastGelu implementation found for ``dtype`` on a single problem."""
    for func in dtype_to_funcs(dtype):
        profile_gemmfastgelu_func(func, dtype, m, n, k, transa, transb)
|
||||
|
||||
|
||||
def profile():
    """Profile the full BERT size list, without transposes, for every dtype in ``dtypes``."""
    for dtype in dtypes:
        for m, n, k in get_gemm_bert_sizes(full=True):
            profile_with_args(False, False, dtype, m, n, k)
            print()
        print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("profile with args")
    group.add_argument("transa", choices="NT")
    group.add_argument("transb", choices="NT")
    group.add_argument("dtype", choices=dtypes)
    group.add_argument("m", type=int)
    group.add_argument("n", type=int)
    group.add_argument("k", type=int)

    # With no command-line arguments, run the built-in sweep; otherwise profile the single case given.
    if len(sys.argv) == 1:
        profile()
    else:
        args = parser.parse_args()
        profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k)
|
||||
|
|
@ -9,6 +9,7 @@ from itertools import product
|
|||
import kernel_explorer as ke
|
||||
import numpy as np
|
||||
import pytest
|
||||
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
|
||||
|
||||
|
||||
def dtype_to_suffix(dtype):
|
||||
|
|
@ -18,15 +19,6 @@ def dtype_to_suffix(dtype):
|
|||
}[dtype]
|
||||
|
||||
|
||||
def transab_to_suffix(transab):
|
||||
return {
|
||||
(True, True): "TT",
|
||||
(True, False): "TN",
|
||||
(False, True): "NT",
|
||||
(False, False): "NN",
|
||||
}[tuple(transab)]
|
||||
|
||||
|
||||
def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False):
|
||||
assert dtype in ["float32", "float16"]
|
||||
|
||||
|
|
@ -38,18 +30,7 @@ def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=Fa
|
|||
b = (np.random.rand(*b_shape) + 0.5).astype(dtype).astype("float64")
|
||||
ref_c = (a.T if transa else a) @ (b.T if transb else b)
|
||||
|
||||
# The machine epsilon, unit roundoff, the smallest positive floating point number n such that the floating point
|
||||
# number that represents 1 + n is greater than 1.
|
||||
machine_eps = 2.0 ** -(24 if dtype == "float32" else 11)
|
||||
|
||||
# The following implements error bound 5.7 in paper I. C. Ipsen and H. Zhou, “Probabilistic error analysis for
|
||||
# Inner Products,” SIAM Journal on Matrix Analysis and Applications, vol. 41, no. 4, pp. 1726–1741, 2020.
|
||||
# NOTE: the bound is not tight for float16 when k is large
|
||||
absa_mul_absb = np.abs(a.T if transa else a) @ np.abs(b.T if transb else b)
|
||||
coeff = np.max(absa_mul_absb / np.abs(ref_c))
|
||||
gamma_2k = (1.0 + machine_eps) ** (2 * k) - 1.0
|
||||
bound_5_7 = coeff * np.sqrt(np.log(2 / 1e-10) * machine_eps * gamma_2k / 2)
|
||||
bound = bound_5_7
|
||||
bound = get_gemm_bound(dtype, a, b, ref_c, transa, transb)
|
||||
|
||||
a = a.astype(dtype)
|
||||
b = b.astype(dtype)
|
||||
|
|
@ -92,47 +73,17 @@ def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=Fa
|
|||
|
||||
dtypes = ["float32", "float16"]
|
||||
all_transabs = list(product([True, False], repeat=2))
|
||||
all_basic_sizes = list(product([1, 3, 4, 16, 127, 128, 129, 133, 1024], repeat=3))
|
||||
|
||||
|
||||
def get_bert_sizes(full=True):
|
||||
bert_base_sizes = [
|
||||
# m, n, k
|
||||
(384, 768, 768),
|
||||
(384, 768, 768 * 3),
|
||||
(384, 768, 768 * 4),
|
||||
(384, 768 * 4, 768),
|
||||
(384, 1024, 1024),
|
||||
(384, 1024, 1024 * 3),
|
||||
(384, 1024, 1024 * 4),
|
||||
(384, 1024 * 4, 1024),
|
||||
]
|
||||
|
||||
# we then multiply m with the batch size
|
||||
if full:
|
||||
batch_sizes = [1, 64]
|
||||
else:
|
||||
batch_sizes = [1]
|
||||
bert_sizes = []
|
||||
for bsz in batch_sizes:
|
||||
bert_sizes.extend([(m * bsz, n, k) for m, n, k in bert_base_sizes])
|
||||
return bert_sizes
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
@pytest.mark.parametrize("size", all_basic_sizes + get_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=True) + get_gemm_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("transab", all_transabs)
|
||||
def test_rocblas_gemm_all_cases(dtype, size, transab):
|
||||
_test_gemm(getattr(ke, "RocblasGemm_" + dtype_to_suffix(dtype)), dtype, *size, *transab)
|
||||
|
||||
|
||||
# ck has various impls to be tested, use the full basic cases will result too many cases to test.
|
||||
# So we use a reduced combination here.
|
||||
reduced_basic_sizes = list(product([1, 4, 127, 133], [3, 16, 128], [3, 129, 1024]))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
@pytest.mark.parametrize("size", reduced_basic_sizes + get_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("transab", all_transabs)
|
||||
def test_ck_gemm_bert_cases(dtype, size, transab):
|
||||
wrapper_name = "CKGemm_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab))
|
||||
|
|
@ -141,7 +92,7 @@ def test_ck_gemm_bert_cases(dtype, size, transab):
|
|||
|
||||
# Tunable is basically wrapped around of rocblas and ck gemm, so no need for full tests
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
@pytest.mark.parametrize("size", reduced_basic_sizes + get_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("transab", all_transabs)
|
||||
def test_gemm_tunable_bert_cases(dtype, size, transab):
|
||||
wrapper_name = "GemmTunable_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab))
|
||||
|
|
@ -192,7 +143,7 @@ def profile_with_args(transa, transb, dtype, m, n, k):
|
|||
|
||||
def profile():
|
||||
for dtype in dtypes:
|
||||
for m, n, k in get_bert_sizes(full=True):
|
||||
for m, n, k in get_gemm_bert_sizes(full=True):
|
||||
profile_with_args(False, False, dtype, m, n, k)
|
||||
print()
|
||||
print()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,146 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h"
|
||||
#include "python/tools/kernel_explorer/device_array.h"
|
||||
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
|
||||
|
||||
using onnxruntime::rocm::tunable::blas::BlasOp;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename T>
|
||||
class GemmFastGeluUnfused : public IKernelExplorer {
|
||||
public:
|
||||
GemmFastGeluUnfused(BlasOp opa, BlasOp opb,
|
||||
int64_t m, int64_t n, int64_t k,
|
||||
double alpha,
|
||||
DeviceArray& a, int64_t lda,
|
||||
DeviceArray& b, int64_t ldb,
|
||||
DeviceArray& bias,
|
||||
double beta,
|
||||
DeviceArray& c, int64_t ldc) : params_{} {
|
||||
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
|
||||
params_.tuning = true;
|
||||
params_.stream = Stream();
|
||||
params_.handle = rocblas_handle_;
|
||||
params_.opa = opa;
|
||||
params_.opb = opb;
|
||||
params_.m = m;
|
||||
params_.n = n;
|
||||
params_.k = k;
|
||||
params_.alpha = alpha;
|
||||
params_.a = static_cast<T*>(a.ptr());
|
||||
params_.lda = lda;
|
||||
params_.b = static_cast<T*>(b.ptr());
|
||||
params_.ldb = ldb;
|
||||
params_.bias = static_cast<T*>(bias.ptr());
|
||||
params_.beta = beta;
|
||||
params_.c = static_cast<T*>(c.ptr());
|
||||
params_.ldc = ldc;
|
||||
}
|
||||
|
||||
~GemmFastGeluUnfused() {
|
||||
ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_));
|
||||
rocblas_handle_ = nullptr;
|
||||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::GemmFastGeluUnfused<T>(¶ms_)));
|
||||
}
|
||||
|
||||
bool IsSupported() {
|
||||
Status status = contrib::rocm::GemmFastGeluUnfused<T>(¶ms_);
|
||||
return status.IsOK();
|
||||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::GemmFastGeluParams<T>;
|
||||
ParamsT params_{};
|
||||
rocblas_handle rocblas_handle_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class GemmFastGeluTunableOp : public IKernelExplorer {
|
||||
public:
|
||||
GemmFastGeluTunableOp(BlasOp opa, BlasOp opb,
|
||||
int64_t m, int64_t n, int64_t k,
|
||||
double alpha,
|
||||
DeviceArray& a, int64_t lda,
|
||||
DeviceArray& b, int64_t ldb,
|
||||
DeviceArray& bias,
|
||||
double beta,
|
||||
DeviceArray& c, int64_t ldc) : params_{} {
|
||||
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
|
||||
params_.tuning = true;
|
||||
params_.stream = Stream();
|
||||
params_.handle = rocblas_handle_;
|
||||
params_.opa = opa;
|
||||
params_.opb = opb;
|
||||
params_.m = m;
|
||||
params_.n = n;
|
||||
params_.k = k;
|
||||
params_.alpha = alpha;
|
||||
params_.a = static_cast<T*>(a.ptr());
|
||||
params_.lda = lda;
|
||||
params_.b = static_cast<T*>(b.ptr());
|
||||
params_.ldb = ldb;
|
||||
params_.bias = static_cast<T*>(bias.ptr());
|
||||
params_.beta = beta;
|
||||
params_.c = static_cast<T*>(c.ptr());
|
||||
params_.ldc = ldc;
|
||||
|
||||
op_.EnableTuning();
|
||||
}
|
||||
|
||||
~GemmFastGeluTunableOp() {
|
||||
ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_));
|
||||
rocblas_handle_ = nullptr;
|
||||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((op_(¶ms_)));
|
||||
}
|
||||
|
||||
bool IsSupported() {
|
||||
Status status = op_(¶ms_);
|
||||
return status.IsOK();
|
||||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::GemmFastGeluParams<T>;
|
||||
ParamsT params_{};
|
||||
rocblas_handle rocblas_handle_;
|
||||
contrib::rocm::GemmFastGeluTunableOp<T> op_{};
|
||||
};
|
||||
|
||||
#define REGISTER_OP(name, type) \
|
||||
py::class_<name<type>>(m, #name "_" #type) \
|
||||
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
|
||||
double, \
|
||||
DeviceArray&, int64_t, \
|
||||
DeviceArray&, int64_t, \
|
||||
DeviceArray&, \
|
||||
double, \
|
||||
DeviceArray&, int64_t>()) \
|
||||
.def("SetRepeats", &name<type>::SetRepeats) \
|
||||
.def("Run", &name<type>::Run) \
|
||||
.def("Profile", &name<type>::Profile) \
|
||||
.def("IsSupported", &name<type>::IsSupported);
|
||||
|
||||
void InitGemmFastGelu(py::module m) {
|
||||
REGISTER_OP(GemmFastGeluUnfused, float)
|
||||
REGISTER_OP(GemmFastGeluUnfused, half)
|
||||
|
||||
REGISTER_OP(GemmFastGeluTunableOp, float)
|
||||
REGISTER_OP(GemmFastGeluTunableOp, half)
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Registers the GemmFastGelu kernel-explorer classes on the given pybind11 module.
void InitGemmFastGelu(py::module mod);
|
||||
|
||||
} // namespace onnxruntime
|
||||
68
onnxruntime/python/tools/kernel_explorer/kernels/utils.py
Normal file
68
onnxruntime/python/tools/kernel_explorer/kernels/utils.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def transab_to_suffix(transab):
    """Map a (transa, transb) pair of booleans to its two-letter suffix, e.g. (True, False) -> "TN"."""
    suffixes = {
        (True, True): "TT",
        (True, False): "TN",
        (False, True): "NT",
        (False, False): "NN",
    }
    return suffixes[tuple(transab)]
|
||||
|
||||
|
||||
def get_gemm_bound(dtype: str, a: np.ndarray, b: np.ndarray, c: np.ndarray, transa: bool, transb: bool):
    """Return the error bound used as the relative tolerance for ``c = op(a) @ op(b)``.

    ``dtype`` selects the machine epsilon: "float32" uses 2**-24, anything else 2**-11.
    ``c`` is the reference product; its entries appear in a denominator.
    """
    k = b.shape[1] if transb else b.shape[0]
    # The machine epsilon, unit roundoff, the smallest positive floating point number n such that the floating point
    # number that represents 1 + n is greater than 1.
    machine_eps = 2.0 ** -(24 if dtype == "float32" else 11)

    # The following implements error bound 5.7 in paper I. C. Ipsen and H. Zhou, “Probabilistic error analysis for
    # Inner Products,” SIAM Journal on Matrix Analysis and Applications, vol. 41, no. 4, pp. 1726–1741, 2020.
    # NOTE: the bound is not tight for float16 when k is large
    absa_mul_absb = np.abs(a.T if transa else a) @ np.abs(b.T if transb else b)
    coeff = np.max(absa_mul_absb / np.abs(c))
    gamma_2k = (1.0 + machine_eps) ** (2 * k) - 1.0
    bound_5_7 = coeff * np.sqrt(np.log(2 / 1e-10) * machine_eps * gamma_2k / 2)

    return bound_5_7
|
||||
|
||||
|
||||
def get_gemm_bert_sizes(full=True):
    """Return the ``(m, n, k)`` GEMM problem sizes used for BERT-like workloads.

    The base sizes all use ``m = 384``; ``m`` is then multiplied by each batch size.
    With ``full=True`` the batch sizes are 1 and 64, otherwise only 1. Results are ordered by
    batch size first and by base size second.
    """
    base_mnk = []
    for hidden in (768, 1024):
        base_mnk += [
            # m, n, k
            (384, hidden, hidden),
            (384, hidden, hidden * 3),
            (384, hidden, hidden * 4),
            (384, hidden * 4, hidden),
        ]

    # m is scaled by the batch size; n and k are left untouched.
    batch_sizes = [1, 64] if full else [1]
    return [(m * batch, n, k) for batch in batch_sizes for m, n, k in base_mnk]
|
||||
|
||||
|
||||
def get_gemm_basic_sizes(full=True):
    """Return a list of basic ``(m, n, k)`` GEMM sizes for testing.

    With ``full=True`` this is every combination of nine representative extents for all three
    dimensions. Otherwise a reduced set is returned: CK has many implementations to test, and
    the full combination would produce too many cases.
    """
    if not full:
        # Reduced combination: separate, smaller candidate lists for m, n and k.
        m_values, n_values, k_values = [1, 4, 127, 133], [3, 16, 128], [3, 129, 1024]
    else:
        extents = [1, 3, 4, 16, 127, 128, 129, 133, 1024]
        m_values = n_values = k_values = extents

    # Same ordering as a cartesian product: the last dimension varies fastest.
    return [(m, n, k) for m in m_values for n in n_values for k in k_values]
|
||||
|
|
@ -4,10 +4,11 @@
|
|||
# --------------------------------------------------------------------------
|
||||
|
||||
from logging import getLogger
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from fusion_base import Fusion
|
||||
from fusion_utils import NumpyHelper
|
||||
from onnx import helper
|
||||
from onnx import NodeProto, TensorProto, helper
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
@ -16,8 +17,35 @@ logger = getLogger(__name__)
|
|||
class FusionGemmFastGelu(Fusion):
|
||||
def __init__(self, model: OnnxModel):
    """Set up the GemmFastGelu fusion for ``model``.

    NOTE(review): the three string arguments are presumably the fused op type, the op type to
    search for, and a description -- confirm against ``Fusion.__init__``.
    """
    super().__init__(model, "GemmFastGelu", "FastGelu", "GemmFastGelu")
    # Shape inference is run lazily by get_dimensions(); the flag records that it was attempted
    # so it is not repeated, while shape_infer holds whatever that attempt returned.
    self.shape_infer, self.shape_infer_done = None, False
|
||||
|
||||
def fuse(self, node, input_name_to_nodes, output_name_to_node):
|
||||
def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
    """Return the number of dimensions recorded on ``tensor_proto``, or None without a shape.

    The rank is the length of ``tensor_proto.type.tensor_type.shape.dim``.
    NOTE(review): the ``.type.tensor_type`` access looks like a ValueInfoProto rather than a
    TensorProto -- confirm the annotation against the callers.
    """
    tensor_type = tensor_proto.type.tensor_type
    if not tensor_type.HasField("shape"):
        return None
    return len(tensor_type.shape.dim)
|
||||
|
||||
def get_dimensions(self, input_name: str) -> Union[int, None]:
    """Return the number of dimensions of the tensor named ``input_name``, or None if unknown.

    Graph inputs are answered directly from their declared type. For any other tensor, shape
    inference is run once via ``self.model.infer_runtime_shape`` and its result cached on the
    instance; the rank is then read from the inferred value infos.

    Returns None when shape inference produced no result, or when it has no entry for
    ``input_name`` (previously the latter raised ``KeyError``).
    """
    graph_input = self.model.find_graph_input(input_name)
    if graph_input:
        return self.get_dimensions_from_tensor_proto(graph_input)

    # Run shape inference at most once, even if it fails and returns None.
    if not self.shape_infer_done:
        self.shape_infer = self.model.infer_runtime_shape({}, update=True)
        self.shape_infer_done = True

    if self.shape_infer is None:
        return None

    # Inference may not cover every tensor; treat a missing entry as "unknown" instead of crashing.
    value_info = self.shape_infer.known_vi_.get(input_name)
    if value_info is None:
        return None
    return self.get_dimensions_from_tensor_proto(value_info)
|
||||
|
||||
def fuse(
|
||||
self,
|
||||
node: NodeProto,
|
||||
input_name_to_nodes: Dict[str, List[NodeProto]],
|
||||
output_name_to_node: Dict[str, NodeProto],
|
||||
):
|
||||
"""
|
||||
This pattern is from PyTorch bert model
|
||||
Fuse MatMul with FastGelu into one node:
|
||||
|
|
@ -34,20 +62,24 @@ class FusionGemmFastGelu(Fusion):
|
|||
return
|
||||
matmul = match_nodes[0]
|
||||
|
||||
weight = None
|
||||
# matmul weight should be two dimension
|
||||
# matmul input X should >= two dimension, input weight should be two dimension
|
||||
weight_index = -1
|
||||
x_dims = 0
|
||||
weight = None
|
||||
|
||||
for i, input in enumerate(matmul.input):
|
||||
initializer = self.model.get_initializer(input)
|
||||
if initializer is None:
|
||||
continue
|
||||
weight_index = i
|
||||
weight = NumpyHelper.to_array(initializer)
|
||||
break
|
||||
x_dims = self.get_dimensions(matmul.input[i])
|
||||
else:
|
||||
weight_index = i
|
||||
weight = NumpyHelper.to_array(initializer)
|
||||
if weight is None:
|
||||
return
|
||||
if len(weight.shape) != 2:
|
||||
return
|
||||
if x_dims < len(weight.shape):
|
||||
return
|
||||
|
||||
# bias weight should be one dimension
|
||||
bias_index = -1
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
153
onnxruntime/test/python/transformers/test_gemmfastgelu_fusion.py
Normal file
153
onnxruntime/test/python/transformers/test_gemmfastgelu_fusion.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
from parity_utilities import find_transformers_source
|
||||
|
||||
if find_transformers_source():
|
||||
from fusion_options import FusionOptions
|
||||
from onnx_model import OnnxModel
|
||||
from optimizer import optimize_model
|
||||
else:
|
||||
from onnxruntime.transformers.fusion_options import FusionOptions
|
||||
from onnxruntime.transformers.onnx_model import OnnxModel
|
||||
from onnxruntime.transformers.optimizer import optimize_model
|
||||
|
||||
|
||||
# The empty string ("") or absence of the domain field implies the operator set that is defined
# as part of the ONNX specification.
onnxdomain = helper.make_opsetid("", 12)
# Contrib operators live in the "com.microsoft" domain.
msdomain = helper.make_opsetid("com.microsoft", 1)
opsets = [onnxdomain, msdomain]
|
||||
|
||||
|
||||
def float_tensor(name: str, shape: List[int], random=False):
    """Build a FLOAT tensor named ``name`` with the given ``shape`` via ``helper.make_tensor``.

    Every element is 1.0 by default; when ``random`` is true each element is drawn separately
    from ``np.random.uniform(0.0, 1.0)``.
    """
    lower, upper = 0.0, 1.0

    # Number of elements is the product of all dimensions (1 for an empty shape).
    element_count = 1
    for dim in shape:
        element_count *= dim

    if random:
        values = [np.random.uniform(lower, upper) for _ in range(element_count)]
    else:
        values = [1.0] * element_count
    return helper.make_tensor(name, TensorProto.FLOAT, shape, values)
|
||||
|
||||
|
||||
def create_MatMul_FastGelu_withoutBias(batch_size, m, n, k):
    """Build a MatMul + FastGelu (no bias) model used as input for the GemmFastGelu fusion test.

    The graph input is ``[batch_size, m, k]``, the MatMul weight initializer is ``[k, n]`` and
    the graph output is ``[batch_size, m, n]``.

    The model declares the module-level ``opsets`` as its opset imports, matching
    ``create_MatMul_FastGelu_withBias``: FastGelu belongs to the "com.microsoft" domain, so that
    domain has to be declared, and the ONNX opset is pinned instead of following whatever the
    installed onnx package defaults to.
    """
    # MatMul + FastGelu
    nodes = [
        helper.make_node("MatMul", ["input", "matmul_weight"], ["fastgelu_input"], "matmul"),
    ]
    fastgelu_node = helper.make_node("FastGelu", ["fastgelu_input"], ["output"], "fastgelu")
    fastgelu_node.domain = "com.microsoft"
    nodes.append(fastgelu_node)

    initializers = [float_tensor("matmul_weight", [k, n])]

    graph = helper.make_graph(
        [node for node in nodes if node],
        "GemmFastGeluNoBiasModel",  # name
        [  # inputs
            helper.make_tensor_value_info(
                "input",
                TensorProto.FLOAT,
                [batch_size, m, k],
            )
        ],
        [  # outputs
            helper.make_tensor_value_info(
                "output",
                TensorProto.FLOAT,
                [batch_size, m, n],
            ),
        ],
        initializers,
    )

    return helper.make_model(graph, opset_imports=opsets)
|
||||
|
||||
|
||||
def create_MatMul_FastGelu_withBias(batch_size, m, n, k):
    """Build a MatMul + FastGelu (with bias) model used as input for the GemmFastGelu fusion test.

    The graph input is ``[batch_size, m, k]``, the MatMul weight initializer is ``[k, n]``, the
    FastGelu bias initializer is ``[n]`` and the graph output is ``[batch_size, m, n]``. The
    model is created with the module-level ``opsets`` as its opset imports.
    """
    matmul_node = helper.make_node("MatMul", ["input", "matmul_weight"], ["fastgelu_input"], "matmul")
    # FastGelu is a contrib op, so it must carry the com.microsoft domain.
    gelu_node = helper.make_node("FastGelu", ["fastgelu_input", "fastgelu_bias"], ["output"], "fastgelu")
    gelu_node.domain = "com.microsoft"

    weight_and_bias = [
        float_tensor("matmul_weight", [k, n]),
        float_tensor("fastgelu_bias", [n]),
    ]
    graph_inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch_size, m, k])]
    graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, m, n])]

    graph = helper.make_graph(
        list(filter(None, [matmul_node, gelu_node])),
        "GemmFastGeluWithBiasModel",
        graph_inputs,
        graph_outputs,
        weight_and_bias,
    )
    return helper.make_model(graph, opset_imports=opsets)
|
||||
|
||||
|
||||
class TestFusion(unittest.TestCase):
    """Checks that MatMul + FastGelu is fused into GemmFastGelu, with and without a bias."""

    def verify_fusion(self, optimized_model, expected_model_filename):
        """Assert that ``optimized_model``'s graph equals the stored expected model's graph.

        The expected model is loaded from ``test_data/models`` next to this file. Both graphs
        are topologically sorted first and then compared through their string form.
        """
        optimized_model.topological_sort()

        expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename)
        expected_model = OnnxModel(onnx.load(expected_model_path))
        expected_model.topological_sort()

        self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))

    def _optimize_with_gemm_fast_gelu(self, model, model_filename):
        """Save ``model`` to the working directory, optimize it with GemmFastGelu fusion enabled,
        and return the optimized model.

        The temporary file is removed even when saving or optimizing raises, so a failing test
        does not leave a stray .onnx file behind.
        """
        model_path = os.path.join(".", model_filename)
        try:
            onnx.save(model, model_path)

            fusion_opt = FusionOptions("bert")
            fusion_opt.enable_gemm_fast_gelu = True
            return optimize_model(input=model_path, optimization_options=fusion_opt)
        finally:
            # The file may not exist if onnx.save itself failed.
            if os.path.exists(model_path):
                os.remove(model_path)

    def test_gemmfastgelu_fusion_withoutbias(self):
        model = create_MatMul_FastGelu_withoutBias(32, 128, 64, 1024)
        optimized_model = self._optimize_with_gemm_fast_gelu(model, "gemmfastgelu_nobias.onnx")

        self.verify_fusion(optimized_model, "gemmfastgelu_nobias_opt.onnx")

    def test_gemmfastgelu_fusion_withbias(self):
        model = create_MatMul_FastGelu_withBias(32, 128, 64, 1024)
        optimized_model = self._optimize_with_gemm_fast_gelu(model, "gemmfastgelu_withbias.onnx")

        self.verify_fusion(optimized_model, "gemmfastgelu_withbias_opt.onnx")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue