Add support for custom ops to minimal build. (#6228)

* Add support for custom ops to minimal build.
Cost is only ~8KB so including in base minimal build.
This commit is contained in:
Scott McKay 2021-01-25 10:41:00 +10:00 committed by GitHub
parent 6507b4f818
commit e1dc268e45
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 470 additions and 216 deletions

View file

@ -10,8 +10,6 @@ file(GLOB_RECURSE onnxruntime_framework_srcs CONFIGURE_DEPENDS
if (onnxruntime_MINIMAL_BUILD)
set(onnxruntime_framework_src_exclude
"${ONNXRUNTIME_ROOT}/core/framework/provider_bridge_ort.cc"
"${ONNXRUNTIME_INCLUDE_DIR}/core/framework/customregistry.h"
"${ONNXRUNTIME_ROOT}/core/framework/customregistry.cc"
"${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.h"
"${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.cc"
)

View file

@ -322,7 +322,10 @@ set (onnxruntime_shared_lib_test_SRC
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_allocator.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc)
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_ort_format_models.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.h
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.cc)
if (NOT onnxruntime_MINIMAL_BUILD)
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc)

View file

@ -5,11 +5,14 @@
#include "core/common/status.h"
#include "core/common/logging/logging.h"
#include "core/graph/schema_registry.h"
#include "core/framework/op_kernel.h"
#include "core/framework/kernel_def_builder.h"
#include "core/framework/kernel_registry.h"
#if !defined(ORT_MINIMAL_BUILD)
#include "core/graph/schema_registry.h"
#endif
namespace onnxruntime {
/**
@ -17,9 +20,14 @@ namespace onnxruntime {
*/
class CustomRegistry final {
public:
CustomRegistry() :
kernel_registry_(std::make_shared<KernelRegistry>()),
opschema_registry_(std::make_shared<onnxruntime::OnnxRuntimeOpSchemaRegistry>()) {}
CustomRegistry()
: kernel_registry_(std::make_shared<KernelRegistry>())
#if !defined(ORT_MINIMAL_BUILD)
,
opschema_registry_(std::make_shared<onnxruntime::OnnxRuntimeOpSchemaRegistry>())
#endif
{
}
/**
* Register a kernel definition together with kernel factory method to this session.
@ -32,18 +40,21 @@ class CustomRegistry final {
common::Status RegisterCustomKernel(KernelCreateInfo&);
common::Status RegisterOpSet(std::vector<ONNX_NAMESPACE::OpSchema>& schemas, const std::string& domain,
int baseline_opset_version, int opset_version);
const std::shared_ptr<KernelRegistry>& GetKernelRegistry();
#if !defined(ORT_MINIMAL_BUILD)
common::Status RegisterOpSet(std::vector<ONNX_NAMESPACE::OpSchema>& schemas, const std::string& domain,
int baseline_opset_version, int opset_version);
const std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry>& GetOpschemaRegistry();
#endif
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomRegistry);
std::shared_ptr<KernelRegistry> kernel_registry_;
std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry> opschema_registry_;
#if !defined(ORT_MINIMAL_BUILD)
std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry> opschema_registry_;
#endif
};
} // namespace onnxruntime

View file

@ -654,7 +654,8 @@ class Graph {
/** Return true if "node_arg" is a input or an initializer. Otherwise, returns false. */
bool IsInputsIncludingInitializers(const NodeArg* node_arg) const noexcept {
return std::find(graph_inputs_including_initializers_.begin(), graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
return std::find(graph_inputs_including_initializers_.begin(),
graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
}
/** Gets the Graph inputs that are initializers
@ -1085,6 +1086,9 @@ class Graph {
static common::Status LoadFromOrtFormat(
const onnxruntime::experimental::fbs::Graph& fbs_graph, const Model& owning_model,
const std::unordered_map<std::string, int>& domain_to_version,
#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
#endif
const logging::Logger& logger, std::unique_ptr<Graph>& graph);
// deserialize a subgraph
@ -1104,6 +1108,9 @@ class Graph {
// Create empty Graph instance to re-create from ORT format serialized data.
Graph(const Model& owning_model,
const std::unordered_map<std::string, int>& domain_to_version,
#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
#endif
Graph* parent_graph, const Node* parent_node,
const logging::Logger& logger);

View file

@ -12,21 +12,22 @@ common::Status CustomRegistry::RegisterCustomKernel(KernelCreateInfo& create_inf
return kernel_registry_->Register(std::move(create_info));
}
const std::shared_ptr<KernelRegistry>& CustomRegistry::GetKernelRegistry() {
return kernel_registry_;
}
#if !defined(ORT_MINIMAL_BUILD)
common::Status CustomRegistry::RegisterOpSet(
std::vector<ONNX_NAMESPACE::OpSchema>& schemas,
const std::string& domain,
int baseline_opset_version,
int opset_version) {
return opschema_registry_->RegisterOpSet(schemas, domain, baseline_opset_version, opset_version);
}
const std::shared_ptr<KernelRegistry>& CustomRegistry::GetKernelRegistry() {
return kernel_registry_;
}
const std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry>& CustomRegistry::GetOpschemaRegistry() {
return opschema_registry_;
}
#endif
} // namespace onnxruntime

View file

@ -77,19 +77,22 @@ void IExecutionProvider::InsertAllocator(AllocatorPtr allocator) {
#if !defined(ORT_MINIMAL_BUILD)
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_node*/,
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
"IExecutionProvider::Compile with fused Node is not implemented by " + type_);
}
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_node*/,
std::string& /*dll_path*/) {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
"IExecutionProvider::Compile with fused Node and dll path is not implemented by " + type_);
}
#endif
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
common::Status IExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& /*fused_nodes_and_graphs*/,
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
"IExecutionProvider::Compile with FusedNodeAndGraph is not implemented by " + type_);
}
#endif

View file

@ -417,6 +417,10 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
std::vector<std::unique_ptr<ComputeCapability>> capabilities =
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type));
if (capabilities.empty()) {
return Status::OK();
}
// storage for the GraphViewer for each IndexedSubGraph
std::vector<std::unique_ptr<GraphViewer>> viewers;
viewers.reserve(capabilities.size());

View file

@ -45,14 +45,12 @@ Status KernelRegistryManager::RegisterKernels(const ExecutionProviders& executio
return Status::OK();
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptr<KernelRegistry> kernel_registry) {
if (nullptr == kernel_registry) {
return;
}
custom_kernel_registries_.push_front(kernel_registry);
}
#endif
#if !defined(ORT_MINIMAL_BUILD)
bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) {
@ -86,14 +84,12 @@ Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node
return Status(ONNXRUNTIME, FAIL, create_error_message("The node is not placed on any Execution Provider. "));
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
for (auto& registry : custom_kernel_registries_) {
status = registry->TryFindKernel(node, std::string(), kernel_def_hash, kernel_create_info);
if (status.IsOK()) {
return status;
}
}
#endif
KernelRegistry* p = nullptr;
auto iter = provider_type_to_registry_.find(ptype);

View file

@ -32,7 +32,6 @@ class KernelRegistryManager {
// Register kernels from providers
Status RegisterKernels(const ExecutionProviders& execution_providers) ORT_MUST_USE_RESULT;
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// The registry passed in this function has higher priority than anything already in this KernelRegistryManager,
// and anything registered from RegisterKernels
// For example, if you do:
@ -42,6 +41,7 @@ class KernelRegistryManager {
// Then B > A > providers
void RegisterKernelRegistry(std::shared_ptr<KernelRegistry> kernel_registry);
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/**
* Search kernel registry by provider type.
* @param type provider type string

View file

@ -3619,13 +3619,19 @@ std::ostream& operator<<(std::ostream& out, const Graph& graph) {
#endif // !defined(ORT_MINIMAL_BUILD)
#if defined(ENABLE_ORT_FORMAT_LOAD)
Status Graph::LoadFromOrtFormat(
const onnxruntime::experimental::fbs::Graph& fbs_graph,
const Model& owning_model,
const std::unordered_map<std::string, int>& domain_to_version,
const logging::Logger& logger, std::unique_ptr<Graph>& graph) {
Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs_graph,
const Model& owning_model,
const std::unordered_map<std::string, int>& domain_to_version,
#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
#endif
const logging::Logger& logger, std::unique_ptr<Graph>& graph) {
// can't use make_unique as we're calling a private ctor
graph.reset(new Graph(owning_model, domain_to_version, nullptr, nullptr, logger));
graph.reset(new Graph(owning_model, domain_to_version,
#if !defined(ORT_MINIMAL_BUILD)
schema_registry,
#endif
nullptr, nullptr, logger));
ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph));
@ -3636,8 +3642,6 @@ Status Graph::LoadFromOrtFormat(
// and in InferenceSession::Initialize skip partitioning and running optimizers.
graph->SetGraphResolveNeeded();
ORT_RETURN_IF_ERROR(graph->Resolve());
#else
// probably nothing required here. validate with model that has nested subgraphs.
#endif
return Status::OK();
@ -3648,7 +3652,11 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs
const logging::Logger& logger, std::unique_ptr<Graph>& graph) {
// can't use make_unique as we're calling a private ctor
graph.reset(new Graph(parent_graph.owning_model_,
parent_graph.domain_to_version_, &parent_graph, &parent_node,
parent_graph.domain_to_version_,
#if !defined(ORT_MINIMAL_BUILD)
parent_graph.schema_registry_,
#endif
&parent_graph, &parent_node,
logger));
return graph->LoadFromOrtFormat(fbs_graph);
@ -3656,12 +3664,15 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs
Graph::Graph(const Model& owning_model,
const std::unordered_map<std::string, int>& domain_to_version,
#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
#endif
Graph* parent_graph, const Node* parent_node,
const logging::Logger& logger)
: owning_model_(owning_model),
graph_proto_(&deserialized_proto_data_),
#if !defined(ORT_MINIMAL_BUILD)
schema_registry_(std::make_shared<SchemaRegistryManager>()),
schema_registry_(schema_registry),
#endif
domain_to_version_(domain_to_version),
ir_version_(owning_model.IrVersion()),

View file

@ -93,7 +93,8 @@ Model::Model(const ModelProto& model_proto, const PathString& model_path,
: Model(ModelProto(model_proto), model_path, local_registries, logger) {
}
Model::Model(ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
Model::Model(ModelProto&& model_proto, const PathString& model_path,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger)
: model_path_(Path::Parse(model_path)) {
if (!utils::HasGraph(model_proto)) {
@ -618,6 +619,9 @@ Model::Model() : model_path_{} {
#if defined(ENABLE_ORT_FORMAT_LOAD)
common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
#if !defined(ORT_MINIMAL_BUILD)
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
#endif
const logging::Logger& logger,
std::unique_ptr<Model>& model) {
model.reset(new Model());
@ -632,6 +636,13 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
}
model->model_proto_.set_model_version(fbs_model.model_version());
model->model_proto_.set_ir_version(fbs_model.ir_version());
auto schema_registry = std::make_shared<SchemaRegistryManager>();
if (local_registries != nullptr) {
for (const auto& schema_collection : *local_registries) {
schema_registry->RegisterRegistry(schema_collection);
}
}
#else
experimental::utils::LoadStringFromOrtFormat(model->producer_name_, fbs_model.producer_name());
experimental::utils::LoadStringFromOrtFormat(model->producer_version_, fbs_model.producer_version());
@ -648,10 +659,14 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
auto fbs_graph = fbs_model.graph();
ORT_RETURN_IF(nullptr == fbs_graph, "Graph is null. Invalid ORT format model.");
#if !defined(ORT_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, schema_registry, logger,
model->graph_));
#else
ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, logger, model->graph_));
#endif
return Status::OK();
}
#endif
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
} // namespace onnxruntime

View file

@ -239,6 +239,9 @@ class Model {
#if defined(ENABLE_ORT_FORMAT_LOAD)
static common::Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::Model& fbs_model,
#if !defined(ORT_MINIMAL_BUILD)
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
#endif
const logging::Logger& logger,
std::unique_ptr<Model>& model);
#endif

View file

@ -4,16 +4,18 @@
#ifdef _WIN32
#pragma warning(disable : 4267)
#endif
#include "core/graph/onnx_protobuf.h"
#include "core/session/inference_session.h"
#include "core/session/ort_apis.h"
#include "core/framework/customregistry.h"
#include "core/framework/data_types.h"
#include "core/framework/op_kernel_info.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/tensor_type_and_shape.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/inference_session.h"
#include "core/session/ort_apis.h"
ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type);
// ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type);
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<float>(name, out);
@ -67,23 +69,22 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernel
return onnxruntime::ToOrtStatus(status);
}
#if !defined(ORT_MINIMAL_BUILD)
#include "core/framework/customregistry.h"
namespace onnxruntime {
struct CustomOpKernel : OpKernel {
CustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) {
if (op_.version > ORT_API_VERSION)
if (op_.version > ORT_API_VERSION) {
ORT_THROW("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op));
op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast<const OrtKernelInfo*>(&info));
}
op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version),
reinterpret_cast<const OrtKernelInfo*>(&info));
}
~CustomOpKernel() override { op_.KernelDestroy(op_kernel_); }
Status Compute(OpKernelContext* ctx) const override {
auto* ictx = static_cast<OpKernelContextInternal*>(ctx);
op_.KernelCompute(op_kernel_, reinterpret_cast<OrtKernelContext*>(ictx));
op_.KernelCompute(op_kernel_, reinterpret_cast<OrtKernelContext*>(ctx));
return Status::OK();
}
@ -94,12 +95,17 @@ struct CustomOpKernel : OpKernel {
void* op_kernel_;
};
common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_domains, std::shared_ptr<CustomRegistry>& output) {
common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_domains,
std::shared_ptr<CustomRegistry>& output) {
output = std::make_shared<CustomRegistry>();
for (auto& domain : op_domains) {
for (const auto& domain : op_domains) {
// Create an OpSchema for each op and register them
#if !defined(ORT_MINIMAL_BUILD)
// Domain is not empty - add it to the DomainToVersion ONNX map
// If domain is empty, it is assumed to be part of the ONNX domain
if (domain->domain_[0]) {
if (!domain->domain_.empty()) {
// Add it to the DomainToVersion ONNX map if it doesn't already exist
// For example, two sessions using the same session_options should not add the same custom op domain to the version map twice
auto& domain_to_version_range_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
@ -111,24 +117,25 @@ common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_do
}
std::vector<ONNX_NAMESPACE::OpSchema> schemas_list;
for (auto& op : domain->custom_ops_) {
ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0);
for (const auto* op : domain->custom_ops_) {
ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "custom op registered at runtime", 0);
auto input_count = op->GetInputTypeCount(op);
for (size_t i = 0; i < input_count; i++) {
auto type = op->GetInputType(op, i);
schema.Input(i, "A", "Description",
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type ? "T" :
DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
schema.Input(i, "Input" + std::to_string(i), "",
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type
? "T"
: DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
}
auto output_count = op->GetOutputTypeCount(op);
for (size_t i = 0; i < output_count; i++) {
auto type = op->GetOutputType(op, i);
schema.Output(i, "A", "Description",
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type ? "T":
DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
schema.Output(i, "Output" + std::to_string(i), "",
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type
? "T"
: DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
}
schema.TypeConstraint("T", DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types");
@ -136,30 +143,38 @@ common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_do
schema.SinceVersion(1);
schema.AllowUncheckedAttributes();
schemas_list.push_back(schema);
KernelDefBuilder def_builder;
def_builder.SetName(op->GetName(op))
.SetDomain(domain->domain_)
.SinceVersion(1)
.TypeConstraint("T", DataTypeImpl::AllTensorTypes());
if (const char* provider_type = op->GetExecutionProviderType(op))
def_builder.Provider(provider_type);
else
def_builder.Provider(onnxruntime::kCpuExecutionProvider);
KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); };
KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);
output->RegisterCustomKernel(create_info);
}
ORT_RETURN_IF_ERROR(output->RegisterOpSet(schemas_list,
domain->domain_,
1 /* baseline opset version */,
1000 /* opset version */));
#endif
// create the KernelDef for each op and register it
for (const auto* op : domain->custom_ops_) {
KernelDefBuilder def_builder;
def_builder.SetName(op->GetName(op))
.SetDomain(domain->domain_)
.SinceVersion(1)
.TypeConstraint("T", DataTypeImpl::AllTensorTypes());
if (const char* provider_type = op->GetExecutionProviderType(op)) {
def_builder.Provider(provider_type);
} else {
def_builder.Provider(onnxruntime::kCpuExecutionProvider);
}
KernelCreateFn kernel_create_fn = [op](const OpKernelInfo& info) -> OpKernel* {
return new CustomOpKernel(info, *op);
};
KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);
ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(create_info));
}
}
return Status::OK();
}
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD)

