Change OpKernel class to be shared with shared providers (#6837)

In the previous shared providers there aren't many OpKernel classes, so the existing Provider_OpKernel wrapper was fine. With the possibility of making CUDA a shared provider, having to change every OpKernel to use that wrapper adds a lot of complexity.

It was fairly straightforward to make OpKernel work with shared providers with minimal changes.

In this change, the ONNX_OPERATOR_* macros can also be shared with the shared providers.
This commit is contained in:
Ryan Hill 2021-03-02 00:53:48 -08:00 committed by GitHub
parent 38796ad451
commit 0d0eb2c85c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 451 additions and 450 deletions

View file

@ -13,6 +13,7 @@
#include "core/common/common.h"
#include "core/common/exceptions.h"
#include "core/framework/endian.h"
#include "core/framework/float16.h"
#include "core/graph/onnx_protobuf.h"
struct OrtValue;
@ -51,93 +52,6 @@ class SequenceTensorTypeBase;
class NonTensorTypeBase;
class PrimitiveDataTypeBase;
// MLFloat16
// 16-bit floating-point value carried as its raw bit pattern in `val`
// (presumably IEEE half precision -- confirm against the out-of-line conversions).
struct MLFloat16 {
  uint16_t val;

  // Starts out as the all-zero bit pattern.
  MLFloat16() : val{0} {}
  // Adopts `x` verbatim as the stored bits; no numeric conversion happens here.
  explicit MLFloat16(uint16_t x) : val{x} {}
  // Builds the value from a float; only declared here, defined elsewhere.
  explicit MLFloat16(float f);

  // Converts the stored bits to float; only declared here, defined elsewhere.
  float ToFloat() const;
  // Implicit conversion to float, implemented through ToFloat().
  operator float() const { return ToFloat(); }
};

// The comparisons below look only at the stored 16-bit patterns, not at the decoded float values.
inline bool operator==(const MLFloat16& left, const MLFloat16& right) {
  return right.val == left.val;
}
inline bool operator!=(const MLFloat16& left, const MLFloat16& right) {
  return !(left.val == right.val);
}
inline bool operator<(const MLFloat16& left, const MLFloat16& right) {
  return right.val > left.val;
}
//BFloat16
// bfloat16 value carried as a raw 16-bit pattern: two of the bytes of a float, selected by
// host endianness (the most significant half, assuming a 4-byte float).
struct BFloat16 {
  uint16_t val{0};

  explicit BFloat16() = default;
  // Adopts `v` verbatim as the stored bits.
  explicit BFloat16(uint16_t v) : val(v) {}
  // Copies one 2-byte half of `v` into `val`; the other half is dropped without rounding.
  explicit BFloat16(float v) {
    const char* kept_half = reinterpret_cast<const char*>(&v);
    if (endian::native == endian::little) {
      // On a little-endian host the half to keep is the second one in memory.
      kept_half += sizeof(uint16_t);
    }
    std::memcpy(&val, kept_half, sizeof(uint16_t));
  }

  // Rebuilds a float by placing `val` in the half the constructor read from and zeroing the other half.
  float ToFloat() const {
    float result;
    char* const first_half = reinterpret_cast<char*>(&result);
    char* const second_half = first_half + sizeof(uint16_t);
    const bool is_little = (endian::native == endian::little);
    char* const value_dst = is_little ? second_half : first_half;
    char* const zero_dst = is_little ? first_half : second_half;
    std::memcpy(value_dst, &val, sizeof(uint16_t));
    std::memset(zero_dst, 0, sizeof(uint16_t));
    return result;
  }

  // Implicit conversion to float, implemented through ToFloat().
  operator float() const { return ToFloat(); }
};
inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
auto src = blf;
auto d = flt;
for (; size != 0; ++src, ++d, --size) {
*d = src->ToFloat();
}
}
// Converts `size` floats starting at `flt` into BFloat16 values constructed in place at `blf`.
// Reads flt[0..size) and writes blf[0..size); does nothing when `size` is 0.
inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
  for (size_t i = 0; i < size; ++i) {
    // Placement new: the BFloat16 is constructed directly in the caller's storage.
    new (blf + i) BFloat16(flt[i]);
  }
}
// The BFloat16 comparisons below look only at the stored 16-bit patterns, not at the decoded floats.
inline bool operator==(const BFloat16& left, const BFloat16& right) {
  return right.val == left.val;
}
inline bool operator!=(const BFloat16& left, const BFloat16& right) {
  return !(left.val == right.val);
}
inline bool operator<(const BFloat16& left, const BFloat16& right) {
  return right.val > left.val;
}
// DataTypeImpl pointer as unique DataTypeImpl identifier.
using MLDataType = const DataTypeImpl*;
// be used with class MLValue

View file

@ -13,9 +13,11 @@
#include "boost/mp11.hpp"
#include "core/common/common.h"
#ifndef SHARED_PROVIDER
#include "core/common/type_list.h"
#include "core/framework/data_types.h"
#include "core/graph/onnx_protobuf.h"
#endif
namespace onnxruntime {
namespace utils {

View file

@ -3,7 +3,7 @@
#pragma once
#ifndef PROVIDER_BRIDGE_PROVIDER
#ifndef SHARED_PROVIDER
#include <unordered_map>
#include <unordered_set>

View file

@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <new>
#include "endian.h"
namespace onnxruntime
{
// MLFloat16
// 16-bit floating-point value carried as its raw bit pattern in `val`
// (presumably IEEE half precision -- confirm against the out-of-line conversions).
struct MLFloat16 {
  uint16_t val;

  // Starts out as the all-zero bit pattern.
  MLFloat16() : val{0} {}
  // Adopts `x` verbatim as the stored bits; no numeric conversion happens here.
  explicit MLFloat16(uint16_t x) : val{x} {}
  // Builds the value from a float; only declared here, defined elsewhere.
  explicit MLFloat16(float f);

  // Converts the stored bits to float; only declared here, defined elsewhere.
  float ToFloat() const;
  // Implicit conversion to float, implemented through ToFloat().
  operator float() const { return ToFloat(); }
};

// The comparisons below look only at the stored 16-bit patterns, not at the decoded float values.
inline bool operator==(const MLFloat16& left, const MLFloat16& right) {
  return right.val == left.val;
}
inline bool operator!=(const MLFloat16& left, const MLFloat16& right) {
  return !(left.val == right.val);
}
inline bool operator<(const MLFloat16& left, const MLFloat16& right) {
  return right.val > left.val;
}
//BFloat16
// bfloat16 value carried as a raw 16-bit pattern: two of the bytes of a float, selected by
// host endianness (the most significant half, assuming a 4-byte float).
struct BFloat16 {
  uint16_t val{0};

  explicit BFloat16() = default;
  // Adopts `v` verbatim as the stored bits.
  explicit BFloat16(uint16_t v) : val(v) {}
  // Copies one 2-byte half of `v` into `val`; the other half is dropped without rounding.
  explicit BFloat16(float v) {
    const char* kept_half = reinterpret_cast<const char*>(&v);
    if (endian::native == endian::little) {
      // On a little-endian host the half to keep is the second one in memory.
      kept_half += sizeof(uint16_t);
    }
    std::memcpy(&val, kept_half, sizeof(uint16_t));
  }

