#pragma once #include "core/framework/op_kernel.h" #include "core/framework/func_api.h" #include "core/framework/op_kernel_context_internal.h" #include "core/graph/function.h" namespace onnxruntime { void* allocate_helper_func(void* allocator, size_t alignment, size_t size); void release_helper_func(void* allocator, void* p); // A kernel that wrapper the ComputeFunction call generated by execution provider when fuse the sub-graph class FunctionKernel : public OpKernel { public: explicit FunctionKernel(const OpKernelInfo& info, const NodeComputeInfo* compute) : OpKernel(info), compute_info_(compute) {} // The original design is we load the dll, find the entry point and wrapper it. // Here for quick prototype, we keep the entry pointer in the node. static Status Create(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr& out) { const NodeComputeInfo* compute; ORT_RETURN_IF_ERROR(func_mgr.GetFuncs(info.node().Name(), compute)); std::unique_ptr funckernel = std::make_unique(info, compute); funckernel->num_inputs_ = info.node().InputDefs().size(); funckernel->num_outputs_ = info.node().OutputDefs().size(); if (compute->create_state_func) { // TODO: we are only provide host allocate method in compute context. // Do we need to hold the ref-counting here? funckernel->host_allocator_ = info.GetAllocator(OrtMemType::OrtMemTypeDefault); ComputeContext context = {allocate_helper_func, release_helper_func, funckernel->host_allocator_.get(), info.node().Name().c_str()}; int ret = funckernel->compute_info_->create_state_func(&context, &funckernel->func_state_); if (ret != 0) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Create state function failed. Return value:", ret); } out = std::move(funckernel); return Status::OK(); } ~FunctionKernel() override { if (compute_info_->release_state_func && func_state_) { compute_info_->release_state_func(func_state_); } } virtual Status Compute(OpKernelContext* context) const override { auto* context_internal = static_cast(context); const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); if (api == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "API VERSION ", ORT_API_VERSION, " is invalid."); return compute_info_->compute_func(func_state_, api, reinterpret_cast(context_internal)); } private: const NodeComputeInfo* const compute_info_; FunctionState func_state_{nullptr}; size_t num_inputs_; size_t num_outputs_; AllocatorPtr host_allocator_; }; } // namespace onnxruntime