View file

@ -14,6 +14,7 @@
#include "core/common/denormal.h"
#include "core/common/logging/logging.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/customregistry.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/execution_frame.h"
#include "core/framework/feeds_fetches_manager.h"
@ -43,6 +44,7 @@
#ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph
#include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h"
#endif
#include "core/session/custom_ops.h"
#include "core/session/environment.h"
#include "core/session/IOBinding.h"
#include "core/session/inference_session_utils.h"
@ -50,10 +52,6 @@
#include "core/util/protobuf_parsing_utils.h"
#include "core/util/thread_utils.h"
#if !defined(ORT_MINIMAL_BUILD)
#include "core/framework/customregistry.h"
#include "core/session/custom_ops.h"
#endif
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::experimental;
@ -468,6 +466,7 @@ common::Status InferenceSession::AddCustomTransformerList(const std::vector<std:
return Status::OK();
}
#endif
common::Status InferenceSession::AddCustomOpDomains(const std::vector<OrtCustomOpDomain*>& op_domains) {
std::shared_ptr<CustomRegistry> custom_registry;
@ -486,10 +485,13 @@ common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptr<CustomRe
// Insert session-level customized kernel registry.
kernel_registry_manager_.RegisterKernelRegistry(custom_registry->GetKernelRegistry());
#if !defined(ORT_MINIMAL_BUILD)
custom_schema_registries_.push_back(custom_registry->GetOpschemaRegistry());
#endif
return Status::OK();
}
#if !defined(ORT_MINIMAL_BUILD)
common::Status InferenceSession::SaveToOrtFormat(const std::basic_string<ORTCHAR_T>& filepath) const {
ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ort format only supports little-edian machines");
@ -1016,7 +1018,15 @@ Status InferenceSession::LoadOrtModel(std::function<Status()> load_ort_format_mo
// need to go from unique_ptr to shared_ptr when moving into model_
std::unique_ptr<Model> tmp_model;
#if !defined(ORT_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr,
*session_logger_, tmp_model));
#else
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, *session_logger_, tmp_model));
#endif
ORT_RETURN_IF_ERROR(SaveModelMetadata(*tmp_model));
model_ = std::move(tmp_model);

