mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Refactor ke register to be decentralized (#15036)
So that we can remove all unnecessary header files
This commit is contained in:
parent
0086f7590d
commit
71b67ec1e2
30 changed files with 88 additions and 298 deletions
|
|
@ -2,39 +2,43 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/embed.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include "python/tools/kernel_explorer/device_array.h"
|
||||
#include "python/tools/kernel_explorer/kernels/vector_add.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/fast_gelu.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/softmax.h"
|
||||
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
PYBIND11_MODULE(_kernel_explorer, m) {
|
||||
static py::module::module_def _kernel_explorer_module_def;
|
||||
|
||||
// Returns the single `_kernel_explorer` pybind11 module object shared by the whole extension.
//
// The module is created on the first call through pybind11::module_::create_extension_module,
// using the file-level `_kernel_explorer_module_def`, and is kept in a function-local static, so
// every later call returns a handle to that same module.
//
// NOTE(review): `tmp.dec_ref()` drops one reference right after creation. Presumably this balances
// the reference bookkeeping of create_extension_module so that the static handle does not hold an
// extra reference -- confirm against the pybind11 version in use.
py::module GetKernelExplorerModule() {
|
||||
static pybind11::module_ m = []() {
|
||||
auto tmp = pybind11::module_::create_extension_module(
|
||||
"_kernel_explorer", "", &_kernel_explorer_module_def);
|
||||
tmp.dec_ref();
|
||||
return tmp;
|
||||
}();
|
||||
return m;
|
||||
}
|
||||
|
||||
// Entry point of the `_kernel_explorer` extension, written with pybind11's lower-level macros
// rather than PYBIND11_MODULE. It invokes PYBIND11_CHECK_PYTHON_VERSION and
// PYBIND11_ENSURE_INTERNALS_READY, then returns the raw pointer of the shared module obtained from
// GetKernelExplorerModule(). No bindings are registered here; that is done by the KE_REGISTER
// blocks.
PYBIND11_PLUGIN_IMPL(_kernel_explorer) {
|
||||
PYBIND11_CHECK_PYTHON_VERSION;
|
||||
PYBIND11_ENSURE_INTERNALS_READY;
|
||||
return GetKernelExplorerModule().ptr();
|
||||
}
|
||||
|
||||
KE_REGISTER(m) {
|
||||
py::class_<DeviceArray>(m, "DeviceArray")
|
||||
.def(py::init<py::array>())
|
||||
.def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray);
|
||||
InitVectorAdd(m);
|
||||
#if USE_ROCM
|
||||
InitFastGelu(m);
|
||||
InitGemm(m);
|
||||
InitSkipLayerNorm(m);
|
||||
InitGemmFastGelu(m);
|
||||
InitSoftmax(m);
|
||||
InitGemmSoftmaxGemmPermute(m);
|
||||
#endif
|
||||
|
||||
m.def("is_composable_kernel_available", []() {
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
return false;
|
||||
#endif
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda_runtime_api.h>
|
||||
|
|
@ -32,6 +34,8 @@ using TuningContextT = onnxruntime::rocm::tunable::RocmTuningContext;
|
|||
#error "kernel explorer only supports CUDA or ROCM"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/// Wrapping around Op and TunableOp
|
||||
class IKernelExplorer {
|
||||
public:
|
||||
|
|
@ -78,3 +82,22 @@ class IKernelExplorer {
|
|||
StreamT stream_{0};
|
||||
int repeats_{100};
|
||||
};
|
||||
|
||||
pybind11::module GetKernelExplorerModule();
|
||||
|
||||
/// Helper whose only job is to run a registration function when it is constructed.
///
/// The constructor immediately calls `init_func`, passing it the shared module returned by
/// GetKernelExplorerModule(). The class stores nothing; KE_REGISTER defines one static const
/// instance per registration block so that constructing the instance executes that block.
class KernelExplorerInit {
|
||||
public:
|
||||
explicit KernelExplorerInit(void (*init_func)(pybind11::module module)) {
|
||||
init_func(GetKernelExplorerModule());
|
||||
}
|
||||
};
|
||||
|
||||
// KE_REGISTER(module_name) { ...bindings... } declares a self-registering block of bindings.
//
// KE_REGISTER_IMPL expands to three pieces, all suffixed with `unique_id`:
//   1. a forward declaration of a static function KeInitFunc<id>(pybind11::module module_name);
//   2. a static const KernelExplorerInit kKeInitializer<id> constructed from that function, whose
//      constructor calls the function with the shared module;
//   3. the header of the function's definition, so the braces written after KE_REGISTER(...)
//      become the function body and `module_name` is the parameter name usable inside it.
//
// KE_REGISTER_ is an extra level of indirection so that __COUNTER__ is expanded to a number before
// it is pasted with `##`. __COUNTER__ supplies a different id for each use within a translation
// unit.
#define KE_REGISTER_IMPL(unique_id, module_name) \
|
||||
static void KeInitFunc##unique_id(pybind11::module module_name); \
|
||||
static const KernelExplorerInit kKeInitializer##unique_id{KeInitFunc##unique_id}; \
|
||||
void KeInitFunc##unique_id(pybind11::module module_name)
|
||||
|
||||
#define KE_REGISTER_(unique_id, module_name) KE_REGISTER_IMPL(unique_id, module_name)
|
||||
#define KE_REGISTER(module_name) KE_REGISTER_(__COUNTER__, module_name)
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class FastGeluTunable : public IKernelExplorer {
|
|||
.def("Run", &name<type>::Run) \
|
||||
.def("IsSupported", &name<type>::IsSupported);
|
||||
|
||||
void InitFastGelu(py::module m) {
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, half);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, float);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, double);
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitFastGelu(py::module m);
|
||||
|
||||
}
|
||||
|
|
@ -1,15 +1,11 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "core/providers/rocm/tunable/gemm_common.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_ck.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_tunable.h"
|
||||
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
|
||||
|
||||
using BlasOp = onnxruntime::rocm::tunable::blas::BlasOp;
|
||||
|
||||
|
|
@ -17,17 +13,13 @@ namespace py = pybind11;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitGemm(py::module mod) {
|
||||
KE_REGISTER(mod) {
|
||||
auto blas_op = mod.def_submodule("blas_op");
|
||||
|
||||
py::enum_<BlasOp>(blas_op, "BlasOp")
|
||||
.value("N", BlasOp::N, "Passthrough")
|
||||
.value("T", BlasOp::T, "Transpose")
|
||||
.export_values();
|
||||
|
||||
InitRocBlasGemm(mod);
|
||||
InitComposableKernelGemm(mod);
|
||||
InitTunableGemm(mod);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitGemm(py::module mod);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_ck.h"
|
||||
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <memory>
|
||||
|
|
@ -208,15 +206,13 @@ class CKStridedBatchedGemm : public IKernelExplorer {
|
|||
REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, Col, Row, "TN"); \
|
||||
REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, Col, Col, "TT");
|
||||
|
||||
void InitComposableKernelGemm(py::module m) {
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_CKGEMM_FOR_ALL_TRANSAB(float);
|
||||
REGISTER_CKGEMM_FOR_ALL_TRANSAB(half);
|
||||
|
||||
REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(float);
|
||||
REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(half);
|
||||
}
|
||||
#else
|
||||
void InitComposableKernelGemm(py::module) {}
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitComposableKernelGemm(py::module mod);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h"
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitGemmFastGelu(py::module mod) {
|
||||
InitGemmFastGeluUnfused(mod);
|
||||
InitGemmFastGeluTunable(mod);
|
||||
InitComposableKernelGemmFastGelu(mod);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitGemmFastGelu(py::module mod);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
|
|
@ -120,12 +118,10 @@ class CKGemmFastGelu : public IKernelExplorer {
|
|||
REGISTER_OP(type, Col, Row, "TN"); \
|
||||
REGISTER_OP(type, Col, Col, "TT");
|
||||
|
||||
void InitComposableKernelGemmFastGelu(py::module m) {
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP_FOR_ALL_TRANSAB(float);
|
||||
REGISTER_OP_FOR_ALL_TRANSAB(half);
|
||||
}
|
||||
#else
|
||||
void InitComposableKernelGemmFastGelu(py::module) {}
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitComposableKernelGemmFastGelu(py::module mod);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h"
|
||||
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <string>
|
||||
|
|
@ -97,11 +95,9 @@ class GemmFastGeluTunable : public IKernelExplorer {
|
|||
REGISTER_OP(type, Col, Row, "TN"); \
|
||||
REGISTER_OP(type, Col, Col, "TT");
|
||||
|
||||
void InitGemmFastGeluTunable(py::module m) {
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP_FOR_ALL_TRANSAB(float);
|
||||
REGISTER_OP_FOR_ALL_TRANSAB(half);
|
||||
}
|
||||
|
||||
#undef REGISTER_OP
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitGemmFastGeluTunable(py::module mod);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h"
|
||||
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
|
|
@ -89,11 +88,9 @@ class GemmFastGeluUnfused : public IKernelExplorer {
|
|||
.def("ListOps", &GemmFastGeluUnfused<type>::ListOps) \
|
||||
.def("SelectOp", &GemmFastGeluUnfused<type>::SelectOp);
|
||||
|
||||
void InitGemmFastGeluUnfused(py::module m) {
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP(float)
|
||||
REGISTER_OP(half)
|
||||
}
|
||||
|
||||
#undef REGISTER_OP
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitGemmFastGeluUnfused(py::module mod);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -43,4 +43,4 @@ class IBatchedGemmKernelExplorer : public IKernelExplorer {
|
|||
std::shared_ptr<T*> dev_cs_;
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
|
|
@ -236,7 +234,7 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer {
|
|||
DeviceArray&, int64_t, int64_t, \
|
||||
int64_t>())
|
||||
|
||||
void InitRocBlasGemm(py::module mod) {
|
||||
KE_REGISTER(mod) {
|
||||
REGISTER_GEMM(float);
|
||||
REGISTER_GEMM(half);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitRocBlasGemm(py::module mod);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.h"
|
||||
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh"
|
||||
|
|
@ -286,7 +284,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor
|
|||
#define REGISTER_TUNABLE(dtype) \
|
||||
REGISTER_COMMON("GemmSoftmaxGemmPermuteTunable_" #dtype, GemmSoftmaxGemmPermuteTunable, dtype)
|
||||
|
||||
void InitGemmSoftmaxGemmPermute(py::module m) {
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_GENERIC(half);
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitGemmSoftmaxGemmPermute(py::module mod);
|
||||
|
||||
}
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/gemm_tunable.h"
|
||||
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <string>
|
||||
|
|
@ -256,7 +254,7 @@ class StridedBatchedGemmTunable : public IKernelExplorer {
|
|||
REGISTER_STRIDED_BATCHED_GEMM(dtype, Col, Row, "TN"); \
|
||||
REGISTER_STRIDED_BATCHED_GEMM(dtype, Col, Col, "TT");
|
||||
|
||||
void InitTunableGemm(py::module m) {
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_GEMM_FOR_ALL_TRANSAB(float);
|
||||
REGISTER_GEMM_FOR_ALL_TRANSAB(half);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitTunableGemm(py::module mod);
|
||||
|
||||
}
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h"
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
|
|
@ -20,8 +18,8 @@ class SkipLayerNormSmall : public IKernelExplorer {
|
|||
SkipLayerNormSmall(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
float epsilon, int hidden_size, int element_count)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
|
||||
void Run() override {
|
||||
|
|
@ -45,7 +43,7 @@ class SkipLayerNormRegular : public IKernelExplorer {
|
|||
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
float epsilon, int hidden_size, int element_count)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
|
||||
void Run() override {
|
||||
|
|
@ -65,11 +63,11 @@ class SkipLayerNormRegular : public IKernelExplorer {
|
|||
template <typename T>
|
||||
class SkipLayerNormStaticSelection : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input,
|
||||
SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input,
|
||||
DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
float epsilon, int hidden_size, int element_count)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
|
||||
void Run() override {
|
||||
|
|
@ -93,9 +91,8 @@ class SkipLayerNormTunable : public IKernelExplorer {
|
|||
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
float epsilon, int hidden_size, int element_count)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {
|
||||
|
||||
params_.TuningContext()->EnableTunableOp();
|
||||
}
|
||||
|
||||
|
|
@ -113,14 +110,14 @@ class SkipLayerNormTunable : public IKernelExplorer {
|
|||
contrib::rocm::SkipLayerNormTunableOp<T> op_{};
|
||||
};
|
||||
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
float, int, int>()) \
|
||||
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
|
||||
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
|
||||
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
float, int, int>()) \
|
||||
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
|
||||
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
|
||||
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
|
||||
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
|
||||
|
||||
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
|
||||
|
|
@ -157,7 +154,7 @@ class SkipLayerNormTunable : public IKernelExplorer {
|
|||
.def("Run", &name<type>::Run) \
|
||||
.def("IsSupported", &name<type>::IsSupported);
|
||||
|
||||
void InitSkipLayerNorm(py::module m) {
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, half);
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitSkipLayerNorm(py::module m);
|
||||
|
||||
}
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/rocm/softmax.h"
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
|
@ -198,7 +196,7 @@ class CKSoftmax : public IKernelExplorer {
|
|||
.def("ListOps", &name<type>::ListOps) \
|
||||
.def("SelectOp", &name<type>::SelectOp);
|
||||
|
||||
void InitSoftmax(py::module m) {
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, half);
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, float);
|
||||
|
||||
|
|
@ -210,11 +208,13 @@ void InitSoftmax(py::module m) {
|
|||
|
||||
REGISTER_OP_TYPED(SoftmaxTunable, half);
|
||||
REGISTER_OP_TYPED(SoftmaxTunable, float);
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
REGISTER_OP_TYPED(CKSoftmax, half);
|
||||
REGISTER_OP_TYPED(CKSoftmax, float);
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
}
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP_TYPED(CKSoftmax, half);
|
||||
REGISTER_OP_TYPED(CKSoftmax, float);
|
||||
}
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitSoftmax(py::module m);
|
||||
|
||||
}
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
// This file serves as a simple example for adding a tunable op to onnxruntime.
|
||||
#include "python/tools/kernel_explorer/kernels/vector_add.h"
|
||||
|
||||
#if USE_CUDA
|
||||
#include <cuda_runtime_api.h>
|
||||
|
|
@ -26,10 +25,10 @@ namespace py = pybind11;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
//#####################################################################################################################
|
||||
// In practice, VectorAddParam, VectorAddOp and VectorAddTunableOp should be tightly integrated to onnxruntime.
|
||||
// We place them here purely for demo purpose.
|
||||
//#####################################################################################################################
|
||||
// #####################################################################################################################
|
||||
// In practice, VectorAddParam, VectorAddOp and VectorAddTunableOp should be tightly integrated to onnxruntime.
|
||||
// We place them here purely for demo purpose.
|
||||
// #####################################################################################################################
|
||||
|
||||
// Extend the OpParams so that all specializations have the same parameter passing interface
|
||||
template <typename T>
|
||||
|
|
@ -91,10 +90,10 @@ class VectorAddTunableOp :
|
|||
|
||||
#undef ADD_OP
|
||||
|
||||
//#####################################################################################################################
|
||||
// Following code just wraps our kernel implementation and expose them as python interface. This is the code that
|
||||
// should be in the kernel_explorer directory.
|
||||
//#####################################################################################################################
|
||||
// #####################################################################################################################
|
||||
// Following code just wraps our kernel implementation and expose them as python interface. This is the code that
|
||||
// should be in the kernel_explorer directory.
|
||||
// #####################################################################################################################
|
||||
|
||||
template <typename T, int TPB, int Vec>
|
||||
class VectorAdd : public IKernelExplorer {
|
||||
|
|
@ -174,7 +173,7 @@ class VectorAddTunable : public IKernelExplorer {
|
|||
.def("Profile", &VectorAddTunable<type>::Profile) \
|
||||
.def("Run", &VectorAddTunable<type>::Run);
|
||||
|
||||
void InitVectorAdd(py::module m) {
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(VectorAdd, half);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(VectorAdd, float);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitVectorAdd(py::module m);
|
||||
|
||||
}
|
||||
Loading…
Reference in a new issue