onnxruntime/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc
PeixuanZuo 8f3c6ea0df
[ROCm] Add GemmFastGelu TunableOp (#13589)
### Description
<!-- Describe your changes. -->

1. Update the rules for GemmFastGelu fusion: MatMul input x must have at least
two dimensions, and input weight must have exactly two dimensions.
2. Add GemmFastGelu fusion test.
3. Add GemmFastGelu TunableOp, which currently contains only the original
implementation (Gemm + FastGelu).


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
2022-11-22 12:58:01 +08:00

30 lines
931 B
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernels/vector_add.h"
#include "python/tools/kernel_explorer/kernels/rocm/fast_gelu.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm.h"
#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
namespace py = pybind11;
namespace onnxruntime {

// Entry point of the `_kernel_explorer` Python extension module.
// It exposes the DeviceArray helper class to Python and then lets each kernel
// register its own bindings through its Init* function. Kernels included from
// kernels/rocm/ are registered only when USE_ROCM evaluates to non-zero.
PYBIND11_MODULE(_kernel_explorer, module_handle) {
  // DeviceArray is constructed from a py::array and exposes
  // UpdateHostNumpyArray under the same name on the Python side.
  py::class_<DeviceArray> device_array_binding(module_handle, "DeviceArray");
  device_array_binding.def(py::init<py::array>());
  device_array_binding.def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray);

  // Registered unconditionally, regardless of USE_ROCM.
  InitVectorAdd(module_handle);

#if USE_ROCM
  // ROCm kernel bindings; registration order is kept as-is.
  InitFastGelu(module_handle);
  InitGemm(module_handle);
  InitSkipLayerNorm(module_handle);
  InitGemmFastGelu(module_handle);
#endif
}

}  // namespace onnxruntime