View file

@ -184,6 +184,7 @@ class InferenceSession {
*/
common::Status AddCustomTransformerList(const std::vector<std::string>& transformers_to_enable) ORT_MUST_USE_RESULT;
#endif // !defined(ORT_MINIMAL_BUILD)
/**
* Add custom ops. This API is not thread safe.
*/
@ -200,8 +201,6 @@ class InferenceSession {
*/
common::Status RegisterCustomRegistry(std::shared_ptr<CustomRegistry> custom_registry) ORT_MUST_USE_RESULT;
#endif // !defined(ORT_MINIMAL_BUILD)
/**
* Load an ONNX or ORT format model.
*
@ -582,11 +581,11 @@ class InferenceSession {
#if !defined(ORT_MINIMAL_BUILD)
std::list<std::shared_ptr<onnxruntime::IOnnxRuntimeOpSchemaCollection>> custom_schema_registries_;
#endif
// CustomRegistry objects own the corresponding KernelRegistry and OnnxRuntimeOpSchemaRegistry objects,
// so each CustomRegistry must live as long as its constituents. This vector extends the lifetime of the owners.
std::vector<std::shared_ptr<CustomRegistry>> custom_registries_;
#endif
ModelMetadata model_metadata_;
std::unordered_set<std::string> required_inputs_;

View file

@ -439,13 +439,11 @@ static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* op
env->GetEnvironment());
}
#if !defined(ORT_MINIMAL_BUILD)
// Add custom domains
Status status;
if (options && !options->custom_op_domains_.empty()) {
ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_));
}
#endif
// Finish load
if (load_config_from_model) {

View file

@ -4,6 +4,7 @@
// if we can't load an ORT format model we can't really test anything
#if defined(ENABLE_ORT_FORMAT_LOAD)
#include "core/common/make_unique.h"
#include "core/framework/data_types.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
@ -18,6 +19,8 @@
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "gtest/gtest.h"
using namespace std;
@ -237,16 +240,16 @@ static void DumpOrtModelAsJson(const std::string& model_uri) {
std::ofstream(model_uri + ".json") << json;
}
*/
/* The full build was causing the following error because the graph node array has some empty (blank) nodes at some indices for certain ORT designs
onnx runtime exception : Satisfied, but should not be : node == nullptr
session_state.cc : 814 onnxruntime::SessionState::LoadFromOrtFormatCan't find node with index 4. Invalid ORT format model.
The bug was due to loading an ORT format model in a full build, allowing optimizers to run, but trying to use the saved kernel information.
The optimizers removed some nodes, leaving gaps in the graph node vector.
The build has been fixed in InferenceSession code. The following test case is there to catch this error.
*/
/*
Validate we don't run optimizers on an ORT format model in a full build. The optimizers will remove nodes,
which will create a mismatch with the saved kernel information and result in a runtime error.
We could take steps to handle this scenario in a full build, but for consistency we choose to not run optimizers
on any ORT format model.
*/
TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) {
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("mnist.onnx.ort");
// we have tests that use a pre-created mnist.onnx.ort, so make the naming for the unit test generated file clearer
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/mnist.onnx.test_output.ort");
SaveAndCompareModels("testdata/mnist.onnx", ort_file);
// DumpOrtModelAsJson(ToMBString(ort_file));
@ -257,8 +260,8 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) {
test_info.configs.push_back(std::make_pair(kOrtSessionOptionsConfigLoadModelFormat, "ORT"));
OrtValue ml_value;
vector<float> data(28*28, 0.0);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1,1,28,28}, data,
vector<float> data(28 * 28, 0.0);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1, 1, 28, 28}, data,
&ml_value);
test_info.inputs.insert(std::make_pair("Input3", ml_value));
@ -267,14 +270,13 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) {
test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
const auto& output = fetches[0].Get<Tensor>();
ASSERT_TRUE(output.Shape().NumDimensions() == 2);
// ASSERT_TRUE(output.Data<float>()[0] == 125.f);
};
RunOrtModel(test_info);
}
TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("ort_github_issue_4031.onnx.ort");
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/ort_github_issue_4031.onnx.test_output.ort");
SaveAndCompareModels("testdata/ort_github_issue_4031.onnx", ort_file);
// DumpOrtModelAsJson(ToMBString(ort_file));
@ -301,7 +303,8 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
}
TEST(OrtModelOnlyTests, SparseInitializerHandling) {
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("sparse_initializer_handling.onnx.ort");
const std::basic_string<ORTCHAR_T> ort_file =
ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.test_output.ort");
SaveAndCompareModels("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx", ort_file);
SessionOptions so;
@ -322,7 +325,8 @@ TEST(OrtModelOnlyTests, SparseInitializerHandling) {
#if !defined(DISABLE_ML_OPS)
TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) {
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("sklearn_bin_voting_classifier_soft_converted.ort");
const std::basic_string<ORTCHAR_T> ort_file =
ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft_converted.test_output.ort");
SaveAndCompareModels("testdata/sklearn_bin_voting_classifier_soft.onnx", ort_file);
OrtModelTestInfo test_info;
@ -361,6 +365,7 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) {
RunOrtModel(test_info);
}
#endif // #if !defined(DISABLE_ML_OPS)
#endif // #if !defined(ORT_MINIMAL_BUILD)

View file

@ -21,6 +21,7 @@
#include "providers.h"
#include "test_allocator.h"
#include "test_fixture.h"
#include "utils.h"
#ifdef _WIN32
#include <Windows.h>
@ -32,11 +33,9 @@
#include <cuda_runtime.h>
#endif
struct Input {
const char* name = nullptr;
std::vector<int64_t> dims;
std::vector<float> values;
};
// Yields the number of elements in a built-in array as a compile-time constant.
// Once we use C++17 this could be replaced with std::size
template <typename Elem, size_t Length>
constexpr size_t countof(Elem (& /*array*/)[Length]) {
  return Length;
}
extern std::unique_ptr<Ort::Env> ort_env;
@ -51,14 +50,18 @@ void RunSession(OrtAllocator* allocator, Ort::Session& session_object,
std::vector<const char*> input_names;
for (size_t i = 0; i < inputs.size(); i++) {
input_names.emplace_back(inputs[i].name);
ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(allocator->Info(allocator), const_cast<float*>(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size()));
ort_inputs.emplace_back(
Ort::Value::CreateTensor<float>(allocator->Info(allocator), const_cast<float*>(inputs[i].values.data()),
inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size()));
}
std::vector<Ort::Value> ort_outputs;
if (output_tensor)
session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, output_tensor, 1);
session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
&output_name, output_tensor, 1);
else {
ort_outputs = session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1);
ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
&output_name, 1);
ASSERT_EQ(ort_outputs.size(), 1u);
output_tensor = &ort_outputs[0];
}
@ -74,17 +77,17 @@ void RunSession(OrtAllocator* allocator, Ort::Session& session_object,
}
}
template <typename T, typename OutT>
void TestInference(Ort::Env& env, T model_uri,
const std::vector<Input>& inputs,
const char* output_name,
const std::vector<int64_t>& expected_dims_y,
const std::vector<OutT>& expected_values_y,
int provider_type,
OrtCustomOpDomain* custom_op_domain_ptr,
const char* custom_op_library_filename,
void** library_handle = nullptr,
bool test_session_creation_only = false) {
template <typename OutT>
static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& model_uri,
const std::vector<Input>& inputs,
const char* output_name,
const std::vector<int64_t>& expected_dims_y,
const std::vector<OutT>& expected_values_y,
int provider_type,
OrtCustomOpDomain* custom_op_domain_ptr,
const char* custom_op_library_filename,
void** library_handle = nullptr,
bool test_session_creation_only = false) {
Ort::SessionOptions session_options;
if (provider_type == 1) {
@ -103,7 +106,8 @@ void TestInference(Ort::Env& env, T model_uri,
#endif
} else if (provider_type == 3) {
#ifdef USE_NUPHAR
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options, /*allow_unaligned_buffers*/ 1, ""));
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options,
/*allow_unaligned_buffers*/ 1, ""));
std::cout << "Running simple inference with nuphar provider" << std::endl;
#else
return;
@ -116,11 +120,12 @@ void TestInference(Ort::Env& env, T model_uri,
}
if (custom_op_library_filename) {
Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, library_handle));
Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary(session_options,
custom_op_library_filename, library_handle));
}
// if session creation passes, model loads fine
Ort::Session session(env, model_uri, session_options);
Ort::Session session(env, model_uri.c_str(), session_options);
// caller wants to test running the model (not just loading the model)
if (!test_session_creation_only) {
@ -136,7 +141,8 @@ void TestInference(Ort::Env& env, T model_uri,
expected_values_y,
nullptr);
//with preallocated output tensor
Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size());
Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(),
expected_dims_y.data(), expected_dims_y.size());
//test it twice
for (int i = 0; i != 2; ++i)
@ -181,7 +187,8 @@ TEST_P(CApiTestWithProvider, simple) {
std::vector<int64_t> expected_dims_y = {3, 2};
std::vector<float> expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
TestInference<PATH_TYPE, float>(*ort_env, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(), nullptr, nullptr);
TestInference<float>(*ort_env, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(),
nullptr, nullptr);
}
TEST(CApiTest, dim_param) {
@ -216,70 +223,6 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders,
CApiTestWithProvider,
::testing::Values(0, 1, 2, 3, 4));
// Shape of a tensor OrtValue, exposed as a std::vector<int64_t> with one entry per dimension.
struct OrtTensorDimensions : std::vector<int64_t> {
  // Queries the type-and-shape info of `value` through `ort` and copies the dimensions into this vector.
  OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
    // Scope guard: the info handle is released on every exit path, including when GetTensorShape
    // (or the vector allocation behind it) throws. Previously the handle leaked in that case.
    struct ShapeInfoGuard {
      Ort::CustomOpApi& api;
      OrtTensorTypeAndShapeInfo* info;
      ~ShapeInfoGuard() { api.ReleaseTensorTypeAndShapeInfo(info); }
    } guard{ort, ort.GetTensorTypeAndShape(value)};

