mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
### Description
Run clang-format in CI. Formatted all C/C++ and Objective-C/C++ files.
Excluded the following paths:
```
'onnxruntime/core/mlas/**',
'onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/**',
```
because they contain assembly or are data-heavy.
### Motivation and Context
Coding style consistency
62 lines
2.7 KiB
C++
62 lines
2.7 KiB
C++
#pragma once
|
|
#include "core/framework/op_kernel.h"
|
|
#include "core/framework/func_api.h"
|
|
#include "core/framework/op_kernel_context_internal.h"
|
|
#include "core/graph/function.h"
|
|
|
|
namespace onnxruntime {
|
|
|
|
// Allocation callback handed to execution providers through ComputeContext
// (FunctionKernel::Create passes it together with the kernel's host allocator).
// NOTE(review): the definition is not visible here; presumably it allocates `size`
// bytes with the requested `alignment` from the opaque `allocator` handle — confirm
// against the implementation file.
void* allocate_helper_func(void* allocator, size_t alignment, size_t size);
|
|
|
|
// Release callback paired with allocate_helper_func; also handed to execution
// providers through ComputeContext in FunctionKernel::Create.
// NOTE(review): the definition is not visible here; presumably it frees `p` through
// the opaque `allocator` handle it was allocated from — confirm.
void release_helper_func(void* allocator, void* p);
|
|
|
|
// A kernel that wrapper the ComputeFunction call generated by execution provider when fuse the sub-graph
|
|
class FunctionKernel : public OpKernel {
|
|
public:
|
|
explicit FunctionKernel(const OpKernelInfo& info, const NodeComputeInfo* compute) : OpKernel(info), compute_info_(compute) {}
|
|
|
|
// The original design is we load the dll, find the entry point and wrapper it.
|
|
// Here for quick prototype, we keep the entry pointer in the node.
|
|
static Status Create(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) {
|
|
const NodeComputeInfo* compute;
|
|
ORT_RETURN_IF_ERROR(func_mgr.GetFuncs(info.node().Name(), compute));
|
|
std::unique_ptr<FunctionKernel> funckernel = std::make_unique<FunctionKernel>(info, compute);
|
|
funckernel->num_inputs_ = info.node().InputDefs().size();
|
|
funckernel->num_outputs_ = info.node().OutputDefs().size();
|
|
|
|
if (compute->create_state_func) {
|
|
// TODO: we are only provide host allocate method in compute context.
|
|
// Do we need to hold the ref-counting here?
|
|
funckernel->host_allocator_ = info.GetAllocator(OrtMemType::OrtMemTypeDefault);
|
|
ComputeContext context = {allocate_helper_func, release_helper_func, funckernel->host_allocator_.get(),
|
|
info.node().Name().c_str()};
|
|
int ret = funckernel->compute_info_->create_state_func(&context, &funckernel->func_state_);
|
|
if (ret != 0)
|
|
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Create state function failed. Return value:", ret);
|
|
}
|
|
out = std::move(funckernel);
|
|
return Status::OK();
|
|
}
|
|
|
|
~FunctionKernel() override {
|
|
if (compute_info_->release_state_func && func_state_) {
|
|
compute_info_->release_state_func(func_state_);
|
|
}
|
|
}
|
|
|
|
virtual Status Compute(OpKernelContext* context) const override {
|
|
auto* context_internal = static_cast<OpKernelContextInternal*>(context);
|
|
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
|
if (api == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "API VERSION ", ORT_API_VERSION, " is invalid.");
|
|
return compute_info_->compute_func(func_state_, api,
|
|
reinterpret_cast<OrtKernelContext*>(context_internal));
|
|
}
|
|
|
|
private:
|
|
const NodeComputeInfo* const compute_info_;
|
|
FunctionState func_state_{nullptr};
|
|
size_t num_inputs_;
|
|
size_t num_outputs_;
|
|
AllocatorPtr host_allocator_;
|
|
};
|
|
} // namespace onnxruntime
|