mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Add support for custom ops to minimal build. (#6228)
* Add support for custom ops to minimal build. Cost is only ~8KB so including in base minimal build.
This commit is contained in:
parent
6507b4f818
commit
e1dc268e45
22 changed files with 470 additions and 216 deletions
|
|
@ -10,8 +10,6 @@ file(GLOB_RECURSE onnxruntime_framework_srcs CONFIGURE_DEPENDS
|
|||
if (onnxruntime_MINIMAL_BUILD)
|
||||
set(onnxruntime_framework_src_exclude
|
||||
"${ONNXRUNTIME_ROOT}/core/framework/provider_bridge_ort.cc"
|
||||
"${ONNXRUNTIME_INCLUDE_DIR}/core/framework/customregistry.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/framework/customregistry.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.cc"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -322,7 +322,10 @@ set (onnxruntime_shared_lib_test_SRC
|
|||
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc
|
||||
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_allocator.cc
|
||||
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc
|
||||
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc)
|
||||
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc
|
||||
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_ort_format_models.cc
|
||||
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.h
|
||||
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.cc)
|
||||
|
||||
if (NOT onnxruntime_MINIMAL_BUILD)
|
||||
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc)
|
||||
|
|
|
|||
|
|
@ -5,11 +5,14 @@
|
|||
|
||||
#include "core/common/status.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/graph/schema_registry.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/kernel_def_builder.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
#include "core/graph/schema_registry.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
|
|
@ -17,9 +20,14 @@ namespace onnxruntime {
|
|||
*/
|
||||
class CustomRegistry final {
|
||||
public:
|
||||
CustomRegistry() :
|
||||
kernel_registry_(std::make_shared<KernelRegistry>()),
|
||||
opschema_registry_(std::make_shared<onnxruntime::OnnxRuntimeOpSchemaRegistry>()) {}
|
||||
CustomRegistry()
|
||||
: kernel_registry_(std::make_shared<KernelRegistry>())
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
,
|
||||
opschema_registry_(std::make_shared<onnxruntime::OnnxRuntimeOpSchemaRegistry>())
|
||||
#endif
|
||||
{
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a kernel definition together with kernel factory method to this session.
|
||||
|
|
@ -32,18 +40,21 @@ class CustomRegistry final {
|
|||
|
||||
common::Status RegisterCustomKernel(KernelCreateInfo&);
|
||||
|
||||
common::Status RegisterOpSet(std::vector<ONNX_NAMESPACE::OpSchema>& schemas, const std::string& domain,
|
||||
int baseline_opset_version, int opset_version);
|
||||
|
||||
const std::shared_ptr<KernelRegistry>& GetKernelRegistry();
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
common::Status RegisterOpSet(std::vector<ONNX_NAMESPACE::OpSchema>& schemas, const std::string& domain,
|
||||
int baseline_opset_version, int opset_version);
|
||||
|
||||
const std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry>& GetOpschemaRegistry();
|
||||
#endif
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomRegistry);
|
||||
std::shared_ptr<KernelRegistry> kernel_registry_;
|
||||
std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry> opschema_registry_;
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry> opschema_registry_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -654,7 +654,8 @@ class Graph {
|
|||
|
||||
/** Return true if "node_arg" is a input or an initializer. Otherwise, returns false. */
|
||||
bool IsInputsIncludingInitializers(const NodeArg* node_arg) const noexcept {
|
||||
return std::find(graph_inputs_including_initializers_.begin(), graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
|
||||
return std::find(graph_inputs_including_initializers_.begin(),
|
||||
graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
|
||||
}
|
||||
|
||||
/** Gets the Graph inputs that are initializers
|
||||
|
|
@ -1085,6 +1086,9 @@ class Graph {
|
|||
static common::Status LoadFromOrtFormat(
|
||||
const onnxruntime::experimental::fbs::Graph& fbs_graph, const Model& owning_model,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
|
||||
#endif
|
||||
const logging::Logger& logger, std::unique_ptr<Graph>& graph);
|
||||
|
||||
// deserialize a subgraph
|
||||
|
|
@ -1104,6 +1108,9 @@ class Graph {
|
|||
// Create empty Graph instance to re-create from ORT format serialized data.
|
||||
Graph(const Model& owning_model,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
|
||||
#endif
|
||||
Graph* parent_graph, const Node* parent_node,
|
||||
const logging::Logger& logger);
|
||||
|
||||
|
|
|
|||
|
|
@ -12,21 +12,22 @@ common::Status CustomRegistry::RegisterCustomKernel(KernelCreateInfo& create_inf
|
|||
return kernel_registry_->Register(std::move(create_info));
|
||||
}
|
||||
|
||||
const std::shared_ptr<KernelRegistry>& CustomRegistry::GetKernelRegistry() {
|
||||
return kernel_registry_;
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
common::Status CustomRegistry::RegisterOpSet(
|
||||
std::vector<ONNX_NAMESPACE::OpSchema>& schemas,
|
||||
const std::string& domain,
|
||||
int baseline_opset_version,
|
||||
int opset_version) {
|
||||
|
||||
return opschema_registry_->RegisterOpSet(schemas, domain, baseline_opset_version, opset_version);
|
||||
}
|
||||
|
||||
const std::shared_ptr<KernelRegistry>& CustomRegistry::GetKernelRegistry() {
|
||||
return kernel_registry_;
|
||||
}
|
||||
|
||||
const std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry>& CustomRegistry::GetOpschemaRegistry() {
|
||||
return opschema_registry_;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -77,19 +77,22 @@ void IExecutionProvider::InsertAllocator(AllocatorPtr allocator) {
|
|||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_node*/,
|
||||
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
|
||||
"IExecutionProvider::Compile with fused Node is not implemented by " + type_);
|
||||
}
|
||||
|
||||
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_node*/,
|
||||
std::string& /*dll_path*/) {
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
|
||||
"IExecutionProvider::Compile with fused Node and dll path is not implemented by " + type_);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
common::Status IExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& /*fused_nodes_and_graphs*/,
|
||||
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
|
||||
"IExecutionProvider::Compile with FusedNodeAndGraph is not implemented by " + type_);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -417,6 +417,10 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
std::vector<std::unique_ptr<ComputeCapability>> capabilities =
|
||||
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type));
|
||||
|
||||
if (capabilities.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// storage for the GraphViewer for each IndexedSubGraph
|
||||
std::vector<std::unique_ptr<GraphViewer>> viewers;
|
||||
viewers.reserve(capabilities.size());
|
||||
|
|
|
|||
|
|
@ -45,14 +45,12 @@ Status KernelRegistryManager::RegisterKernels(const ExecutionProviders& executio
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptr<KernelRegistry> kernel_registry) {
|
||||
if (nullptr == kernel_registry) {
|
||||
return;
|
||||
}
|
||||
custom_kernel_registries_.push_front(kernel_registry);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) {
|
||||
|
|
@ -86,14 +84,12 @@ Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node
|
|||
return Status(ONNXRUNTIME, FAIL, create_error_message("The node is not placed on any Execution Provider. "));
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
for (auto& registry : custom_kernel_registries_) {
|
||||
status = registry->TryFindKernel(node, std::string(), kernel_def_hash, kernel_create_info);
|
||||
if (status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
KernelRegistry* p = nullptr;
|
||||
auto iter = provider_type_to_registry_.find(ptype);
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ class KernelRegistryManager {
|
|||
// Register kernels from providers
|
||||
Status RegisterKernels(const ExecutionProviders& execution_providers) ORT_MUST_USE_RESULT;
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
// The registry passed in this function has highest priority than anything already in this KernelRegistryManager,
|
||||
// and anything registered from RegisterKernels
|
||||
// For example, if you do:
|
||||
|
|
@ -42,6 +41,7 @@ class KernelRegistryManager {
|
|||
// Then B > A > providers
|
||||
void RegisterKernelRegistry(std::shared_ptr<KernelRegistry> kernel_registry);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
/**
|
||||
* Search kernel registry by provider type.
|
||||
* @param type provider type string
|
||||
|
|
|
|||
|
|
@ -3619,13 +3619,19 @@ std::ostream& operator<<(std::ostream& out, const Graph& graph) {
|
|||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
Status Graph::LoadFromOrtFormat(
|
||||
const onnxruntime::experimental::fbs::Graph& fbs_graph,
|
||||
const Model& owning_model,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
const logging::Logger& logger, std::unique_ptr<Graph>& graph) {
|
||||
Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs_graph,
|
||||
const Model& owning_model,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
|
||||
#endif
|
||||
const logging::Logger& logger, std::unique_ptr<Graph>& graph) {
|
||||
// can't use make_unique as we're calling a private ctor
|
||||
graph.reset(new Graph(owning_model, domain_to_version, nullptr, nullptr, logger));
|
||||
graph.reset(new Graph(owning_model, domain_to_version,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
schema_registry,
|
||||
#endif
|
||||
nullptr, nullptr, logger));
|
||||
|
||||
ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph));
|
||||
|
||||
|
|
@ -3636,8 +3642,6 @@ Status Graph::LoadFromOrtFormat(
|
|||
// and in InferenceSession::Initialize skip partitioning and running optimizers.
|
||||
graph->SetGraphResolveNeeded();
|
||||
ORT_RETURN_IF_ERROR(graph->Resolve());
|
||||
#else
|
||||
// probably nothing required here. validate with model that has nested subgraphs.
|
||||
#endif
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -3648,7 +3652,11 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs
|
|||
const logging::Logger& logger, std::unique_ptr<Graph>& graph) {
|
||||
// can't use make_unique as we're calling a private ctor
|
||||
graph.reset(new Graph(parent_graph.owning_model_,
|
||||
parent_graph.domain_to_version_, &parent_graph, &parent_node,
|
||||
parent_graph.domain_to_version_,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
parent_graph.schema_registry_,
|
||||
#endif
|
||||
&parent_graph, &parent_node,
|
||||
logger));
|
||||
|
||||
return graph->LoadFromOrtFormat(fbs_graph);
|
||||
|
|
@ -3656,12 +3664,15 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs
|
|||
|
||||
Graph::Graph(const Model& owning_model,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
|
||||
#endif
|
||||
Graph* parent_graph, const Node* parent_node,
|
||||
const logging::Logger& logger)
|
||||
: owning_model_(owning_model),
|
||||
graph_proto_(&deserialized_proto_data_),
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
schema_registry_(std::make_shared<SchemaRegistryManager>()),
|
||||
schema_registry_(schema_registry),
|
||||
#endif
|
||||
domain_to_version_(domain_to_version),
|
||||
ir_version_(owning_model.IrVersion()),
|
||||
|
|
|
|||
|
|
@ -93,7 +93,8 @@ Model::Model(const ModelProto& model_proto, const PathString& model_path,
|
|||
: Model(ModelProto(model_proto), model_path, local_registries, logger) {
|
||||
}
|
||||
|
||||
Model::Model(ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
Model::Model(ModelProto&& model_proto, const PathString& model_path,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger)
|
||||
: model_path_(Path::Parse(model_path)) {
|
||||
if (!utils::HasGraph(model_proto)) {
|
||||
|
|
@ -618,6 +619,9 @@ Model::Model() : model_path_{} {
|
|||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
#endif
|
||||
const logging::Logger& logger,
|
||||
std::unique_ptr<Model>& model) {
|
||||
model.reset(new Model());
|
||||
|
|
@ -632,6 +636,13 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
|
|||
}
|
||||
model->model_proto_.set_model_version(fbs_model.model_version());
|
||||
model->model_proto_.set_ir_version(fbs_model.ir_version());
|
||||
|
||||
auto schema_registry = std::make_shared<SchemaRegistryManager>();
|
||||
if (local_registries != nullptr) {
|
||||
for (const auto& schema_collection : *local_registries) {
|
||||
schema_registry->RegisterRegistry(schema_collection);
|
||||
}
|
||||
}
|
||||
#else
|
||||
experimental::utils::LoadStringFromOrtFormat(model->producer_name_, fbs_model.producer_name());
|
||||
experimental::utils::LoadStringFromOrtFormat(model->producer_version_, fbs_model.producer_version());
|
||||
|
|
@ -648,10 +659,14 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
|
|||
auto fbs_graph = fbs_model.graph();
|
||||
ORT_RETURN_IF(nullptr == fbs_graph, "Graph is null. Invalid ORT format model.");
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, schema_registry, logger,
|
||||
model->graph_));
|
||||
#else
|
||||
ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, logger, model->graph_));
|
||||
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -239,6 +239,9 @@ class Model {
|
|||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
static common::Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::Model& fbs_model,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
#endif
|
||||
const logging::Logger& logger,
|
||||
std::unique_ptr<Model>& model);
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -4,16 +4,18 @@
|
|||
#ifdef _WIN32
|
||||
#pragma warning(disable : 4267)
|
||||
#endif
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/session/ort_apis.h"
|
||||
|
||||
#include "core/framework/customregistry.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/op_kernel_info.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/framework/error_code_helper.h"
|
||||
#include "core/framework/tensor_type_and_shape.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/session/ort_apis.h"
|
||||
|
||||
ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type);
|
||||
// ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type);
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
|
||||
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<float>(name, out);
|
||||
|
|
@ -67,23 +69,22 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernel
|
|||
return onnxruntime::ToOrtStatus(status);
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
#include "core/framework/customregistry.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct CustomOpKernel : OpKernel {
|
||||
CustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) {
|
||||
if (op_.version > ORT_API_VERSION)
|
||||
if (op_.version > ORT_API_VERSION) {
|
||||
ORT_THROW("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op));
|
||||
op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast<const OrtKernelInfo*>(&info));
|
||||
}
|
||||
|
||||
op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version),
|
||||
reinterpret_cast<const OrtKernelInfo*>(&info));
|
||||
}
|
||||
|
||||
~CustomOpKernel() override { op_.KernelDestroy(op_kernel_); }
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override {
|
||||
auto* ictx = static_cast<OpKernelContextInternal*>(ctx);
|
||||
op_.KernelCompute(op_kernel_, reinterpret_cast<OrtKernelContext*>(ictx));
|
||||
op_.KernelCompute(op_kernel_, reinterpret_cast<OrtKernelContext*>(ctx));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -94,12 +95,17 @@ struct CustomOpKernel : OpKernel {
|
|||
void* op_kernel_;
|
||||
};
|
||||
|
||||
common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_domains, std::shared_ptr<CustomRegistry>& output) {
|
||||
common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_domains,
|
||||
std::shared_ptr<CustomRegistry>& output) {
|
||||
output = std::make_shared<CustomRegistry>();
|
||||
for (auto& domain : op_domains) {
|
||||
|
||||
for (const auto& domain : op_domains) {
|
||||
// Create an OpSchema for each op and register them
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// Domain is not empty - add it to the DomainToVersion ONNX map
|
||||
// If domain is empty, it is assumed to be part of the ONNX domain
|
||||
if (domain->domain_[0]) {
|
||||
if (!domain->domain_.empty()) {
|
||||
// Add it to the DomainToVersion ONNX map if it doesn't already exist
|
||||
// For example, two sessions using the same session_options should not add the same custom op domain to the version map twice
|
||||
auto& domain_to_version_range_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
|
||||
|
|
@ -111,24 +117,25 @@ common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_do
|
|||
}
|
||||
|
||||
std::vector<ONNX_NAMESPACE::OpSchema> schemas_list;
|
||||
|
||||
for (auto& op : domain->custom_ops_) {
|
||||
ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0);
|
||||
for (const auto* op : domain->custom_ops_) {
|
||||
ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "custom op registered at runtime", 0);
|
||||
|
||||
auto input_count = op->GetInputTypeCount(op);
|
||||
for (size_t i = 0; i < input_count; i++) {
|
||||
auto type = op->GetInputType(op, i);
|
||||
schema.Input(i, "A", "Description",
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type ? "T" :
|
||||
DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
|
||||
schema.Input(i, "Input" + std::to_string(i), "",
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type
|
||||
? "T"
|
||||
: DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
|
||||
}
|
||||
|
||||
auto output_count = op->GetOutputTypeCount(op);
|
||||
for (size_t i = 0; i < output_count; i++) {
|
||||
auto type = op->GetOutputType(op, i);
|
||||
schema.Output(i, "A", "Description",
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type ? "T":
|
||||
DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
|
||||
schema.Output(i, "Output" + std::to_string(i), "",
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type
|
||||
? "T"
|
||||
: DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
|
||||
}
|
||||
|
||||
schema.TypeConstraint("T", DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types");
|
||||
|
|
@ -136,30 +143,38 @@ common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_do
|
|||
schema.SinceVersion(1);
|
||||
schema.AllowUncheckedAttributes();
|
||||
schemas_list.push_back(schema);
|
||||
|
||||
KernelDefBuilder def_builder;
|
||||
def_builder.SetName(op->GetName(op))
|
||||
.SetDomain(domain->domain_)
|
||||
.SinceVersion(1)
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes());
|
||||
|
||||
if (const char* provider_type = op->GetExecutionProviderType(op))
|
||||
def_builder.Provider(provider_type);
|
||||
else
|
||||
def_builder.Provider(onnxruntime::kCpuExecutionProvider);
|
||||
KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); };
|
||||
KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);
|
||||
output->RegisterCustomKernel(create_info);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(output->RegisterOpSet(schemas_list,
|
||||
domain->domain_,
|
||||
1 /* baseline opset version */,
|
||||
1000 /* opset version */));
|
||||
|
||||
#endif
|
||||
// create the KernelDef for each op and register it
|
||||
for (const auto* op : domain->custom_ops_) {
|
||||
KernelDefBuilder def_builder;
|
||||
def_builder.SetName(op->GetName(op))
|
||||
.SetDomain(domain->domain_)
|
||||
.SinceVersion(1)
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes());
|
||||
|
||||
if (const char* provider_type = op->GetExecutionProviderType(op)) {
|
||||
def_builder.Provider(provider_type);
|
||||
} else {
|
||||
def_builder.Provider(onnxruntime::kCpuExecutionProvider);
|
||||
}
|
||||
|
||||
KernelCreateFn kernel_create_fn = [op](const OpKernelInfo& info) -> OpKernel* {
|
||||
return new CustomOpKernel(info, *op);
|
||||
};
|
||||
|
||||
KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);
|
||||
ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(create_info));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#include "core/common/denormal.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/allocatormgr.h"
|
||||
#include "core/framework/customregistry.h"
|
||||
#include "core/framework/error_code_helper.h"
|
||||
#include "core/framework/execution_frame.h"
|
||||
#include "core/framework/feeds_fetches_manager.h"
|
||||
|
|
@ -43,6 +44,7 @@
|
|||
#ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h"
|
||||
#endif
|
||||
#include "core/session/custom_ops.h"
|
||||
#include "core/session/environment.h"
|
||||
#include "core/session/IOBinding.h"
|
||||
#include "core/session/inference_session_utils.h"
|
||||
|
|
@ -50,10 +52,6 @@
|
|||
#include "core/util/protobuf_parsing_utils.h"
|
||||
#include "core/util/thread_utils.h"
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
#include "core/framework/customregistry.h"
|
||||
#include "core/session/custom_ops.h"
|
||||
#endif
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::experimental;
|
||||
|
|
@ -468,6 +466,7 @@ common::Status InferenceSession::AddCustomTransformerList(const std::vector<std:
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
common::Status InferenceSession::AddCustomOpDomains(const std::vector<OrtCustomOpDomain*>& op_domains) {
|
||||
std::shared_ptr<CustomRegistry> custom_registry;
|
||||
|
|
@ -486,10 +485,13 @@ common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptr<CustomRe
|
|||
// Insert session-level customized kernel registry.
|
||||
kernel_registry_manager_.RegisterKernelRegistry(custom_registry->GetKernelRegistry());
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
custom_schema_registries_.push_back(custom_registry->GetOpschemaRegistry());
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
common::Status InferenceSession::SaveToOrtFormat(const std::basic_string<ORTCHAR_T>& filepath) const {
|
||||
ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ort format only supports little-edian machines");
|
||||
|
||||
|
|
@ -1016,7 +1018,15 @@ Status InferenceSession::LoadOrtModel(std::function<Status()> load_ort_format_mo
|
|||
|
||||
// need to go from unique_ptr to shared_ptr when moving into model_
|
||||
std::unique_ptr<Model> tmp_model;
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model,
|
||||
HasLocalSchema() ? &custom_schema_registries_ : nullptr,
|
||||
*session_logger_, tmp_model));
|
||||
|
||||
#else
|
||||
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, *session_logger_, tmp_model));
|
||||
#endif
|
||||
|
||||
ORT_RETURN_IF_ERROR(SaveModelMetadata(*tmp_model));
|
||||
model_ = std::move(tmp_model);
|
||||
|
||||
|
|
|
|||
|
|
@ -184,6 +184,7 @@ class InferenceSession {
|
|||
*/
|
||||
common::Status AddCustomTransformerList(const std::vector<std::string>& transformers_to_enable) ORT_MUST_USE_RESULT;
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
/**
|
||||
* Add custom ops. This API is not thread safe.
|
||||
*/
|
||||
|
|
@ -200,8 +201,6 @@ class InferenceSession {
|
|||
*/
|
||||
common::Status RegisterCustomRegistry(std::shared_ptr<CustomRegistry> custom_registry) ORT_MUST_USE_RESULT;
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
/**
|
||||
* Load an ONNX or ORT format model.
|
||||
*
|
||||
|
|
@ -582,11 +581,11 @@ class InferenceSession {
|
|||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
std::list<std::shared_ptr<onnxruntime::IOnnxRuntimeOpSchemaCollection>> custom_schema_registries_;
|
||||
#endif
|
||||
|
||||
//CustomRegistry objects own the corresponding KernelRegistry and OnnxRuntimeOpSchemaRegistry objects.
|
||||
//So its lifetime should be same as its constituents. This vector is to extend the lifetime of the owner.
|
||||
std::vector<std::shared_ptr<CustomRegistry>> custom_registries_;
|
||||
#endif
|
||||
|
||||
ModelMetadata model_metadata_;
|
||||
std::unordered_set<std::string> required_inputs_;
|
||||
|
|
|
|||
|
|
@ -439,13 +439,11 @@ static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* op
|
|||
env->GetEnvironment());
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// Add custom domains
|
||||
Status status;
|
||||
if (options && !options->custom_op_domains_.empty()) {
|
||||
ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_));
|
||||
}
|
||||
#endif
|
||||
|
||||
// Finish load
|
||||
if (load_config_from_model) {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
// if we can't load an ORT format model we can't really test anything
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
#include "core/common/make_unique.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
|
|
@ -18,6 +19,8 @@
|
|||
#include "flatbuffers/idl.h"
|
||||
#include "flatbuffers/util.h"
|
||||
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace std;
|
||||
|
|
@ -237,16 +240,16 @@ static void DumpOrtModelAsJson(const std::string& model_uri) {
|
|||
std::ofstream(model_uri + ".json") << json;
|
||||
}
|
||||
*/
|
||||
/* The full build was causing the following error because the graph node array has some empyt (blank) node at some indices for certain ORT designs
|
||||
onnx runtime exception : Satisfied, but should not be : node == nullptr
|
||||
session_state.cc : 814 onnxruntime::SessionState::LoadFromOrtFormatCan't find node with index 4. Invalid ORT format model.
|
||||
The bug was due to loading an ORT format model in a full build, allowing optimizers to run, but trying to use the saved kernel information.
|
||||
As the optimizer removed some leaving gaps in the graph node vector.
|
||||
The build has been fixed in InferenceSession code. The following test case to catch this error.
|
||||
*/
|
||||
|
||||
/*
|
||||
Validate we don't run optimizers on an ORT format model in a full build. The optimizers will remove nodes,
|
||||
which will create a mismatch with the saved kernel information and result in a runtime error.
|
||||
We could take steps to handle this scenario in a full build, but for consistency we choose to not run optimizers
|
||||
on any ORT format model.
|
||||
*/
|
||||
TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) {
|
||||
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("mnist.onnx.ort");
|
||||
// we have tests that use a pre-created minst.onnx.ort, so make the naming for the unit test generated file clearer
|
||||
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/mnist.onnx.test_output.ort");
|
||||
SaveAndCompareModels("testdata/mnist.onnx", ort_file);
|
||||
|
||||
// DumpOrtModelAsJson(ToMBString(ort_file));
|
||||
|
|
@ -257,8 +260,8 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) {
|
|||
test_info.configs.push_back(std::make_pair(kOrtSessionOptionsConfigLoadModelFormat, "ORT"));
|
||||
|
||||
OrtValue ml_value;
|
||||
vector<float> data(28*28, 0.0);
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1,1,28,28}, data,
|
||||
vector<float> data(28 * 28, 0.0);
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1, 1, 28, 28}, data,
|
||||
&ml_value);
|
||||
test_info.inputs.insert(std::make_pair("Input3", ml_value));
|
||||
|
||||
|
|
@ -267,14 +270,13 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) {
|
|||
test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
|
||||
const auto& output = fetches[0].Get<Tensor>();
|
||||
ASSERT_TRUE(output.Shape().NumDimensions() == 2);
|
||||
// ASSERT_TRUE(output.Data<float>()[0] == 125.f);
|
||||
};
|
||||
|
||||
RunOrtModel(test_info);
|
||||
}
|
||||
|
||||
TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
|
||||
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("ort_github_issue_4031.onnx.ort");
|
||||
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/ort_github_issue_4031.onnx.test_output.ort");
|
||||
SaveAndCompareModels("testdata/ort_github_issue_4031.onnx", ort_file);
|
||||
|
||||
// DumpOrtModelAsJson(ToMBString(ort_file));
|
||||
|
|
@ -301,7 +303,8 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
|
|||
}
|
||||
|
||||
TEST(OrtModelOnlyTests, SparseInitializerHandling) {
|
||||
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("sparse_initializer_handling.onnx.ort");
|
||||
const std::basic_string<ORTCHAR_T> ort_file =
|
||||
ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.test_output.ort");
|
||||
SaveAndCompareModels("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx", ort_file);
|
||||
|
||||
SessionOptions so;
|
||||
|
|
@ -322,7 +325,8 @@ TEST(OrtModelOnlyTests, SparseInitializerHandling) {
|
|||
|
||||
#if !defined(DISABLE_ML_OPS)
|
||||
TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) {
|
||||
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("sklearn_bin_voting_classifier_soft_converted.ort");
|
||||
const std::basic_string<ORTCHAR_T> ort_file =
|
||||
ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft_converted.test_output.ort");
|
||||
SaveAndCompareModels("testdata/sklearn_bin_voting_classifier_soft.onnx", ort_file);
|
||||
|
||||
OrtModelTestInfo test_info;
|
||||
|
|
@ -361,6 +365,7 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) {
|
|||
|
||||
RunOrtModel(test_info);
|
||||
}
|
||||
|
||||
#endif // #if !defined(DISABLE_ML_OPS)
|
||||
#endif // #if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@
|
|||
#include "providers.h"
|
||||
#include "test_allocator.h"
|
||||
#include "test_fixture.h"
|
||||
#include "utils.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <Windows.h>
|
||||
|
|
@ -32,11 +33,9 @@
|
|||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
||||
struct Input {
|
||||
const char* name = nullptr;
|
||||
std::vector<int64_t> dims;
|
||||
std::vector<float> values;
|
||||
};
|
||||
// Once we use C++17 this could be replaced with std::size
|
||||
template <typename T, size_t N>
|
||||
constexpr size_t countof(T (&)[N]) { return N; }
|
||||
|
||||
extern std::unique_ptr<Ort::Env> ort_env;
|
||||
|
||||
|
|
@ -51,14 +50,18 @@ void RunSession(OrtAllocator* allocator, Ort::Session& session_object,
|
|||
std::vector<const char*> input_names;
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
input_names.emplace_back(inputs[i].name);
|
||||
ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(allocator->Info(allocator), const_cast<float*>(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size()));
|
||||
ort_inputs.emplace_back(
|
||||
Ort::Value::CreateTensor<float>(allocator->Info(allocator), const_cast<float*>(inputs[i].values.data()),
|
||||
inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size()));
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> ort_outputs;
|
||||
if (output_tensor)
|
||||
session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, output_tensor, 1);
|
||||
session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
|
||||
&output_name, output_tensor, 1);
|
||||
else {
|
||||
ort_outputs = session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1);
|
||||
ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
|
||||
&output_name, 1);
|
||||
ASSERT_EQ(ort_outputs.size(), 1u);
|
||||
output_tensor = &ort_outputs[0];
|
||||
}
|
||||
|
|
@ -74,17 +77,17 @@ void RunSession(OrtAllocator* allocator, Ort::Session& session_object,
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void TestInference(Ort::Env& env, T model_uri,
|
||||
const std::vector<Input>& inputs,
|
||||
const char* output_name,
|
||||
const std::vector<int64_t>& expected_dims_y,
|
||||
const std::vector<OutT>& expected_values_y,
|
||||
int provider_type,
|
||||
OrtCustomOpDomain* custom_op_domain_ptr,
|
||||
const char* custom_op_library_filename,
|
||||
void** library_handle = nullptr,
|
||||
bool test_session_creation_only = false) {
|
||||
template <typename OutT>
|
||||
static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& model_uri,
|
||||
const std::vector<Input>& inputs,
|
||||
const char* output_name,
|
||||
const std::vector<int64_t>& expected_dims_y,
|
||||
const std::vector<OutT>& expected_values_y,
|
||||
int provider_type,
|
||||
OrtCustomOpDomain* custom_op_domain_ptr,
|
||||
const char* custom_op_library_filename,
|
||||
void** library_handle = nullptr,
|
||||
bool test_session_creation_only = false) {
|
||||
Ort::SessionOptions session_options;
|
||||
|
||||
if (provider_type == 1) {
|
||||
|
|
@ -103,7 +106,8 @@ void TestInference(Ort::Env& env, T model_uri,
|
|||
#endif
|
||||
} else if (provider_type == 3) {
|
||||
#ifdef USE_NUPHAR
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options, /*allow_unaligned_buffers*/ 1, ""));
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options,
|
||||
/*allow_unaligned_buffers*/ 1, ""));
|
||||
std::cout << "Running simple inference with nuphar provider" << std::endl;
|
||||
#else
|
||||
return;
|
||||
|
|
@ -116,11 +120,12 @@ void TestInference(Ort::Env& env, T model_uri,
|
|||
}
|
||||
|
||||
if (custom_op_library_filename) {
|
||||
Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, library_handle));
|
||||
Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary(session_options,
|
||||
custom_op_library_filename, library_handle));
|
||||
}
|
||||
|
||||
// if session creation passes, model loads fine
|
||||
Ort::Session session(env, model_uri, session_options);
|
||||
Ort::Session session(env, model_uri.c_str(), session_options);
|
||||
|
||||
// caller wants to test running the model (not just loading the model)
|
||||
if (!test_session_creation_only) {
|
||||
|
|
@ -136,7 +141,8 @@ void TestInference(Ort::Env& env, T model_uri,
|
|||
expected_values_y,
|
||||
nullptr);
|
||||
//with preallocated output tensor
|
||||
Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size());
|
||||
Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(),
|
||||
expected_dims_y.data(), expected_dims_y.size());
|
||||
|
||||
//test it twice
|
||||
for (int i = 0; i != 2; ++i)
|
||||
|
|
@ -181,7 +187,8 @@ TEST_P(CApiTestWithProvider, simple) {
|
|||
std::vector<int64_t> expected_dims_y = {3, 2};
|
||||
std::vector<float> expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
|
||||
|
||||
TestInference<PATH_TYPE, float>(*ort_env, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(), nullptr, nullptr);
|
||||
TestInference<float>(*ort_env, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(),
|
||||
nullptr, nullptr);
|
||||
}
|
||||
|
||||
TEST(CApiTest, dim_param) {
|
||||
|
|
@ -216,70 +223,6 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders,
|
|||
CApiTestWithProvider,
|
||||
::testing::Values(0, 1, 2, 3, 4));
|
||||
|
||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||
OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
|
||||
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
||||
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
|
||||
ort.ReleaseTensorTypeAndShapeInfo(info);
|
||||
}
|
||||
};
|
||||
|
||||
// Once we use C++17 this could be replaced with std::size
|
||||
template <typename T, size_t N>
|
||||
constexpr size_t countof(T (&)[N]) { return N; }
|
||||
void cuda_add(int64_t, float*, const float*, const float*);
|
||||
|
||||
struct MyCustomKernel {
|
||||
MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
const float* Y = ort_.GetTensorData<float>(input_Y);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
float* out = ort_.GetTensorMutableData<float>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
#ifdef USE_CUDA
|
||||
cuda_add(size, out, X, Y);
|
||||
#else
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = X[i] + Y[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
Ort::CustomOpApi ort_;
|
||||
};
|
||||
|
||||
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
|
||||
explicit MyCustomOp(const char* provider) : provider_(provider) {}
|
||||
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new MyCustomKernel(api, info); };
|
||||
const char* GetName() const { return "Foo"; };
|
||||
|
||||
const char* GetExecutionProviderType() const { return provider_; };
|
||||
|
||||
size_t GetInputTypeCount() const { return 2; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
||||
size_t GetOutputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
||||
private:
|
||||
const char* provider_;
|
||||
};
|
||||
|
||||
TEST(CApiTest, custom_op_handler) {
|
||||
std::cout << "Running custom op inference" << std::endl;
|
||||
|
||||
|
|
@ -303,9 +246,11 @@ TEST(CApiTest, custom_op_handler) {
|
|||
custom_op_domain.Add(&custom_op);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
TestInference<PATH_TYPE, float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1, custom_op_domain, nullptr, nullptr);
|
||||
TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1,
|
||||
custom_op_domain, nullptr, nullptr);
|
||||
#else
|
||||
TestInference<PATH_TYPE, float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, custom_op_domain, nullptr);
|
||||
TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0,
|
||||
custom_op_domain, nullptr);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -368,9 +313,11 @@ struct SliceCustomOpKernel {
|
|||
|
||||
struct SliceCustomOp : Ort::CustomOpBase<SliceCustomOp, SliceCustomOpKernel> {
|
||||
explicit SliceCustomOp(const char* provider) : provider_(provider) {}
|
||||
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new SliceCustomOpKernel(api, info); };
|
||||
const char* GetName() const { return "Slice"; };
|
||||
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
|
||||
return new SliceCustomOpKernel(api, info);
|
||||
};
|
||||
|
||||
const char* GetName() const { return "Slice"; };
|
||||
const char* GetExecutionProviderType() const { return provider_; };
|
||||
|
||||
size_t GetInputTypeCount() const { return 3; };
|
||||
|
|
@ -427,9 +374,11 @@ TEST(CApiTest, varied_input_custom_op_handler) {
|
|||
custom_op_domain.Add(&slice_custom_op);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
TestInference<PATH_TYPE, float>(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z", expected_dims_z, expected_values_z, 1, custom_op_domain, nullptr, nullptr);
|
||||
TestInference<float>(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z",
|
||||
expected_dims_z, expected_values_z, 1, custom_op_domain, nullptr, nullptr);
|
||||
#else
|
||||
TestInference<PATH_TYPE, float>(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z", expected_dims_z, expected_values_z, 0, custom_op_domain, nullptr);
|
||||
TestInference<float>(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z",
|
||||
expected_dims_z, expected_values_z, 0, custom_op_domain, nullptr);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -454,8 +403,8 @@ TEST(CApiTest, RegisterCustomOpForCPUAndCUDA) {
|
|||
custom_op_domain.Add(&custom_op_cpu);
|
||||
custom_op_domain.Add(&custom_op_cuda);
|
||||
|
||||
TestInference<PATH_TYPE, float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y,
|
||||
expected_values_y, 1, custom_op_domain, nullptr, nullptr, true);
|
||||
TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y,
|
||||
expected_values_y, 1, custom_op_domain, nullptr, nullptr, true);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -496,8 +445,8 @@ lib_name = "./libcustom_op_library.so";
|
|||
#endif
|
||||
|
||||
void* library_handle = nullptr;
|
||||
TestInference<PATH_TYPE, int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y,
|
||||
expected_values_y, 0, nullptr, lib_name.c_str(), &library_handle);
|
||||
TestInference<int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y,
|
||||
expected_values_y, 0, nullptr, lib_name.c_str(), &library_handle);
|
||||
|
||||
#ifdef _WIN32
|
||||
bool success = ::FreeLibrary(reinterpret_cast<HMODULE>(library_handle));
|
||||
|
|
@ -552,7 +501,8 @@ TEST(CApiTest, test_pyop) {
|
|||
input.values = {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
std::vector<int64_t> expected_dims_y = {2, 2};
|
||||
std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f};
|
||||
TestInference<PATH_TYPE, float>(*ort_env, PYOP_FLOAT_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, nullptr, nullptr);
|
||||
TestInference<float>(*ort_env, PYOP_FLOAT_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0,
|
||||
nullptr, nullptr);
|
||||
}
|
||||
|
||||
TEST(CApiTest, test_pyop_multi) {
|
||||
|
|
@ -564,7 +514,8 @@ TEST(CApiTest, test_pyop_multi) {
|
|||
input.values = {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
std::vector<int64_t> expected_dims_y = {2, 2};
|
||||
std::vector<float> expected_values_y = {8.0f, 16.0f, 24.0f, 32.0f};
|
||||
TestInference<PATH_TYPE, float>(*ort_env, PYOP_MULTI_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0, nullptr, nullptr);
|
||||
TestInference<float>(*ort_env, PYOP_MULTI_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0,
|
||||
nullptr, nullptr);
|
||||
}
|
||||
|
||||
TEST(CApiTest, test_pyop_kwarg) {
|
||||
|
|
@ -576,7 +527,8 @@ TEST(CApiTest, test_pyop_kwarg) {
|
|||
input.values = {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
std::vector<int64_t> expected_dims_y = {2, 2};
|
||||
std::vector<float> expected_values_y = {25.0f, 50.0f, 75.0f, 100.0f};
|
||||
TestInference<PATH_TYPE, float>(*ort_env, PYOP_KWARG_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0, nullptr, nullptr);
|
||||
TestInference<float>(*ort_env, PYOP_KWARG_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0,
|
||||
nullptr, nullptr);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -596,7 +548,8 @@ TEST(ReducedOpsBuildTest, test_included_ops) {
|
|||
std::vector<Input> inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}};
|
||||
std::vector<int64_t> expected_dims_y = {1};
|
||||
std::vector<float> expected_values_y = {0.75};
|
||||
TestInference<PATH_TYPE, float>(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0, nullptr, nullptr);
|
||||
TestInference<float>(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0,
|
||||
nullptr, nullptr);
|
||||
}
|
||||
|
||||
TEST(ReducedOpsBuildTest, test_excluded_ops) {
|
||||
|
|
@ -609,7 +562,8 @@ TEST(ReducedOpsBuildTest, test_excluded_ops) {
|
|||
bool failed = false;
|
||||
try {
|
||||
//only test model loading, exception expected
|
||||
TestInference<PATH_TYPE, float>(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0, nullptr, nullptr, nullptr, true);
|
||||
TestInference<float>(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0,
|
||||
nullptr, nullptr, nullptr, true);
|
||||
} catch (const Ort::Exception& e) {
|
||||
failed = e.GetOrtErrorCode() == ORT_NOT_IMPLEMENTED;
|
||||
}
|
||||
|
|
@ -782,8 +736,8 @@ TEST(CApiTest, io_binding_cuda) {
|
|||
ASSERT_NE(output_data.get(), nullptr);
|
||||
|
||||
// Create an OrtValue tensor backed by data on CUDA memory
|
||||
Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast<float*>(output_data.get()), expected_y.size(),
|
||||
expected_y_shape.data(), expected_y_shape.size());
|
||||
Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast<float*>(output_data.get()),
|
||||
expected_y.size(), expected_y_shape.data(), expected_y_shape.size());
|
||||
|
||||
Ort::IoBinding binding(session);
|
||||
binding.BindInput("X", bound_x);
|
||||
|
|
@ -854,7 +808,8 @@ TEST(CApiTest, create_tensor) {
|
|||
int64_t expected_len = 2;
|
||||
auto default_allocator = onnxruntime::make_unique<MockedOrtAllocator>();
|
||||
|
||||
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1,
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
|
||||
Ort::ThrowOnError(Ort::GetApi().FillStringTensor(tensor, s, expected_len));
|
||||
auto shape_info = tensor.GetTensorTypeAndShapeInfo();
|
||||
|
|
@ -874,7 +829,8 @@ TEST(CApiTest, fill_string_tensor) {
|
|||
int64_t expected_len = 2;
|
||||
auto default_allocator = onnxruntime::make_unique<MockedOrtAllocator>();
|
||||
|
||||
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1,
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
|
||||
for (int64_t i = 0; i < expected_len; i++) {
|
||||
tensor.FillStringTensorElement(s[i], i);
|
||||
|
|
@ -892,7 +848,8 @@ TEST(CApiTest, get_string_tensor_element) {
|
|||
int64_t element_index = 0;
|
||||
auto default_allocator = onnxruntime::make_unique<MockedOrtAllocator>();
|
||||
|
||||
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1,
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
|
||||
tensor.FillStringTensor(s, expected_len);
|
||||
|
||||
|
|
@ -1007,7 +964,8 @@ TEST(CApiTest, override_initializer) {
|
|||
|
||||
std::string f2_data{"f2_string"};
|
||||
// Place a string into Tensor OrtValue and assign to the
|
||||
Ort::Value f2_input_tensor = Ort::Value::CreateTensor(allocator.get(), dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
Ort::Value f2_input_tensor = Ort::Value::CreateTensor(allocator.get(), dims.data(), dims.size(),
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
const char* const input_char_string[] = {f2_data.c_str()};
|
||||
f2_input_tensor.FillStringTensor(input_char_string, 1U);
|
||||
|
||||
|
|
@ -1138,7 +1096,8 @@ TEST(CApiTest, model_metadata) {
|
|||
ASSERT_TRUE(version == 1);
|
||||
|
||||
int64_t num_keys_in_custom_metadata_map;
|
||||
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map);
|
||||
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(),
|
||||
num_keys_in_custom_metadata_map);
|
||||
ASSERT_TRUE(num_keys_in_custom_metadata_map == 2);
|
||||
|
||||
allocator.get()->Free(custom_metadata_map_keys[0]);
|
||||
|
|
@ -1147,7 +1106,8 @@ TEST(CApiTest, model_metadata) {
|
|||
|
||||
char* lookup_value_1 = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get());
|
||||
ASSERT_TRUE(strcmp(lookup_value_1,
|
||||
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0);
|
||||
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, "
|
||||
"\"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0);
|
||||
allocator.get()->Free(lookup_value_1);
|
||||
|
||||
char* lookup_value_2 = model_metadata.LookupCustomMetadataMap("dummy_key", allocator.get());
|
||||
|
|
@ -1180,7 +1140,8 @@ TEST(CApiTest, model_metadata) {
|
|||
|
||||
// Model does not contain custom metadata map
|
||||
int64_t num_keys_in_custom_metadata_map;
|
||||
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map);
|
||||
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(),
|
||||
num_keys_in_custom_metadata_map);
|
||||
ASSERT_TRUE(num_keys_in_custom_metadata_map == 0);
|
||||
ASSERT_TRUE(custom_metadata_map_keys == nullptr);
|
||||
}
|
||||
|
|
@ -1235,9 +1196,9 @@ TEST(CApiTest, TestSharedAllocatorUsingCreateAndRegisterAllocator) {
|
|||
ASSERT_TRUE(api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg) == nullptr);
|
||||
|
||||
// test for duplicates
|
||||
std::unique_ptr<OrtStatus, decltype(api.ReleaseStatus)> status_releaser(api.CreateAndRegisterAllocator(env_ptr, mem_info,
|
||||
arena_cfg),
|
||||
api.ReleaseStatus);
|
||||
std::unique_ptr<OrtStatus, decltype(api.ReleaseStatus)> status_releaser(
|
||||
api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg),
|
||||
api.ReleaseStatus);
|
||||
ASSERT_FALSE(status_releaser.get() == nullptr);
|
||||
|
||||
Ort::SessionOptions session_options;
|
||||
|
|
|
|||
134
onnxruntime/test/shared_lib/test_ort_format_models.cc
Normal file
134
onnxruntime/test/shared_lib/test_ort_format_models.cc
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// if we can't load an ORT format model we can't really test anything
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
#include "core/common/make_unique.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
#include "test_allocator.h"
|
||||
#include "utils.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
extern std::unique_ptr<Ort::Env> ort_env;
|
||||
|
||||
static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& model_uri,
|
||||
const std::vector<Input>& inputs, const char* output_name,
|
||||
const std::vector<int64_t>& expected_dims_y, const std::vector<float>& expected_values_y,
|
||||
Ort::CustomOpDomain& custom_op_domain) {
|
||||
Ort::SessionOptions session_options;
|
||||
session_options.Add(custom_op_domain);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
OrtCUDAProviderOptions cuda_options{};
|
||||
session_options.AppendExecutionProvider_CUDA(cuda_options);
|
||||
#endif
|
||||
|
||||
Ort::Session session(env, model_uri.c_str(), session_options);
|
||||
|
||||
MockedOrtAllocator allocator;
|
||||
std::vector<Ort::Value> ort_inputs;
|
||||
std::vector<const char*> input_names;
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
// we put the data in a Tensor, and Tensor has a method to get a mutable pointer to the data.
|
||||
// we never call that for an input, but need to do the const_cast to make this potential explicit
|
||||
float* input_data = const_cast<float*>(inputs[i].values.data());
|
||||
|
||||
auto input_tensor = Ort::Value::CreateTensor<float>(allocator.Info(), input_data, inputs[i].values.size(),
|
||||
inputs[i].dims.data(), inputs[i].dims.size());
|
||||
|
||||
input_names.push_back(inputs[i].name);
|
||||
ort_inputs.push_back(std::move(input_tensor));
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> ort_outputs;
|
||||
ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
|
||||
&output_name, 1);
|
||||
|
||||
ASSERT_EQ(ort_outputs.size(), size_t(1));
|
||||
const auto& output_tensor = &ort_outputs[0];
|
||||
|
||||
auto type_info = output_tensor->GetTensorTypeAndShapeInfo();
|
||||
ASSERT_EQ(type_info.GetShape(), expected_dims_y);
|
||||
size_t total_len = type_info.GetElementCount();
|
||||
ASSERT_EQ(expected_values_y.size(), total_len);
|
||||
|
||||
const auto* f = output_tensor->GetTensorMutableData<float>();
|
||||
for (size_t i = 0; i < total_len; ++i) {
|
||||
ASSERT_EQ(expected_values_y[i], f[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
TEST(OrtFormatCustomOpTests, ConvertOnnxModelToOrt) {
|
||||
const std::basic_string<ORTCHAR_T> onnx_file = ORT_TSTR("testdata/foo_1.onnx");
|
||||
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/foo_1.onnx.test_output.ort");
|
||||
|
||||
#ifdef USE_CUDA
|
||||
MyCustomOp custom_op{onnxruntime::kCudaExecutionProvider};
|
||||
#else
|
||||
MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider};
|
||||
#endif
|
||||
Ort::CustomOpDomain custom_op_domain("");
|
||||
custom_op_domain.Add(&custom_op);
|
||||
|
||||
// convert to ort by loading the onnx model
|
||||
{
|
||||
Ort::SessionOptions so;
|
||||
so.Add(custom_op_domain);
|
||||
so.SetLogId("CustomOp");
|
||||
so.SetOptimizedModelFilePath(ort_file.c_str());
|
||||
|
||||
#ifdef USE_CUDA
|
||||
OrtCUDAProviderOptions cuda_options{};
|
||||
so.AppendExecutionProvider_CUDA(cuda_options);
|
||||
#endif
|
||||
|
||||
Ort::Session session(*ort_env, onnx_file.c_str(), so);
|
||||
}
|
||||
|
||||
// now load the ORT format model and execute it
|
||||
std::vector<Input> inputs(1);
|
||||
Input& input = inputs[0];
|
||||
input.name = "X";
|
||||
input.dims = {3, 2};
|
||||
input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
|
||||
// model adds 1, 2, 3, 4, 5, 6 to the input values
|
||||
std::vector<int64_t> expected_dims_y = {3, 2};
|
||||
std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
|
||||
|
||||
TestInference(*ort_env, ort_file, inputs, "Y", expected_dims_y, expected_values_y, custom_op_domain);
|
||||
}
|
||||
#endif // if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
// the saved ORT format model has the CPU EP assigned to the custom op node, so we only test if we're not using the
|
||||
// CUDA EP for the test.
|
||||
#ifndef USE_CUDA
|
||||
TEST(OrtFormatCustomOpTests, LoadOrtModel) {
|
||||
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/foo_1.onnx.ort");
|
||||
|
||||
MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider};
|
||||
Ort::CustomOpDomain custom_op_domain("");
|
||||
custom_op_domain.Add(&custom_op);
|
||||
|
||||
// load the ORT format model and execute it
|
||||
std::vector<Input> inputs(1);
|
||||
Input& input = inputs[0];
|
||||
input.name = "X";
|
||||
input.dims = {3, 2};
|
||||
input.values = {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f};
|
||||
|
||||
// model adds 1, 2, 3, 4, 5, 6 to the input values
|
||||
std::vector<int64_t> expected_dims_y = {3, 2};
|
||||
std::vector<float> expected_values_y = {7.0f, 7.0f, 7.0f, 7.0f, 7.0f, 7.0f};
|
||||
|
||||
TestInference(*ort_env, ort_file, inputs, "Y", expected_dims_y, expected_values_y, custom_op_domain);
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // #if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
35
onnxruntime/test/shared_lib/utils.cc
Normal file
35
onnxruntime/test/shared_lib/utils.cc
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
void cuda_add(int64_t, float*, const float*, const float*);
|
||||
#endif
|
||||
|
||||
void MyCustomKernel::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
const float* Y = ort_.GetTensorData<float>(input_Y);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
float* out = ort_.GetTensorMutableData<float>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
#ifdef USE_CUDA
|
||||
cuda_add(size, out, X, Y);
|
||||
#else
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = X[i] + Y[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
45
onnxruntime/test/shared_lib/utils.h
Normal file
45
onnxruntime/test/shared_lib/utils.h
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
struct Input {
|
||||
const char* name = nullptr;
|
||||
std::vector<int64_t> dims;
|
||||
std::vector<float> values;
|
||||
};
|
||||
|
||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||
OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
|
||||
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
||||
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
|
||||
ort.ReleaseTensorTypeAndShapeInfo(info);
|
||||
}
|
||||
};
|
||||
|
||||
struct MyCustomKernel {
|
||||
MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
Ort::CustomOpApi ort_;
|
||||
};
|
||||
|
||||
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
|
||||
explicit MyCustomOp(const char* provider) : provider_(provider) {}
|
||||
|
||||
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new MyCustomKernel(api, info); };
|
||||
const char* GetName() const { return "Foo"; };
|
||||
const char* GetExecutionProviderType() const { return provider_; };
|
||||
|
||||
size_t GetInputTypeCount() const { return 2; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
||||
size_t GetOutputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
||||
private:
|
||||
const char* provider_;
|
||||
};
|
||||
BIN
onnxruntime/test/testdata/foo_1.onnx.ort
vendored
Normal file
BIN
onnxruntime/test/testdata/foo_1.onnx.ort
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue