Refactor ke register to be decentralized (#15036)

So that we can remove all unnecessary header files
This commit is contained in:
cloudhan 2023-03-22 14:49:26 +08:00 committed by GitHub
parent 0086f7590d
commit 71b67ec1e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 88 additions and 298 deletions

View file

@ -2,39 +2,43 @@
// Licensed under the MIT License.
#include <pybind11/pybind11.h>
#include <pybind11/embed.h>
#include <pybind11/numpy.h>
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernels/vector_add.h"
#include "python/tools/kernel_explorer/kernels/rocm/fast_gelu.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.h"
#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h"
#include "python/tools/kernel_explorer/kernels/rocm/softmax.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
namespace py = pybind11;
namespace onnxruntime {
PYBIND11_MODULE(_kernel_explorer, m) {
// Storage for the module definition handed to create_extension_module() below.
static py::module::module_def _kernel_explorer_module_def;
// Returns the `_kernel_explorer` pybind11 module object.
//
// The module is created once, by the initializer of the function-local static `m`;
// every call returns a copy of that same handle.
py::module GetKernelExplorerModule() {
static pybind11::module_ m = []() {
auto tmp = pybind11::module_::create_extension_module(
"_kernel_explorer", "", &_kernel_explorer_module_def);
// NOTE(review): drops one reference on the freshly created module before the handle is
// stored in the static -- presumably to balance the reference kept by that static;
// confirm against pybind11's ownership rules for create_extension_module().
tmp.dec_ref();
return tmp;
}();
return m;
}
// Entry point of the `_kernel_explorer` extension, written with pybind11's
// PYBIND11_PLUGIN_IMPL macro. After the two pybind11 check macros it returns the raw
// pointer of the module object obtained from GetKernelExplorerModule().
PYBIND11_PLUGIN_IMPL(_kernel_explorer) {
PYBIND11_CHECK_PYTHON_VERSION;
PYBIND11_ENSURE_INTERNALS_READY;
return GetKernelExplorerModule().ptr();
}
// Registers the bindings owned by this translation unit on the kernel explorer module `m`:
// the DeviceArray class and the `is_composable_kernel_available` query.
KE_REGISTER(m) {
  py::class_<DeviceArray>(m, "DeviceArray")
      .def(py::init<py::array>())
      .def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray);

  // NOTE(review): these explicit Init* calls coexist with the KE_REGISTER-based
  // registration used by the individual kernel files; confirm they are still required.
  InitVectorAdd(m);
#if USE_ROCM
  InitFastGelu(m);
  InitGemm(m);
  InitSkipLayerNorm(m);
  InitGemmFastGelu(m);
  InitSoftmax(m);
  InitGemmSoftmaxGemmPermute(m);
#endif

  // Reports whether this build was compiled with USE_COMPOSABLE_KERNEL defined.
  m.def("is_composable_kernel_available", []() {
#ifdef USE_COMPOSABLE_KERNEL
    return true;
#else
    return false;
#endif
  });
}

View file

@ -3,6 +3,8 @@
#pragma once
#include <pybind11/pybind11.h>
#include "core/providers/shared_library/provider_api.h"
#ifdef USE_CUDA
#include <cuda_runtime_api.h>
@ -32,6 +34,8 @@ using TuningContextT = onnxruntime::rocm::tunable::RocmTuningContext;
#error "kernel explorer only supports CUDA or ROCM"
#endif
namespace onnxruntime {
/// Wrapping around Op and TunableOp
class IKernelExplorer {
public:
@ -78,3 +82,22 @@ class IKernelExplorer {
StreamT stream_{0};
int repeats_{100};
};
// Returns the pybind11 module object onto which kernel explorer bindings are registered.
// Declared here because KernelExplorerInit passes its result to each registration function.
pybind11::module GetKernelExplorerModule();
class KernelExplorerInit {
public:
explicit KernelExplorerInit(void (*init_func)(pybind11::module module)) {
init_func(GetKernelExplorerModule());
}
};
// Usage: KE_REGISTER(module_name) { ...registration body... }
//
// Expands to a file-local function KeInitFunc<N> whose body is the brace block written
// after the macro, plus a file-local static KernelExplorerInit object constructed with that
// function. Constructing that object calls the function with the module returned by
// GetKernelExplorerModule(), so the body runs when the static is initialized.
// `module_name` becomes the name of the function's pybind11::module parameter.
// <N> is taken from __COUNTER__; the intermediate KE_REGISTER_ level makes sure
// __COUNTER__ is expanded to a number before it is pasted with ## in KE_REGISTER_IMPL.
#define KE_REGISTER_IMPL(unique_id, module_name) \
static void KeInitFunc##unique_id(pybind11::module module_name); \
static const KernelExplorerInit kKeInitializer##unique_id{KeInitFunc##unique_id}; \
void KeInitFunc##unique_id(pybind11::module module_name)
#define KE_REGISTER_(unique_id, module_name) KE_REGISTER_IMPL(unique_id, module_name)
#define KE_REGISTER(module_name) KE_REGISTER_(__COUNTER__, module_name)
} // namespace onnxruntime

View file

@ -110,7 +110,7 @@ class FastGeluTunable : public IKernelExplorer {
.def("Run", &name<type>::Run) \
.def("IsSupported", &name<type>::IsSupported);
void InitFastGelu(py::module m) {
KE_REGISTER(m) {
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, half);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, float);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, double);

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitFastGelu(py::module m);
}

View file

@ -1,15 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm.h"
#include <pybind11/pybind11.h>
#include <type_traits>
#include "core/providers/rocm/tunable/gemm_common.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_ck.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_tunable.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
using BlasOp = onnxruntime::rocm::tunable::blas::BlasOp;
@ -17,17 +13,13 @@ namespace py = pybind11;
namespace onnxruntime {
void InitGemm(py::module mod) {
KE_REGISTER(mod) {
auto blas_op = mod.def_submodule("blas_op");
py::enum_<BlasOp>(blas_op, "BlasOp")
.value("N", BlasOp::N, "Passthrough")
.value("T", BlasOp::T, "Transpose")
.export_values();
InitRocBlasGemm(mod);
InitComposableKernelGemm(mod);
InitTunableGemm(mod);
}
} // namespace onnxruntime

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitGemm(py::module mod);
} // namespace onnxruntime

View file

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_ck.h"
#include <pybind11/stl.h>
#include <memory>
@ -208,15 +206,13 @@ class CKStridedBatchedGemm : public IKernelExplorer {
REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, Col, Row, "TN"); \
REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, Col, Col, "TT");
void InitComposableKernelGemm(py::module m) {
KE_REGISTER(m) {
REGISTER_CKGEMM_FOR_ALL_TRANSAB(float);
REGISTER_CKGEMM_FOR_ALL_TRANSAB(half);
REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(float);
REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(half);
}
#else
void InitComposableKernelGemm(py::module) {}
#endif // USE_COMPOSABLE_KERNEL
} // namespace onnxruntime

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitComposableKernelGemm(py::module mod);
} // namespace onnxruntime