    std::vector<int64_t>::operator=(ort.GetTensorShape(guard.info));
  }
};
// Returns the element count of a built-in array; the count is deduced from the array type.
// Once we use C++17 this could be replaced with std::size
template <typename Elem, size_t Count>
constexpr size_t countof(Elem (&)[Count]) {
  return Count;
}
void cuda_add(int64_t, float*, const float*, const float*);
struct MyCustomKernel {
MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) {
}
void Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
const float* X = ort_.GetTensorData<float>(input_X);
const float* Y = ort_.GetTensorData<float>(input_Y);
// Setup output
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
float* out = ort_.GetTensorMutableData<float>(output);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// Do computation
#ifdef USE_CUDA
cuda_add(size, out, X, Y);
#else
for (int64_t i = 0; i < size; i++) {
out[i] = X[i] + Y[i];
}
#endif
}
private:
Ort::CustomOpApi ort_;
};
// Custom op definition named "Foo": two float inputs, one float output, implemented by MyCustomKernel.
// The execution provider it is registered for is chosen by the caller at construction time.
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
  // `provider` is stored as a raw pointer and not copied, so the string must outlive this object.
  explicit MyCustomOp(const char* provider) : provider_(provider) {}

  // Returns a heap-allocated kernel as an untyped pointer.
  // NOTE(review): presumably released by the CustomOpBase machinery - confirm.
  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
    return new MyCustomKernel(api, info);
  }

  const char* GetName() const { return "Foo"; }
  const char* GetExecutionProviderType() const { return provider_; }

  // Every input and output is a float tensor, whatever the index.
  size_t GetInputTypeCount() const { return 2; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

 private:
  const char* provider_;
};
TEST(CApiTest, custom_op_handler) {
std::cout << "Running custom op inference" << std::endl;
@ -303,9 +246,11 @@ TEST(CApiTest, custom_op_handler) {
custom_op_domain.Add(&custom_op);
#ifdef USE_CUDA
TestInference<PATH_TYPE, float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1, custom_op_domain, nullptr, nullptr);
TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1,
custom_op_domain, nullptr, nullptr);
#else
TestInference<PATH_TYPE, float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, custom_op_domain, nullptr);
TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0,
custom_op_domain, nullptr);
#endif
}
@ -368,9 +313,11 @@ struct SliceCustomOpKernel {
struct SliceCustomOp : Ort::CustomOpBase<SliceCustomOp, SliceCustomOpKernel> {
explicit SliceCustomOp(const char* provider) : provider_(provider) {}
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new SliceCustomOpKernel(api, info); };
const char* GetName() const { return "Slice"; };
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
return new SliceCustomOpKernel(api, info);
};
const char* GetName() const { return "Slice"; };
const char* GetExecutionProviderType() const { return provider_; };
size_t GetInputTypeCount() const { return 3; };
@ -427,9 +374,11 @@ TEST(CApiTest, varied_input_custom_op_handler) {
custom_op_domain.Add(&slice_custom_op);
#ifdef USE_CUDA
TestInference<PATH_TYPE, float>(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z", expected_dims_z, expected_values_z, 1, custom_op_domain, nullptr, nullptr);
TestInference<float>(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z",
expected_dims_z, expected_values_z, 1, custom_op_domain, nullptr, nullptr);
#else
TestInference<PATH_TYPE, float>(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z", expected_dims_z, expected_values_z, 0, custom_op_domain, nullptr);
TestInference<float>(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z",
expected_dims_z, expected_values_z, 0, custom_op_domain, nullptr);
#endif
}
@ -454,8 +403,8 @@ TEST(CApiTest, RegisterCustomOpForCPUAndCUDA) {
custom_op_domain.Add(&custom_op_cpu);
custom_op_domain.Add(&custom_op_cuda);
TestInference<PATH_TYPE, float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y,
expected_values_y, 1, custom_op_domain, nullptr, nullptr, true);
TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y,
expected_values_y, 1, custom_op_domain, nullptr, nullptr, true);
}
#endif
@ -496,8 +445,8 @@ lib_name = "./libcustom_op_library.so";
#endif
void* library_handle = nullptr;
TestInference<PATH_TYPE, int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y,
expected_values_y, 0, nullptr, lib_name.c_str(), &library_handle);
TestInference<int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y,
expected_values_y, 0, nullptr, lib_name.c_str(), &library_handle);
#ifdef _WIN32
bool success = ::FreeLibrary(reinterpret_cast<HMODULE>(library_handle));
@ -552,7 +501,8 @@ TEST(CApiTest, test_pyop) {
input.values = {1.0f, 2.0f, 3.0f, 4.0f};
std::vector<int64_t> expected_dims_y = {2, 2};
std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f};
TestInference<PATH_TYPE, float>(*ort_env, PYOP_FLOAT_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, nullptr, nullptr);
TestInference<float>(*ort_env, PYOP_FLOAT_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0,
nullptr, nullptr);
}
TEST(CApiTest, test_pyop_multi) {
@ -564,7 +514,8 @@ TEST(CApiTest, test_pyop_multi) {
input.values = {1.0f, 2.0f, 3.0f, 4.0f};
std::vector<int64_t> expected_dims_y = {2, 2};
std::vector<float> expected_values_y = {8.0f, 16.0f, 24.0f, 32.0f};
TestInference<PATH_TYPE, float>(*ort_env, PYOP_MULTI_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0, nullptr, nullptr);
TestInference<float>(*ort_env, PYOP_MULTI_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0,
nullptr, nullptr);
}
TEST(CApiTest, test_pyop_kwarg) {
@ -576,7 +527,8 @@ TEST(CApiTest, test_pyop_kwarg) {
input.values = {1.0f, 2.0f, 3.0f, 4.0f};
std::vector<int64_t> expected_dims_y = {2, 2};
std::vector<float> expected_values_y = {25.0f, 50.0f, 75.0f, 100.0f};
TestInference<PATH_TYPE, float>(*ort_env, PYOP_KWARG_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0, nullptr, nullptr);
TestInference<float>(*ort_env, PYOP_KWARG_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0,
nullptr, nullptr);
}
#endif
@ -596,7 +548,8 @@ TEST(ReducedOpsBuildTest, test_included_ops) {
std::vector<Input> inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}};
std::vector<int64_t> expected_dims_y = {1};
std::vector<float> expected_values_y = {0.75};
TestInference<PATH_TYPE, float>(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0, nullptr, nullptr);
TestInference<float>(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0,
nullptr, nullptr);
}
TEST(ReducedOpsBuildTest, test_excluded_ops) {
@ -609,7 +562,8 @@ TEST(ReducedOpsBuildTest, test_excluded_ops) {
bool failed = false;
try {
//only test model loading, exception expected
TestInference<PATH_TYPE, float>(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0, nullptr, nullptr, nullptr, true);
TestInference<float>(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0,
nullptr, nullptr, nullptr, true);
} catch (const Ort::Exception& e) {
failed = e.GetOrtErrorCode() == ORT_NOT_IMPLEMENTED;
}
@ -782,8 +736,8 @@ TEST(CApiTest, io_binding_cuda) {
ASSERT_NE(output_data.get(), nullptr);
// Create an OrtValue tensor backed by data on CUDA memory
Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast<float*>(output_data.get()), expected_y.size(),
expected_y_shape.data(), expected_y_shape.size());
Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast<float*>(output_data.get()),
expected_y.size(), expected_y_shape.data(), expected_y_shape.size());
Ort::IoBinding binding(session);
binding.BindInput("X", bound_x);
@ -854,7 +808,8 @@ TEST(CApiTest, create_tensor) {
int64_t expected_len = 2;
auto default_allocator = onnxruntime::make_unique<MockedOrtAllocator>();
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1,
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
Ort::ThrowOnError(Ort::GetApi().FillStringTensor(tensor, s, expected_len));
auto shape_info = tensor.GetTensorTypeAndShapeInfo();
@ -874,7 +829,8 @@ TEST(CApiTest, fill_string_tensor) {
int64_t expected_len = 2;
auto default_allocator = onnxruntime::make_unique<MockedOrtAllocator>();
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1,
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
for (int64_t i = 0; i < expected_len; i++) {
tensor.FillStringTensorElement(s[i], i);
@ -892,7 +848,8 @@ TEST(CApiTest, get_string_tensor_element) {
int64_t element_index = 0;
auto default_allocator = onnxruntime::make_unique<MockedOrtAllocator>();
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1,
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
tensor.FillStringTensor(s, expected_len);
@ -1007,7 +964,8 @@ TEST(CApiTest, override_initializer) {
std::string f2_data{"f2_string"};
// Place a string into Tensor OrtValue and assign to the
Ort::Value f2_input_tensor = Ort::Value::CreateTensor(allocator.get(), dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
Ort::Value f2_input_tensor = Ort::Value::CreateTensor(allocator.get(), dims.data(), dims.size(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
const char* const input_char_string[] = {f2_data.c_str()};
f2_input_tensor.FillStringTensor(input_char_string, 1U);
@ -1138,7 +1096,8 @@ TEST(CApiTest, model_metadata) {
ASSERT_TRUE(version == 1);
int64_t num_keys_in_custom_metadata_map;
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map);
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(),
num_keys_in_custom_metadata_map);
ASSERT_TRUE(num_keys_in_custom_metadata_map == 2);
allocator.get()->Free(custom_metadata_map_keys[0]);
@ -1147,7 +1106,8 @@ TEST(CApiTest, model_metadata) {
char* lookup_value_1 = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get());
ASSERT_TRUE(strcmp(lookup_value_1,
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0);
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, "
"\"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0);
allocator.get()->Free(lookup_value_1);
char* lookup_value_2 = model_metadata.LookupCustomMetadataMap("dummy_key", allocator.get());
@ -1180,7 +1140,8 @@ TEST(CApiTest, model_metadata) {
// Model does not contain custom metadata map
int64_t num_keys_in_custom_metadata_map;
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map);
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(),
num_keys_in_custom_metadata_map);
ASSERT_TRUE(num_keys_in_custom_metadata_map == 0);
ASSERT_TRUE(custom_metadata_map_keys == nullptr);
}
@ -1235,9 +1196,9 @@ TEST(CApiTest, TestSharedAllocatorUsingCreateAndRegisterAllocator) {
ASSERT_TRUE(api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg) == nullptr);
// test for duplicates
std::unique_ptr<OrtStatus, decltype(api.ReleaseStatus)> status_releaser(api.CreateAndRegisterAllocator(env_ptr, mem_info,
arena_cfg),
api.ReleaseStatus);
std::unique_ptr<OrtStatus, decltype(api.ReleaseStatus)> status_releaser(
api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg),
api.ReleaseStatus);
ASSERT_FALSE(status_releaser.get() == nullptr);
Ort::SessionOptions session_options;

View file

@ -0,0 +1,134 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// if we can't load an ORT format model we can't really test anything
#if defined(ENABLE_ORT_FORMAT_LOAD)
#include "core/common/make_unique.h"
#include "core/graph/constants.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test_allocator.h"
#include "utils.h"
#include "gtest/gtest.h"
extern std::unique_ptr<Ort::Env> ort_env;
// Creates a session for `model_uri` with `custom_op_domain` registered (plus the CUDA EP when built with
// USE_CUDA), runs it once on `inputs`, and asserts that the single output named `output_name` has the
// expected shape and exactly the expected float values.
static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& model_uri,
                          const std::vector<Input>& inputs, const char* output_name,
                          const std::vector<int64_t>& expected_dims_y, const std::vector<float>& expected_values_y,
                          Ort::CustomOpDomain& custom_op_domain) {
  Ort::SessionOptions session_options;
  session_options.Add(custom_op_domain);

#ifdef USE_CUDA
  OrtCUDAProviderOptions cuda_options{};
  session_options.AppendExecutionProvider_CUDA(cuda_options);
#endif

  Ort::Session session(env, model_uri.c_str(), session_options);

  MockedOrtAllocator allocator;
  std::vector<Ort::Value> ort_inputs;
  std::vector<const char*> input_names;

  for (const Input& in : inputs) {
    // The tensor wraps the caller's buffer. Tensor exposes a mutable data pointer, which is never used
    // for an input, but the const_cast is needed to hand the const buffer over.
    Ort::Value tensor = Ort::Value::CreateTensor<float>(allocator.Info(), const_cast<float*>(in.values.data()),
                                                        in.values.size(), in.dims.data(), in.dims.size());
    input_names.push_back(in.name);
    ort_inputs.push_back(std::move(tensor));
  }

  std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(),
                                                    ort_inputs.size(), &output_name, 1);
  ASSERT_EQ(ort_outputs.size(), size_t(1));

  // Check the shape first, then the element count, then each value
  Ort::Value& result = ort_outputs[0];
  auto shape_info = result.GetTensorTypeAndShapeInfo();
  ASSERT_EQ(shape_info.GetShape(), expected_dims_y);

  const size_t element_count = shape_info.GetElementCount();
  ASSERT_EQ(expected_values_y.size(), element_count);

  const float* actual = result.GetTensorMutableData<float>();
  for (size_t idx = 0; idx < element_count; ++idx) {
    ASSERT_EQ(expected_values_y[idx], actual[idx]);
  }
}
#if !defined(ORT_MINIMAL_BUILD)
// Creates a session on the ONNX model with an optimized-model output path set, which writes the converted
// model to disk, then loads that output file and checks the custom op result.
TEST(OrtFormatCustomOpTests, ConvertOnnxModelToOrt) {
  const std::basic_string<ORTCHAR_T> onnx_file = ORT_TSTR("testdata/foo_1.onnx");
  const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/foo_1.onnx.test_output.ort");

  // register the custom op for the execution provider this build tests with
#ifdef USE_CUDA
  const char* const provider = onnxruntime::kCudaExecutionProvider;
#else
  const char* const provider = onnxruntime::kCpuExecutionProvider;
#endif
  MyCustomOp custom_op{provider};

  Ort::CustomOpDomain custom_op_domain("");
  custom_op_domain.Add(&custom_op);

  // convert to ort by loading the onnx model; the session is scoped so it is destroyed before the
  // output file is loaded below
  {
    Ort::SessionOptions conversion_options;
    conversion_options.Add(custom_op_domain);
    conversion_options.SetLogId("CustomOp");
    conversion_options.SetOptimizedModelFilePath(ort_file.c_str());

#ifdef USE_CUDA
    OrtCUDAProviderOptions cuda_options{};
    conversion_options.AppendExecutionProvider_CUDA(cuda_options);
#endif

    Ort::Session session(*ort_env, onnx_file.c_str(), conversion_options);
  }

  // now load the ORT format model and execute it
  std::vector<Input> inputs(1);
  inputs.front().name = "X";
  inputs.front().dims = {3, 2};
  inputs.front().values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

  // model adds 1, 2, 3, 4, 5, 6 to the input values
  const std::vector<int64_t> expected_dims_y = {3, 2};
  const std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};

  TestInference(*ort_env, ort_file, inputs, "Y", expected_dims_y, expected_values_y, custom_op_domain);
}
#endif // if !defined(ORT_MINIMAL_BUILD)
// the saved ORT format model has the CPU EP assigned to the custom op node, so we only test if we're not using the
// CUDA EP for the test.
#ifndef USE_CUDA
// Loads the checked-in ORT format model with the custom op registered for the CPU EP and verifies the output.
TEST(OrtFormatCustomOpTests, LoadOrtModel) {
  const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("testdata/foo_1.onnx.ort");

  MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider};
  Ort::CustomOpDomain custom_op_domain("");
  custom_op_domain.Add(&custom_op);

  // load the ORT format model and execute it
  std::vector<Input> inputs(1);
  inputs.front().name = "X";
  inputs.front().dims = {3, 2};
  inputs.front().values = {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f};

  // model adds 1, 2, 3, 4, 5, 6 to the input values, so every element comes out as 7
  const std::vector<int64_t> expected_dims_y = {3, 2};
  const std::vector<float> expected_values_y = {7.0f, 7.0f, 7.0f, 7.0f, 7.0f, 7.0f};

  TestInference(*ort_env, ort_file, inputs, "Y", expected_dims_y, expected_values_y, custom_op_domain);
}
#endif
#endif // #if defined(ENABLE_ORT_FORMAT_LOAD)

View file

@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "utils.h"
#ifdef USE_CUDA
#include <cuda_runtime.h>
void cuda_add(int64_t, float*, const float*, const float*);
#endif
// Computes output 0 from inputs 0 and 1. On the CPU path this is an element-wise float addition; with
// USE_CUDA the work is delegated to cuda_add, which is declared above and defined elsewhere.
void MyCustomKernel::Compute(OrtKernelContext* context) {
  // Fetch both operands and their raw float buffers
  const OrtValue* lhs_value = ort_.KernelContext_GetInput(context, 0);
  const OrtValue* rhs_value = ort_.KernelContext_GetInput(context, 1);
  const float* lhs = ort_.GetTensorData<float>(lhs_value);
  const float* rhs = ort_.GetTensorData<float>(rhs_value);

  // Output 0 is requested with the shape of the first input
  OrtTensorDimensions out_shape(ort_, lhs_value);
  OrtValue* out_value = ort_.KernelContext_GetOutput(context, 0, out_shape.data(), out_shape.size());
  float* result = ort_.GetTensorMutableData<float>(out_value);

  // Number of elements to produce, taken from the output's own shape info
  OrtTensorTypeAndShapeInfo* out_info = ort_.GetTensorTypeAndShape(out_value);
  const int64_t count = ort_.GetTensorShapeElementCount(out_info);
  ort_.ReleaseTensorTypeAndShapeInfo(out_info);

  // Do computation
  // NOTE(review): the CPU loop reads `count` elements from both inputs, so it assumes the second
  // input has at least as many elements as the first - confirm against the test models.
#ifdef USE_CUDA
  cuda_add(count, result, lhs, rhs);
#else
  for (int64_t idx = 0; idx < count; ++idx) {
    result[idx] = lhs[idx] + rhs[idx];
  }
#endif
}

View file

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/session/onnxruntime_cxx_api.h"
// One float input tensor for a test run: the input name, the tensor shape, and the element values.
struct Input {
  const char* name{nullptr};  // input name passed to the session; not owned
  std::vector<int64_t> dims;  // tensor shape
  std::vector<float> values;  // tensor elements backing the created tensor
};
// Shape of a tensor OrtValue, exposed as a std::vector<int64_t> with one entry per dimension.
struct OrtTensorDimensions : std::vector<int64_t> {
  // Queries the type-and-shape info of `value` through `ort` and copies the dimensions into this vector.
  OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
    // Scope guard: the info handle is released on every exit path, including when GetTensorShape
    // (or the vector allocation behind it) throws. Previously the handle leaked in that case.
    struct ShapeInfoGuard {
      Ort::CustomOpApi& api;
      OrtTensorTypeAndShapeInfo* info;
      ~ShapeInfoGuard() { api.ReleaseTensorTypeAndShapeInfo(info); }
    } guard{ort, ort.GetTensorTypeAndShape(value)};

    std::vector<int64_t>::operator=(ort.GetTensorShape(guard.info));
  }
};
// Kernel for the "Foo" test custom op. Holds a copy of the custom op API wrapper that Compute uses to
// read its inputs and produce its output.
struct MyCustomKernel {
  // The kernel info argument is accepted to match the kernel-creation call but is not used.
  MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/)
      : ort_(ort) {}

  // Implemented out of line, in the file that includes this header.
  void Compute(OrtKernelContext* context);

 private:
  Ort::CustomOpApi ort_;
};
// Custom op definition named "Foo": two float inputs, one float output, implemented by MyCustomKernel.
// The execution provider it is registered for is chosen by the caller at construction time.
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
  // `provider` is stored as a raw pointer and not copied, so the string must outlive this object.
  explicit MyCustomOp(const char* provider) : provider_(provider) {}

  // Returns a heap-allocated kernel as an untyped pointer.
  // NOTE(review): presumably released by the CustomOpBase machinery - confirm.
  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
    return new MyCustomKernel(api, info);
  }

  const char* GetName() const { return "Foo"; }
  const char* GetExecutionProviderType() const { return provider_; }

  // Every input and output is a float tensor, whatever the index.
  size_t GetInputTypeCount() const { return 2; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

 private:
  const char* provider_;
};

BIN
onnxruntime/test/testdata/foo_1.onnx.ort vendored Normal file

Binary file not shown.