onnxruntime/onnxruntime/core/framework/func_kernel.h
Justin Chu cf19c3697d
Run clang-format in CI (#15524)
### Description

Run clang-format in CI. Formatted all C/C++ and Objective-C/C++ files.

Excluded

```
    'onnxruntime/core/mlas/**',
    'onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/**',
```

because they contain assembly or are data-heavy.


### Motivation and Context

Coding style consistency
2023-04-18 09:26:58 -07:00

62 lines
2.7 KiB
C++

#pragma once
#include "core/framework/op_kernel.h"
#include "core/framework/func_api.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/graph/function.h"
namespace onnxruntime {
// Allocation callbacks placed into the ComputeContext that FunctionKernel::Create hands to an
// execution provider's create_state_func, alongside the raw host-allocator pointer (presumably the
// opaque handle that is later passed back as `allocator`). Their definitions are not in this header.
// NOTE(review): presumably allocate_helper_func returns `size` bytes aligned to `alignment` obtained
// from `allocator`, and release_helper_func frees `p` through the same allocator — confirm against
// the definitions.
void* allocate_helper_func(void* allocator, size_t alignment, size_t size);
void release_helper_func(void* allocator, void* p);
// A kernel that wraps the compute function an execution provider generates when it fuses a sub-graph.
//
// The NodeComputeInfo looked up (by node name) in the FuncManager supplies the callbacks:
//  - create_state_func (optional): invoked once from Create() to build per-kernel state.
//  - compute_func: invoked from every Compute() with that state.
//  - release_state_func (optional): invoked from the destructor to free the state.
class FunctionKernel : public OpKernel {
 public:
  // `compute` is not owned by the kernel.
  // NOTE(review): it must outlive this kernel — presumably owned by the FuncManager; confirm.
  explicit FunctionKernel(const OpKernelInfo& info, const NodeComputeInfo* compute)
      : OpKernel(info), compute_info_(compute) {}

  // The kernel releases func_state_ in its destructor, so a copy or move would release the
  // same state twice. Ownership therefore stays with exactly one instance.
  FunctionKernel(const FunctionKernel&) = delete;
  FunctionKernel& operator=(const FunctionKernel&) = delete;
  FunctionKernel(FunctionKernel&&) = delete;
  FunctionKernel& operator=(FunctionKernel&&) = delete;

  // Builds a FunctionKernel for the fused node described by `info` and stores it in `out`.
  // The original design is we load the dll, find the entry point and wrap it.
  // Here for quick prototype, we keep the entry pointer in the node.
  //
  // Returns the error from FuncManager::GetFuncs if the lookup fails, or FAIL if no compute info
  // was produced or create_state_func returns a non-zero value. `out` is only assigned on success.
  static Status Create(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) {
    const NodeComputeInfo* compute = nullptr;
    ORT_RETURN_IF_ERROR(func_mgr.GetFuncs(info.node().Name(), compute));
    // Guard against a successful lookup that still yields no compute info; it is dereferenced below.
    if (compute == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No compute functions registered for fused node: ",
                             info.node().Name());
    }

    auto funckernel = std::make_unique<FunctionKernel>(info, compute);
    funckernel->num_inputs_ = info.node().InputDefs().size();
    funckernel->num_outputs_ = info.node().OutputDefs().size();

    if (compute->create_state_func) {
      // TODO: we only provide a host allocate method in the compute context.
      // Do we need to hold the ref-counting here?
      funckernel->host_allocator_ = info.GetAllocator(OrtMemType::OrtMemTypeDefault);
      ComputeContext context = {allocate_helper_func, release_helper_func, funckernel->host_allocator_.get(),
                                info.node().Name().c_str()};
      const int ret = compute->create_state_func(&context, &funckernel->func_state_);
      if (ret != 0) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Create state function failed. Return value: ", ret);
      }
    }

    out = std::move(funckernel);
    return Status::OK();
  }

  ~FunctionKernel() override {
    // Skipped when the provider has no release callback or no state was ever created.
    if (compute_info_->release_state_func && func_state_) {
      compute_info_->release_state_func(func_state_);
    }
  }

  // Forwards to the provider's compute_func with the kernel state, the C API table for
  // ORT_API_VERSION, and the context reinterpreted as the C-API OrtKernelContext.
  Status Compute(OpKernelContext* context) const override {
    auto* context_internal = static_cast<OpKernelContextInternal*>(context);
    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    if (api == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "API VERSION ", ORT_API_VERSION, " is invalid.");
    }
    return compute_info_->compute_func(func_state_, api, reinterpret_cast<OrtKernelContext*>(context_internal));
  }

 private:
  const NodeComputeInfo* const compute_info_;  // not owned
  FunctionState func_state_{nullptr};          // owned; released in the destructor
  size_t num_inputs_{0};                       // populated by Create()
  size_t num_outputs_{0};                      // populated by Create()
  AllocatorPtr host_allocator_;                // holds the allocator whose raw pointer is given to create_state_func
};
} // namespace onnxruntime