From 71b67ec1e2ad74b5eb1875d44ed329746aca104f Mon Sep 17 00:00:00 2001 From: cloudhan Date: Wed, 22 Mar 2023 14:49:26 +0800 Subject: [PATCH] Refactor ke register to be decentralized (#15036) So that we can remove all unnecessay header files --- .../tools/kernel_explorer/kernel_explorer.cc | 40 ++++++++++--------- .../kernel_explorer_interface.h | 23 +++++++++++ ...rnel_explorer.pyi => _kernel_explorer.pyi} | 0 .../kernel_explorer/kernels/rocm/fast_gelu.cu | 2 +- .../kernel_explorer/kernels/rocm/fast_gelu.h | 14 ------- .../kernel_explorer/kernels/rocm/gemm.cc | 12 +----- .../tools/kernel_explorer/kernels/rocm/gemm.h | 14 ------- .../kernel_explorer/kernels/rocm/gemm_ck.cu | 6 +-- .../kernel_explorer/kernels/rocm/gemm_ck.h | 14 ------- .../kernels/rocm/gemm_fast_gelu.cc | 22 ---------- .../kernels/rocm/gemm_fast_gelu.h | 14 ------- .../kernels/rocm/gemm_fast_gelu_ck.cu | 6 +-- .../kernels/rocm/gemm_fast_gelu_ck.h | 14 ------- .../kernels/rocm/gemm_fast_gelu_tunable.cu | 6 +-- .../kernels/rocm/gemm_fast_gelu_tunable.h | 14 ------- .../kernels/rocm/gemm_fast_gelu_unfused.cu | 5 +-- .../kernels/rocm/gemm_fast_gelu_unfused.h | 14 ------- .../kernel_explorer/kernels/rocm/gemm_ke.h | 2 +- .../kernels/rocm/gemm_rocblas.cc | 4 +- .../kernels/rocm/gemm_rocblas.h | 14 ------- .../kernels/rocm/gemm_softmax_gemm_permute.cu | 4 +- .../kernels/rocm/gemm_softmax_gemm_permute.h | 14 ------- .../kernels/rocm/gemm_tunable.cu | 4 +- .../kernels/rocm/gemm_tunable.h | 14 ------- .../kernels/rocm/skip_layer_norm.cu | 33 +++++++-------- .../kernels/rocm/skip_layer_norm.h | 14 ------- .../kernel_explorer/kernels/rocm/softmax.cu | 16 ++++---- .../kernel_explorer/kernels/rocm/softmax.h | 14 ------- .../kernel_explorer/kernels/vector_add.cu | 19 +++++---- .../kernel_explorer/kernels/vector_add.h | 14 ------- 30 files changed, 88 insertions(+), 298 deletions(-) rename onnxruntime/python/tools/kernel_explorer/kernels/{kernel_explorer.pyi => _kernel_explorer.pyi} (100%) delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/vector_add.h diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc index 1a2b0c702b..fbcb55c781 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc @@ -2,39 +2,43 @@ // Licensed under the MIT License. #include +#include #include #include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernels/vector_add.h" -#include "python/tools/kernel_explorer/kernels/rocm/fast_gelu.h" -#include "python/tools/kernel_explorer/kernels/rocm/gemm.h" -#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h" -#include "python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.h" -#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h" -#include "python/tools/kernel_explorer/kernels/rocm/softmax.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" namespace py = pybind11; namespace onnxruntime { -PYBIND11_MODULE(_kernel_explorer, m) { +static py::module::module_def _kernel_explorer_module_def; + +py::module GetKernelExplorerModule() { + static pybind11::module_ m = []() { + auto tmp = pybind11::module_::create_extension_module( + "_kernel_explorer", "", &_kernel_explorer_module_def); + tmp.dec_ref(); + return tmp; + }(); + return m; +} + +PYBIND11_PLUGIN_IMPL(_kernel_explorer) { + PYBIND11_CHECK_PYTHON_VERSION; + PYBIND11_ENSURE_INTERNALS_READY; + return GetKernelExplorerModule().ptr(); +} + +KE_REGISTER(m) { py::class_(m, "DeviceArray") .def(py::init()) .def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray); - InitVectorAdd(m); -#if USE_ROCM - InitFastGelu(m); - InitGemm(m); - InitSkipLayerNorm(m); - InitGemmFastGelu(m); - InitSoftmax(m); - InitGemmSoftmaxGemmPermute(m); -#endif m.def("is_composable_kernel_available", []() { #ifdef USE_COMPOSABLE_KERNEL return true; #else - return false; + return false; #endif }); } diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h b/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h index 49d932101f..2f3d08b1ff 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h +++ b/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h @@ -3,6 +3,8 @@ #pragma once +#include + #include "core/providers/shared_library/provider_api.h" #ifdef USE_CUDA #include @@ -32,6 +34,8 @@ using TuningContextT = onnxruntime::rocm::tunable::RocmTuningContext; #error "kernel explorer only supports CUDA or ROCM" #endif +namespace onnxruntime { + /// Wrapping around Op and TunableOp class IKernelExplorer { public: @@ -78,3 +82,22 @@ class IKernelExplorer { StreamT stream_{0}; int repeats_{100}; }; + +pybind11::module GetKernelExplorerModule(); + +class KernelExplorerInit { + public: + explicit KernelExplorerInit(void (*init_func)(pybind11::module module)) { + init_func(GetKernelExplorerModule()); + } +}; + +#define KE_REGISTER_IMPL(unique_id, module_name) \ + static void KeInitFunc##unique_id(pybind11::module module_name); \ + static const KernelExplorerInit kKeInitializer##unique_id{KeInitFunc##unique_id}; \ + void KeInitFunc##unique_id(pybind11::module module_name) + +#define KE_REGISTER_(unique_id, module_name) KE_REGISTER_IMPL(unique_id, module_name) +#define KE_REGISTER(module_name) KE_REGISTER_(__COUNTER__, module_name) + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.pyi b/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi similarity index 100% rename from onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.pyi rename to onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu index a6bad5f1d5..8f57806e13 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu @@ -110,7 +110,7 @@ class FastGeluTunable : public IKernelExplorer { .def("Run", &name::Run) \ .def("IsSupported", &name::IsSupported); -void InitFastGelu(py::module m) { +KE_REGISTER(m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, half); REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, float); REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, double); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.h deleted file mode 100644 index 563a883a40..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitFastGelu(py::module m); - -} diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.cc b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.cc index f30371c935..540964a149 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.cc @@ -1,15 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "python/tools/kernel_explorer/kernels/rocm/gemm.h" - #include #include #include "core/providers/rocm/tunable/gemm_common.h" -#include "python/tools/kernel_explorer/kernels/rocm/gemm_ck.h" -#include "python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.h" -#include "python/tools/kernel_explorer/kernels/rocm/gemm_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" using BlasOp = onnxruntime::rocm::tunable::blas::BlasOp; @@ -17,17 +13,13 @@ namespace py = pybind11; namespace onnxruntime { -void InitGemm(py::module mod) { +KE_REGISTER(mod) { auto blas_op = mod.def_submodule("blas_op"); py::enum_(blas_op, "BlasOp") .value("N", BlasOp::N, "Passthrough") .value("T", BlasOp::T, "Transpose") .export_values(); - - InitRocBlasGemm(mod); - InitComposableKernelGemm(mod); - InitTunableGemm(mod); } } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.h deleted file mode 100644 index 3c6d0a5630..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitGemm(py::module mod); - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu index f8b07484cb..a29f9bdd0f 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "python/tools/kernel_explorer/kernels/rocm/gemm_ck.h" - #include #include @@ -208,15 +206,13 @@ class CKStridedBatchedGemm : public IKernelExplorer { REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, Col, Row, "TN"); \ REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, Col, Col, "TT"); -void InitComposableKernelGemm(py::module m) { +KE_REGISTER(m) { REGISTER_CKGEMM_FOR_ALL_TRANSAB(float); REGISTER_CKGEMM_FOR_ALL_TRANSAB(half); REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(float); REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(half); } -#else -void InitComposableKernelGemm(py::module) {} #endif // USE_COMPOSABLE_KERNEL } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.h deleted file mode 100644 index ba0dfdcb64..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitComposableKernelGemm(py::module mod); - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc deleted file mode 100644 index f494af834d..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h" - -#include - -#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h" -#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h" -#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h" - -namespace py = pybind11; - -namespace onnxruntime { - -void InitGemmFastGelu(py::module mod) { - InitGemmFastGeluUnfused(mod); - InitGemmFastGeluTunable(mod); - InitComposableKernelGemmFastGelu(mod); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h deleted file mode 100644 index 4a0cebe8cd..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitGemmFastGelu(py::module mod); - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu index cfd39068ea..5f2e9fbd6c 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h" - #include #include @@ -120,12 +118,10 @@ class CKGemmFastGelu : public IKernelExplorer { REGISTER_OP(type, Col, Row, "TN"); \ REGISTER_OP(type, Col, Col, "TT"); -void InitComposableKernelGemmFastGelu(py::module m) { +KE_REGISTER(m) { REGISTER_OP_FOR_ALL_TRANSAB(float); REGISTER_OP_FOR_ALL_TRANSAB(half); } -#else -void InitComposableKernelGemmFastGelu(py::module) {} #endif // USE_COMPOSABLE_KERNEL } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h deleted file mode 100644 index 13a22fae97..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitComposableKernelGemmFastGelu(py::module mod); - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu index e184a529f4..9c66e43170 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h" - #include #include @@ -97,11 +95,9 @@ class GemmFastGeluTunable : public IKernelExplorer { REGISTER_OP(type, Col, Row, "TN"); \ REGISTER_OP(type, Col, Col, "TT"); -void InitGemmFastGeluTunable(py::module m) { +KE_REGISTER(m) { REGISTER_OP_FOR_ALL_TRANSAB(float); REGISTER_OP_FOR_ALL_TRANSAB(half); } -#undef REGISTER_OP - } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h deleted file mode 100644 index f67b4950a4..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitGemmFastGeluTunable(py::module mod); - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu index 8c39da12b4..c1b1cca2cc 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h" #include @@ -89,11 +88,9 @@ class GemmFastGeluUnfused : public IKernelExplorer { .def("ListOps", &GemmFastGeluUnfused::ListOps) \ .def("SelectOp", &GemmFastGeluUnfused::SelectOp); -void InitGemmFastGeluUnfused(py::module m) { +KE_REGISTER(m) { REGISTER_OP(float) REGISTER_OP(half) } -#undef REGISTER_OP - } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h deleted file mode 100644 index 96ea8b8360..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitGemmFastGeluUnfused(py::module mod); - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ke.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ke.h index 0d4ad8d7a4..7b20732d2c 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ke.h +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ke.h @@ -43,4 +43,4 @@ class IBatchedGemmKernelExplorer : public IKernelExplorer { std::shared_ptr dev_cs_; }; -} +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc index f6346b9eaf..673e04621d 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.h" - #include #include @@ -236,7 +234,7 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer { DeviceArray&, int64_t, int64_t, \ int64_t>()) -void InitRocBlasGemm(py::module mod) { +KE_REGISTER(mod) { REGISTER_GEMM(float); REGISTER_GEMM(half); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.h deleted file mode 100644 index 2d49783821..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitRocBlasGemm(py::module mod); - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu index e39bf534bd..a18dfcb021 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.h" - #include "pybind11/stl.h" #include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" @@ -286,7 +284,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor #define REGISTER_TUNABLE(dtype) \ REGISTER_COMMON("GemmSoftmaxGemmPermuteTunable_" #dtype, GemmSoftmaxGemmPermuteTunable, dtype) -void InitGemmSoftmaxGemmPermute(py::module m) { +KE_REGISTER(m) { REGISTER_GENERIC(half); #ifdef USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.h deleted file mode 100644 index 67dc16ca31..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitGemmSoftmaxGemmPermute(py::module mod); - -} diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu index 5233692395..ee39cf32a2 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "python/tools/kernel_explorer/kernels/rocm/gemm_tunable.h" - #include #include @@ -256,7 +254,7 @@ class StridedBatchedGemmTunable : public IKernelExplorer { REGISTER_STRIDED_BATCHED_GEMM(dtype, Col, Row, "TN"); \ REGISTER_STRIDED_BATCHED_GEMM(dtype, Col, Col, "TT"); -void InitTunableGemm(py::module m) { +KE_REGISTER(m) { REGISTER_GEMM_FOR_ALL_TRANSAB(float); REGISTER_GEMM_FOR_ALL_TRANSAB(half); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.h deleted file mode 100644 index 48aa2b792b..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitTunableGemm(py::module mod); - -} diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu index 52f1f5c1b1..ac5ec602f8 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h" - #include #include @@ -20,8 +18,8 @@ class SkipLayerNormSmall : public IKernelExplorer { SkipLayerNormSmall(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, float epsilon, int hidden_size, int element_count) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} void Run() override { @@ -45,7 +43,7 @@ class SkipLayerNormRegular : public IKernelExplorer { DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, float epsilon, int hidden_size, int element_count) : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} void Run() override { @@ -65,11 +63,11 @@ class SkipLayerNormRegular : public IKernelExplorer { template class SkipLayerNormStaticSelection : public IKernelExplorer { public: - SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, + SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, float epsilon, int hidden_size, int element_count) : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} void Run() override { @@ -93,9 +91,8 @@ class SkipLayerNormTunable : public IKernelExplorer { DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, float epsilon, int hidden_size, int element_count) : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) { - params_.TuningContext()->EnableTunableOp(); } @@ -113,14 +110,14 @@ class SkipLayerNormTunable : public IKernelExplorer { contrib::rocm::SkipLayerNormTunableOp op_{}; }; -#define REGISTER_OP(name, type, threads_per_block, vec_size) \ - py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run) \ +#define REGISTER_OP(name, type, threads_per_block, vec_size) \ + py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run) \ .def("IsSupported", &name::IsSupported); #define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \ @@ -157,7 +154,7 @@ class SkipLayerNormTunable : public IKernelExplorer { .def("Run", &name::Run) \ .def("IsSupported", &name::IsSupported); -void InitSkipLayerNorm(py::module m) { +KE_REGISTER(m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half); REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float); REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, half); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h deleted file mode 100644 index 4dc72bc3da..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitSkipLayerNorm(py::module m); - -} diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu index c514c65157..2c9980ef30 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "python/tools/kernel_explorer/kernels/rocm/softmax.h" - #include #include #include @@ -198,7 +196,7 @@ class CKSoftmax : public IKernelExplorer { .def("ListOps", &name::ListOps) \ .def("SelectOp", &name::SelectOp); -void InitSoftmax(py::module m) { +KE_REGISTER(m) { REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, half); REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, float); @@ -210,11 +208,13 @@ void InitSoftmax(py::module m) { REGISTER_OP_TYPED(SoftmaxTunable, half); REGISTER_OP_TYPED(SoftmaxTunable, float); - -#ifdef USE_COMPOSABLE_KERNEL - REGISTER_OP_TYPED(CKSoftmax, half); - REGISTER_OP_TYPED(CKSoftmax, float); -#endif // USE_COMPOSABLE_KERNEL } +#ifdef USE_COMPOSABLE_KERNEL +KE_REGISTER(m) { + REGISTER_OP_TYPED(CKSoftmax, half); + REGISTER_OP_TYPED(CKSoftmax, float); +} +#endif // USE_COMPOSABLE_KERNEL + } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h deleted file mode 100644 index 5ae71614e2..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitSoftmax(py::module m); - -} diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu index 0f63274b59..ed10fc0a2c 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu @@ -2,7 +2,6 @@ // Licensed under the MIT License. // This file serve as a simple example for adding a tunable op to onnxruntime. -#include "python/tools/kernel_explorer/kernels/vector_add.h" #if USE_CUDA #include @@ -26,10 +25,10 @@ namespace py = pybind11; namespace onnxruntime { -//##################################################################################################################### -// In practice, VectorAddParam, VectorAddOp and VectorAddTunableOp should be tightly integrated to onnxruntime. -// We place them here purely for demo purpose. -//##################################################################################################################### +// ##################################################################################################################### +// In practice, VectorAddParam, VectorAddOp and VectorAddTunableOp should be tightly integrated to onnxruntime. +// We place them here purely for demo purpose. +// ##################################################################################################################### // Extend the OpParams so that all specializations have the same parameter passing interface template @@ -91,10 +90,10 @@ class VectorAddTunableOp : #undef ADD_OP -//##################################################################################################################### -// Following code just wraps our kernel implementation and expose them as python interface. This is the code that -// should be in the kernel_explorer directory. -//##################################################################################################################### +// ##################################################################################################################### +// Following code just wraps our kernel implementation and expose them as python interface. This is the code that +// should be in the kernel_explorer directory. +// ##################################################################################################################### template class VectorAdd : public IKernelExplorer { @@ -174,7 +173,7 @@ class VectorAddTunable : public IKernelExplorer { .def("Profile", &VectorAddTunable::Profile) \ .def("Run", &VectorAddTunable::Run); -void InitVectorAdd(py::module m) { +KE_REGISTER(m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(VectorAdd, half); REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(VectorAdd, float); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.h b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.h deleted file mode 100644 index c895961d96..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace py = pybind11; - -namespace onnxruntime { - -void InitVectorAdd(py::module m); - -}