  // Rebuilds a float by placing `val` in the half the constructor read from and zeroing the other half.
  float ToFloat() const {
    float result;
    char* const first_half = reinterpret_cast<char*>(&result);
    char* const second_half = first_half + sizeof(uint16_t);
    const bool is_little = (endian::native == endian::little);
    char* const value_dst = is_little ? second_half : first_half;
    char* const zero_dst = is_little ? first_half : second_half;
    std::memcpy(value_dst, &val, sizeof(uint16_t));
    std::memset(zero_dst, 0, sizeof(uint16_t));
    return result;
  }

  // Implicit conversion to float, implemented through ToFloat().
  operator float() const { return ToFloat(); }
};
inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
auto src = blf;
auto d = flt;
for (; size != 0; ++src, ++d, --size) {
*d = src->ToFloat();
}
}
// Converts `size` floats starting at `flt` into BFloat16 values constructed in place at `blf`.
// Reads flt[0..size) and writes blf[0..size); does nothing when `size` is 0.
inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
  for (size_t i = 0; i < size; ++i) {
    // Placement new: the BFloat16 is constructed directly in the caller's storage.
    new (blf + i) BFloat16(flt[i]);
  }
}
// The BFloat16 comparisons below look only at the stored 16-bit patterns, not at the decoded floats.
inline bool operator==(const BFloat16& left, const BFloat16& right) {
  return right.val == left.val;
}
inline bool operator!=(const BFloat16& left, const BFloat16& right) {
  return !(left.val == right.val);
}
inline bool operator<(const BFloat16& left, const BFloat16& right) {
  return right.val > left.val;
}
}

View file

@ -6,6 +6,9 @@
#include "core/framework/op_kernel.h"
namespace onnxruntime {
using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
/**
* Each provider has a KernelRegistry. Often, the KernelRegistry only belongs to that specific provider.
*

View file

@ -3,10 +3,10 @@
#pragma once
#include <functional>
#include "boost/mp11.hpp"
#ifndef SHARED_PROVIDER
#include <functional>
#include "core/common/exceptions.h"
#include "core/common/logging/logging.h"
#include "core/common/status.h"
@ -21,29 +21,24 @@
#include "core/graph/graph_viewer.h"
#include "core/graph/onnx_protobuf.h"
#include "gsl/gsl"
namespace onnxruntime {
class OpKernelContext;
}
#endif
namespace onnxruntime {
class IExecutionFrame;
class OpKernelContext;
class OpKernelWrapper;
namespace concurrency {
class ThreadPool;
}
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info);
class OpKernel {
public:
using DoneCallback = std::function<void()>;
explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(info) {}
explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(CopyOpKernelInfo(info)) {}
virtual ~OpKernel() = default;
const onnxruntime::Node& Node() const {
return op_kernel_info_.node();
}
const onnxruntime::KernelDef& KernelDef() const {
return op_kernel_info_.GetKernelDef();
}
const onnxruntime::Node& Node() const;
const onnxruntime::KernelDef& KernelDef() const;
virtual Status Compute(_Inout_ OpKernelContext* context) const ORT_MUST_USE_RESULT = 0;
@ -73,219 +68,14 @@ class OpKernel {
return Status::OK();
}
const OrtMemoryInfo& Allocator(int id, OrtMemType mem_type) const {
return op_kernel_info_.GetMemoryInfo(id, mem_type);
}
const OpKernelInfo& Info() const { return op_kernel_info_; }
const OrtMemoryInfo& Allocator(int id, OrtMemType mem_type) const;
const OpKernelInfo& Info() const { return *op_kernel_info_; }
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel);
OpKernelInfo op_kernel_info_;
std::unique_ptr<OpKernelInfo> op_kernel_info_;
};
// Context object passed to OpKernel::Compute. It gives a kernel access to its node's inputs and
// outputs, a logger and an optional thread pool. All four pointer members are raw pointers that
// this class never deletes (the destructor is defaulted).
class OpKernelContext {
public:
using ArgMap = std::unordered_map<std::string, size_t>;
// `frame`, `kernel` and `threadpool` are kept as raw pointers; `logger` is kept by address.
OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel,
_In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger);
virtual ~OpKernelContext() = default;
/**
Return the number of inputs for a variadic argument.
@param arg_num The operator argument number.
@returns Number of inputs the argument has.
*/
int NumVariadicInputs(size_t arg_num) const;
MLDataType InputType(int index) const;
MLDataType OutputType(int index) const;
// Fetch the input at `index` as a const T*.
// Returns nullptr when GetInputMLValue(index) yields no value. A std::exception raised while
// reading the value as T is reported through ORT_THROW as "Missing Input: <input name>".
template <typename T>
const T* Input(int index) const {
const OrtValue* p_ml_value = GetInputMLValue(index);
ORT_TRY {
return p_ml_value ? &(p_ml_value->Get<T>()) : nullptr;
}
ORT_CATCH(const std::exception& /*e*/) {
ORT_THROW("Missing Input: " + kernel_->Node().InputDefs()[index]->Name());
}
}
// Fetch a required input, enforcing that it is present.
template <typename T>
const T& RequiredInput(int index) const {
const T* input_ptr = Input<T>(index);
ORT_ENFORCE(input_ptr, "Required input at index ", index, " is not present.");
return *input_ptr;
}
// Fetch output (non-tensor) with specified index.
// Returns nullptr when `index` is outside [0, OutputCount()) or no OrtValue could be obtained.
template <typename T>
T* Output(int index) {
if (index < 0 || index >= OutputCount())
return nullptr;
OrtValue* p_ml_value = GetOrCreateOutputMLValue(index);
return p_ml_value ? p_ml_value->GetMutable<T>() : nullptr;
}
// In the case that memory allocation has not been done for an output tensor,
// The memory allocation will be done on-the-fly with given tensor shape.
// Return nullptr if the output is an unused optional output.
Tensor* Output(int index, const TensorShape& shape);
Tensor* Output(int index, const std::vector<int64_t>& shape);
Tensor* Output(int index, const std::initializer_list<int64_t>& shape);
// Fetch a required tensor output, enforcing that it is present.
Tensor& RequiredOutput(int index, const TensorShape& shape) {
Tensor* output_ptr = Output(index, shape);
ORT_ENFORCE(output_ptr, "Required output at index ", index, " is not present.");
return *output_ptr;
}
// Fetch a sparse-tensor output corresponding to the specified index.
// num_values must specify the number of non-zero values (commonly known as NNZ/nnz),
// and shape must specify the shape of the underlying dense-tensor.
// Memory allocation for the output may happen when this method is invoked,
// unless static optimization pre-allocates it.
SparseTensor* Output(int index, size_t num_values, const TensorShape& shape);
// Retrieve indexed shape obtained from memory planning before actual
// computation. If the indexed shape cannot be inferred, this function returns
// false.
bool TryGetInferredInputShape(int index, TensorShape& shape) const;
// Retrieve indexed shape obtained from memory planning before actual
// computation. If the indexed shape cannot be inferred, this function returns
// false.
bool TryGetInferredOutputShape(int index, TensorShape& shape) const;
// Logger supplied at construction.
const logging::Logger& Logger() const {
return *logger_;
}
// Number of entries in the kernel node's InputDefs().
// always >= 0
int InputCount() const {
return static_cast<int>(kernel_->Node().InputDefs().size());
}
// Number of entries in the kernel node's ImplicitInputDefs().
// always >= 0
int ImplicitInputCount() const {
return static_cast<int>(kernel_->Node().ImplicitInputDefs().size());
}
// Number of entries in the kernel node's OutputDefs().
// always >= 0
int OutputCount() const {
return static_cast<int>(kernel_->Node().OutputDefs().size());
}
/**
Return an allocator on device 0, with memtype of OrtMemTypeDefault.
@remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
*/
Status GetTempSpaceAllocator(AllocatorPtr* output) const ORT_MUST_USE_RESULT;
/**
Return the fence of current node's input.
@param index The index of the input.
@returns Point to the Fence of the input OrtValue.
It is null if the input OrtValue doesn't have fence or the input is optional.
*/
Fence_t InputFence(int index) const;
/**
Return the fence of current node's implicit input.
@param index The index of the implicit input.
@returns Point to the Fence of the implicit input OrtValue.
It is null if the input OrtValue doesn't have fence or the input is optional.
*/
Fence_t ImplicitInputFence(int index) const;
/**
Return the fence of current node's output identified by index.
@param index The index of the output.
@returns Point to the Fence of the output OrtValue.
It is null if the output OrtValue doesn't have fence or the output is optional.
*/
Fence_t OutputFence(int index) const;
/**
Return the device id that current kernel runs on.
*/
int GetDeviceId() const {
return kernel_->Info().GetExecutionProvider()->GetDeviceId();
}
/**
Returns the opset domain of the underlying kernel
**/
const std::string& GetOpDomain() const;
/**
Returns the optype of the underlying kernel
**/
const std::string& GetOpType() const;
/**
Returns the node name of the underlying kernel
**/
const std::string& GetNodeName() const;
/**
Returns the intra-op threadpool, if available.
*/
_Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return threadpool_; }
/**
Returns whether deterministic computation is preferred.
This base implementation always returns true; it is virtual so a derived context can override it.
*/
virtual bool GetUseDeterministicCompute() const {
return true;
}
protected:
onnxruntime::NodeIndex GetNodeIndex() const;
const OrtValue* GetInputMLValue(int index) const;
const OrtValue* GetImplicitInputMLValue(int index) const;
OrtValue* GetOutputMLValue(int index);
// Creates the OrtValue* based on the shape, if it does not exist
// The parameter nnz is used only for sparse-tensors and indicates the
// number of non-zero values (the number of elements in the values buffer allocated).
OrtValue* OutputMLValue(int index, const TensorShape& shape, size_t nnz = 0);
private:
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
OrtValue* GetOrCreateOutputMLValue(int index);
int GetInputArgIndex(int index) const;
int GetImplicitInputArgIndex(int index) const;
int GetOutputArgIndex(int index) const;
// Raw pointers received by the constructor; this class does not delete them.
IExecutionFrame* const execution_frame_;
const OpKernel* const kernel_;
concurrency::ThreadPool* const threadpool_;
const logging::Logger* const logger_;
// The argument starting index in ExecutionFrame.
// -1 until assigned (the assignment is not visible in this declaration).
int node_input_start_index_{-1};
int node_implicit_input_start_index_{-1};
int node_output_start_index_{-1};
};
// Fetching output tensor without shape is not allowed except when it already exists
// Tensor specialization of Output<T>(int): it only looks up an output through GetOutputMLValue
// and enforces that one exists. Unlike the generic template it does not bounds-check `index`
// and does not call GetOrCreateOutputMLValue.
template <>
inline Tensor* OpKernelContext::Output<Tensor>(int index) {
OrtValue* p_ml_value = GetOutputMLValue(index);
// ORT_ENFORCE fails with this message when no OrtValue is available for the output.
ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape.");
return p_ml_value->GetMutable<Tensor>();
}
using KernelCreateFn = std::function<OpKernel*(const OpKernelInfo& info)>;
using KernelCreatePtrFn = std::add_pointer<OpKernel*(const OpKernelInfo& info)>::type;
@ -306,8 +96,6 @@ struct KernelCreateInfo {
KernelCreateInfo() = default;
};
using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
// Forward declarations for the non-specialized BuildKernelCreateInfo method.
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();
@ -504,3 +292,7 @@ std::vector<MLDataType> BuildKernelDefConstraintsFromTypeList() {
}
} // namespace onnxruntime
#ifndef SHARED_PROVIDER
#include "core/framework/op_kernel_context.h"
#endif

