diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index c71376d316..1157493651 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -13,6 +13,7 @@ #include "core/common/common.h" #include "core/common/exceptions.h" #include "core/framework/endian.h" +#include "core/framework/float16.h" #include "core/graph/onnx_protobuf.h" struct OrtValue; @@ -51,93 +52,6 @@ class SequenceTensorTypeBase; class NonTensorTypeBase; class PrimitiveDataTypeBase; -// MLFloat16 -struct MLFloat16 { - uint16_t val; - - MLFloat16() : val(0) {} - explicit MLFloat16(uint16_t x) : val(x) {} - explicit MLFloat16(float f); - - float ToFloat() const; - - operator float() const { - return ToFloat(); - } -}; - -inline bool operator==(const MLFloat16& left, const MLFloat16& right) { - return left.val == right.val; -} - -inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { - return left.val != right.val; -} - -inline bool operator<(const MLFloat16& left, const MLFloat16& right) { - return left.val < right.val; -} - -//BFloat16 -struct BFloat16 { - uint16_t val{0}; - explicit BFloat16() = default; - explicit BFloat16(uint16_t v) : val(v) {} - explicit BFloat16(float v) { - if (endian::native == endian::little) { - std::memcpy(&val, reinterpret_cast(&v) + sizeof(uint16_t), sizeof(uint16_t)); - } else { - std::memcpy(&val, &v, sizeof(uint16_t)); - } - } - - float ToFloat() const { - float result; - char* const first = reinterpret_cast(&result); - char* const second = first + sizeof(uint16_t); - if (endian::native == endian::little) { - std::memset(first, 0, sizeof(uint16_t)); - std::memcpy(second, &val, sizeof(uint16_t)); - } else { - std::memcpy(first, &val, sizeof(uint16_t)); - std::memset(second, 0, sizeof(uint16_t)); - } - return result; - } - - operator float() const { - return ToFloat(); - } -}; - -inline void BFloat16ToFloat(const BFloat16* blf, float* 
flt, size_t size) { - auto src = blf; - auto d = flt; - for (; size != 0; ++src, ++d, --size) { - *d = src->ToFloat(); - } -} - -inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) { - auto src = flt; - auto d = blf; - for (; size != 0; ++src, ++d, --size) { - new (d) BFloat16(*src); - } -} - -inline bool operator==(const BFloat16& left, const BFloat16& right) { - return left.val == right.val; -} - -inline bool operator!=(const BFloat16& left, const BFloat16& right) { - return left.val != right.val; -} - -inline bool operator<(const BFloat16& left, const BFloat16& right) { - return left.val < right.val; -} - // DataTypeImpl pointer as unique DataTypeImpl identifier. using MLDataType = const DataTypeImpl*; // be used with class MLValue diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index c1ce7581d2..83dc5c2210 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -13,9 +13,11 @@ #include "boost/mp11.hpp" #include "core/common/common.h" +#ifndef SHARED_PROVIDER #include "core/common/type_list.h" #include "core/framework/data_types.h" #include "core/graph/onnx_protobuf.h" +#endif namespace onnxruntime { namespace utils { diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index a2454997bc..02c0622b41 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -3,7 +3,7 @@ #pragma once -#ifndef PROVIDER_BRIDGE_PROVIDER +#ifndef SHARED_PROVIDER #include #include diff --git a/include/onnxruntime/core/framework/float16.h b/include/onnxruntime/core/framework/float16.h new file mode 100644 index 0000000000..9f846fdc85 --- /dev/null +++ b/include/onnxruntime/core/framework/float16.h @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "endian.h" + +namespace onnxruntime +{ + +// MLFloat16 +struct MLFloat16 { + uint16_t val; + + MLFloat16() : val(0) {} + explicit MLFloat16(uint16_t x) : val(x) {} + explicit MLFloat16(float f); + + float ToFloat() const; + + operator float() const { + return ToFloat(); + } +}; + +inline bool operator==(const MLFloat16& left, const MLFloat16& right) { + return left.val == right.val; +} + +inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { + return left.val != right.val; +} + +inline bool operator<(const MLFloat16& left, const MLFloat16& right) { + return left.val < right.val; +} + +//BFloat16 +struct BFloat16 { + uint16_t val{0}; + explicit BFloat16() = default; + explicit BFloat16(uint16_t v) : val(v) {} + explicit BFloat16(float v) { + if (endian::native == endian::little) { + std::memcpy(&val, reinterpret_cast(&v) + sizeof(uint16_t), sizeof(uint16_t)); + } else { + std::memcpy(&val, &v, sizeof(uint16_t)); + } + } + + float ToFloat() const { + float result; + char* const first = reinterpret_cast(&result); + char* const second = first + sizeof(uint16_t); + if (endian::native == endian::little) { + std::memset(first, 0, sizeof(uint16_t)); + std::memcpy(second, &val, sizeof(uint16_t)); + } else { + std::memcpy(first, &val, sizeof(uint16_t)); + std::memset(second, 0, sizeof(uint16_t)); + } + return result; + } + + operator float() const { + return ToFloat(); + } +}; + +inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) { + auto src = blf; + auto d = flt; + for (; size != 0; ++src, ++d, --size) { + *d = src->ToFloat(); + } +} + +inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) { + auto src = flt; + auto d = blf; + for (; size != 0; ++src, ++d, --size) { + new (d) BFloat16(*src); + } +} + +inline bool operator==(const BFloat16& left, const BFloat16& right) { + return left.val == right.val; +} + +inline bool 
operator!=(const BFloat16& left, const BFloat16& right) { + return left.val != right.val; +} + +inline bool operator<(const BFloat16& left, const BFloat16& right) { + return left.val < right.val; +} + +} \ No newline at end of file diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index 0542183d6d..33dbb76c5e 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -6,6 +6,9 @@ #include "core/framework/op_kernel.h" namespace onnxruntime { + +using KernelCreateMap = std::multimap; + /** * Each provider has a KernelRegistry. Often, the KernelRegistry only belongs to that specific provider. * diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index afd215df44..fdce63cad0 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -3,10 +3,10 @@ #pragma once -#include - #include "boost/mp11.hpp" +#ifndef SHARED_PROVIDER +#include #include "core/common/exceptions.h" #include "core/common/logging/logging.h" #include "core/common/status.h" @@ -21,29 +21,24 @@ #include "core/graph/graph_viewer.h" #include "core/graph/onnx_protobuf.h" #include "gsl/gsl" +namespace onnxruntime { +class OpKernelContext; +} +#endif namespace onnxruntime { -class IExecutionFrame; -class OpKernelContext; -class OpKernelWrapper; -namespace concurrency { -class ThreadPool; -} + +std::unique_ptr CopyOpKernelInfo(const OpKernelInfo& info); class OpKernel { public: using DoneCallback = std::function; - explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(info) {} + explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(CopyOpKernelInfo(info)) {} virtual ~OpKernel() = default; - const onnxruntime::Node& Node() const { - return op_kernel_info_.node(); - } - - const onnxruntime::KernelDef& KernelDef() const { - return 
op_kernel_info_.GetKernelDef(); - } + const onnxruntime::Node& Node() const; + const onnxruntime::KernelDef& KernelDef() const; virtual Status Compute(_Inout_ OpKernelContext* context) const ORT_MUST_USE_RESULT = 0; @@ -73,219 +68,14 @@ class OpKernel { return Status::OK(); } - const OrtMemoryInfo& Allocator(int id, OrtMemType mem_type) const { - return op_kernel_info_.GetMemoryInfo(id, mem_type); - } - - const OpKernelInfo& Info() const { return op_kernel_info_; } + const OrtMemoryInfo& Allocator(int id, OrtMemType mem_type) const; + const OpKernelInfo& Info() const { return *op_kernel_info_; } private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel); - OpKernelInfo op_kernel_info_; + std::unique_ptr op_kernel_info_; }; -class OpKernelContext { - public: - using ArgMap = std::unordered_map; - - OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel, - _In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger); - - virtual ~OpKernelContext() = default; - - /** - Return the number of inputs for a variadic argument. - @param arg_num The operator argument number. - @returns Number of inputs the argument has. - */ - int NumVariadicInputs(size_t arg_num) const; - - MLDataType InputType(int index) const; - MLDataType OutputType(int index) const; - - template - const T* Input(int index) const { - const OrtValue* p_ml_value = GetInputMLValue(index); - ORT_TRY { - return p_ml_value ? &(p_ml_value->Get()) : nullptr; - } - ORT_CATCH(const std::exception& /*e*/) { - ORT_THROW("Missing Input: " + kernel_->Node().InputDefs()[index]->Name()); - } - } - - // Fetch a required input, enforcing that it is present. - template - const T& RequiredInput(int index) const { - const T* input_ptr = Input(index); - ORT_ENFORCE(input_ptr, "Required input at index ", index, " is not present."); - return *input_ptr; - } - - // Fetch output (non-tensor) with specified index. 
- template - T* Output(int index) { - if (index < 0 || index >= OutputCount()) - return nullptr; - - OrtValue* p_ml_value = GetOrCreateOutputMLValue(index); - return p_ml_value ? p_ml_value->GetMutable() : nullptr; - } - - // In the case that memory allocation has not been done for an output tensor, - // The memory allocation will be done on-the-fly with given tensor shape. - // Return nullptr if the output is an unused optional output. - Tensor* Output(int index, const TensorShape& shape); - Tensor* Output(int index, const std::vector& shape); - Tensor* Output(int index, const std::initializer_list& shape); - - // Fetch a required tensor output, enforcing that it is present. - Tensor& RequiredOutput(int index, const TensorShape& shape) { - Tensor* output_ptr = Output(index, shape); - ORT_ENFORCE(output_ptr, "Required output at index ", index, " is not present."); - return *output_ptr; - } - - // Fetch a sparse-tensor output corresponding to the specified index. - // num_values must specify the number of non-zero values (commonly known as NNZ/nnz), - // and shape must specify the shape of the underlying dense-tensor. - // Memory allocation for the output may happen when this method is invoked, - // unless static optimization pre-allocates it. - SparseTensor* Output(int index, size_t num_values, const TensorShape& shape); - - // Retrieve indexed shape obtained from memory planning before actual - // computation. If the indexed shape cannot be inferred, this function returns - // false. - bool TryGetInferredInputShape(int index, TensorShape& shape) const; - - // Retrieve indexed shape obtained from memory planning before actual - // computation. If the indexed shape cannot be inferred, this function returns - // false. 
- bool TryGetInferredOutputShape(int index, TensorShape& shape) const; - - const logging::Logger& Logger() const { - return *logger_; - } - - // always >= 0 - int InputCount() const { - return static_cast(kernel_->Node().InputDefs().size()); - } - - // always >= 0 - int ImplicitInputCount() const { - return static_cast(kernel_->Node().ImplicitInputDefs().size()); - } - - // always >= 0 - int OutputCount() const { - return static_cast(kernel_->Node().OutputDefs().size()); - } - - /** - Return an allocator on device 0, with memtype of OrtMemTypeDefault. - @remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc. - */ - Status GetTempSpaceAllocator(AllocatorPtr* output) const ORT_MUST_USE_RESULT; - - /** - Return the fence of current node's input. - @param index The index of the input. - @returns Point to the Fence of the input OrtValue. - It is null if the input OrtValue doesn't have fence or the input is optional. - */ - Fence_t InputFence(int index) const; - - /** - Return the fence of current node's implicit input. - @param index The index of the implicit input. - @returns Point to the Fence of the implicit input OrtValue. - It is null if the input OrtValue doesn't have fence or the input is optional. - */ - Fence_t ImplicitInputFence(int index) const; - - /** - Return the fence of current node's output identifed by index. - @param index The index of the output. - @returns Point to the Fence of the output OrtValue. - It is null if the output OrtValue doesn't have fence or the output is optional. - */ - Fence_t OutputFence(int index) const; - - /** - Return the device id that current kernel runs on. 
- */ - int GetDeviceId() const { - return kernel_->Info().GetExecutionProvider()->GetDeviceId(); - } - - /** - Returns the opset domain of the underlying kernel - **/ - const std::string& GetOpDomain() const; - - /** - Returns the optype of the underlying kernel - **/ - const std::string& GetOpType() const; - - /** - Returns the node name of the underlying kernel - **/ - const std::string& GetNodeName() const; - - /** - Returns the intra-op threadpool, if available. - */ - _Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return threadpool_; } - - /** - Returns whether deterministic computation is preferred. - */ - virtual bool GetUseDeterministicCompute() const { - return true; - } - - protected: - onnxruntime::NodeIndex GetNodeIndex() const; - - const OrtValue* GetInputMLValue(int index) const; - const OrtValue* GetImplicitInputMLValue(int index) const; - OrtValue* GetOutputMLValue(int index); - - // Creates the OrtValue* based on the shape, if it does not exist - // The parameter nnz is used only for sparse-tensors and indicates the - // number of non-zero values (the number of elements in the values buffer allocated). - OrtValue* OutputMLValue(int index, const TensorShape& shape, size_t nnz = 0); - - private: - ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext); - - OrtValue* GetOrCreateOutputMLValue(int index); - - int GetInputArgIndex(int index) const; - int GetImplicitInputArgIndex(int index) const; - int GetOutputArgIndex(int index) const; - - IExecutionFrame* const execution_frame_; - const OpKernel* const kernel_; - concurrency::ThreadPool* const threadpool_; - const logging::Logger* const logger_; - - // The argument starting index in ExecutionFrame. 
- int node_input_start_index_{-1}; - int node_implicit_input_start_index_{-1}; - int node_output_start_index_{-1}; -}; - -// Fetching output tensor without shape is not allowed except when it already exists -template <> -inline Tensor* OpKernelContext::Output(int index) { - OrtValue* p_ml_value = GetOutputMLValue(index); - ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape."); - return p_ml_value->GetMutable(); -} - using KernelCreateFn = std::function; using KernelCreatePtrFn = std::add_pointer::type; @@ -306,8 +96,6 @@ struct KernelCreateInfo { KernelCreateInfo() = default; }; -using KernelCreateMap = std::multimap; - // Forward declarations for the non-specialized BuildKernelCreateInfo method. template KernelCreateInfo BuildKernelCreateInfo(); @@ -504,3 +292,7 @@ std::vector BuildKernelDefConstraintsFromTypeList() { } } // namespace onnxruntime + +#ifndef SHARED_PROVIDER +#include "core/framework/op_kernel_context.h" +#endif diff --git a/include/onnxruntime/core/framework/op_kernel_context.h b/include/onnxruntime/core/framework/op_kernel_context.h new file mode 100644 index 0000000000..833f6f9185 --- /dev/null +++ b/include/onnxruntime/core/framework/op_kernel_context.h @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace onnxruntime { +class IExecutionFrame; +namespace concurrency { +class ThreadPool; +} + +class OpKernelContext { + public: + using ArgMap = std::unordered_map; + + OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel, + _In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger); + + virtual ~OpKernelContext() = default; + + /** + Return the number of inputs for a variadic argument. + @param arg_num The operator argument number. + @returns Number of inputs the argument has. 
+ */ + int NumVariadicInputs(size_t arg_num) const; + + MLDataType InputType(int index) const; + MLDataType OutputType(int index) const; + + template + const T* Input(int index) const { + const OrtValue* p_ml_value = GetInputMLValue(index); + ORT_TRY { + return p_ml_value ? &(p_ml_value->Get()) : nullptr; + } + ORT_CATCH(const std::exception& /*e*/) { + ORT_THROW("Missing Input: " + kernel_->Node().InputDefs()[index]->Name()); + } + } + + // Fetch a required input, enforcing that it is present. + template + const T& RequiredInput(int index) const { + const T* input_ptr = Input(index); + ORT_ENFORCE(input_ptr, "Required input at index ", index, " is not present."); + return *input_ptr; + } + + // Fetch output (non-tensor) with specified index. + template + T* Output(int index) { + if (index < 0 || index >= OutputCount()) + return nullptr; + + OrtValue* p_ml_value = GetOrCreateOutputMLValue(index); + return p_ml_value ? p_ml_value->GetMutable() : nullptr; + } + + // In the case that memory allocation has not been done for an output tensor, + // The memory allocation will be done on-the-fly with given tensor shape. + // Return nullptr if the output is an unused optional output. + Tensor* Output(int index, const TensorShape& shape); + Tensor* Output(int index, const std::vector& shape); + Tensor* Output(int index, const std::initializer_list& shape); + + // Fetch a required tensor output, enforcing that it is present. + Tensor& RequiredOutput(int index, const TensorShape& shape) { + Tensor* output_ptr = Output(index, shape); + ORT_ENFORCE(output_ptr, "Required output at index ", index, " is not present."); + return *output_ptr; + } + + // Fetch a sparse-tensor output corresponding to the specified index. + // num_values must specify the number of non-zero values (commonly known as NNZ/nnz), + // and shape must specify the shape of the underlying dense-tensor. 
+ // Memory allocation for the output may happen when this method is invoked, + // unless static optimization pre-allocates it. + SparseTensor* Output(int index, size_t num_values, const TensorShape& shape); + + // Retrieve indexed shape obtained from memory planning before actual + // computation. If the indexed shape cannot be inferred, this function returns + // false. + bool TryGetInferredInputShape(int index, TensorShape& shape) const; + + // Retrieve indexed shape obtained from memory planning before actual + // computation. If the indexed shape cannot be inferred, this function returns + // false. + bool TryGetInferredOutputShape(int index, TensorShape& shape) const; + + const logging::Logger& Logger() const { + return *logger_; + } + + // always >= 0 + int InputCount() const { + return static_cast(kernel_->Node().InputDefs().size()); + } + + // always >= 0 + int ImplicitInputCount() const { + return static_cast(kernel_->Node().ImplicitInputDefs().size()); + } + + // always >= 0 + int OutputCount() const { + return static_cast(kernel_->Node().OutputDefs().size()); + } + + /** + Return an allocator on device 0, with memtype of OrtMemTypeDefault. + @remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc. + */ + Status GetTempSpaceAllocator(AllocatorPtr* output) const ORT_MUST_USE_RESULT; + + /** + Return the fence of current node's input. + @param index The index of the input. + @returns Point to the Fence of the input OrtValue. + It is null if the input OrtValue doesn't have fence or the input is optional. + */ + Fence_t InputFence(int index) const; + + /** + Return the fence of current node's implicit input. + @param index The index of the implicit input. + @returns Point to the Fence of the implicit input OrtValue. + It is null if the input OrtValue doesn't have fence or the input is optional. + */ + Fence_t ImplicitInputFence(int index) const; + + /** + Return the fence of current node's output identified by index. 
+ @param index The index of the output. + @returns Point to the Fence of the output OrtValue. + It is null if the output OrtValue doesn't have fence or the output is optional. + */ + Fence_t OutputFence(int index) const; + + /** + Return the device id that current kernel runs on. + */ + int GetDeviceId() const { + return kernel_->Info().GetExecutionProvider()->GetDeviceId(); + } + + /** + Returns the opset domain of the underlying kernel + **/ + const std::string& GetOpDomain() const; + + /** + Returns the optype of the underlying kernel + **/ + const std::string& GetOpType() const; + + /** + Returns the node name of the underlying kernel + **/ + const std::string& GetNodeName() const; + + /** + Returns the intra-op threadpool, if available. + */ + _Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return threadpool_; } + + /** + Returns whether deterministic computation is preferred. + */ + virtual bool GetUseDeterministicCompute() const { + return true; + } + + protected: + onnxruntime::NodeIndex GetNodeIndex() const; + + const OrtValue* GetInputMLValue(int index) const; + const OrtValue* GetImplicitInputMLValue(int index) const; + OrtValue* GetOutputMLValue(int index); + + // Creates the OrtValue* based on the shape, if it does not exist + // The parameter nnz is used only for sparse-tensors and indicates the + // number of non-zero values (the number of elements in the values buffer allocated). 
+ OrtValue* OutputMLValue(int index, const TensorShape& shape, size_t nnz = 0); + + private: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext); + + OrtValue* GetOrCreateOutputMLValue(int index); + + int GetInputArgIndex(int index) const; + int GetImplicitInputArgIndex(int index) const; + int GetOutputArgIndex(int index) const; + + IExecutionFrame* const execution_frame_; + const OpKernel* const kernel_; + concurrency::ThreadPool* const threadpool_; + const logging::Logger* const logger_; + + // The argument starting index in ExecutionFrame. + int node_input_start_index_{-1}; + int node_implicit_input_start_index_{-1}; + int node_output_start_index_{-1}; +}; + +// Fetching output tensor without shape is not allowed except when it already exists +template <> +inline Tensor* OpKernelContext::Output(int index) { + OrtValue* p_ml_value = GetOutputMLValue(index); + ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape."); + return p_ml_value->GetMutable(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc index 6c7ddf1abf..fca87d371c 100644 --- a/onnxruntime/core/framework/op_kernel.cc +++ b/onnxruntime/core/framework/op_kernel.cc @@ -9,6 +9,22 @@ using namespace ::onnxruntime::common; namespace onnxruntime { +std::unique_ptr CopyOpKernelInfo(const OpKernelInfo& info) { + return onnxruntime::make_unique(info); +} + +const onnxruntime::Node& OpKernel::Node() const { + return op_kernel_info_->node(); +} + +const onnxruntime::KernelDef& OpKernel::KernelDef() const { + return op_kernel_info_->GetKernelDef(); +} + +const OrtMemoryInfo& OpKernel::Allocator(int id, OrtMemType mem_type) const { + return op_kernel_info_->GetMemoryInfo(id, mem_type); +} + OpKernelContext::OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel, _In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger) : execution_frame_(frame), kernel_(kernel), 
threadpool_(threadpool), logger_(&logger) { diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc index a12b5989d7..14135ccf5e 100644 --- a/onnxruntime/core/framework/provider_bridge_ort.cc +++ b/onnxruntime/core/framework/provider_bridge_ort.cc @@ -26,7 +26,7 @@ #endif namespace onnxruntime { -// These Provider types are really just internal types, so we #define PROVIDER_BRIDGE_ORT so that only these definitions are seen by provider_interfaces.h +// These Provider types are really just internal types; since we don't include provider_api.h, only these definitions are seen by provider_interfaces.h // Users of provider_interfaces.h (through provider_api.h) will see the wrappers that call through the provider shared interface which is implemented by this file using Provider_int64s = google::protobuf::RepeatedField; using Provider_AttributeProto = ONNX_NAMESPACE::AttributeProto; @@ -46,7 +46,6 @@ using Provider_ValueInfoProtos = google::protobuf::RepeatedPtrFieldCompute(context, *reinterpret_cast(static_cast(this))); - } - - std::unique_ptr p_; - - ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernel_Translator); -}; - struct ProviderHostImpl : ProviderHost { ProviderHostImpl() { DataTypeImpl_GetType_Tensor = &DataTypeImpl::GetType; @@ -184,9 +170,13 @@ struct ProviderHostImpl : ProviderHost { return const_cast(&logging::LoggingManager::DefaultLogger()); } - const std::vector& DataTypeImpl_AllFixedSizeTensorTypes() override { - return DataTypeImpl::AllFixedSizeTensorTypes(); - } + int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) override { return p->GetDataType(); } + + const char* DataTypeImpl__ToString(MLDataType type) override { return DataTypeImpl::ToString(type); } + const std::vector& DataTypeImpl__AllFixedSizeTensorTypes() override { return DataTypeImpl::AllFixedSizeTensorTypes(); } + const std::vector& DataTypeImpl__AllTensorTypes() override { return DataTypeImpl::AllTensorTypes(); } + 
size_t DataTypeImpl__Size(const DataTypeImpl* p) override { return p->Size(); } + const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) override { return p->AsPrimitiveDataType(); } void* HeapAllocate(size_t size) override { return new uint8_t[size]; } void HeapFree(void* p) override { delete[] reinterpret_cast(p); } @@ -398,6 +388,7 @@ struct ProviderHostImpl : ProviderHost { // KernelDef void KernelDef__operator_delete(KernelDef* p) override { delete p; } + int KernelDef__ExecQueueId(const KernelDef* p) override { return p->ExecQueueId(); } // KernelDefBuilder std::unique_ptr KernelDefBuilder__construct() override { return onnxruntime::make_unique(); } @@ -418,13 +409,7 @@ struct ProviderHostImpl : ProviderHost { // KernelRegistry std::shared_ptr KernelRegistry__construct() override { return std::make_shared(); } void KernelRegistry__operator_delete(KernelRegistry* p) override { delete p; } - Status KernelRegistry__Register(KernelRegistry* p, Provider_KernelCreateInfo&& create_info) override { - KernelCreateInfo info_real(std::move(create_info.kernel_def), - [kernel_create_func = create_info.kernel_create_func](const OpKernelInfo& info) -> OpKernel* { - return new OpKernel_Translator(info, kernel_create_func(info)); - }); - return p->Register(std::move(info_real)); - } + Status KernelRegistry__Register(KernelRegistry* p, KernelCreateInfo&& create_info) override { return p->Register(std::move(create_info)); } // Function const Graph& Function__Body(const Function* p) override { return p->Body(); } @@ -543,23 +528,21 @@ struct ProviderHostImpl : ProviderHost { const std::vector& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) override { return p->GetNodesInTopologicalOrder(); } const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept override { return p->GetInputsIncludingInitializers(); } - //Path - //Gets a string representation of the path. 
+ // Path PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); } - // Provider_OpKernel_Base - const OpKernelInfo& Provider_OpKernel_Base__GetInfo(const Provider_OpKernel_Base* p) override { return reinterpret_cast(p)->Info(); } - // OpKernelContext const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) override { return p->Input(index); } Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) override { return p->Output(index, shape); } // OpKernelInfo + std::unique_ptr CopyOpKernelInfo(const OpKernelInfo& info) override { return onnxruntime::CopyOpKernelInfo(info); } + void OpKernelInfo__operator_delete(OpKernelInfo* p) override { delete p; } Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) override { return p->GetAttr(name, value); } Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) override { return p->GetAttr(name, value); } const DataTransferManager& OpKernelInfo__GetDataTransferManager(const OpKernelInfo* p) noexcept override { return p->GetDataTransferManager(); } - int OpKernelInfo__GetKernelDef_ExecQueueId(const OpKernelInfo* p) noexcept override { return p->GetKernelDef().ExecQueueId(); } + const KernelDef& OpKernelInfo__GetKernelDef(const OpKernelInfo* p) override { return p->GetKernelDef(); } // Tensor float* Tensor__MutableData_float(Tensor* p) override { return p->MutableData(); } @@ -570,7 +553,7 @@ struct ProviderHostImpl : ProviderHost { const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); } size_t Tensor__SizeInBytes(const Tensor* p) override { return p->SizeInBytes(); } - const OrtMemoryInfo& Tensor__Location(const Tensor* p) override { return p->Location(); } + const OrtMemoryInfo& Tensor__Location(const Tensor* p) override { return p->Location(); } // AllocatorManager void AllocatorManager__InsertAllocator(AllocatorManager* 
p, AllocatorPtr allocator) override { p->InsertAllocator(allocator); } diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index e2fa3f7801..6dca36d182 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -42,7 +42,7 @@ namespace ort_dnnl { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDnnlExecutionProvider, kOnnxDomain, 7, Gemm); Status RegisterDNNLKernels(KernelRegistry& kernel_registry) { - static const Provider_BuildKernelCreateInfoFn function_table[] = { + static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, }; diff --git a/onnxruntime/core/providers/dnnl/dnnl_fwd.h b/onnxruntime/core/providers/dnnl/dnnl_fwd.h index d0a42737fa..0541e1b372 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_fwd.h +++ b/onnxruntime/core/providers/dnnl/dnnl_fwd.h @@ -6,6 +6,6 @@ namespace onnxruntime { namespace ort_dnnl { template -Provider_KernelCreateInfo BuildKernelCreateInfo(); +KernelCreateInfo BuildKernelCreateInfo(); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/math/gemm.cc b/onnxruntime/core/providers/dnnl/math/gemm.cc index 9b95062473..1fd9621a2a 100644 --- a/onnxruntime/core/providers/dnnl/math/gemm.cc +++ b/onnxruntime/core/providers/dnnl/math/gemm.cc @@ -92,7 +92,7 @@ template using ConstEigenMatrixMapRowMajor = Eigen::Map>; template <> -Status Gemm::Compute(OpKernelContext* ctx, const Provider_OpKernel_Base&) const { +Status Gemm::Compute(OpKernelContext* ctx) const { const auto X = ctx->Input(0); const auto W = ctx->Input(1); const auto B = ctx->Input(2); diff --git a/onnxruntime/core/providers/dnnl/math/gemm.h b/onnxruntime/core/providers/dnnl/math/gemm.h index c3dd36d38c..6e51cb6c94 100644 --- a/onnxruntime/core/providers/dnnl/math/gemm.h +++ b/onnxruntime/core/providers/dnnl/math/gemm.h @@ -6,9 +6,9 @@ namespace onnxruntime { namespace ort_dnnl { 
template -class Gemm final : public Provider_OpKernel { +class Gemm final : public OpKernel { public: - Gemm(const OpKernelInfo& info) { + Gemm(const OpKernelInfo& info) : OpKernel(info) { int64_t temp; ORT_ENFORCE(info.GetAttr("transA", &temp).IsOK()); trans_A_ = (temp != 0); @@ -20,7 +20,7 @@ class Gemm final : public Provider_OpKernel { ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK()); } - Status Compute(OpKernelContext* context, const Provider_OpKernel_Base& base) const override; + Status Compute(OpKernelContext* context) const override; private: bool trans_A_; diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 2313eb02b8..f3e291fa53 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -7,7 +7,7 @@ // switching providers to be runnable as shared libraries. The interfaces will become more tightly integrated into the core code. 
#pragma once -#define PROVIDER_BRIDGE_PROVIDER 1 +#define SHARED_PROVIDER 1 #include #include @@ -16,9 +16,11 @@ #include "onnx/common/stl_backports.h" #include "core/common/common.h" #include "core/common/const_pointer_container.h" +#include "core/common/type_list.h" #include "core/common/logging/severity.h" #include "core/framework/allocator.h" #include "core/framework/allocatormgr.h" +#include "core/framework/float16.h" #include "core/framework/tensor_shape.h" #include "core/providers/providers.h" #include "core/common/path_string.h" @@ -67,6 +69,7 @@ struct DataTransferManager; struct IDataTransfer; struct IndexedSubGraph; struct IndexedSubGraph_MetaDef; +struct KernelCreateInfo; struct KernelDef; struct KernelDefBuilder; struct KernelRegistry; @@ -80,7 +83,11 @@ struct NodeArg; struct NodeAttributes; struct OpKernelContext; struct OpKernelInfo; +struct PrimitiveDataTypeBase; struct Tensor; + +class DataTypeImpl; +using MLDataType = const DataTypeImpl*; } // namespace onnxruntime namespace ONNX_NAMESPACE { @@ -147,6 +154,8 @@ enum OperatorStatus : int { #include "core/framework/execution_provider.h" #include "provider_interfaces.h" +#include "core/framework/op_kernel.h" +#include "core/framework/data_types_internal.h" namespace onnxruntime { @@ -185,18 +194,6 @@ enum CUDAStreamType : int { kTotalCudaStreams, }; -class DataTypeImpl { - public: - virtual ~DataTypeImpl() = default; - - template - static MLDataType GetType(); - template - static MLDataType GetTensorType(); - - static const std::vector& AllFixedSizeTensorTypes(); -}; - template using IAllocatorUniquePtr = std::unique_ptr>; @@ -234,23 +231,6 @@ constexpr T roundUpPow2(T a) { } // namespace onnxruntime -#define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \ - provider##_##name##_##domain##_ver##ver - -#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) 
\ - class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \ - template <> \ - Provider_KernelCreateInfo \ - BuildKernelCreateInfo() { \ - return Provider_KernelCreateInfo( \ - builder.SetName(#name) \ - .SetDomain(domain) \ - .SinceVersion(ver) \ - .Provider(provider) \ - .Build(), \ - static_cast([](const OpKernelInfo& info) -> Provider_OpKernel* { return new __VA_ARGS__(info); })); \ - } - #define CREATE_MESSAGE(logger, severity, category, datatype) \ ::onnxruntime::logging::Capture::Create(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE) diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index fbb061472c..51b77aba09 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -63,10 +63,6 @@ MLDataType DataTypeImpl::GetTensorType() { return g_host->DataTypeImpl_GetTensorType_float(); } -const std::vector& DataTypeImpl::AllFixedSizeTensorTypes() { - return g_host->DataTypeImpl_AllFixedSizeTensorTypes(); -} - TensorShape::TensorShape(const int64_t* dimension_sizes, size_t dimension_count) : std::vector(dimension_count) { for (size_t i = 0; i < dimension_count; ++i) { @@ -221,4 +217,8 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, return g_host->LogRuntimeError(session_id, status, file, function, line); } +std::unique_ptr CopyOpKernelInfo(const OpKernelInfo& info) { + return g_host->CopyOpKernelInfo(info); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 474d133420..80d3298757 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -27,8 +27,6 @@ using DataType = 
const std::string*; namespace onnxruntime { // These types don't directly map to internal types -struct Provider_KernelCreateInfo; -struct Provider_OpKernel_Base; struct ProviderHost; class TensorShape; @@ -66,19 +64,6 @@ struct Provider_TensorShapeProto_Dimension_Iterator { virtual const Provider_TensorShapeProto_Dimension& operator*() = 0; }; -class DataTypeImpl; -using MLDataType = const DataTypeImpl*; - -struct Provider_OpKernel { - Provider_OpKernel() {} - virtual ~Provider_OpKernel() = default; - - virtual Status Compute(OpKernelContext* context, const Provider_OpKernel_Base& base) const = 0; - - Provider_OpKernel(const Provider_OpKernel&) = delete; - void operator=(const Provider_OpKernel&) = delete; -}; - using NodeIndex = size_t; using Provider_NodeArgInfo = Provider_ValueInfoProto; // We can't just reinterpret_cast this one, since it's an unordered_map of object BY VALUE (can't do anything by value on the real types) @@ -142,10 +127,19 @@ struct ProviderHost { virtual std::string GetEnvironmentVar(const std::string& var_name) = 0; + // PrimitiveDataTypeBase + virtual int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) = 0; + + // DataTypeImpl MLDataType (*DataTypeImpl_GetType_Tensor)(); MLDataType (*DataTypeImpl_GetType_float)(); MLDataType (*DataTypeImpl_GetTensorType_float)(); - virtual const std::vector& DataTypeImpl_AllFixedSizeTensorTypes() = 0; + + virtual const char* DataTypeImpl__ToString(MLDataType type) = 0; + virtual const std::vector& DataTypeImpl__AllFixedSizeTensorTypes() = 0; + virtual const std::vector& DataTypeImpl__AllTensorTypes() = 0; + virtual size_t DataTypeImpl__Size(const DataTypeImpl* p) = 0; + virtual const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) = 0; virtual void* HeapAllocate(size_t size) = 0; virtual void HeapFree(void*) = 0; @@ -335,6 +329,7 @@ struct ProviderHost { // KernelDef virtual void KernelDef__operator_delete(KernelDef* p) = 0; + virtual int 
KernelDef__ExecQueueId(const KernelDef* p) = 0; // KernelDefBuilder virtual std::unique_ptr KernelDefBuilder__construct() = 0; @@ -355,7 +350,7 @@ struct ProviderHost { // KernelRegistry virtual std::shared_ptr KernelRegistry__construct() = 0; virtual void KernelRegistry__operator_delete(KernelRegistry* p) = 0; - virtual Status KernelRegistry__Register(KernelRegistry* p, Provider_KernelCreateInfo&& create_info) = 0; + virtual Status KernelRegistry__Register(KernelRegistry* p, KernelCreateInfo&& create_info) = 0; // Function virtual const Graph& Function__Body(const Function* p) = 0; @@ -465,19 +460,18 @@ struct ProviderHost { // Path virtual PathString Path__ToPathString(const Path* p) noexcept = 0; - // Provider_OpKernel_Base - virtual const OpKernelInfo& Provider_OpKernel_Base__GetInfo(const Provider_OpKernel_Base* p) = 0; - // OpKernelContext virtual const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) = 0; virtual Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) = 0; // OpKernelInfo + virtual std::unique_ptr CopyOpKernelInfo(const OpKernelInfo& info) = 0; + virtual void OpKernelInfo__operator_delete(OpKernelInfo* p) = 0; virtual Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) = 0; virtual Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) = 0; virtual const DataTransferManager& OpKernelInfo__GetDataTransferManager(const OpKernelInfo* p) noexcept = 0; - virtual int OpKernelInfo__GetKernelDef_ExecQueueId(const OpKernelInfo* p) noexcept = 0; + virtual const KernelDef& OpKernelInfo__GetKernelDef(const OpKernelInfo* p) = 0; // Tensor virtual float* Tensor__MutableData_float(Tensor* p) = 0; @@ -497,7 +491,7 @@ struct ProviderHost { extern ProviderHost* g_host; -#ifndef PROVIDER_BRIDGE_ORT +#ifdef SHARED_PROVIDER struct CPUIDInfo { static const CPUIDInfo& GetCPUIDInfo() { return 
g_host->CPUIDInfo__GetCPUIDInfo(); } @@ -767,32 +761,17 @@ struct IndexedSubGraph { struct KernelDef { static void operator delete(void* p) { g_host->KernelDef__operator_delete(reinterpret_cast(p)); } + int ExecQueueId() const { return g_host->KernelDef__ExecQueueId(this); } + KernelDef() = delete; KernelDef(const KernelDef*) = delete; void operator=(const KernelDef&) = delete; }; #endif -using Provider_KernelCreateFn = std::function; -using Provider_KernelCreatePtrFn = std::add_pointer::type; +using BuildKernelCreateInfoFn = KernelCreateInfo (*)(); -struct Provider_KernelCreateInfo { - std::unique_ptr kernel_def; // Owned and stored in the global kernel registry. - Provider_KernelCreateFn kernel_create_func; - - Provider_KernelCreateInfo(std::unique_ptr definition, - Provider_KernelCreateFn create_func) - : kernel_def(std::move(definition)), - kernel_create_func(create_func) {} - - Provider_KernelCreateInfo(Provider_KernelCreateInfo&& other) noexcept - : kernel_def(std::move(other.kernel_def)), - kernel_create_func(std::move(other.kernel_create_func)) {} -}; - -using Provider_BuildKernelCreateInfoFn = Provider_KernelCreateInfo (*)(); - -#ifndef PROVIDER_BRIDGE_ORT +#ifdef SHARED_PROVIDER struct KernelDefBuilder { static std::unique_ptr Create() { return g_host->KernelDefBuilder__construct(); } static void operator delete(void* p) { g_host->KernelDefBuilder__operator_delete(reinterpret_cast(p)); } @@ -845,13 +824,38 @@ struct KernelRegistry { static std::shared_ptr Create() { return g_host->KernelRegistry__construct(); } static void operator delete(void* p) { g_host->KernelRegistry__operator_delete(reinterpret_cast(p)); } - Status Register(Provider_KernelCreateInfo&& create_info) { return g_host->KernelRegistry__Register(this, std::move(create_info)); } + Status Register(KernelCreateInfo&& create_info) { return g_host->KernelRegistry__Register(this, std::move(create_info)); } KernelRegistry() = delete; KernelRegistry(const KernelRegistry&) = delete; void 
operator=(const KernelRegistry&) = delete; }; +struct PrimitiveDataTypeBase { + int32_t GetDataType() const { return g_host->PrimitiveDataTypeBase__GetDataType(this); } + + PROVIDER_DISALLOW_ALL(PrimitiveDataTypeBase) +}; + +class DataTypeImpl { + public: + size_t Size() const { return g_host->DataTypeImpl__Size(this); } + + template + static MLDataType GetType(); + template + static MLDataType GetTensorType(); + + static const std::vector& AllFixedSizeTensorTypes() { return g_host->DataTypeImpl__AllFixedSizeTensorTypes(); } + static const std::vector& AllTensorTypes() { return g_host->DataTypeImpl__AllTensorTypes(); } + + const PrimitiveDataTypeBase* AsPrimitiveDataType() const { return g_host->DataTypeImpl__AsPrimitiveDataType(this); } + + static const char* ToString(MLDataType type) { return g_host->DataTypeImpl__ToString(type); } + + PROVIDER_DISALLOW_ALL(DataTypeImpl) +}; + struct Function { const Graph& Body() const { return g_host->Function__Body(this); } @@ -1023,13 +1027,7 @@ struct Path { #endif -struct Provider_OpKernel_Base { - const OpKernelInfo& GetInfo() const { return g_host->Provider_OpKernel_Base__GetInfo(this); } - - PROVIDER_DISALLOW_ALL(Provider_OpKernel_Base) -}; - -#ifndef PROVIDER_BRIDGE_ORT +#ifdef SHARED_PROVIDER struct OpKernelContext { const Tensor* Input_Tensor(int index) const { return g_host->OpKernelContext__Input_Tensor(this, index); } @@ -1047,6 +1045,8 @@ inline const Tensor* OpKernelContext::Input(int index) const { } struct OpKernelInfo { + static void operator delete(void* p) { g_host->OpKernelInfo__operator_delete(reinterpret_cast(p)); } + template Status GetAttr(const std::string& name, T* value) const; @@ -1054,9 +1054,11 @@ struct OpKernelInfo { Status GetAttr(const std::string& name, float* value) const { return g_host->OpKernelInfo__GetAttr_float(this, name, value); } const DataTransferManager& GetDataTransferManager() const noexcept { return g_host->OpKernelInfo__GetDataTransferManager(this); } - int 
GetKernelDef_ExecQueueId() const noexcept { return g_host->OpKernelInfo__GetKernelDef_ExecQueueId(this); } + const KernelDef& GetKernelDef() const { return g_host->OpKernelInfo__GetKernelDef(this); } - PROVIDER_DISALLOW_ALL(OpKernelInfo) + OpKernelInfo() = delete; + OpKernelInfo(const OpKernelInfo&) = delete; + void operator=(const OpKernelInfo&) = delete; }; template <> diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 6faae63c49..3546609994 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -307,20 +307,20 @@ bool CudaCall(cudaError retCode, const char* exprString, const return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg); } -class Memcpy final : public Provider_OpKernel { +class Memcpy final : public OpKernel { public: - Memcpy(const OpKernelInfo&) {} + Memcpy(const OpKernelInfo& info) : OpKernel(info) {} - Status Compute(OpKernelContext* ctx, const Provider_OpKernel_Base& base) const override { + Status Compute(OpKernelContext* ctx) const override { const auto* X = ctx->Input(0); Tensor* Y = ctx->Output(0, X->Shape()); - Status retval = base.GetInfo().GetDataTransferManager().CopyTensor(*X, *Y, base.GetInfo().GetKernelDef_ExecQueueId()); + Status retval = Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId()); return retval; } }; template -Provider_KernelCreateInfo BuildKernelCreateInfo(); +KernelCreateInfo BuildKernelCreateInfo(); ONNX_OPERATOR_KERNEL_EX( MemcpyFromHost, @@ -348,7 +348,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyToHost); static Status RegisterTensorrtKernels(KernelRegistry& kernel_registry) { - static const Provider_BuildKernelCreateInfoFn 
function_table[] = { + static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, BuildKernelCreateInfo, }; @@ -463,12 +463,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (engine_decryption_enable_) { std::string engine_decryption_lib_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionLibPath); - LIBTYPE handle = OPENLIB(engine_decryption_lib_path.c_str()); - if (handle == nullptr) { - ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not open shared library from " + engine_decryption_lib_path); - } - engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); + LIBTYPE handle = OPENLIB(engine_decryption_lib_path.c_str()); + if (handle == nullptr) { + ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not open shared library from " + engine_decryption_lib_path); + } + engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); } } @@ -1289,7 +1289,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); } trt_context = trt_state->context->get(); - } else if (trt_state->engine_decryption_enable && !engine_file && profile_file) { + } else if (trt_state->engine_decryption_enable && !engine_file && profile_file) { shape_ranges = DeserializeProfile(profile_file); LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; // Decrypt engine