mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
### Description <!-- Describe your changes. --> 1. Update the rules for GemmFastGelu fusion: the MatMul input x must have at least two dimensions, and the input weight must have exactly two dimensions. 2. Add a GemmFastGelu fusion test. 3. Add a GemmFastGelu TunableOp that contains only the original implementation (Gemm + FastGelu). ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
30 lines
931 B
C++
30 lines
931 B
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include <pybind11/pybind11.h>
|
|
#include <pybind11/numpy.h>
|
|
#include "python/tools/kernel_explorer/device_array.h"
|
|
#include "python/tools/kernel_explorer/kernels/vector_add.h"
|
|
#include "python/tools/kernel_explorer/kernels/rocm/fast_gelu.h"
|
|
#include "python/tools/kernel_explorer/kernels/rocm/gemm.h"
|
|
#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h"
|
|
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
|
|
|
|
// Short alias for the pybind11 namespace; used below as `py::class_`, `py::init`, and `py::array`.
namespace py = pybind11;
|
|
|
|
namespace onnxruntime {
|
|
|
|
PYBIND11_MODULE(_kernel_explorer, m) {
|
|
py::class_<DeviceArray>(m, "DeviceArray")
|
|
.def(py::init<py::array>())
|
|
.def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray);
|
|
InitVectorAdd(m);
|
|
#if USE_ROCM
|
|
InitFastGelu(m);
|
|
InitGemm(m);
|
|
InitSkipLayerNorm(m);
|
|
InitGemmFastGelu(m);
|
|
#endif
|
|
}
|
|
|
|
} // namespace onnxruntime
|