View file

@ -0,0 +1,212 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
namespace onnxruntime {
class IExecutionFrame;
namespace concurrency {
class ThreadPool;
}
// Context object passed to OpKernel::Compute. It gives a kernel access to its node's inputs and
// outputs, a logger and an optional thread pool. All four pointer members are raw pointers that
// this class never deletes (the destructor is defaulted).
// NOTE(review): this header shows no #pragma once and no #include lines of its own; it appears to
// rely on being included at the end of op_kernel.h, after OpKernel, logging::Logger, Tensor, etc.
// are already declared -- confirm before including it from anywhere else.
class OpKernelContext {
public:
using ArgMap = std::unordered_map<std::string, size_t>;
// `frame`, `kernel` and `threadpool` are kept as raw pointers; `logger` is kept by address.
OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel,
_In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger);
virtual ~OpKernelContext() = default;
/**
Return the number of inputs for a variadic argument.
@param arg_num The operator argument number.
@returns Number of inputs the argument has.
*/
int NumVariadicInputs(size_t arg_num) const;
MLDataType InputType(int index) const;
MLDataType OutputType(int index) const;
// Fetch the input at `index` as a const T*.
// Returns nullptr when GetInputMLValue(index) yields no value. A std::exception raised while
// reading the value as T is reported through ORT_THROW as "Missing Input: <input name>".
template <typename T>
const T* Input(int index) const {
const OrtValue* p_ml_value = GetInputMLValue(index);
ORT_TRY {
return p_ml_value ? &(p_ml_value->Get<T>()) : nullptr;
}
ORT_CATCH(const std::exception& /*e*/) {
ORT_THROW("Missing Input: " + kernel_->Node().InputDefs()[index]->Name());
}
}
// Fetch a required input, enforcing that it is present.
template <typename T>
const T& RequiredInput(int index) const {
const T* input_ptr = Input<T>(index);
ORT_ENFORCE(input_ptr, "Required input at index ", index, " is not present.");
return *input_ptr;
}
// Fetch output (non-tensor) with specified index.
// Returns nullptr when `index` is outside [0, OutputCount()) or no OrtValue could be obtained.
template <typename T>
T* Output(int index) {
if (index < 0 || index >= OutputCount())
return nullptr;
OrtValue* p_ml_value = GetOrCreateOutputMLValue(index);
return p_ml_value ? p_ml_value->GetMutable<T>() : nullptr;
}
// In the case that memory allocation has not been done for an output tensor,
// The memory allocation will be done on-the-fly with given tensor shape.
// Return nullptr if the output is an unused optional output.
Tensor* Output(int index, const TensorShape& shape);
Tensor* Output(int index, const std::vector<int64_t>& shape);
Tensor* Output(int index, const std::initializer_list<int64_t>& shape);
// Fetch a required tensor output, enforcing that it is present.
Tensor& RequiredOutput(int index, const TensorShape& shape) {
Tensor* output_ptr = Output(index, shape);
ORT_ENFORCE(output_ptr, "Required output at index ", index, " is not present.");
return *output_ptr;
}
// Fetch a sparse-tensor output corresponding to the specified index.
// num_values must specify the number of non-zero values (commonly known as NNZ/nnz),
// and shape must specify the shape of the underlying dense-tensor.
// Memory allocation for the output may happen when this method is invoked,
// unless static optimization pre-allocates it.
SparseTensor* Output(int index, size_t num_values, const TensorShape& shape);
// Retrieve indexed shape obtained from memory planning before actual
// computation. If the indexed shape cannot be inferred, this function returns
// false.
bool TryGetInferredInputShape(int index, TensorShape& shape) const;
// Retrieve indexed shape obtained from memory planning before actual
// computation. If the indexed shape cannot be inferred, this function returns
// false.
bool TryGetInferredOutputShape(int index, TensorShape& shape) const;
// Logger supplied at construction.
const logging::Logger& Logger() const {
return *logger_;
}
// Number of entries in the kernel node's InputDefs().
// always >= 0
int InputCount() const {
return static_cast<int>(kernel_->Node().InputDefs().size());
}
// Number of entries in the kernel node's ImplicitInputDefs().
// always >= 0
int ImplicitInputCount() const {
return static_cast<int>(kernel_->Node().ImplicitInputDefs().size());
}
// Number of entries in the kernel node's OutputDefs().
// always >= 0
int OutputCount() const {
return static_cast<int>(kernel_->Node().OutputDefs().size());
}
/**
Return an allocator on device 0, with memtype of OrtMemTypeDefault.
@remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
*/
Status GetTempSpaceAllocator(AllocatorPtr* output) const ORT_MUST_USE_RESULT;
/**
Return the fence of current node's input.
@param index The index of the input.
@returns Point to the Fence of the input OrtValue.
It is null if the input OrtValue doesn't have fence or the input is optional.
*/
Fence_t InputFence(int index) const;
/**
Return the fence of current node's implicit input.
@param index The index of the implicit input.
@returns Point to the Fence of the implicit input OrtValue.
It is null if the input OrtValue doesn't have fence or the input is optional.
*/
Fence_t ImplicitInputFence(int index) const;
/**
Return the fence of current node's output identified by index.
@param index The index of the output.
@returns Point to the Fence of the output OrtValue.
It is null if the output OrtValue doesn't have fence or the output is optional.
*/
Fence_t OutputFence(int index) const;
/**
Return the device id that current kernel runs on.
*/
int GetDeviceId() const {
return kernel_->Info().GetExecutionProvider()->GetDeviceId();
}
/**
Returns the opset domain of the underlying kernel
**/
const std::string& GetOpDomain() const;
/**
Returns the optype of the underlying kernel
**/
const std::string& GetOpType() const;
/**
Returns the node name of the underlying kernel
**/
const std::string& GetNodeName() const;
/**
Returns the intra-op threadpool, if available.
*/
_Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return threadpool_; }
/**
Returns whether deterministic computation is preferred.
This base implementation always returns true; it is virtual so a derived context can override it.
*/
virtual bool GetUseDeterministicCompute() const {
return true;
}
protected:
onnxruntime::NodeIndex GetNodeIndex() const;
const OrtValue* GetInputMLValue(int index) const;
const OrtValue* GetImplicitInputMLValue(int index) const;
OrtValue* GetOutputMLValue(int index);
// Creates the OrtValue* based on the shape, if it does not exist
// The parameter nnz is used only for sparse-tensors and indicates the
// number of non-zero values (the number of elements in the values buffer allocated).
OrtValue* OutputMLValue(int index, const TensorShape& shape, size_t nnz = 0);
private:
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
OrtValue* GetOrCreateOutputMLValue(int index);
int GetInputArgIndex(int index) const;
int GetImplicitInputArgIndex(int index) const;
int GetOutputArgIndex(int index) const;
// Raw pointers received by the constructor; this class does not delete them.
IExecutionFrame* const execution_frame_;
const OpKernel* const kernel_;
concurrency::ThreadPool* const threadpool_;
const logging::Logger* const logger_;
// The argument starting index in ExecutionFrame.
// -1 until assigned (the assignment is not visible in this declaration).
int node_input_start_index_{-1};
int node_implicit_input_start_index_{-1};
int node_output_start_index_{-1};
};
// Fetching output tensor without shape is not allowed except when it already exists
// Tensor specialization of Output<T>(int): it only looks up an output through GetOutputMLValue
// and enforces that one exists. Unlike the generic template it does not bounds-check `index`
// and does not call GetOrCreateOutputMLValue.
template <>
inline Tensor* OpKernelContext::Output<Tensor>(int index) {
OrtValue* p_ml_value = GetOutputMLValue(index);
// ORT_ENFORCE fails with this message when no OrtValue is available for the output.
ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape.");
return p_ml_value->GetMutable<Tensor>();
}
} // namespace onnxruntime