View file

@ -1,22 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
#include <pybind11/pybind11.h>
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h"
namespace py = pybind11;
namespace onnxruntime {
// Forwards `mod` to the three GemmFastGelu registration functions (unfused, tunable and
// composable-kernel variants), in that order.
void InitGemmFastGelu(py::module mod) {
InitGemmFastGeluUnfused(mod);
InitGemmFastGeluTunable(mod);
InitComposableKernelGemmFastGelu(mod);
}
} // namespace onnxruntime

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitGemmFastGelu(py::module mod);
} // namespace onnxruntime

View file

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@ -120,12 +118,10 @@ class CKGemmFastGelu : public IKernelExplorer {
REGISTER_OP(type, Col, Row, "TN"); \
REGISTER_OP(type, Col, Col, "TT");
void InitComposableKernelGemmFastGelu(py::module m) {
KE_REGISTER(m) {
REGISTER_OP_FOR_ALL_TRANSAB(float);
REGISTER_OP_FOR_ALL_TRANSAB(half);
}
#else
void InitComposableKernelGemmFastGelu(py::module) {}
#endif // USE_COMPOSABLE_KERNEL
} // namespace onnxruntime

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitComposableKernelGemmFastGelu(py::module mod);
} // namespace onnxruntime

View file

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h"
#include <pybind11/stl.h>
#include <string>
@ -97,11 +95,9 @@ class GemmFastGeluTunable : public IKernelExplorer {
REGISTER_OP(type, Col, Row, "TN"); \
REGISTER_OP(type, Col, Col, "TT");
void InitGemmFastGeluTunable(py::module m) {
KE_REGISTER(m) {
REGISTER_OP_FOR_ALL_TRANSAB(float);
REGISTER_OP_FOR_ALL_TRANSAB(half);
}
#undef REGISTER_OP
} // namespace onnxruntime

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitGemmFastGeluTunable(py::module mod);
} // namespace onnxruntime

