mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Change OpKernel class to be shared with shared providers (#6837)
In the previous shared providers there aren't many OpKernel classes, and the existing Provider_OpKernel wrapper was fine. With the opposibility of making Cuda a shared provider, having this need to be changed per OpKernel adds a lot of complexity. It was fairly straightforward to make OpKernel work with shared providers with minimal changes. In this change, the ONNX_OPERATOR_* macros can also be shared with the shared providers.
This commit is contained in:
parent
38796ad451
commit
0d0eb2c85c
17 changed files with 451 additions and 450 deletions
|
|
@ -13,6 +13,7 @@
|
|||
#include "core/common/common.h"
|
||||
#include "core/common/exceptions.h"
|
||||
#include "core/framework/endian.h"
|
||||
#include "core/framework/float16.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
|
||||
struct OrtValue;
|
||||
|
|
@ -51,93 +52,6 @@ class SequenceTensorTypeBase;
|
|||
class NonTensorTypeBase;
|
||||
class PrimitiveDataTypeBase;
|
||||
|
||||
// MLFloat16
|
||||
struct MLFloat16 {
|
||||
uint16_t val;
|
||||
|
||||
MLFloat16() : val(0) {}
|
||||
explicit MLFloat16(uint16_t x) : val(x) {}
|
||||
explicit MLFloat16(float f);
|
||||
|
||||
float ToFloat() const;
|
||||
|
||||
operator float() const {
|
||||
return ToFloat();
|
||||
}
|
||||
};
|
||||
|
||||
inline bool operator==(const MLFloat16& left, const MLFloat16& right) {
|
||||
return left.val == right.val;
|
||||
}
|
||||
|
||||
inline bool operator!=(const MLFloat16& left, const MLFloat16& right) {
|
||||
return left.val != right.val;
|
||||
}
|
||||
|
||||
inline bool operator<(const MLFloat16& left, const MLFloat16& right) {
|
||||
return left.val < right.val;
|
||||
}
|
||||
|
||||
//BFloat16
|
||||
struct BFloat16 {
|
||||
uint16_t val{0};
|
||||
explicit BFloat16() = default;
|
||||
explicit BFloat16(uint16_t v) : val(v) {}
|
||||
explicit BFloat16(float v) {
|
||||
if (endian::native == endian::little) {
|
||||
std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
|
||||
} else {
|
||||
std::memcpy(&val, &v, sizeof(uint16_t));
|
||||
}
|
||||
}
|
||||
|
||||
float ToFloat() const {
|
||||
float result;
|
||||
char* const first = reinterpret_cast<char*>(&result);
|
||||
char* const second = first + sizeof(uint16_t);
|
||||
if (endian::native == endian::little) {
|
||||
std::memset(first, 0, sizeof(uint16_t));
|
||||
std::memcpy(second, &val, sizeof(uint16_t));
|
||||
} else {
|
||||
std::memcpy(first, &val, sizeof(uint16_t));
|
||||
std::memset(second, 0, sizeof(uint16_t));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
operator float() const {
|
||||
return ToFloat();
|
||||
}
|
||||
};
|
||||
|
||||
inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
|
||||
auto src = blf;
|
||||
auto d = flt;
|
||||
for (; size != 0; ++src, ++d, --size) {
|
||||
*d = src->ToFloat();
|
||||
}
|
||||
}
|
||||
|
||||
inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
|
||||
auto src = flt;
|
||||
auto d = blf;
|
||||
for (; size != 0; ++src, ++d, --size) {
|
||||
new (d) BFloat16(*src);
|
||||
}
|
||||
}
|
||||
|
||||
inline bool operator==(const BFloat16& left, const BFloat16& right) {
|
||||
return left.val == right.val;
|
||||
}
|
||||
|
||||
inline bool operator!=(const BFloat16& left, const BFloat16& right) {
|
||||
return left.val != right.val;
|
||||
}
|
||||
|
||||
inline bool operator<(const BFloat16& left, const BFloat16& right) {
|
||||
return left.val < right.val;
|
||||
}
|
||||
|
||||
// DataTypeImpl pointer as unique DataTypeImpl identifier.
|
||||
using MLDataType = const DataTypeImpl*;
|
||||
// be used with class MLValue
|
||||
|
|
|
|||
|
|
@ -13,9 +13,11 @@
|
|||
#include "boost/mp11.hpp"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include "core/common/type_list.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace utils {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#ifndef PROVIDER_BRIDGE_PROVIDER
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
|
|
|
|||
97
include/onnxruntime/core/framework/float16.h
Normal file
97
include/onnxruntime/core/framework/float16.h
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include "endian.h"
|
||||
|
||||
namespace onnxruntime
|
||||
{
|
||||
|
||||
// MLFloat16
|
||||
struct MLFloat16 {
|
||||
uint16_t val;
|
||||
|
||||
MLFloat16() : val(0) {}
|
||||
explicit MLFloat16(uint16_t x) : val(x) {}
|
||||
explicit MLFloat16(float f);
|
||||
|
||||
float ToFloat() const;
|
||||
|
||||
operator float() const {
|
||||
return ToFloat();
|
||||
}
|
||||
};
|
||||
|
||||
inline bool operator==(const MLFloat16& left, const MLFloat16& right) {
|
||||
return left.val == right.val;
|
||||
}
|
||||
|
||||
inline bool operator!=(const MLFloat16& left, const MLFloat16& right) {
|
||||
return left.val != right.val;
|
||||
}
|
||||
|
||||
inline bool operator<(const MLFloat16& left, const MLFloat16& right) {
|
||||
return left.val < right.val;
|
||||
}
|
||||
|
||||
//BFloat16
|
||||
struct BFloat16 {
|
||||
uint16_t val{0};
|
||||
explicit BFloat16() = default;
|
||||
explicit BFloat16(uint16_t v) : val(v) {}
|
||||
explicit BFloat16(float v) {
|
||||
if (endian::native == endian::little) {
|
||||
std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
|
||||
} else {
|
||||
std::memcpy(&val, &v, sizeof(uint16_t));
|
||||
}
|
||||
}
|
||||
|
||||
float ToFloat() const {
|
||||
float result;
|
||||
char* const first = reinterpret_cast<char*>(&result);
|
||||
char* const second = first + sizeof(uint16_t);
|
||||
if (endian::native == endian::little) {
|
||||
std::memset(first, 0, sizeof(uint16_t));
|
||||
std::memcpy(second, &val, sizeof(uint16_t));
|
||||
} else {
|
||||
std::memcpy(first, &val, sizeof(uint16_t));
|
||||
std::memset(second, 0, sizeof(uint16_t));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
operator float() const {
|
||||
return ToFloat();
|
||||
}
|
||||
};
|
||||
|
||||
inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
|
||||
auto src = blf;
|
||||
auto d = flt;
|
||||
for (; size != 0; ++src, ++d, --size) {
|
||||
*d = src->ToFloat();
|
||||
}
|
||||
}
|
||||
|
||||
inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
|
||||
auto src = flt;
|
||||
auto d = blf;
|
||||
for (; size != 0; ++src, ++d, --size) {
|
||||
new (d) BFloat16(*src);
|
||||
}
|
||||
}
|
||||
|
||||
inline bool operator==(const BFloat16& left, const BFloat16& right) {
|
||||
return left.val == right.val;
|
||||
}
|
||||
|
||||
inline bool operator!=(const BFloat16& left, const BFloat16& right) {
|
||||
return left.val != right.val;
|
||||
}
|
||||
|
||||
inline bool operator<(const BFloat16& left, const BFloat16& right) {
|
||||
return left.val < right.val;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -6,6 +6,9 @@
|
|||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
|
||||
|
||||
/**
|
||||
* Each provider has a KernelRegistry. Often, the KernelRegistry only belongs to that specific provider.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "boost/mp11.hpp"
|
||||
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include <functional>
|
||||
#include "core/common/exceptions.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/status.h"
|
||||
|
|
@ -21,29 +21,24 @@
|
|||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "gsl/gsl"
|
||||
namespace onnxruntime {
|
||||
class OpKernelContext;
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
class IExecutionFrame;
|
||||
class OpKernelContext;
|
||||
class OpKernelWrapper;
|
||||
namespace concurrency {
|
||||
class ThreadPool;
|
||||
}
|
||||
|
||||
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info);
|
||||
|
||||
class OpKernel {
|
||||
public:
|
||||
using DoneCallback = std::function<void()>;
|
||||
|
||||
explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(info) {}
|
||||
explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(CopyOpKernelInfo(info)) {}
|
||||
virtual ~OpKernel() = default;
|
||||
|
||||
const onnxruntime::Node& Node() const {
|
||||
return op_kernel_info_.node();
|
||||
}
|
||||
|
||||
const onnxruntime::KernelDef& KernelDef() const {
|
||||
return op_kernel_info_.GetKernelDef();
|
||||
}
|
||||
const onnxruntime::Node& Node() const;
|
||||
const onnxruntime::KernelDef& KernelDef() const;
|
||||
|
||||
virtual Status Compute(_Inout_ OpKernelContext* context) const ORT_MUST_USE_RESULT = 0;
|
||||
|
||||
|
|
@ -73,219 +68,14 @@ class OpKernel {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
const OrtMemoryInfo& Allocator(int id, OrtMemType mem_type) const {
|
||||
return op_kernel_info_.GetMemoryInfo(id, mem_type);
|
||||
}
|
||||
|
||||
const OpKernelInfo& Info() const { return op_kernel_info_; }
|
||||
const OrtMemoryInfo& Allocator(int id, OrtMemType mem_type) const;
|
||||
const OpKernelInfo& Info() const { return *op_kernel_info_; }
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel);
|
||||
OpKernelInfo op_kernel_info_;
|
||||
std::unique_ptr<OpKernelInfo> op_kernel_info_;
|
||||
};
|
||||
|
||||
class OpKernelContext {
|
||||
public:
|
||||
using ArgMap = std::unordered_map<std::string, size_t>;
|
||||
|
||||
OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel,
|
||||
_In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger);
|
||||
|
||||
virtual ~OpKernelContext() = default;
|
||||
|
||||
/**
|
||||
Return the number of inputs for a variadic argument.
|
||||
@param arg_num The operator argument number.
|
||||
@returns Number of inputs the argument has.
|
||||
*/
|
||||
int NumVariadicInputs(size_t arg_num) const;
|
||||
|
||||
MLDataType InputType(int index) const;
|
||||
MLDataType OutputType(int index) const;
|
||||
|
||||
template <typename T>
|
||||
const T* Input(int index) const {
|
||||
const OrtValue* p_ml_value = GetInputMLValue(index);
|
||||
ORT_TRY {
|
||||
return p_ml_value ? &(p_ml_value->Get<T>()) : nullptr;
|
||||
}
|
||||
ORT_CATCH(const std::exception& /*e*/) {
|
||||
ORT_THROW("Missing Input: " + kernel_->Node().InputDefs()[index]->Name());
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch a required input, enforcing that it is present.
|
||||
template <typename T>
|
||||
const T& RequiredInput(int index) const {
|
||||
const T* input_ptr = Input<T>(index);
|
||||
ORT_ENFORCE(input_ptr, "Required input at index ", index, " is not present.");
|
||||
return *input_ptr;
|
||||
}
|
||||
|
||||
// Fetch output (non-tensor) with specified index.
|
||||
template <typename T>
|
||||
T* Output(int index) {
|
||||
if (index < 0 || index >= OutputCount())
|
||||
return nullptr;
|
||||
|
||||
OrtValue* p_ml_value = GetOrCreateOutputMLValue(index);
|
||||
return p_ml_value ? p_ml_value->GetMutable<T>() : nullptr;
|
||||
}
|
||||
|
||||
// In the case that memory allocation has not been done for an output tensor,
|
||||
// The memory allocation will be done on-the-fly with given tensor shape.
|
||||
// Return nullptr if the output is an unused optional output.
|
||||
Tensor* Output(int index, const TensorShape& shape);
|
||||
Tensor* Output(int index, const std::vector<int64_t>& shape);
|
||||
Tensor* Output(int index, const std::initializer_list<int64_t>& shape);
|
||||
|
||||
// Fetch a required tensor output, enforcing that it is present.
|
||||
Tensor& RequiredOutput(int index, const TensorShape& shape) {
|
||||
Tensor* output_ptr = Output(index, shape);
|
||||
ORT_ENFORCE(output_ptr, "Required output at index ", index, " is not present.");
|
||||
return *output_ptr;
|
||||
}
|
||||
|
||||
// Fetch a sparse-tensor output corresponding to the specified index.
|
||||
// num_values must specify the number of non-zero values (commonly known as NNZ/nnz),
|
||||
// and shape must specify the shape of the underlying dense-tensor.
|
||||
// Memory allocation for the output may happen when this method is invoked,
|
||||
// unless static optimization pre-allocates it.
|
||||
SparseTensor* Output(int index, size_t num_values, const TensorShape& shape);
|
||||
|
||||
// Retrieve indexed shape obtained from memory planning before actual
|
||||
// computation. If the indexed shape cannot be inferred, this function returns
|
||||
// false.
|
||||
bool TryGetInferredInputShape(int index, TensorShape& shape) const;
|
||||
|
||||
// Retrieve indexed shape obtained from memory planning before actual
|
||||
// computation. If the indexed shape cannot be inferred, this function returns
|
||||
// false.
|
||||
bool TryGetInferredOutputShape(int index, TensorShape& shape) const;
|
||||
|
||||
const logging::Logger& Logger() const {
|
||||
return *logger_;
|
||||
}
|
||||
|
||||
// always >= 0
|
||||
int InputCount() const {
|
||||
return static_cast<int>(kernel_->Node().InputDefs().size());
|
||||
}
|
||||
|
||||
// always >= 0
|
||||
int ImplicitInputCount() const {
|
||||
return static_cast<int>(kernel_->Node().ImplicitInputDefs().size());
|
||||
}
|
||||
|
||||
// always >= 0
|
||||
int OutputCount() const {
|
||||
return static_cast<int>(kernel_->Node().OutputDefs().size());
|
||||
}
|
||||
|
||||
/**
|
||||
Return an allocator on device 0, with memtype of OrtMemTypeDefault.
|
||||
@remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
|
||||
*/
|
||||
Status GetTempSpaceAllocator(AllocatorPtr* output) const ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
Return the fence of current node's input.
|
||||
@param index The index of the input.
|
||||
@returns Point to the Fence of the input OrtValue.
|
||||
It is null if the input OrtValue doesn't have fence or the input is optional.
|
||||
*/
|
||||
Fence_t InputFence(int index) const;
|
||||
|
||||
/**
|
||||
Return the fence of current node's implicit input.
|
||||
@param index The index of the implicit input.
|
||||
@returns Point to the Fence of the implicit input OrtValue.
|
||||
It is null if the input OrtValue doesn't have fence or the input is optional.
|
||||
*/
|
||||
Fence_t ImplicitInputFence(int index) const;
|
||||
|
||||
/**
|
||||
Return the fence of current node's output identifed by index.
|
||||
@param index The index of the output.
|
||||
@returns Point to the Fence of the output OrtValue.
|
||||
It is null if the output OrtValue doesn't have fence or the output is optional.
|
||||
*/
|
||||
Fence_t OutputFence(int index) const;
|
||||
|
||||
/**
|
||||
Return the device id that current kernel runs on.
|
||||
*/
|
||||
int GetDeviceId() const {
|
||||
return kernel_->Info().GetExecutionProvider()->GetDeviceId();
|
||||
}
|
||||
|
||||
/**
|
||||
Returns the opset domain of the underlying kernel
|
||||
**/
|
||||
const std::string& GetOpDomain() const;
|
||||
|
||||
/**
|
||||
Returns the optype of the underlying kernel
|
||||
**/
|
||||
const std::string& GetOpType() const;
|
||||
|
||||
/**
|
||||
Returns the node name of the underlying kernel
|
||||
**/
|
||||
const std::string& GetNodeName() const;
|
||||
|
||||
/**
|
||||
Returns the intra-op threadpool, if available.
|
||||
*/
|
||||
_Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return threadpool_; }
|
||||
|
||||
/**
|
||||
Returns whether deterministic computation is preferred.
|
||||
*/
|
||||
virtual bool GetUseDeterministicCompute() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
onnxruntime::NodeIndex GetNodeIndex() const;
|
||||
|
||||
const OrtValue* GetInputMLValue(int index) const;
|
||||
const OrtValue* GetImplicitInputMLValue(int index) const;
|
||||
OrtValue* GetOutputMLValue(int index);
|
||||
|
||||
// Creates the OrtValue* based on the shape, if it does not exist
|
||||
// The parameter nnz is used only for sparse-tensors and indicates the
|
||||
// number of non-zero values (the number of elements in the values buffer allocated).
|
||||
OrtValue* OutputMLValue(int index, const TensorShape& shape, size_t nnz = 0);
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
|
||||
|
||||
OrtValue* GetOrCreateOutputMLValue(int index);
|
||||
|
||||
int GetInputArgIndex(int index) const;
|
||||
int GetImplicitInputArgIndex(int index) const;
|
||||
int GetOutputArgIndex(int index) const;
|
||||
|
||||
IExecutionFrame* const execution_frame_;
|
||||
const OpKernel* const kernel_;
|
||||
concurrency::ThreadPool* const threadpool_;
|
||||
const logging::Logger* const logger_;
|
||||
|
||||
// The argument starting index in ExecutionFrame.
|
||||
int node_input_start_index_{-1};
|
||||
int node_implicit_input_start_index_{-1};
|
||||
int node_output_start_index_{-1};
|
||||
};
|
||||
|
||||
// Fetching output tensor without shape is not allowed except when it already exists
|
||||
template <>
|
||||
inline Tensor* OpKernelContext::Output<Tensor>(int index) {
|
||||
OrtValue* p_ml_value = GetOutputMLValue(index);
|
||||
ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape.");
|
||||
return p_ml_value->GetMutable<Tensor>();
|
||||
}
|
||||
|
||||
using KernelCreateFn = std::function<OpKernel*(const OpKernelInfo& info)>;
|
||||
using KernelCreatePtrFn = std::add_pointer<OpKernel*(const OpKernelInfo& info)>::type;
|
||||
|
||||
|
|
@ -306,8 +96,6 @@ struct KernelCreateInfo {
|
|||
KernelCreateInfo() = default;
|
||||
};
|
||||
|
||||
using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
|
||||
|
||||
// Forward declarations for the non-specialized BuildKernelCreateInfo method.
|
||||
template <typename T>
|
||||
KernelCreateInfo BuildKernelCreateInfo();
|
||||
|
|
@ -504,3 +292,7 @@ std::vector<MLDataType> BuildKernelDefConstraintsFromTypeList() {
|
|||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include "core/framework/op_kernel_context.h"
|
||||
#endif
|
||||
|
|
|
|||
212
include/onnxruntime/core/framework/op_kernel_context.h
Normal file
212
include/onnxruntime/core/framework/op_kernel_context.h
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
namespace onnxruntime {
|
||||
class IExecutionFrame;
|
||||
namespace concurrency {
|
||||
class ThreadPool;
|
||||
}
|
||||
|
||||
class OpKernelContext {
|
||||
public:
|
||||
using ArgMap = std::unordered_map<std::string, size_t>;
|
||||
|
||||
OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel,
|
||||
_In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger);
|
||||
|
||||
virtual ~OpKernelContext() = default;
|
||||
|
||||
/**
|
||||
Return the number of inputs for a variadic argument.
|
||||
@param arg_num The operator argument number.
|
||||
@returns Number of inputs the argument has.
|
||||
*/
|
||||
int NumVariadicInputs(size_t arg_num) const;
|
||||
|
||||
MLDataType InputType(int index) const;
|
||||
MLDataType OutputType(int index) const;
|
||||
|
||||
template <typename T>
|
||||
const T* Input(int index) const {
|
||||
const OrtValue* p_ml_value = GetInputMLValue(index);
|
||||
ORT_TRY {
|
||||
return p_ml_value ? &(p_ml_value->Get<T>()) : nullptr;
|
||||
}
|
||||
ORT_CATCH(const std::exception& /*e*/) {
|
||||
ORT_THROW("Missing Input: " + kernel_->Node().InputDefs()[index]->Name());
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch a required input, enforcing that it is present.
|
||||
template <typename T>
|
||||
const T& RequiredInput(int index) const {
|
||||
const T* input_ptr = Input<T>(index);
|
||||
ORT_ENFORCE(input_ptr, "Required input at index ", index, " is not present.");
|
||||
return *input_ptr;
|
||||
}
|
||||
|
||||
// Fetch output (non-tensor) with specified index.
|
||||
template <typename T>
|
||||
T* Output(int index) {
|
||||
if (index < 0 || index >= OutputCount())
|
||||
return nullptr;
|
||||
|
||||
OrtValue* p_ml_value = GetOrCreateOutputMLValue(index);
|
||||
return p_ml_value ? p_ml_value->GetMutable<T>() : nullptr;
|
||||
}
|
||||
|
||||
// In the case that memory allocation has not been done for an output tensor,
|
||||
// The memory allocation will be done on-the-fly with given tensor shape.
|
||||
// Return nullptr if the output is an unused optional output.
|
||||
Tensor* Output(int index, const TensorShape& shape);
|
||||
Tensor* Output(int index, const std::vector<int64_t>& shape);
|
||||
Tensor* Output(int index, const std::initializer_list<int64_t>& shape);
|
||||
|
||||
// Fetch a required tensor output, enforcing that it is present.
|
||||
Tensor& RequiredOutput(int index, const TensorShape& shape) {
|
||||
Tensor* output_ptr = Output(index, shape);
|
||||
ORT_ENFORCE(output_ptr, "Required output at index ", index, " is not present.");
|
||||
return *output_ptr;
|
||||
}
|
||||
|
||||
// Fetch a sparse-tensor output corresponding to the specified index.
|
||||
// num_values must specify the number of non-zero values (commonly known as NNZ/nnz),
|
||||
// and shape must specify the shape of the underlying dense-tensor.
|
||||
// Memory allocation for the output may happen when this method is invoked,
|
||||
// unless static optimization pre-allocates it.
|
||||
SparseTensor* Output(int index, size_t num_values, const TensorShape& shape);
|
||||
|
||||
// Retrieve indexed shape obtained from memory planning before actual
|
||||
// computation. If the indexed shape cannot be inferred, this function returns
|
||||
// false.
|
||||
bool TryGetInferredInputShape(int index, TensorShape& shape) const;
|
||||
|
||||
// Retrieve indexed shape obtained from memory planning before actual
|
||||
// computation. If the indexed shape cannot be inferred, this function returns
|
||||
// false.
|
||||
bool TryGetInferredOutputShape(int index, TensorShape& shape) const;
|
||||
|
||||
const logging::Logger& Logger() const {
|
||||
return *logger_;
|
||||
}
|
||||
|
||||
// always >= 0
|
||||
int InputCount() const {
|
||||
return static_cast<int>(kernel_->Node().InputDefs().size());
|
||||
}
|
||||
|
||||
// always >= 0
|
||||
int ImplicitInputCount() const {
|
||||
return static_cast<int>(kernel_->Node().ImplicitInputDefs().size());
|
||||
}
|
||||
|
||||
// always >= 0
|
||||
int OutputCount() const {
|
||||
return static_cast<int>(kernel_->Node().OutputDefs().size());
|
||||
}
|
||||
|
||||
/**
|
||||
Return an allocator on device 0, with memtype of OrtMemTypeDefault.
|
||||
@remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
|
||||
*/
|
||||
Status GetTempSpaceAllocator(AllocatorPtr* output) const ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
Return the fence of current node's input.
|
||||
@param index The index of the input.
|
||||
@returns Point to the Fence of the input OrtValue.
|
||||
It is null if the input OrtValue doesn't have fence or the input is optional.
|
||||
*/
|
||||
Fence_t InputFence(int index) const;
|
||||
|
||||
/**
|
||||
Return the fence of current node's implicit input.
|
||||
@param index The index of the implicit input.
|
||||
@returns Point to the Fence of the implicit input OrtValue.
|
||||
It is null if the input OrtValue doesn't have fence or the input is optional.
|
||||
*/
|
||||
Fence_t ImplicitInputFence(int index) const;
|
||||
|
||||
/**
|
||||
Return the fence of current node's output identifed by index.
|
||||
@param index The index of the output.
|
||||
@returns Point to the Fence of the output OrtValue.
|
||||
It is null if the output OrtValue doesn't have fence or the output is optional.
|
||||
*/
|
||||
Fence_t OutputFence(int index) const;
|
||||
|
||||
/**
|
||||
Return the device id that current kernel runs on.
|
||||
*/
|
||||
int GetDeviceId() const {
|
||||
return kernel_->Info().GetExecutionProvider()->GetDeviceId();
|
||||
}
|
||||
|
||||
/**
|
||||
Returns the opset domain of the underlying kernel
|
||||
**/
|
||||
const std::string& GetOpDomain() const;
|
||||
|
||||
/**
|
||||
Returns the optype of the underlying kernel
|
||||
**/
|
||||
const std::string& GetOpType() const;
|
||||
|
||||
/**
|
||||
Returns the node name of the underlying kernel
|
||||
**/
|
||||
const std::string& GetNodeName() const;
|
||||
|
||||
/**
|
||||
Returns the intra-op threadpool, if available.
|
||||
*/
|
||||
_Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return threadpool_; }
|
||||
|
||||
/**
|
||||
Returns whether deterministic computation is preferred.
|
||||
*/
|
||||
virtual bool GetUseDeterministicCompute() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
onnxruntime::NodeIndex GetNodeIndex() const;
|
||||
|
||||
const OrtValue* GetInputMLValue(int index) const;
|
||||
const OrtValue* GetImplicitInputMLValue(int index) const;
|
||||
OrtValue* GetOutputMLValue(int index);
|
||||
|
||||
// Creates the OrtValue* based on the shape, if it does not exist
|
||||
// The parameter nnz is used only for sparse-tensors and indicates the
|
||||
// number of non-zero values (the number of elements in the values buffer allocated).
|
||||
OrtValue* OutputMLValue(int index, const TensorShape& shape, size_t nnz = 0);
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
|
||||
|
||||
OrtValue* GetOrCreateOutputMLValue(int index);
|
||||
|
||||
int GetInputArgIndex(int index) const;
|
||||
int GetImplicitInputArgIndex(int index) const;
|
||||
int GetOutputArgIndex(int index) const;
|
||||
|
||||
IExecutionFrame* const execution_frame_;
|
||||
const OpKernel* const kernel_;
|
||||
concurrency::ThreadPool* const threadpool_;
|
||||
const logging::Logger* const logger_;
|
||||
|
||||
// The argument starting index in ExecutionFrame.
|
||||
int node_input_start_index_{-1};
|
||||
int node_implicit_input_start_index_{-1};
|
||||
int node_output_start_index_{-1};
|
||||
};
|
||||
|
||||
// Fetching output tensor without shape is not allowed except when it already exists
|
||||
template <>
|
||||
inline Tensor* OpKernelContext::Output<Tensor>(int index) {
|
||||
OrtValue* p_ml_value = GetOutputMLValue(index);
|
||||
ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape.");
|
||||
return p_ml_value->GetMutable<Tensor>();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -9,6 +9,22 @@
|
|||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) {
|
||||
return onnxruntime::make_unique<OpKernelInfo>(info);
|
||||
}
|
||||
|
||||
const onnxruntime::Node& OpKernel::Node() const {
|
||||
return op_kernel_info_->node();
|
||||
}
|
||||
|
||||
const onnxruntime::KernelDef& OpKernel::KernelDef() const {
|
||||
return op_kernel_info_->GetKernelDef();
|
||||
}
|
||||
|
||||
const OrtMemoryInfo& OpKernel::Allocator(int id, OrtMemType mem_type) const {
|
||||
return op_kernel_info_->GetMemoryInfo(id, mem_type);
|
||||
}
|
||||
|
||||
OpKernelContext::OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel,
|
||||
_In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger)
|
||||
: execution_frame_(frame), kernel_(kernel), threadpool_(threadpool), logger_(&logger) {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
// These Provider types are really just internal types, so we #define PROVIDER_BRIDGE_ORT so that only these definitions are seen by provider_interfaces.h
|
||||
// These Provider types are really just internal types, since we don't include provider_api.h only these definitions are seen by provider_interfaces.h
|
||||
// Users of provider_interfaces.h (through provider_api.h) will see the wrappers that call through the provider shared interface which is implemented by this file
|
||||
using Provider_int64s = google::protobuf::RepeatedField<int64_t>;
|
||||
using Provider_AttributeProto = ONNX_NAMESPACE::AttributeProto;
|
||||
|
|
@ -46,7 +46,6 @@ using Provider_ValueInfoProtos = google::protobuf::RepeatedPtrField<ONNX_NAMESPA
|
|||
using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef;
|
||||
} // namespace onnxruntime
|
||||
|
||||
#define PROVIDER_BRIDGE_ORT
|
||||
#include "core/common/cpuid_info.h"
|
||||
#include "onnx/common/stl_backports.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
|
|
@ -123,19 +122,6 @@ struct Node__EdgeIterator_Impl : Node__EdgeIterator {
|
|||
Node::EdgeConstIterator v_;
|
||||
};
|
||||
|
||||
struct OpKernel_Translator : OpKernel {
|
||||
OpKernel_Translator(const OpKernelInfo& info, Provider_OpKernel* p) : OpKernel{info}, p_{p} {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
return p_->Compute(context, *reinterpret_cast<const Provider_OpKernel_Base*>(static_cast<const OpKernel*>(this)));
|
||||
}
|
||||
|
||||
std::unique_ptr<Provider_OpKernel> p_;
|
||||
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernel_Translator);
|
||||
};
|
||||
|
||||
struct ProviderHostImpl : ProviderHost {
|
||||
ProviderHostImpl() {
|
||||
DataTypeImpl_GetType_Tensor = &DataTypeImpl::GetType<Tensor>;
|
||||
|
|
@ -184,9 +170,13 @@ struct ProviderHostImpl : ProviderHost {
|
|||
return const_cast<logging::Logger*>(&logging::LoggingManager::DefaultLogger());
|
||||
}
|
||||
|
||||
const std::vector<MLDataType>& DataTypeImpl_AllFixedSizeTensorTypes() override {
|
||||
return DataTypeImpl::AllFixedSizeTensorTypes();
|
||||
}
|
||||
int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) override { return p->GetDataType(); }
|
||||
|
||||
const char* DataTypeImpl__ToString(MLDataType type) override { return DataTypeImpl::ToString(type); }
|
||||
const std::vector<MLDataType>& DataTypeImpl__AllFixedSizeTensorTypes() override { return DataTypeImpl::AllFixedSizeTensorTypes(); }
|
||||
const std::vector<MLDataType>& DataTypeImpl__AllTensorTypes() override { return DataTypeImpl::AllTensorTypes(); }
|
||||
size_t DataTypeImpl__Size(const DataTypeImpl* p) override { return p->Size(); }
|
||||
const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) override { return p->AsPrimitiveDataType(); }
|
||||
|
||||
void* HeapAllocate(size_t size) override { return new uint8_t[size]; }
|
||||
void HeapFree(void* p) override { delete[] reinterpret_cast<uint8_t*>(p); }
|
||||
|
|
@ -398,6 +388,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
|
||||
// KernelDef
|
||||
void KernelDef__operator_delete(KernelDef* p) override { delete p; }
|
||||
int KernelDef__ExecQueueId(const KernelDef* p) override { return p->ExecQueueId(); }
|
||||
|
||||
// KernelDefBuilder
|
||||
std::unique_ptr<KernelDefBuilder> KernelDefBuilder__construct() override { return onnxruntime::make_unique<KernelDefBuilder>(); }
|
||||
|
|
@ -418,13 +409,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
// KernelRegistry
|
||||
std::shared_ptr<KernelRegistry> KernelRegistry__construct() override { return std::make_shared<KernelRegistry>(); }
|
||||
void KernelRegistry__operator_delete(KernelRegistry* p) override { delete p; }
|
||||
Status KernelRegistry__Register(KernelRegistry* p, Provider_KernelCreateInfo&& create_info) override {
|
||||
KernelCreateInfo info_real(std::move(create_info.kernel_def),
|
||||
[kernel_create_func = create_info.kernel_create_func](const OpKernelInfo& info) -> OpKernel* {
|
||||
return new OpKernel_Translator(info, kernel_create_func(info));
|
||||
});
|
||||
return p->Register(std::move(info_real));
|
||||
}
|
||||
Status KernelRegistry__Register(KernelRegistry* p, KernelCreateInfo&& create_info) override { return p->Register(std::move(create_info)); }
|
||||
|
||||
// Function
|
||||
const Graph& Function__Body(const Function* p) override { return p->Body(); }
|
||||
|
|
@ -543,23 +528,21 @@ struct ProviderHostImpl : ProviderHost {
|
|||
const std::vector<NodeIndex>& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) override { return p->GetNodesInTopologicalOrder(); }
|
||||
const std::vector<const NodeArg*>& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept override { return p->GetInputsIncludingInitializers(); }
|
||||
|
||||
//Path
|
||||
//Gets a string representation of the path.
|
||||
// Path
|
||||
PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); }
|
||||
|
||||
// Provider_OpKernel_Base
|
||||
const OpKernelInfo& Provider_OpKernel_Base__GetInfo(const Provider_OpKernel_Base* p) override { return reinterpret_cast<const OpKernel*>(p)->Info(); }
|
||||
|
||||
// OpKernelContext
|
||||
const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) override { return p->Input<Tensor>(index); }
|
||||
Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) override { return p->Output(index, shape); }
|
||||
|
||||
// OpKernelInfo
|
||||
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) override { return onnxruntime::CopyOpKernelInfo(info); }
|
||||
void OpKernelInfo__operator_delete(OpKernelInfo* p) override { delete p; }
|
||||
Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) override { return p->GetAttr(name, value); }
|
||||
Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) override { return p->GetAttr(name, value); }
|
||||
|
||||
const DataTransferManager& OpKernelInfo__GetDataTransferManager(const OpKernelInfo* p) noexcept override { return p->GetDataTransferManager(); }
|
||||
int OpKernelInfo__GetKernelDef_ExecQueueId(const OpKernelInfo* p) noexcept override { return p->GetKernelDef().ExecQueueId(); }
|
||||
const KernelDef& OpKernelInfo__GetKernelDef(const OpKernelInfo* p) override { return p->GetKernelDef(); }
|
||||
|
||||
// Tensor
|
||||
float* Tensor__MutableData_float(Tensor* p) override { return p->MutableData<float>(); }
|
||||
|
|
@ -570,7 +553,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
|
||||
const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); }
|
||||
size_t Tensor__SizeInBytes(const Tensor* p) override { return p->SizeInBytes(); }
|
||||
const OrtMemoryInfo& Tensor__Location(const Tensor* p) override { return p->Location(); }
|
||||
const OrtMemoryInfo& Tensor__Location(const Tensor* p) override { return p->Location(); }
|
||||
|
||||
// AllocatorManager
|
||||
void AllocatorManager__InsertAllocator(AllocatorManager* p, AllocatorPtr allocator) override { p->InsertAllocator(allocator); }
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ namespace ort_dnnl {
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDnnlExecutionProvider, kOnnxDomain, 7, Gemm);
|
||||
|
||||
Status RegisterDNNLKernels(KernelRegistry& kernel_registry) {
|
||||
static const Provider_BuildKernelCreateInfoFn function_table[] = {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDnnlExecutionProvider, kOnnxDomain, 7, Gemm)>,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,6 @@
|
|||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
template <typename T>
|
||||
Provider_KernelCreateInfo BuildKernelCreateInfo();
|
||||
KernelCreateInfo BuildKernelCreateInfo();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ template <typename T>
|
|||
using ConstEigenMatrixMapRowMajor = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
|
||||
|
||||
template <>
|
||||
Status Gemm<float>::Compute(OpKernelContext* ctx, const Provider_OpKernel_Base&) const {
|
||||
Status Gemm<float>::Compute(OpKernelContext* ctx) const {
|
||||
const auto X = ctx->Input<Tensor>(0);
|
||||
const auto W = ctx->Input<Tensor>(1);
|
||||
const auto B = ctx->Input<Tensor>(2);
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@
|
|||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
template <typename T>
|
||||
class Gemm final : public Provider_OpKernel {
|
||||
class Gemm final : public OpKernel {
|
||||
public:
|
||||
Gemm(const OpKernelInfo& info) {
|
||||
Gemm(const OpKernelInfo& info) : OpKernel(info) {
|
||||
int64_t temp;
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("transA", &temp).IsOK());
|
||||
trans_A_ = (temp != 0);
|
||||
|
|
@ -20,7 +20,7 @@ class Gemm final : public Provider_OpKernel {
|
|||
ORT_ENFORCE(info.GetAttr<float>("beta", &beta_).IsOK());
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context, const Provider_OpKernel_Base& base) const override;
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
bool trans_A_;
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
// switching providers to be runnable as shared libraries. The interfaces will become more tightly integrated into the core code.
|
||||
|
||||
#pragma once
|
||||
#define PROVIDER_BRIDGE_PROVIDER 1
|
||||
#define SHARED_PROVIDER 1
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
|
@ -16,9 +16,11 @@
|
|||
#include "onnx/common/stl_backports.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/const_pointer_container.h"
|
||||
#include "core/common/type_list.h"
|
||||
#include "core/common/logging/severity.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/allocatormgr.h"
|
||||
#include "core/framework/float16.h"
|
||||
#include "core/framework/tensor_shape.h"
|
||||
#include "core/providers/providers.h"
|
||||
#include "core/common/path_string.h"
|
||||
|
|
@ -67,6 +69,7 @@ struct DataTransferManager;
|
|||
struct IDataTransfer;
|
||||
struct IndexedSubGraph;
|
||||
struct IndexedSubGraph_MetaDef;
|
||||
struct KernelCreateInfo;
|
||||
struct KernelDef;
|
||||
struct KernelDefBuilder;
|
||||
struct KernelRegistry;
|
||||
|
|
@ -80,7 +83,11 @@ struct NodeArg;
|
|||
struct NodeAttributes;
|
||||
struct OpKernelContext;
|
||||
struct OpKernelInfo;
|
||||
struct PrimitiveDataTypeBase;
|
||||
struct Tensor;
|
||||
|
||||
class DataTypeImpl;
|
||||
using MLDataType = const DataTypeImpl*;
|
||||
} // namespace onnxruntime
|
||||
|
||||
namespace ONNX_NAMESPACE {
|
||||
|
|
@ -147,6 +154,8 @@ enum OperatorStatus : int {
|
|||
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "provider_interfaces.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/data_types_internal.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -185,18 +194,6 @@ enum CUDAStreamType : int {
|
|||
kTotalCudaStreams,
|
||||
};
|
||||
|
||||
class DataTypeImpl {
|
||||
public:
|
||||
virtual ~DataTypeImpl() = default;
|
||||
|
||||
template <typename T>
|
||||
static MLDataType GetType();
|
||||
template <typename elemT>
|
||||
static MLDataType GetTensorType();
|
||||
|
||||
static const std::vector<MLDataType>& AllFixedSizeTensorTypes();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
|
||||
|
||||
|
|
@ -234,23 +231,6 @@ constexpr T roundUpPow2(T a) {
|
|||
|
||||
} // namespace onnxruntime
|
||||
|
||||
#define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \
|
||||
provider##_##name##_##domain##_ver##ver
|
||||
|
||||
#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \
|
||||
template <> \
|
||||
Provider_KernelCreateInfo \
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name)>() { \
|
||||
return Provider_KernelCreateInfo( \
|
||||
builder.SetName(#name) \
|
||||
.SetDomain(domain) \
|
||||
.SinceVersion(ver) \
|
||||
.Provider(provider) \
|
||||
.Build(), \
|
||||
static_cast<Provider_KernelCreatePtrFn>([](const OpKernelInfo& info) -> Provider_OpKernel* { return new __VA_ARGS__(info); })); \
|
||||
}
|
||||
|
||||
#define CREATE_MESSAGE(logger, severity, category, datatype) \
|
||||
::onnxruntime::logging::Capture::Create(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE)
|
||||
|
||||
|
|
|
|||
|
|
@ -63,10 +63,6 @@ MLDataType DataTypeImpl::GetTensorType<float>() {
|
|||
return g_host->DataTypeImpl_GetTensorType_float();
|
||||
}
|
||||
|
||||
const std::vector<MLDataType>& DataTypeImpl::AllFixedSizeTensorTypes() {
|
||||
return g_host->DataTypeImpl_AllFixedSizeTensorTypes();
|
||||
}
|
||||
|
||||
TensorShape::TensorShape(const int64_t* dimension_sizes, size_t dimension_count)
|
||||
: std::vector<int64_t>(dimension_count) {
|
||||
for (size_t i = 0; i < dimension_count; ++i) {
|
||||
|
|
@ -221,4 +217,8 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status,
|
|||
return g_host->LogRuntimeError(session_id, status, file, function, line);
|
||||
}
|
||||
|
||||
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) {
|
||||
return g_host->CopyOpKernelInfo(info);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -27,8 +27,6 @@ using DataType = const std::string*;
|
|||
|
||||
namespace onnxruntime {
|
||||
// These types don't directly map to internal types
|
||||
struct Provider_KernelCreateInfo;
|
||||
struct Provider_OpKernel_Base;
|
||||
struct ProviderHost;
|
||||
|
||||
class TensorShape;
|
||||
|
|
@ -66,19 +64,6 @@ struct Provider_TensorShapeProto_Dimension_Iterator {
|
|||
virtual const Provider_TensorShapeProto_Dimension& operator*() = 0;
|
||||
};
|
||||
|
||||
class DataTypeImpl;
|
||||
using MLDataType = const DataTypeImpl*;
|
||||
|
||||
struct Provider_OpKernel {
|
||||
Provider_OpKernel() {}
|
||||
virtual ~Provider_OpKernel() = default;
|
||||
|
||||
virtual Status Compute(OpKernelContext* context, const Provider_OpKernel_Base& base) const = 0;
|
||||
|
||||
Provider_OpKernel(const Provider_OpKernel&) = delete;
|
||||
void operator=(const Provider_OpKernel&) = delete;
|
||||
};
|
||||
|
||||
using NodeIndex = size_t;
|
||||
using Provider_NodeArgInfo = Provider_ValueInfoProto;
|
||||
// We can't just reinterpret_cast this one, since it's an unordered_map of object BY VALUE (can't do anything by value on the real types)
|
||||
|
|
@ -142,10 +127,19 @@ struct ProviderHost {
|
|||
|
||||
virtual std::string GetEnvironmentVar(const std::string& var_name) = 0;
|
||||
|
||||
// PrimitiveDataTypeBase
|
||||
virtual int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) = 0;
|
||||
|
||||
// DataTypeImpl
|
||||
MLDataType (*DataTypeImpl_GetType_Tensor)();
|
||||
MLDataType (*DataTypeImpl_GetType_float)();
|
||||
MLDataType (*DataTypeImpl_GetTensorType_float)();
|
||||
virtual const std::vector<MLDataType>& DataTypeImpl_AllFixedSizeTensorTypes() = 0;
|
||||
|
||||
virtual const char* DataTypeImpl__ToString(MLDataType type) = 0;
|
||||
virtual const std::vector<MLDataType>& DataTypeImpl__AllFixedSizeTensorTypes() = 0;
|
||||
virtual const std::vector<MLDataType>& DataTypeImpl__AllTensorTypes() = 0;
|
||||
virtual size_t DataTypeImpl__Size(const DataTypeImpl* p) = 0;
|
||||
virtual const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) = 0;
|
||||
|
||||
virtual void* HeapAllocate(size_t size) = 0;
|
||||
virtual void HeapFree(void*) = 0;
|
||||
|
|
@ -335,6 +329,7 @@ struct ProviderHost {
|
|||
|
||||
// KernelDef
|
||||
virtual void KernelDef__operator_delete(KernelDef* p) = 0;
|
||||
virtual int KernelDef__ExecQueueId(const KernelDef* p) = 0;
|
||||
|
||||
// KernelDefBuilder
|
||||
virtual std::unique_ptr<KernelDefBuilder> KernelDefBuilder__construct() = 0;
|
||||
|
|
@ -355,7 +350,7 @@ struct ProviderHost {
|
|||
// KernelRegistry
|
||||
virtual std::shared_ptr<KernelRegistry> KernelRegistry__construct() = 0;
|
||||
virtual void KernelRegistry__operator_delete(KernelRegistry* p) = 0;
|
||||
virtual Status KernelRegistry__Register(KernelRegistry* p, Provider_KernelCreateInfo&& create_info) = 0;
|
||||
virtual Status KernelRegistry__Register(KernelRegistry* p, KernelCreateInfo&& create_info) = 0;
|
||||
|
||||
// Function
|
||||
virtual const Graph& Function__Body(const Function* p) = 0;
|
||||
|
|
@ -465,19 +460,18 @@ struct ProviderHost {
|
|||
// Path
|
||||
virtual PathString Path__ToPathString(const Path* p) noexcept = 0;
|
||||
|
||||
// Provider_OpKernel_Base
|
||||
virtual const OpKernelInfo& Provider_OpKernel_Base__GetInfo(const Provider_OpKernel_Base* p) = 0;
|
||||
|
||||
// OpKernelContext
|
||||
virtual const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) = 0;
|
||||
virtual Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) = 0;
|
||||
|
||||
// OpKernelInfo
|
||||
virtual std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) = 0;
|
||||
virtual void OpKernelInfo__operator_delete(OpKernelInfo* p) = 0;
|
||||
virtual Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) = 0;
|
||||
virtual Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) = 0;
|
||||
|
||||
virtual const DataTransferManager& OpKernelInfo__GetDataTransferManager(const OpKernelInfo* p) noexcept = 0;
|
||||
virtual int OpKernelInfo__GetKernelDef_ExecQueueId(const OpKernelInfo* p) noexcept = 0;
|
||||
virtual const KernelDef& OpKernelInfo__GetKernelDef(const OpKernelInfo* p) = 0;
|
||||
|
||||
// Tensor
|
||||
virtual float* Tensor__MutableData_float(Tensor* p) = 0;
|
||||
|
|
@ -497,7 +491,7 @@ struct ProviderHost {
|
|||
|
||||
extern ProviderHost* g_host;
|
||||
|
||||
#ifndef PROVIDER_BRIDGE_ORT
|
||||
#ifdef SHARED_PROVIDER
|
||||
|
||||
struct CPUIDInfo {
|
||||
static const CPUIDInfo& GetCPUIDInfo() { return g_host->CPUIDInfo__GetCPUIDInfo(); }
|
||||
|
|
@ -767,32 +761,17 @@ struct IndexedSubGraph {
|
|||
struct KernelDef {
|
||||
static void operator delete(void* p) { g_host->KernelDef__operator_delete(reinterpret_cast<KernelDef*>(p)); }
|
||||
|
||||
int ExecQueueId() const { return g_host->KernelDef__ExecQueueId(this); }
|
||||
|
||||
KernelDef() = delete;
|
||||
KernelDef(const KernelDef*) = delete;
|
||||
void operator=(const KernelDef&) = delete;
|
||||
};
|
||||
#endif
|
||||
|
||||
using Provider_KernelCreateFn = std::function<Provider_OpKernel*(const OpKernelInfo& info)>;
|
||||
using Provider_KernelCreatePtrFn = std::add_pointer<Provider_OpKernel*(const OpKernelInfo& info)>::type;
|
||||
using BuildKernelCreateInfoFn = KernelCreateInfo (*)();
|
||||
|
||||
struct Provider_KernelCreateInfo {
|
||||
std::unique_ptr<KernelDef> kernel_def; // Owned and stored in the global kernel registry.
|
||||
Provider_KernelCreateFn kernel_create_func;
|
||||
|
||||
Provider_KernelCreateInfo(std::unique_ptr<KernelDef> definition,
|
||||
Provider_KernelCreateFn create_func)
|
||||
: kernel_def(std::move(definition)),
|
||||
kernel_create_func(create_func) {}
|
||||
|
||||
Provider_KernelCreateInfo(Provider_KernelCreateInfo&& other) noexcept
|
||||
: kernel_def(std::move(other.kernel_def)),
|
||||
kernel_create_func(std::move(other.kernel_create_func)) {}
|
||||
};
|
||||
|
||||
using Provider_BuildKernelCreateInfoFn = Provider_KernelCreateInfo (*)();
|
||||
|
||||
#ifndef PROVIDER_BRIDGE_ORT
|
||||
#ifdef SHARED_PROVIDER
|
||||
struct KernelDefBuilder {
|
||||
static std::unique_ptr<KernelDefBuilder> Create() { return g_host->KernelDefBuilder__construct(); }
|
||||
static void operator delete(void* p) { g_host->KernelDefBuilder__operator_delete(reinterpret_cast<KernelDefBuilder*>(p)); }
|
||||
|
|
@ -845,13 +824,38 @@ struct KernelRegistry {
|
|||
static std::shared_ptr<KernelRegistry> Create() { return g_host->KernelRegistry__construct(); }
|
||||
static void operator delete(void* p) { g_host->KernelRegistry__operator_delete(reinterpret_cast<KernelRegistry*>(p)); }
|
||||
|
||||
Status Register(Provider_KernelCreateInfo&& create_info) { return g_host->KernelRegistry__Register(this, std::move(create_info)); }
|
||||
Status Register(KernelCreateInfo&& create_info) { return g_host->KernelRegistry__Register(this, std::move(create_info)); }
|
||||
|
||||
KernelRegistry() = delete;
|
||||
KernelRegistry(const KernelRegistry&) = delete;
|
||||
void operator=(const KernelRegistry&) = delete;
|
||||
};
|
||||
|
||||
struct PrimitiveDataTypeBase {
|
||||
int32_t GetDataType() const { return g_host->PrimitiveDataTypeBase__GetDataType(this); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(PrimitiveDataTypeBase)
|
||||
};
|
||||
|
||||
class DataTypeImpl {
|
||||
public:
|
||||
size_t Size() const { return g_host->DataTypeImpl__Size(this); }
|
||||
|
||||
template <typename T>
|
||||
static MLDataType GetType();
|
||||
template <typename elemT>
|
||||
static MLDataType GetTensorType();
|
||||
|
||||
static const std::vector<MLDataType>& AllFixedSizeTensorTypes() { return g_host->DataTypeImpl__AllFixedSizeTensorTypes(); }
|
||||
static const std::vector<MLDataType>& AllTensorTypes() { return g_host->DataTypeImpl__AllTensorTypes(); }
|
||||
|
||||
const PrimitiveDataTypeBase* AsPrimitiveDataType() const { return g_host->DataTypeImpl__AsPrimitiveDataType(this); }
|
||||
|
||||
static const char* ToString(MLDataType type) { return g_host->DataTypeImpl__ToString(type); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(DataTypeImpl)
|
||||
};
|
||||
|
||||
struct Function {
|
||||
const Graph& Body() const { return g_host->Function__Body(this); }
|
||||
|
||||
|
|
@ -1023,13 +1027,7 @@ struct Path {
|
|||
|
||||
#endif
|
||||
|
||||
struct Provider_OpKernel_Base {
|
||||
const OpKernelInfo& GetInfo() const { return g_host->Provider_OpKernel_Base__GetInfo(this); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(Provider_OpKernel_Base)
|
||||
};
|
||||
|
||||
#ifndef PROVIDER_BRIDGE_ORT
|
||||
#ifdef SHARED_PROVIDER
|
||||
struct OpKernelContext {
|
||||
const Tensor* Input_Tensor(int index) const { return g_host->OpKernelContext__Input_Tensor(this, index); }
|
||||
|
||||
|
|
@ -1047,6 +1045,8 @@ inline const Tensor* OpKernelContext::Input<Tensor>(int index) const {
|
|||
}
|
||||
|
||||
struct OpKernelInfo {
|
||||
static void operator delete(void* p) { g_host->OpKernelInfo__operator_delete(reinterpret_cast<OpKernelInfo*>(p)); }
|
||||
|
||||
template <typename T>
|
||||
Status GetAttr(const std::string& name, T* value) const;
|
||||
|
||||
|
|
@ -1054,9 +1054,11 @@ struct OpKernelInfo {
|
|||
Status GetAttr(const std::string& name, float* value) const { return g_host->OpKernelInfo__GetAttr_float(this, name, value); }
|
||||
|
||||
const DataTransferManager& GetDataTransferManager() const noexcept { return g_host->OpKernelInfo__GetDataTransferManager(this); }
|
||||
int GetKernelDef_ExecQueueId() const noexcept { return g_host->OpKernelInfo__GetKernelDef_ExecQueueId(this); }
|
||||
const KernelDef& GetKernelDef() const { return g_host->OpKernelInfo__GetKernelDef(this); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(OpKernelInfo)
|
||||
OpKernelInfo() = delete;
|
||||
OpKernelInfo(const OpKernelInfo&) = delete;
|
||||
void operator=(const OpKernelInfo&) = delete;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
|
|
|||
|
|
@ -307,20 +307,20 @@ bool CudaCall<cudaError, true>(cudaError retCode, const char* exprString, const
|
|||
return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg);
|
||||
}
|
||||
|
||||
class Memcpy final : public Provider_OpKernel {
|
||||
class Memcpy final : public OpKernel {
|
||||
public:
|
||||
Memcpy(const OpKernelInfo&) {}
|
||||
Memcpy(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
|
||||
Status Compute(OpKernelContext* ctx, const Provider_OpKernel_Base& base) const override {
|
||||
Status Compute(OpKernelContext* ctx) const override {
|
||||
const auto* X = ctx->Input<Tensor>(0);
|
||||
Tensor* Y = ctx->Output(0, X->Shape());
|
||||
Status retval = base.GetInfo().GetDataTransferManager().CopyTensor(*X, *Y, base.GetInfo().GetKernelDef_ExecQueueId());
|
||||
Status retval = Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId());
|
||||
return retval;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Provider_KernelCreateInfo BuildKernelCreateInfo();
|
||||
KernelCreateInfo BuildKernelCreateInfo();
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
MemcpyFromHost,
|
||||
|
|
@ -348,7 +348,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
|
||||
|
||||
static Status RegisterTensorrtKernels(KernelRegistry& kernel_registry) {
|
||||
static const Provider_BuildKernelCreateInfoFn function_table[] = {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyToHost)>,
|
||||
};
|
||||
|
|
@ -463,12 +463,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
|
|||
|
||||
if (engine_decryption_enable_) {
|
||||
std::string engine_decryption_lib_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionLibPath);
|
||||
LIBTYPE handle = OPENLIB(engine_decryption_lib_path.c_str());
|
||||
if (handle == nullptr) {
|
||||
ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not open shared library from " + engine_decryption_lib_path);
|
||||
}
|
||||
engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt");
|
||||
LIBTYPE handle = OPENLIB(engine_decryption_lib_path.c_str());
|
||||
if (handle == nullptr) {
|
||||
ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not open shared library from " + engine_decryption_lib_path);
|
||||
}
|
||||
engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1289,7 +1289,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
|
||||
}
|
||||
trt_context = trt_state->context->get();
|
||||
} else if (trt_state->engine_decryption_enable && !engine_file && profile_file) {
|
||||
} else if (trt_state->engine_decryption_enable && !engine_file && profile_file) {
|
||||
shape_ranges = DeserializeProfile(profile_file);
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
|
||||
// Decrypt engine
|
||||
|
|
|
|||
Loading…
Reference in a new issue