View file

@ -9,6 +9,22 @@
using namespace ::onnxruntime::common;
namespace onnxruntime {
// Returns a newly allocated copy of `info`, owned by the returned unique_ptr.
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) {
  std::unique_ptr<OpKernelInfo> info_copy = onnxruntime::make_unique<OpKernelInfo>(info);
  return info_copy;
}
// Node recorded in this kernel's OpKernelInfo copy.
const onnxruntime::Node& OpKernel::Node() const {
  const OpKernelInfo& kernel_info = *op_kernel_info_;
  return kernel_info.node();
}
// Kernel definition recorded in this kernel's OpKernelInfo copy.
const onnxruntime::KernelDef& OpKernel::KernelDef() const {
  const OpKernelInfo& kernel_info = *op_kernel_info_;
  return kernel_info.GetKernelDef();
}
// Memory info that this kernel's OpKernelInfo reports for the given `id` and `mem_type`.
const OrtMemoryInfo& OpKernel::Allocator(int id, OrtMemType mem_type) const {
  const OpKernelInfo& kernel_info = *op_kernel_info_;
  return kernel_info.GetMemoryInfo(id, mem_type);
}
OpKernelContext::OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel,
_In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger)
: execution_frame_(frame), kernel_(kernel), threadpool_(threadpool), logger_(&logger) {

View file

@ -26,7 +26,7 @@
#endif
namespace onnxruntime {
// These Provider types are really just internal types, so we #define PROVIDER_BRIDGE_ORT so that only these definitions are seen by provider_interfaces.h
// These Provider types are really just internal types, since we don't include provider_api.h only these definitions are seen by provider_interfaces.h
// Users of provider_interfaces.h (through provider_api.h) will see the wrappers that call through the provider shared interface which is implemented by this file
using Provider_int64s = google::protobuf::RepeatedField<int64_t>;
using Provider_AttributeProto = ONNX_NAMESPACE::AttributeProto;
@ -46,7 +46,6 @@ using Provider_ValueInfoProtos = google::protobuf::RepeatedPtrField<ONNX_NAMESPA
using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef;
} // namespace onnxruntime
#define PROVIDER_BRIDGE_ORT
#include "core/common/cpuid_info.h"
#include "onnx/common/stl_backports.h"
#include "core/common/logging/logging.h"
@ -123,19 +122,6 @@ struct Node__EdgeIterator_Impl : Node__EdgeIterator {
Node::EdgeConstIterator v_;
};
// OpKernel subclass that owns a provider-side Provider_OpKernel (p_) and forwards Compute() to it.
struct OpKernel_Translator : OpKernel {
// Takes ownership of `p` (stored in the unique_ptr p_); `info` is passed on to the OpKernel base.
OpKernel_Translator(const OpKernelInfo& info, Provider_OpKernel* p) : OpKernel{info}, p_{p} {
}
// Delegates to the wrapped kernel. The second argument is this object's OpKernel base,
// reinterpreted as a Provider_OpKernel_Base, so the provider side can hand it back to the host.
Status Compute(OpKernelContext* context) const override {
return p_->Compute(context, *reinterpret_cast<const Provider_OpKernel_Base*>(static_cast<const OpKernel*>(this)));
}
// Wrapped provider kernel; destroyed together with this translator.
std::unique_ptr<Provider_OpKernel> p_;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernel_Translator);
};
struct ProviderHostImpl : ProviderHost {
ProviderHostImpl() {
DataTypeImpl_GetType_Tensor = &DataTypeImpl::GetType<Tensor>;
@ -184,9 +170,13 @@ struct ProviderHostImpl : ProviderHost {
return const_cast<logging::Logger*>(&logging::LoggingManager::DefaultLogger());
}
// Host-side forwarder: exposes DataTypeImpl::AllFixedSizeTensorTypes() through the provider interface.
const std::vector<MLDataType>& DataTypeImpl_AllFixedSizeTensorTypes() override {
  const std::vector<MLDataType>& fixed_size_types = DataTypeImpl::AllFixedSizeTensorTypes();
  return fixed_size_types;
}
// Data-type forwarders: each call is passed straight through to the corresponding
// PrimitiveDataTypeBase / DataTypeImpl member with the same arguments.
int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) override { return p->GetDataType(); }
const char* DataTypeImpl__ToString(MLDataType type) override { return DataTypeImpl::ToString(type); }
const std::vector<MLDataType>& DataTypeImpl__AllFixedSizeTensorTypes() override { return DataTypeImpl::AllFixedSizeTensorTypes(); }
const std::vector<MLDataType>& DataTypeImpl__AllTensorTypes() override { return DataTypeImpl::AllTensorTypes(); }
size_t DataTypeImpl__Size(const DataTypeImpl* p) override { return p->Size(); }
const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) override { return p->AsPrimitiveDataType(); }
// Raw byte allocation performed on the host side. HeapAllocate uses new uint8_t[size] and
// HeapFree uses delete[] on a uint8_t*, so memory from HeapAllocate must be released with HeapFree.
void* HeapAllocate(size_t size) override { return new uint8_t[size]; }
void HeapFree(void* p) override { delete[] reinterpret_cast<uint8_t*>(p); }
@ -398,6 +388,7 @@ struct ProviderHostImpl : ProviderHost {
// KernelDef
// Deletes a KernelDef on the host side.
void KernelDef__operator_delete(KernelDef* p) override { delete p; }
// Forwards to KernelDef::ExecQueueId().
int KernelDef__ExecQueueId(const KernelDef* p) override { return p->ExecQueueId(); }
// KernelDefBuilder
// Creates a default-constructed KernelDefBuilder on the host side and hands ownership to the caller.
std::unique_ptr<KernelDefBuilder> KernelDefBuilder__construct() override { return onnxruntime::make_unique<KernelDefBuilder>(); }
@ -418,13 +409,7 @@ struct ProviderHostImpl : ProviderHost {
// KernelRegistry
// Creates a default-constructed KernelRegistry held by a shared_ptr.
std::shared_ptr<KernelRegistry> KernelRegistry__construct() override { return std::make_shared<KernelRegistry>(); }
// Deletes a KernelRegistry on the host side.
void KernelRegistry__operator_delete(KernelRegistry* p) override { delete p; }
// Registers a provider-side kernel with the host registry `p`.
// The provider's factory is wrapped so that the registry receives a real OpKernel: each created
// Provider_OpKernel is owned by a new OpKernel_Translator. `create_info.kernel_def` is moved from.
Status KernelRegistry__Register(KernelRegistry* p, Provider_KernelCreateInfo&& create_info) override {
  auto make_translator = [kernel_create_func = create_info.kernel_create_func](const OpKernelInfo& info) -> OpKernel* {
    return new OpKernel_Translator(info, kernel_create_func(info));
  };
  KernelCreateInfo info_real(std::move(create_info.kernel_def), std::move(make_translator));
  return p->Register(std::move(info_real));
}
// Registers `create_info` with the host registry `p` unchanged; `create_info` is moved from.
Status KernelRegistry__Register(KernelRegistry* p, KernelCreateInfo&& create_info) override { return p->Register(std::move(create_info)); }
// Function
// Forwards to Function::Body().
const Graph& Function__Body(const Function* p) override { return p->Body(); }
@ -543,23 +528,21 @@ struct ProviderHostImpl : ProviderHost {
// GraphViewer forwarders: each returns the reference produced by the same-named GraphViewer member.
const std::vector<NodeIndex>& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) override { return p->GetNodesInTopologicalOrder(); }
const std::vector<const NodeArg*>& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept override { return p->GetInputsIncludingInitializers(); }
//Path
//Gets a string representation of the path.
// Path
// Forwards to Path::ToPathString() and returns its result by value.
PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); }
// Provider_OpKernel_Base
// Treats `p` as the OpKernel it was reinterpreted from and returns that kernel's Info().
const OpKernelInfo& Provider_OpKernel_Base__GetInfo(const Provider_OpKernel_Base* p) override { return reinterpret_cast<const OpKernel*>(p)->Info(); }
// OpKernelContext
// Forwards to OpKernelContext::Input<Tensor>(index).
const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) override { return p->Input<Tensor>(index); }
// Forwards to OpKernelContext::Output(index, shape).
Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) override { return p->Output(index, shape); }
// OpKernelInfo
// OpKernelInfo forwarders: copying and deletion happen on the host side; the remaining calls pass
// straight through to the corresponding OpKernelInfo member.
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) override { return onnxruntime::CopyOpKernelInfo(info); }
void OpKernelInfo__operator_delete(OpKernelInfo* p) override { delete p; }
Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) override { return p->GetAttr(name, value); }
Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) override { return p->GetAttr(name, value); }
const DataTransferManager& OpKernelInfo__GetDataTransferManager(const OpKernelInfo* p) noexcept override { return p->GetDataTransferManager(); }
int OpKernelInfo__GetKernelDef_ExecQueueId(const OpKernelInfo* p) noexcept override { return p->GetKernelDef().ExecQueueId(); }
const KernelDef& OpKernelInfo__GetKernelDef(const OpKernelInfo* p) override { return p->GetKernelDef(); }
// Tensor
// Forwards to Tensor::MutableData<float>().
float* Tensor__MutableData_float(Tensor* p) override { return p->MutableData<float>(); }
@ -570,7 +553,7 @@ struct ProviderHostImpl : ProviderHost {
// Tensor forwarders: each returns the result of the same-named Tensor member.
const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); }
size_t Tensor__SizeInBytes(const Tensor* p) override { return p->SizeInBytes(); }
const OrtMemoryInfo& Tensor__Location(const Tensor* p) override { return p->Location(); }
const OrtMemoryInfo& Tensor__Location(const Tensor* p) override { return p->Location(); }
// AllocatorManager
// Forwards to AllocatorManager::InsertAllocator(allocator).
void AllocatorManager__InsertAllocator(AllocatorManager* p, AllocatorPtr allocator) override { p->InsertAllocator(allocator); }