View file

@ -1,6 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h"
#include <pybind11/stl.h>
@ -89,11 +88,9 @@ class GemmFastGeluUnfused : public IKernelExplorer {
.def("ListOps", &GemmFastGeluUnfused<type>::ListOps) \
.def("SelectOp", &GemmFastGeluUnfused<type>::SelectOp);
void InitGemmFastGeluUnfused(py::module m) {
KE_REGISTER(m) {
REGISTER_OP(float)
REGISTER_OP(half)
}
#undef REGISTER_OP
} // namespace onnxruntime

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitGemmFastGeluUnfused(py::module mod);
} // namespace onnxruntime

View file

@ -43,4 +43,4 @@ class IBatchedGemmKernelExplorer : public IKernelExplorer {
std::shared_ptr<T*> dev_cs_;
};
}
} // namespace onnxruntime

View file

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@ -236,7 +234,7 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer {
DeviceArray&, int64_t, int64_t, \
int64_t>())
void InitRocBlasGemm(py::module mod) {
KE_REGISTER(mod) {
REGISTER_GEMM(float);
REGISTER_GEMM(half);

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitRocBlasGemm(py::module mod);
} // namespace onnxruntime

View file

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.h"
#include "pybind11/stl.h"
#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh"
@ -286,7 +284,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor
#define REGISTER_TUNABLE(dtype) \
REGISTER_COMMON("GemmSoftmaxGemmPermuteTunable_" #dtype, GemmSoftmaxGemmPermuteTunable, dtype)
void InitGemmSoftmaxGemmPermute(py::module m) {
KE_REGISTER(m) {
REGISTER_GENERIC(half);
#ifdef USE_COMPOSABLE_KERNEL

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitGemmSoftmaxGemmPermute(py::module mod);
}

View file

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_tunable.h"
#include <pybind11/stl.h>
#include <string>
@ -256,7 +254,7 @@ class StridedBatchedGemmTunable : public IKernelExplorer {
REGISTER_STRIDED_BATCHED_GEMM(dtype, Col, Row, "TN"); \
REGISTER_STRIDED_BATCHED_GEMM(dtype, Col, Col, "TT");
void InitTunableGemm(py::module m) {
KE_REGISTER(m) {
REGISTER_GEMM_FOR_ALL_TRANSAB(float);
REGISTER_GEMM_FOR_ALL_TRANSAB(half);

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitTunableGemm(py::module mod);
}

View file

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h"
#include <hip/hip_fp16.h>
#include <pybind11/pybind11.h>
@ -20,8 +18,8 @@ class SkipLayerNormSmall : public IKernelExplorer {
SkipLayerNormSmall(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
float epsilon, int hidden_size, int element_count)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
void Run() override {
@ -45,7 +43,7 @@ class SkipLayerNormRegular : public IKernelExplorer {
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
float epsilon, int hidden_size, int element_count)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
void Run() override {
@ -65,11 +63,11 @@ class SkipLayerNormRegular : public IKernelExplorer {
template <typename T>
class SkipLayerNormStaticSelection : public IKernelExplorer {
public:
SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input,
SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input,
DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
float epsilon, int hidden_size, int element_count)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
void Run() override {
@ -93,9 +91,8 @@ class SkipLayerNormTunable : public IKernelExplorer {
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
float epsilon, int hidden_size, int element_count)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {
params_.TuningContext()->EnableTunableOp();
}
@ -113,14 +110,14 @@ class SkipLayerNormTunable : public IKernelExplorer {
contrib::rocm::SkipLayerNormTunableOp<T> op_{};
};
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, DeviceArray&, \
float, int, int>()) \
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, DeviceArray&, \
float, int, int>()) \
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
@ -157,7 +154,7 @@ class SkipLayerNormTunable : public IKernelExplorer {
.def("Run", &name<type>::Run) \
.def("IsSupported", &name<type>::IsSupported);
void InitSkipLayerNorm(py::module m) {
KE_REGISTER(m) {
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, half);

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitSkipLayerNorm(py::module m);
}

View file

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/softmax.h"
#include <hip/hip_fp16.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@ -198,7 +196,7 @@ class CKSoftmax : public IKernelExplorer {
.def("ListOps", &name<type>::ListOps) \
.def("SelectOp", &name<type>::SelectOp);
void InitSoftmax(py::module m) {
KE_REGISTER(m) {
REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, half);
REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, float);
@ -210,11 +208,13 @@ void InitSoftmax(py::module m) {
REGISTER_OP_TYPED(SoftmaxTunable, half);
REGISTER_OP_TYPED(SoftmaxTunable, float);
#ifdef USE_COMPOSABLE_KERNEL
REGISTER_OP_TYPED(CKSoftmax, half);
REGISTER_OP_TYPED(CKSoftmax, float);
#endif // USE_COMPOSABLE_KERNEL
}
#ifdef USE_COMPOSABLE_KERNEL
// Separate registration block for CKSoftmax (half and float), compiled only when
// USE_COMPOSABLE_KERNEL is defined. The bindings themselves come from REGISTER_OP_TYPED.
KE_REGISTER(m) {
REGISTER_OP_TYPED(CKSoftmax, half);
REGISTER_OP_TYPED(CKSoftmax, float);
}
#endif // USE_COMPOSABLE_KERNEL
} // namespace onnxruntime

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitSoftmax(py::module m);
}

View file

@ -2,7 +2,6 @@
// Licensed under the MIT License.
// This file serves as a simple example for adding a tunable op to onnxruntime.
#include "python/tools/kernel_explorer/kernels/vector_add.h"
#if USE_CUDA
#include <cuda_runtime_api.h>
@ -26,10 +25,10 @@ namespace py = pybind11;
namespace onnxruntime {
//#####################################################################################################################
// In practice, VectorAddParam, VectorAddOp and VectorAddTunableOp should be tightly integrated to onnxruntime.
// We place them here purely for demo purpose.
//#####################################################################################################################
// #####################################################################################################################
// In practice, VectorAddParam, VectorAddOp and VectorAddTunableOp should be tightly integrated to onnxruntime.
// We place them here purely for demo purpose.
// #####################################################################################################################
// Extend the OpParams so that all specializations have the same parameter passing interface
template <typename T>
@ -91,10 +90,10 @@ class VectorAddTunableOp :
#undef ADD_OP
//#####################################################################################################################
// Following code just wraps our kernel implementation and expose them as python interface. This is the code that
// should be in the kernel_explorer directory.
//#####################################################################################################################
// #####################################################################################################################
// The following code just wraps our kernel implementation and exposes it as a Python interface. This is the code that
// should be in the kernel_explorer directory.
// #####################################################################################################################
template <typename T, int TPB, int Vec>
class VectorAdd : public IKernelExplorer {
@ -174,7 +173,7 @@ class VectorAddTunable : public IKernelExplorer {
.def("Profile", &VectorAddTunable<type>::Profile) \
.def("Run", &VectorAddTunable<type>::Run);
void InitVectorAdd(py::module m) {
KE_REGISTER(m) {
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(VectorAdd, half);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(VectorAdd, float);

View file

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitVectorAdd(py::module m);
}