View file

@ -42,7 +42,7 @@ namespace ort_dnnl {
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDnnlExecutionProvider, kOnnxDomain, 7, Gemm);
Status RegisterDNNLKernels(KernelRegistry& kernel_registry) {
static const Provider_BuildKernelCreateInfoFn function_table[] = {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDnnlExecutionProvider, kOnnxDomain, 7, Gemm)>,
};

View file

@ -6,6 +6,6 @@
namespace onnxruntime {
namespace ort_dnnl {
template <typename T>
Provider_KernelCreateInfo BuildKernelCreateInfo();
KernelCreateInfo BuildKernelCreateInfo();
}
} // namespace onnxruntime

View file

@ -92,7 +92,7 @@ template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
template <>
Status Gemm<float>::Compute(OpKernelContext* ctx, const Provider_OpKernel_Base&) const {
Status Gemm<float>::Compute(OpKernelContext* ctx) const {
const auto X = ctx->Input<Tensor>(0);
const auto W = ctx->Input<Tensor>(1);
const auto B = ctx->Input<Tensor>(2);

View file

@ -6,9 +6,9 @@
namespace onnxruntime {
namespace ort_dnnl {
template <typename T>
class Gemm final : public Provider_OpKernel {
class Gemm final : public OpKernel {
public:
Gemm(const OpKernelInfo& info) {
Gemm(const OpKernelInfo& info) : OpKernel(info) {
int64_t temp;
ORT_ENFORCE(info.GetAttr<int64_t>("transA", &temp).IsOK());
trans_A_ = (temp != 0);
@ -20,7 +20,7 @@ class Gemm final : public Provider_OpKernel {
ORT_ENFORCE(info.GetAttr<float>("beta", &beta_).IsOK());
}
Status Compute(OpKernelContext* context, const Provider_OpKernel_Base& base) const override;
Status Compute(OpKernelContext* context) const override;
private:
bool trans_A_;

View file

@ -7,7 +7,7 @@
// switching providers to be runnable as shared libraries. The interfaces will become more tightly integrated into the core code.
#pragma once
#define PROVIDER_BRIDGE_PROVIDER 1
#define SHARED_PROVIDER 1
#include <vector>
#include <string>
@ -16,9 +16,11 @@
#include "onnx/common/stl_backports.h"
#include "core/common/common.h"
#include "core/common/const_pointer_container.h"
#include "core/common/type_list.h"
#include "core/common/logging/severity.h"
#include "core/framework/allocator.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/float16.h"
#include "core/framework/tensor_shape.h"
#include "core/providers/providers.h"
#include "core/common/path_string.h"
@ -67,6 +69,7 @@ struct DataTransferManager;
struct IDataTransfer;
struct IndexedSubGraph;
struct IndexedSubGraph_MetaDef;
struct KernelCreateInfo;
struct KernelDef;
struct KernelDefBuilder;
struct KernelRegistry;
@ -80,7 +83,11 @@ struct NodeArg;
struct NodeAttributes;
struct OpKernelContext;
struct OpKernelInfo;
struct PrimitiveDataTypeBase;
struct Tensor;
class DataTypeImpl;
using MLDataType = const DataTypeImpl*;
} // namespace onnxruntime
namespace ONNX_NAMESPACE {
@ -147,6 +154,8 @@ enum OperatorStatus : int {
#include "core/framework/execution_provider.h"
#include "provider_interfaces.h"
#include "core/framework/op_kernel.h"
#include "core/framework/data_types_internal.h"
namespace onnxruntime {
@ -185,18 +194,6 @@ enum CUDAStreamType : int {
kTotalCudaStreams,
};
// Provider-side declaration of DataTypeImpl. Only static lookup entry points are
// declared here; MLDataType (their result type) is the `const DataTypeImpl*` alias.
class DataTypeImpl {
public:
virtual ~DataTypeImpl() = default;
// The templates below are declared only; their definitions live outside this class body.
template <typename T>
static MLDataType GetType();
template <typename elemT>
static MLDataType GetTensorType();
// Returns a reference to a vector of MLDataType; the vector is not owned by the caller.
static const std::vector<MLDataType>& AllFixedSizeTensorTypes();
};
template <typename T>
using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
@ -234,23 +231,6 @@ constexpr T roundUpPow2(T a) {
} // namespace onnxruntime
// Token-pastes its arguments into the identifier `<provider>_<name>_<domain>_ver<ver>`,
// used as the class name that tags a kernel registration.
#define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \
provider##_##name##_##domain##_ver##ver
// Declares the tag class produced by ONNX_OPERATOR_KERNEL_CLASS_NAME and defines the
// BuildKernelCreateInfo<> specialization for it. The specialization returns a
// Provider_KernelCreateInfo made of:
//   - the kernel definition from `builder`, with name, domain, since-version and
//     provider filled in from the macro arguments, and
//   - a creation function that heap-allocates the kernel class given in __VA_ARGS__,
//     constructing it from the OpKernelInfo.
// NOTE(review): the creation function returns a raw `new`-ed pointer; presumably the
// caller takes ownership - confirm against the registry code.
#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \
class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \
template <> \
Provider_KernelCreateInfo \
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name)>() { \
return Provider_KernelCreateInfo( \
builder.SetName(#name) \
.SetDomain(domain) \
.SinceVersion(ver) \
.Provider(provider) \
.Build(), \
static_cast<Provider_KernelCreatePtrFn>([](const OpKernelInfo& info) -> Provider_OpKernel* { return new __VA_ARGS__(info); })); \
}
// Expands to a ::onnxruntime::logging::Capture::Create call for `logger`, with the
// severity mapped to the Severity::k<severity> enumerator and ORT_WHERE appended as
// the last argument (presumably the call-site location - confirm against ORT_WHERE).
#define CREATE_MESSAGE(logger, severity, category, datatype) \
::onnxruntime::logging::Capture::Create(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE)

View file

@ -63,10 +63,6 @@ MLDataType DataTypeImpl::GetTensorType<float>() {
return g_host->DataTypeImpl_GetTensorType_float();
}
// Forwards to the host through g_host and returns the host's vector by reference.
const std::vector<MLDataType>& DataTypeImpl::AllFixedSizeTensorTypes() {
return g_host->DataTypeImpl_AllFixedSizeTensorTypes();
}
TensorShape::TensorShape(const int64_t* dimension_sizes, size_t dimension_count)
: std::vector<int64_t>(dimension_count) {
for (size_t i = 0; i < dimension_count; ++i) {
@ -221,4 +217,8 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status,
return g_host->LogRuntimeError(session_id, status, file, function, line);
}
// Asks the host (via g_host) for a copy of `info` and returns it as a unique_ptr
// owned by the caller.
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) {
return g_host->CopyOpKernelInfo(info);
}
} // namespace onnxruntime

View file

@ -27,8 +27,6 @@ using DataType = const std::string*;
namespace onnxruntime {
// These types don't directly map to internal types
struct Provider_KernelCreateInfo;
struct Provider_OpKernel_Base;
struct ProviderHost;
class TensorShape;
@ -66,19 +64,6 @@ struct Provider_TensorShapeProto_Dimension_Iterator {
virtual const Provider_TensorShapeProto_Dimension& operator*() = 0;
};
class DataTypeImpl;
using MLDataType = const DataTypeImpl*;
// Abstract, non-copyable base for provider-side kernels. Implementations override
// Compute, which receives the kernel context plus a Provider_OpKernel_Base reference.
struct Provider_OpKernel {
Provider_OpKernel() {}
virtual ~Provider_OpKernel() = default;
// Pure virtual; returns a Status.
virtual Status Compute(OpKernelContext* context, const Provider_OpKernel_Base& base) const = 0;
// Copy construction and copy assignment are disabled.
Provider_OpKernel(const Provider_OpKernel&) = delete;
void operator=(const Provider_OpKernel&) = delete;
};
using NodeIndex = size_t;
using Provider_NodeArgInfo = Provider_ValueInfoProto;
// We can't just reinterpret_cast this one, since it's an unordered_map of object BY VALUE (can't do anything by value on the real types)
@ -142,10 +127,19 @@ struct ProviderHost {
virtual std::string GetEnvironmentVar(const std::string& var_name) = 0;
// PrimitiveDataTypeBase
virtual int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) = 0;
// DataTypeImpl
MLDataType (*DataTypeImpl_GetType_Tensor)();
MLDataType (*DataTypeImpl_GetType_float)();
MLDataType (*DataTypeImpl_GetTensorType_float)();
virtual const std::vector<MLDataType>& DataTypeImpl_AllFixedSizeTensorTypes() = 0;
virtual const char* DataTypeImpl__ToString(MLDataType type) = 0;
virtual const std::vector<MLDataType>& DataTypeImpl__AllFixedSizeTensorTypes() = 0;
virtual const std::vector<MLDataType>& DataTypeImpl__AllTensorTypes() = 0;
virtual size_t DataTypeImpl__Size(const DataTypeImpl* p) = 0;
virtual const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) = 0;
virtual void* HeapAllocate(size_t size) = 0;
virtual void HeapFree(void*) = 0;
@ -335,6 +329,7 @@ struct ProviderHost {
// KernelDef
virtual void KernelDef__operator_delete(KernelDef* p) = 0;
virtual int KernelDef__ExecQueueId(const KernelDef* p) = 0;
// KernelDefBuilder
virtual std::unique_ptr<KernelDefBuilder> KernelDefBuilder__construct() = 0;
@ -355,7 +350,7 @@ struct ProviderHost {
// KernelRegistry
virtual std::shared_ptr<KernelRegistry> KernelRegistry__construct() = 0;
virtual void KernelRegistry__operator_delete(KernelRegistry* p) = 0;
virtual Status KernelRegistry__Register(KernelRegistry* p, Provider_KernelCreateInfo&& create_info) = 0;
virtual Status KernelRegistry__Register(KernelRegistry* p, KernelCreateInfo&& create_info) = 0;
// Function
virtual const Graph& Function__Body(const Function* p) = 0;
@ -465,19 +460,18 @@ struct ProviderHost {
// Path
virtual PathString Path__ToPathString(const Path* p) noexcept = 0;
// Provider_OpKernel_Base
virtual const OpKernelInfo& Provider_OpKernel_Base__GetInfo(const Provider_OpKernel_Base* p) = 0;
// OpKernelContext
virtual const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) = 0;
virtual Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) = 0;
// OpKernelInfo
virtual std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) = 0;
virtual void OpKernelInfo__operator_delete(OpKernelInfo* p) = 0;
virtual Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) = 0;
virtual Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) = 0;
virtual const DataTransferManager& OpKernelInfo__GetDataTransferManager(const OpKernelInfo* p) noexcept = 0;
virtual int OpKernelInfo__GetKernelDef_ExecQueueId(const OpKernelInfo* p) noexcept = 0;
virtual const KernelDef& OpKernelInfo__GetKernelDef(const OpKernelInfo* p) = 0;
// Tensor
virtual float* Tensor__MutableData_float(Tensor* p) = 0;
@ -497,7 +491,7 @@ struct ProviderHost {
extern ProviderHost* g_host;
#ifndef PROVIDER_BRIDGE_ORT
#ifdef SHARED_PROVIDER
struct CPUIDInfo {
static const CPUIDInfo& GetCPUIDInfo() { return g_host->CPUIDInfo__GetCPUIDInfo(); }
@ -767,32 +761,17 @@ struct IndexedSubGraph {
// Provider-side proxy for the host's KernelDef. Every operation is forwarded to the
// host through g_host; the type cannot be default-constructed, copied or assigned
// on the provider side.
struct KernelDef {
  // Deallocation is routed to the host rather than the provider's own heap.
  static void operator delete(void* p) { g_host->KernelDef__operator_delete(reinterpret_cast<KernelDef*>(p)); }

  // Returns the host-reported execution queue id for this kernel definition.
  int ExecQueueId() const { return g_host->KernelDef__ExecQueueId(this); }

  KernelDef() = delete;
  // Fixed: this was `KernelDef(const KernelDef*) = delete;`, which only deleted a
  // pointer-taking constructor and left the implicit copy constructor usable.
  KernelDef(const KernelDef&) = delete;
  void operator=(const KernelDef&) = delete;
};
#endif
using Provider_KernelCreateFn = std::function<Provider_OpKernel*(const OpKernelInfo& info)>;
using Provider_KernelCreatePtrFn = std::add_pointer<Provider_OpKernel*(const OpKernelInfo& info)>::type;
using BuildKernelCreateInfoFn = KernelCreateInfo (*)();
// Pairs a kernel definition with the function that creates the kernel.
// Move-constructible; the kernel definition is held by unique_ptr.
struct Provider_KernelCreateInfo {
  std::unique_ptr<KernelDef> kernel_def;  // Owned and stored in the global kernel registry.
  Provider_KernelCreateFn kernel_create_func;

  // Takes ownership of `definition`. `create_func` is a by-value sink parameter, so
  // it is moved into the member instead of being copied a second time.
  Provider_KernelCreateInfo(std::unique_ptr<KernelDef> definition,
                            Provider_KernelCreateFn create_func)
      : kernel_def(std::move(definition)),
        kernel_create_func(std::move(create_func)) {}

  // Transfers both members out of `other`.
  Provider_KernelCreateInfo(Provider_KernelCreateInfo&& other) noexcept
      : kernel_def(std::move(other.kernel_def)),
        kernel_create_func(std::move(other.kernel_create_func)) {}
};
using Provider_BuildKernelCreateInfoFn = Provider_KernelCreateInfo (*)();
#ifndef PROVIDER_BRIDGE_ORT
#ifdef SHARED_PROVIDER
struct KernelDefBuilder {
static std::unique_ptr<KernelDefBuilder> Create() { return g_host->KernelDefBuilder__construct(); }
static void operator delete(void* p) { g_host->KernelDefBuilder__operator_delete(reinterpret_cast<KernelDefBuilder*>(p)); }
@ -845,13 +824,38 @@ struct KernelRegistry {
static std::shared_ptr<KernelRegistry> Create() { return g_host->KernelRegistry__construct(); }
static void operator delete(void* p) { g_host->KernelRegistry__operator_delete(reinterpret_cast<KernelRegistry*>(p)); }
Status Register(Provider_KernelCreateInfo&& create_info) { return g_host->KernelRegistry__Register(this, std::move(create_info)); }
Status Register(KernelCreateInfo&& create_info) { return g_host->KernelRegistry__Register(this, std::move(create_info)); }
KernelRegistry() = delete;
KernelRegistry(const KernelRegistry&) = delete;
void operator=(const KernelRegistry&) = delete;
};
// Provider-side proxy for PrimitiveDataTypeBase; calls are forwarded to the host via g_host.
struct PrimitiveDataTypeBase {
// Returns the int32 data-type code reported by the host for this object.
int32_t GetDataType() const { return g_host->PrimitiveDataTypeBase__GetDataType(this); }
// NOTE(review): presumably deletes construction/copy for this proxy - confirm against the macro definition.
PROVIDER_DISALLOW_ALL(PrimitiveDataTypeBase)
};
// Provider-side proxy for DataTypeImpl. The inline members forward to the host via
// g_host; GetType/GetTensorType are declared only, with definitions outside this class body.
class DataTypeImpl {
public:
// Size as reported by the host for this type object.
size_t Size() const { return g_host->DataTypeImpl__Size(this); }
template <typename T>
static MLDataType GetType();
template <typename elemT>
static MLDataType GetTensorType();
// Both return references to vectors provided by the host.
static const std::vector<MLDataType>& AllFixedSizeTensorTypes() { return g_host->DataTypeImpl__AllFixedSizeTensorTypes(); }
static const std::vector<MLDataType>& AllTensorTypes() { return g_host->DataTypeImpl__AllTensorTypes(); }
// NOTE(review): presumably null when this is not a primitive type - confirm against the host implementation.
const PrimitiveDataTypeBase* AsPrimitiveDataType() const { return g_host->DataTypeImpl__AsPrimitiveDataType(this); }
// Returns the host's C-string description of `type`.
static const char* ToString(MLDataType type) { return g_host->DataTypeImpl__ToString(type); }
// NOTE(review): presumably deletes construction/copy for this proxy - confirm against the macro definition.
PROVIDER_DISALLOW_ALL(DataTypeImpl)
};
struct Function {
const Graph& Body() const { return g_host->Function__Body(this); }
@ -1023,13 +1027,7 @@ struct Path {
#endif
// Provider-side proxy giving a kernel access to its OpKernelInfo through the host.
struct Provider_OpKernel_Base {
// Returns the host's OpKernelInfo for this object by const reference.
const OpKernelInfo& GetInfo() const { return g_host->Provider_OpKernel_Base__GetInfo(this); }
// NOTE(review): presumably deletes construction/copy for this proxy - confirm against the macro definition.
PROVIDER_DISALLOW_ALL(Provider_OpKernel_Base)
};
#ifndef PROVIDER_BRIDGE_ORT
#ifdef SHARED_PROVIDER
struct OpKernelContext {
const Tensor* Input_Tensor(int index) const { return g_host->OpKernelContext__Input_Tensor(this, index); }
@ -1047,6 +1045,8 @@ inline const Tensor* OpKernelContext::Input<Tensor>(int index) const {
}
struct OpKernelInfo {
static void operator delete(void* p) { g_host->OpKernelInfo__operator_delete(reinterpret_cast<OpKernelInfo*>(p)); }
template <typename T>
Status GetAttr(const std::string& name, T* value) const;
@ -1054,9 +1054,11 @@ struct OpKernelInfo {
Status GetAttr(const std::string& name, float* value) const { return g_host->OpKernelInfo__GetAttr_float(this, name, value); }
const DataTransferManager& GetDataTransferManager() const noexcept { return g_host->OpKernelInfo__GetDataTransferManager(this); }
int GetKernelDef_ExecQueueId() const noexcept { return g_host->OpKernelInfo__GetKernelDef_ExecQueueId(this); }
const KernelDef& GetKernelDef() const { return g_host->OpKernelInfo__GetKernelDef(this); }
PROVIDER_DISALLOW_ALL(OpKernelInfo)
OpKernelInfo() = delete;
OpKernelInfo(const OpKernelInfo&) = delete;
void operator=(const OpKernelInfo&) = delete;
};
template <>

View file

@ -307,20 +307,20 @@ bool CudaCall<cudaError, true>(cudaError retCode, const char* exprString, const
return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg);
}
class Memcpy final : public Provider_OpKernel {
class Memcpy final : public OpKernel {
public:
Memcpy(const OpKernelInfo&) {}
Memcpy(const OpKernelInfo& info) : OpKernel(info) {}
Status Compute(OpKernelContext* ctx, const Provider_OpKernel_Base& base) const override {
Status Compute(OpKernelContext* ctx) const override {
const auto* X = ctx->Input<Tensor>(0);
Tensor* Y = ctx->Output(0, X->Shape());
Status retval = base.GetInfo().GetDataTransferManager().CopyTensor(*X, *Y, base.GetInfo().GetKernelDef_ExecQueueId());
Status retval = Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId());
return retval;
}
};
template <typename T>
Provider_KernelCreateInfo BuildKernelCreateInfo();
KernelCreateInfo BuildKernelCreateInfo();
ONNX_OPERATOR_KERNEL_EX(
MemcpyFromHost,
@ -348,7 +348,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
static Status RegisterTensorrtKernels(KernelRegistry& kernel_registry) {
static const Provider_BuildKernelCreateInfoFn function_table[] = {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyToHost)>,
};
@ -463,12 +463,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
if (engine_decryption_enable_) {
std::string engine_decryption_lib_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionLibPath);
LIBTYPE handle = OPENLIB(engine_decryption_lib_path.c_str());
if (handle == nullptr) {
ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not open shared library from " + engine_decryption_lib_path);
}
engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt");
LIBTYPE handle = OPENLIB(engine_decryption_lib_path.c_str());
if (handle == nullptr) {
ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not open shared library from " + engine_decryption_lib_path);
}
engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt");
}
}
@ -1289,7 +1289,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();
} else if (trt_state->engine_decryption_enable && !engine_file && profile_file) {
} else if (trt_state->engine_decryption_enable && !engine_file && profile_file) {
shape_ranges = DeserializeProfile(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Decrypt engine