From e1dc268e45f1f37a7b46d37e47d65416832b1db7 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Mon, 25 Jan 2021 10:41:00 +1000 Subject: [PATCH] Add support for custom ops to minimal build. (#6228) * Add support for custom ops to minimal build. Cost is only ~8KB so including in base minimal build. --- cmake/onnxruntime_framework.cmake | 2 - cmake/onnxruntime_unittests.cmake | 5 +- .../core/framework/customregistry.h | 29 ++- include/onnxruntime/core/graph/graph.h | 9 +- onnxruntime/core/framework/customregistry.cc | 11 +- .../core/framework/execution_provider.cc | 9 +- .../core/framework/graph_partitioner.cc | 4 + .../core/framework/kernel_registry_manager.cc | 4 - .../core/framework/kernel_registry_manager.h | 2 +- onnxruntime/core/graph/graph.cc | 31 ++- onnxruntime/core/graph/model.cc | 21 +- onnxruntime/core/graph/model.h | 3 + onnxruntime/core/session/custom_ops.cc | 93 +++++---- onnxruntime/core/session/inference_session.cc | 18 +- onnxruntime/core/session/inference_session.h | 5 +- onnxruntime/core/session/onnxruntime_c_api.cc | 2 - .../test/framework/ort_model_only_test.cc | 33 +-- onnxruntime/test/shared_lib/test_inference.cc | 191 +++++++----------- .../test/shared_lib/test_ort_format_models.cc | 134 ++++++++++++ onnxruntime/test/shared_lib/utils.cc | 35 ++++ onnxruntime/test/shared_lib/utils.h | 45 +++++ onnxruntime/test/testdata/foo_1.onnx.ort | Bin 0 -> 1312 bytes 22 files changed, 470 insertions(+), 216 deletions(-) create mode 100644 onnxruntime/test/shared_lib/test_ort_format_models.cc create mode 100644 onnxruntime/test/shared_lib/utils.cc create mode 100644 onnxruntime/test/shared_lib/utils.h create mode 100644 onnxruntime/test/testdata/foo_1.onnx.ort diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index bca3d0f384..06542f7437 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -10,8 +10,6 @@ file(GLOB_RECURSE onnxruntime_framework_srcs CONFIGURE_DEPENDS if (onnxruntime_MINIMAL_BUILD) set(onnxruntime_framework_src_exclude "${ONNXRUNTIME_ROOT}/core/framework/provider_bridge_ort.cc" - "${ONNXRUNTIME_INCLUDE_DIR}/core/framework/customregistry.h" - "${ONNXRUNTIME_ROOT}/core/framework/customregistry.cc" "${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.h" "${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.cc" ) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index b06047991c..7d04c84bef 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -322,7 +322,10 @@ set (onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_allocator.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc) + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_ort_format_models.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.h + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.cc) if (NOT onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc) diff --git a/include/onnxruntime/core/framework/customregistry.h b/include/onnxruntime/core/framework/customregistry.h index aafe5a5df8..52f6169e2e 100644 --- a/include/onnxruntime/core/framework/customregistry.h +++ b/include/onnxruntime/core/framework/customregistry.h @@ -5,11 +5,14 @@ #include "core/common/status.h" #include "core/common/logging/logging.h" -#include "core/graph/schema_registry.h" #include "core/framework/op_kernel.h" #include "core/framework/kernel_def_builder.h" #include "core/framework/kernel_registry.h" +#if !defined(ORT_MINIMAL_BUILD) +#include "core/graph/schema_registry.h" +#endif + namespace onnxruntime { /** @@ -17,9 +20,14 @@ namespace onnxruntime { */ class CustomRegistry final { public: - CustomRegistry() : - kernel_registry_(std::make_shared()), - opschema_registry_(std::make_shared()) {} + CustomRegistry() + : kernel_registry_(std::make_shared()) +#if !defined(ORT_MINIMAL_BUILD) + , + opschema_registry_(std::make_shared()) +#endif + { + } /** * Register a kernel definition together with kernel factory method to this session. @@ -32,18 +40,21 @@ class CustomRegistry final { common::Status RegisterCustomKernel(KernelCreateInfo&); - common::Status RegisterOpSet(std::vector& schemas, const std::string& domain, - int baseline_opset_version, int opset_version); - const std::shared_ptr& GetKernelRegistry(); +#if !defined(ORT_MINIMAL_BUILD) + common::Status RegisterOpSet(std::vector& schemas, const std::string& domain, + int baseline_opset_version, int opset_version); + const std::shared_ptr& GetOpschemaRegistry(); +#endif private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomRegistry); std::shared_ptr kernel_registry_; - std::shared_ptr opschema_registry_; - +#if !defined(ORT_MINIMAL_BUILD) + std::shared_ptr opschema_registry_; +#endif }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index d1841ff43b..8ce8e4ff2f 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -654,7 +654,8 @@ class Graph { /** Return true if "node_arg" is a input or an initializer. Otherwise, returns false. */ bool IsInputsIncludingInitializers(const NodeArg* node_arg) const noexcept { - return std::find(graph_inputs_including_initializers_.begin(), graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end(); + return std::find(graph_inputs_including_initializers_.begin(), + graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end(); } /** Gets the Graph inputs that are initializers @@ -1085,6 +1086,9 @@ class Graph { static common::Status LoadFromOrtFormat( const onnxruntime::experimental::fbs::Graph& fbs_graph, const Model& owning_model, const std::unordered_map& domain_to_version, +#if !defined(ORT_MINIMAL_BUILD) + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, +#endif const logging::Logger& logger, std::unique_ptr& graph); // deserialize a subgraph @@ -1104,6 +1108,9 @@ class Graph { // Create empty Graph instance to re-create from ORT format serialized data. Graph(const Model& owning_model, const std::unordered_map& domain_to_version, +#if !defined(ORT_MINIMAL_BUILD) + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, +#endif Graph* parent_graph, const Node* parent_node, const logging::Logger& logger); diff --git a/onnxruntime/core/framework/customregistry.cc b/onnxruntime/core/framework/customregistry.cc index a70f8bbfd2..601ac6dc74 100644 --- a/onnxruntime/core/framework/customregistry.cc +++ b/onnxruntime/core/framework/customregistry.cc @@ -12,21 +12,22 @@ common::Status CustomRegistry::RegisterCustomKernel(KernelCreateInfo& create_inf return kernel_registry_->Register(std::move(create_info)); } +const std::shared_ptr& CustomRegistry::GetKernelRegistry() { + return kernel_registry_; +} + +#if !defined(ORT_MINIMAL_BUILD) common::Status CustomRegistry::RegisterOpSet( std::vector& schemas, const std::string& domain, int baseline_opset_version, int opset_version) { - return opschema_registry_->RegisterOpSet(schemas, domain, baseline_opset_version, opset_version); } -const std::shared_ptr& CustomRegistry::GetKernelRegistry() { - return kernel_registry_; -} - const std::shared_ptr& CustomRegistry::GetOpschemaRegistry() { return opschema_registry_; } +#endif } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index 27ab81bdae..094f6756de 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -77,19 +77,22 @@ void IExecutionProvider::InsertAllocator(AllocatorPtr allocator) { #if !defined(ORT_MINIMAL_BUILD) common::Status IExecutionProvider::Compile(const std::vector& /*fused_node*/, std::vector& /*node_compute_funcs*/) { - return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); + return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, + "IExecutionProvider::Compile with fused Node is not implemented by " + type_); } common::Status IExecutionProvider::Compile(const std::vector& /*fused_node*/, std::string& /*dll_path*/) { - return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); + return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, + "IExecutionProvider::Compile with fused Node and dll path is not implemented by " + type_); } #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status IExecutionProvider::Compile(const std::vector& /*fused_nodes_and_graphs*/, std::vector& /*node_compute_funcs*/) { - return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); + return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, + "IExecutionProvider::Compile with FusedNodeAndGraph is not implemented by " + type_); } #endif diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 04f5979aa4..9abb3cb6f3 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -417,6 +417,10 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::vector> capabilities = current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type)); + if (capabilities.empty()) { + return Status::OK(); + } + // storage for the GraphViewer for each IndexedSubGraph std::vector> viewers; viewers.reserve(capabilities.size()); diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index d4c27112e3..3f028323ef 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -45,14 +45,12 @@ Status KernelRegistryManager::RegisterKernels(const ExecutionProviders& executio return Status::OK(); } -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptr kernel_registry) { if (nullptr == kernel_registry) { return; } custom_kernel_registries_.push_front(kernel_registry); } -#endif #if !defined(ORT_MINIMAL_BUILD) bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) { @@ -86,14 +84,12 @@ Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node return Status(ONNXRUNTIME, FAIL, create_error_message("The node is not placed on any Execution Provider. ")); } -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) for (auto& registry : custom_kernel_registries_) { status = registry->TryFindKernel(node, std::string(), kernel_def_hash, kernel_create_info); if (status.IsOK()) { return status; } } -#endif KernelRegistry* p = nullptr; auto iter = provider_type_to_registry_.find(ptype); diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 62d91e4ab7..b7f1b1e7f2 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -32,7 +32,6 @@ class KernelRegistryManager { // Register kernels from providers Status RegisterKernels(const ExecutionProviders& execution_providers) ORT_MUST_USE_RESULT; -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // The registry passed in this function has highest priority than anything already in this KernelRegistryManager, // and anything registered from RegisterKernels // For example, if you do: @@ -42,6 +41,7 @@ class KernelRegistryManager { // Then B > A > providers void RegisterKernelRegistry(std::shared_ptr kernel_registry); +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** * Search kernel registry by provider type. * @param type provider type string diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 1f3de24a89..038aaa41e7 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3619,13 +3619,19 @@ std::ostream& operator<<(std::ostream& out, const Graph& graph) { #endif // !defined(ORT_MINIMAL_BUILD) #if defined(ENABLE_ORT_FORMAT_LOAD) -Status Graph::LoadFromOrtFormat( - const onnxruntime::experimental::fbs::Graph& fbs_graph, - const Model& owning_model, - const std::unordered_map& domain_to_version, - const logging::Logger& logger, std::unique_ptr& graph) { +Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs_graph, + const Model& owning_model, + const std::unordered_map& domain_to_version, +#if !defined(ORT_MINIMAL_BUILD) + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, +#endif + const logging::Logger& logger, std::unique_ptr& graph) { // can't use make_unique as we're calling a private ctor - graph.reset(new Graph(owning_model, domain_to_version, nullptr, nullptr, logger)); + graph.reset(new Graph(owning_model, domain_to_version, +#if !defined(ORT_MINIMAL_BUILD) + schema_registry, +#endif + nullptr, nullptr, logger)); ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph)); @@ -3636,8 +3642,6 @@ Status Graph::LoadFromOrtFormat( // and in InferenceSession::Initialize skip partitioning and running optimizers. graph->SetGraphResolveNeeded(); ORT_RETURN_IF_ERROR(graph->Resolve()); -#else - // probably nothing required here. validate with model that has nested subgraphs. #endif return Status::OK(); @@ -3648,7 +3652,11 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs const logging::Logger& logger, std::unique_ptr& graph) { // can't use make_unique as we're calling a private ctor graph.reset(new Graph(parent_graph.owning_model_, - parent_graph.domain_to_version_, &parent_graph, &parent_node, + parent_graph.domain_to_version_, +#if !defined(ORT_MINIMAL_BUILD) + parent_graph.schema_registry_, +#endif + &parent_graph, &parent_node, logger)); return graph->LoadFromOrtFormat(fbs_graph); @@ -3656,12 +3664,15 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs Graph::Graph(const Model& owning_model, const std::unordered_map& domain_to_version, +#if !defined(ORT_MINIMAL_BUILD) + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, +#endif Graph* parent_graph, const Node* parent_node, const logging::Logger& logger) : owning_model_(owning_model), graph_proto_(&deserialized_proto_data_), #if !defined(ORT_MINIMAL_BUILD) - schema_registry_(std::make_shared()), + schema_registry_(schema_registry), #endif domain_to_version_(domain_to_version), ir_version_(owning_model.IrVersion()), diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index a5ebe136f0..7666d8390c 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -93,7 +93,8 @@ Model::Model(const ModelProto& model_proto, const PathString& model_path, : Model(ModelProto(model_proto), model_path, local_registries, logger) { } -Model::Model(ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, +Model::Model(ModelProto&& model_proto, const PathString& model_path, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) : model_path_(Path::Parse(model_path)) { if (!utils::HasGraph(model_proto)) { @@ -618,6 +619,9 @@ Model::Model() : model_path_{} { #if defined(ENABLE_ORT_FORMAT_LOAD) common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, +#if !defined(ORT_MINIMAL_BUILD) + const IOnnxRuntimeOpSchemaRegistryList* local_registries, +#endif const logging::Logger& logger, std::unique_ptr& model) { model.reset(new Model()); @@ -632,6 +636,13 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, } model->model_proto_.set_model_version(fbs_model.model_version()); model->model_proto_.set_ir_version(fbs_model.ir_version()); + + auto schema_registry = std::make_shared(); + if (local_registries != nullptr) { + for (const auto& schema_collection : *local_registries) { + schema_registry->RegisterRegistry(schema_collection); + } + } #else experimental::utils::LoadStringFromOrtFormat(model->producer_name_, fbs_model.producer_name()); experimental::utils::LoadStringFromOrtFormat(model->producer_version_, fbs_model.producer_version()); @@ -648,10 +659,14 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, auto fbs_graph = fbs_model.graph(); ORT_RETURN_IF(nullptr == fbs_graph, "Graph is null. Invalid ORT format model."); +#if !defined(ORT_MINIMAL_BUILD) + ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, schema_registry, logger, + model->graph_)); +#else ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, logger, model->graph_)); - +#endif return Status::OK(); } -#endif +#endif // defined(ENABLE_ORT_FORMAT_LOAD) } // namespace onnxruntime diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 7f4dfd0d2f..8083ff9f14 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -239,6 +239,9 @@ class Model { #if defined(ENABLE_ORT_FORMAT_LOAD) static common::Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::Model& fbs_model, +#if !defined(ORT_MINIMAL_BUILD) + const IOnnxRuntimeOpSchemaRegistryList* local_registries, +#endif const logging::Logger& logger, std::unique_ptr& model); #endif diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 9ed0299e03..125fecc373 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -4,16 +4,18 @@ #ifdef _WIN32 #pragma warning(disable : 4267) #endif -#include "core/graph/onnx_protobuf.h" -#include "core/session/inference_session.h" -#include "core/session/ort_apis.h" + +#include "core/framework/customregistry.h" #include "core/framework/data_types.h" #include "core/framework/op_kernel_info.h" #include "core/framework/op_kernel_context_internal.h" #include "core/framework/error_code_helper.h" #include "core/framework/tensor_type_and_shape.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "core/session/ort_apis.h" -ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type); +// ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type); ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) { auto status = reinterpret_cast(info)->GetAttr(name, out); @@ -67,23 +69,22 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernel return onnxruntime::ToOrtStatus(status); } -#if !defined(ORT_MINIMAL_BUILD) -#include "core/framework/customregistry.h" - namespace onnxruntime { struct CustomOpKernel : OpKernel { CustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { - if (op_.version > ORT_API_VERSION) + if (op_.version > ORT_API_VERSION) { ORT_THROW("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op)); - op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast(&info)); + } + + op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), + reinterpret_cast(&info)); } ~CustomOpKernel() override { op_.KernelDestroy(op_kernel_); } Status Compute(OpKernelContext* ctx) const override { - auto* ictx = static_cast(ctx); - op_.KernelCompute(op_kernel_, reinterpret_cast(ictx)); + op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); return Status::OK(); } @@ -94,12 +95,17 @@ struct CustomOpKernel : OpKernel { void* op_kernel_; }; -common::Status CreateCustomRegistry(const std::vector& op_domains, std::shared_ptr& output) { +common::Status CreateCustomRegistry(const std::vector& op_domains, + std::shared_ptr& output) { output = std::make_shared(); - for (auto& domain : op_domains) { + + for (const auto& domain : op_domains) { + // Create an OpSchema for each op and register them + +#if !defined(ORT_MINIMAL_BUILD) // Domain is not empty - add it to the DomainToVersion ONNX map // If domain is empty, it is assumed to be part of the ONNX domain - if (domain->domain_[0]) { + if (!domain->domain_.empty()) { // Add it to the DomainToVersion ONNX map if it doesn't already exist // For example, two sessions using the same session_options should not add the same custom op domain to the version map twice auto& domain_to_version_range_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); @@ -111,24 +117,25 @@ common::Status CreateCustomRegistry(const std::vector& op_do } std::vector schemas_list; - - for (auto& op : domain->custom_ops_) { - ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0); + for (const auto* op : domain->custom_ops_) { + ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "custom op registered at runtime", 0); auto input_count = op->GetInputTypeCount(op); for (size_t i = 0; i < input_count; i++) { auto type = op->GetInputType(op, i); - schema.Input(i, "A", "Description", - ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type ? "T" : - DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); + schema.Input(i, "Input" + std::to_string(i), "", + ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type + ? "T" + : DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); } auto output_count = op->GetOutputTypeCount(op); for (size_t i = 0; i < output_count; i++) { auto type = op->GetOutputType(op, i); - schema.Output(i, "A", "Description", - ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type ? "T": - DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); + schema.Output(i, "Output" + std::to_string(i), "", + ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type + ? "T" + : DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); } schema.TypeConstraint("T", DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types"); @@ -136,30 +143,38 @@ common::Status CreateCustomRegistry(const std::vector& op_do schema.SinceVersion(1); schema.AllowUncheckedAttributes(); schemas_list.push_back(schema); - - KernelDefBuilder def_builder; - def_builder.SetName(op->GetName(op)) - .SetDomain(domain->domain_) - .SinceVersion(1) - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()); - - if (const char* provider_type = op->GetExecutionProviderType(op)) - def_builder.Provider(provider_type); - else - def_builder.Provider(onnxruntime::kCpuExecutionProvider); - KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); }; - KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn); - output->RegisterCustomKernel(create_info); } ORT_RETURN_IF_ERROR(output->RegisterOpSet(schemas_list, domain->domain_, 1 /* baseline opset version */, 1000 /* opset version */)); + +#endif + // create the KernelDef for each op and register it + for (const auto* op : domain->custom_ops_) { + KernelDefBuilder def_builder; + def_builder.SetName(op->GetName(op)) + .SetDomain(domain->domain_) + .SinceVersion(1) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()); + + if (const char* provider_type = op->GetExecutionProviderType(op)) { + def_builder.Provider(provider_type); + } else { + def_builder.Provider(onnxruntime::kCpuExecutionProvider); + } + + KernelCreateFn kernel_create_fn = [op](const OpKernelInfo& info) -> OpKernel* { + return new CustomOpKernel(info, *op); + }; + + KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn); + ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(create_info)); + } } + return Status::OK(); } } // namespace onnxruntime - -#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c4c60d590f..8cd255676d 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -14,6 +14,7 @@ #include "core/common/denormal.h" #include "core/common/logging/logging.h" #include "core/framework/allocatormgr.h" +#include "core/framework/customregistry.h" #include "core/framework/error_code_helper.h" #include "core/framework/execution_frame.h" #include "core/framework/feeds_fetches_manager.h" @@ -43,6 +44,7 @@ #ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph #include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h" #endif +#include "core/session/custom_ops.h" #include "core/session/environment.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" @@ -50,10 +52,6 @@ #include "core/util/protobuf_parsing_utils.h" #include "core/util/thread_utils.h" -#if !defined(ORT_MINIMAL_BUILD) -#include "core/framework/customregistry.h" -#include "core/session/custom_ops.h" -#endif using namespace ONNX_NAMESPACE; using namespace onnxruntime::experimental; @@ -468,6 +466,7 @@ common::Status InferenceSession::AddCustomTransformerList(const std::vector& op_domains) { std::shared_ptr custom_registry; @@ -486,10 +485,13 @@ common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptrGetKernelRegistry()); +#if !defined(ORT_MINIMAL_BUILD) custom_schema_registries_.push_back(custom_registry->GetOpschemaRegistry()); +#endif return Status::OK(); } +#if !defined(ORT_MINIMAL_BUILD) common::Status InferenceSession::SaveToOrtFormat(const std::basic_string& filepath) const { ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ort format only supports little-edian machines"); @@ -1016,7 +1018,15 @@ Status InferenceSession::LoadOrtModel(std::function load_ort_format_mo // need to go from unique_ptr to shared_ptr when moving into model_ std::unique_ptr tmp_model; +#if !defined(ORT_MINIMAL_BUILD) + ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, + HasLocalSchema() ? &custom_schema_registries_ : nullptr, + *session_logger_, tmp_model)); + +#else ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, *session_logger_, tmp_model)); +#endif + ORT_RETURN_IF_ERROR(SaveModelMetadata(*tmp_model)); model_ = std::move(tmp_model); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 2b98143621..50727ca7e1 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -184,6 +184,7 @@ class InferenceSession { */ common::Status AddCustomTransformerList(const std::vector& transformers_to_enable) ORT_MUST_USE_RESULT; +#endif // !defined(ORT_MINIMAL_BUILD) /** * Add custom ops. This API is not thread safe. */ @@ -200,8 +201,6 @@ class InferenceSession { */ common::Status RegisterCustomRegistry(std::shared_ptr custom_registry) ORT_MUST_USE_RESULT; -#endif // !defined(ORT_MINIMAL_BUILD) - /** * Load an ONNX or ORT format model. * @@ -582,11 +581,11 @@ class InferenceSession { #if !defined(ORT_MINIMAL_BUILD) std::list> custom_schema_registries_; +#endif //CustomRegistry objects own the corresponding KernelRegistry and OnnxRuntimeOpSchemaRegistry objects. //So its lifetime should be same as its constituents. This vector is to extend the lifetime of the owner. std::vector> custom_registries_; -#endif ModelMetadata model_metadata_; std::unordered_set required_inputs_; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 8d36c67ad9..0c161336eb 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -439,13 +439,11 @@ static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* op env->GetEnvironment()); } -#if !defined(ORT_MINIMAL_BUILD) // Add custom domains Status status; if (options && !options->custom_op_domains_.empty()) { ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); } -#endif // Finish load if (load_config_from_model) { diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index 3253981b51..db1ffe475e 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -4,6 +4,7 @@ // if we can't load an ORT format model we can't really test anything #if defined(ENABLE_ORT_FORMAT_LOAD) +#include "core/common/make_unique.h" #include "core/framework/data_types.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" @@ -18,6 +19,8 @@ #include "flatbuffers/idl.h" #include "flatbuffers/util.h" +#include "core/session/onnxruntime_cxx_api.h" + #include "gtest/gtest.h" using namespace std; @@ -237,16 +240,16 @@ static void DumpOrtModelAsJson(const std::string& model_uri) { std::ofstream(model_uri + ".json") << json; } */ -/* The full build was causing the following error because the graph node array has some empyt (blank) node at some indices for certain ORT designs - onnx runtime exception : Satisfied, but should not be : node == nullptr - session_state.cc : 814 onnxruntime::SessionState::LoadFromOrtFormatCan't find node with index 4. Invalid ORT format model. - The bug was due to loading an ORT format model in a full build, allowing optimizers to run, but trying to use the saved kernel information. - As the optimizer removed some leaving gaps in the graph node vector. - The build has been fixed in InferenceSession code. The following test case to catch this error. -*/ +/* +Validate we don't run optimizers on an ORT format model in a full build. The optimizers will remove nodes, +which will create a mismatch with the saved kernel information and result in a runtime error. +We could take steps to handle this scenario in a full build, but for consistency we choose to not run optimizers +on any ORT format model. +*/ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) { - const std::basic_string ort_file = ORT_TSTR("mnist.onnx.ort"); + // we have tests that use a pre-created minst.onnx.ort, so make the naming for the unit test generated file clearer + const std::basic_string ort_file = ORT_TSTR("testdata/mnist.onnx.test_output.ort"); SaveAndCompareModels("testdata/mnist.onnx", ort_file); // DumpOrtModelAsJson(ToMBString(ort_file)); @@ -257,8 +260,8 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) { test_info.configs.push_back(std::make_pair(kOrtSessionOptionsConfigLoadModelFormat, "ORT")); OrtValue ml_value; - vector data(28*28, 0.0); - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1,1,28,28}, data, + vector data(28 * 28, 0.0); + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1, 1, 28, 28}, data, &ml_value); test_info.inputs.insert(std::make_pair("Input3", ml_value)); @@ -267,14 +270,13 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) { test_info.output_verifier = [](const std::vector& fetches) { const auto& output = fetches[0].Get(); ASSERT_TRUE(output.Shape().NumDimensions() == 2); - // ASSERT_TRUE(output.Data()[0] == 125.f); }; RunOrtModel(test_info); } TEST(OrtModelOnlyTests, SerializeToOrtFormat) { - const std::basic_string ort_file = ORT_TSTR("ort_github_issue_4031.onnx.ort"); + const std::basic_string ort_file = ORT_TSTR("testdata/ort_github_issue_4031.onnx.test_output.ort"); SaveAndCompareModels("testdata/ort_github_issue_4031.onnx", ort_file); // DumpOrtModelAsJson(ToMBString(ort_file)); @@ -301,7 +303,8 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) { } TEST(OrtModelOnlyTests, SparseInitializerHandling) { - const std::basic_string ort_file = ORT_TSTR("sparse_initializer_handling.onnx.ort"); + const std::basic_string ort_file = + ORT_TSTR("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.test_output.ort"); SaveAndCompareModels("testdata/ort_minimal_test_models/sparse_initializer_handling.onnx", ort_file); SessionOptions so; @@ -322,7 +325,8 @@ TEST(OrtModelOnlyTests, SparseInitializerHandling) { #if !defined(DISABLE_ML_OPS) TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) { - const std::basic_string ort_file = ORT_TSTR("sklearn_bin_voting_classifier_soft_converted.ort"); + const std::basic_string ort_file = + ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft_converted.test_output.ort"); SaveAndCompareModels("testdata/sklearn_bin_voting_classifier_soft.onnx", ort_file); OrtModelTestInfo test_info; @@ -361,6 +365,7 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) { RunOrtModel(test_info); } + #endif // #if !defined(DISABLE_ML_OPS) #endif // #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 4efdd116e6..bcffcda1e5 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -21,6 +21,7 @@ #include "providers.h" #include "test_allocator.h" #include "test_fixture.h" +#include "utils.h" #ifdef _WIN32 #include @@ -32,11 +33,9 @@ #include #endif -struct Input { - const char* name = nullptr; - std::vector dims; - std::vector values; -}; +// Once we use C++17 this could be replaced with std::size +template +constexpr size_t countof(T (&)[N]) { return N; } extern std::unique_ptr ort_env; @@ -51,14 +50,18 @@ void RunSession(OrtAllocator* allocator, Ort::Session& session_object, std::vector input_names; for (size_t i = 0; i < inputs.size(); i++) { input_names.emplace_back(inputs[i].name); - ort_inputs.emplace_back(Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), + inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); } std::vector ort_outputs; if (output_tensor) - session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, output_tensor, 1); + session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, output_tensor, 1); else { - ort_outputs = session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1); + ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, 1); ASSERT_EQ(ort_outputs.size(), 1u); output_tensor = &ort_outputs[0]; } @@ -74,17 +77,17 @@ void RunSession(OrtAllocator* allocator, Ort::Session& session_object, } } -template -void TestInference(Ort::Env& env, T model_uri, - const std::vector& inputs, - const char* output_name, - const std::vector& expected_dims_y, - const std::vector& expected_values_y, - int provider_type, - OrtCustomOpDomain* custom_op_domain_ptr, - const char* custom_op_library_filename, - void** library_handle = nullptr, - bool test_session_creation_only = false) { +template +static void TestInference(Ort::Env& env, const std::basic_string& model_uri, + const std::vector& inputs, + const char* output_name, + const std::vector& expected_dims_y, + const std::vector& expected_values_y, + int provider_type, + OrtCustomOpDomain* custom_op_domain_ptr, + const char* custom_op_library_filename, + void** library_handle = nullptr, + bool test_session_creation_only = false) { Ort::SessionOptions session_options; if (provider_type == 1) { @@ -103,7 +106,8 @@ void TestInference(Ort::Env& env, T model_uri, #endif } else if (provider_type == 3) { #ifdef USE_NUPHAR - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options, /*allow_unaligned_buffers*/ 1, "")); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options, + /*allow_unaligned_buffers*/ 1, "")); std::cout << "Running simple inference with nuphar provider" << std::endl; #else return; @@ -116,11 +120,12 @@ void TestInference(Ort::Env& env, T model_uri, } if (custom_op_library_filename) { - Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, library_handle)); + Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary(session_options, + custom_op_library_filename, library_handle)); } // if session creation passes, model loads fine - Ort::Session session(env, model_uri, session_options); + Ort::Session session(env, model_uri.c_str(), session_options); // caller wants to test running the model (not just loading the model) if (!test_session_creation_only) { @@ -136,7 +141,8 @@ void TestInference(Ort::Env& env, T model_uri, expected_values_y, nullptr); //with preallocated output tensor - Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size()); + Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), + expected_dims_y.data(), expected_dims_y.size()); //test it twice for (int i = 0; i != 2; ++i) @@ -181,7 +187,8 @@ TEST_P(CApiTestWithProvider, simple) { std::vector expected_dims_y = {3, 2}; std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; - TestInference(*ort_env, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(), nullptr, nullptr); + TestInference(*ort_env, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(), + nullptr, nullptr); } TEST(CApiTest, dim_param) { @@ -216,70 +223,6 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders, CApiTestWithProvider, ::testing::Values(0, 1, 2, 3, 4)); -struct OrtTensorDimensions : std::vector { - OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) { - OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value); - std::vector::operator=(ort.GetTensorShape(info)); - ort.ReleaseTensorTypeAndShapeInfo(info); - } -}; - -// Once we use C++17 this could be replaced with std::size -template -constexpr size_t countof(T (&)[N]) { return N; } -void cuda_add(int64_t, float*, const float*, const float*); - -struct MyCustomKernel { - MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) { - } - - void Compute(OrtKernelContext* context) { - // Setup inputs - const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0); - const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1); - const float* X = ort_.GetTensorData(input_X); - const float* Y = ort_.GetTensorData(input_Y); - - // Setup output - OrtTensorDimensions dimensions(ort_, input_X); - OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size()); - float* out = ort_.GetTensorMutableData(output); - - OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output); - int64_t size = ort_.GetTensorShapeElementCount(output_info); - ort_.ReleaseTensorTypeAndShapeInfo(output_info); - - // Do computation -#ifdef USE_CUDA - cuda_add(size, out, X, Y); -#else - for (int64_t i = 0; i < size; i++) { - out[i] = X[i] + Y[i]; - } -#endif - } - - private: - Ort::CustomOpApi ort_; -}; - -struct MyCustomOp : Ort::CustomOpBase { - explicit MyCustomOp(const char* provider) : provider_(provider) {} - void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new MyCustomKernel(api, info); }; - const char* GetName() const { return "Foo"; }; - - const char* GetExecutionProviderType() const { return provider_; }; - - size_t GetInputTypeCount() const { return 2; }; - ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; - - size_t GetOutputTypeCount() const { return 1; }; - ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; - - private: - const char* provider_; -}; - TEST(CApiTest, custom_op_handler) { std::cout << "Running custom op inference" << std::endl; @@ -303,9 +246,11 @@ TEST(CApiTest, custom_op_handler) { custom_op_domain.Add(&custom_op); #ifdef USE_CUDA - TestInference(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1, custom_op_domain, nullptr, nullptr); + TestInference(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1, + custom_op_domain, nullptr, nullptr); #else - TestInference(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, custom_op_domain, nullptr); + TestInference(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, + custom_op_domain, nullptr); #endif } @@ -368,9 +313,11 @@ struct SliceCustomOpKernel { struct SliceCustomOp : Ort::CustomOpBase { explicit SliceCustomOp(const char* provider) : provider_(provider) {} - void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new SliceCustomOpKernel(api, info); }; - const char* GetName() const { return "Slice"; }; + void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { + return new SliceCustomOpKernel(api, info); + }; + const char* GetName() const { return "Slice"; }; const char* GetExecutionProviderType() const { return provider_; }; size_t GetInputTypeCount() const { return 3; }; @@ -427,9 +374,11 @@ TEST(CApiTest, varied_input_custom_op_handler) { custom_op_domain.Add(&slice_custom_op); #ifdef USE_CUDA - TestInference(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z", expected_dims_z, expected_values_z, 1, custom_op_domain, nullptr, nullptr); + TestInference(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z", + expected_dims_z, expected_values_z, 1, custom_op_domain, nullptr, nullptr); #else - TestInference(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z", expected_dims_z, expected_values_z, 0, custom_op_domain, nullptr); + TestInference(*ort_env, VARIED_INPUT_CUSTOM_OP_MODEL_URI, inputs, "Z", + expected_dims_z, expected_values_z, 0, custom_op_domain, nullptr); #endif } @@ -454,8 +403,8 @@ TEST(CApiTest, RegisterCustomOpForCPUAndCUDA) { custom_op_domain.Add(&custom_op_cpu); custom_op_domain.Add(&custom_op_cuda); - TestInference(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, - expected_values_y, 1, custom_op_domain, nullptr, nullptr, true); + TestInference(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, + expected_values_y, 1, custom_op_domain, nullptr, nullptr, true); } #endif @@ -496,8 +445,8 @@ lib_name = "./libcustom_op_library.so"; #endif void* library_handle = nullptr; - TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, - expected_values_y, 0, nullptr, lib_name.c_str(), &library_handle); + TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, + expected_values_y, 0, nullptr, lib_name.c_str(), &library_handle); #ifdef _WIN32 bool success = ::FreeLibrary(reinterpret_cast(library_handle)); @@ -552,7 +501,8 @@ TEST(CApiTest, test_pyop) { input.values = {1.0f, 2.0f, 3.0f, 4.0f}; std::vector expected_dims_y = {2, 2}; std::vector expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f}; - TestInference(*ort_env, PYOP_FLOAT_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, nullptr, nullptr); + TestInference(*ort_env, PYOP_FLOAT_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, + nullptr, nullptr); } TEST(CApiTest, test_pyop_multi) { @@ -564,7 +514,8 @@ TEST(CApiTest, test_pyop_multi) { input.values = {1.0f, 2.0f, 3.0f, 4.0f}; std::vector expected_dims_y = {2, 2}; std::vector expected_values_y = {8.0f, 16.0f, 24.0f, 32.0f}; - TestInference(*ort_env, PYOP_MULTI_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0, nullptr, nullptr); + TestInference(*ort_env, PYOP_MULTI_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0, + nullptr, nullptr); } TEST(CApiTest, test_pyop_kwarg) { @@ -576,7 +527,8 @@ TEST(CApiTest, test_pyop_kwarg) { input.values = {1.0f, 2.0f, 3.0f, 4.0f}; std::vector expected_dims_y = {2, 2}; std::vector expected_values_y = {25.0f, 50.0f, 75.0f, 100.0f}; - TestInference(*ort_env, PYOP_KWARG_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0, nullptr, nullptr); + TestInference(*ort_env, PYOP_KWARG_MODEL_URI, inputs, "Z", expected_dims_y, expected_values_y, 0, + nullptr, nullptr); } #endif @@ -596,7 +548,8 @@ TEST(ReducedOpsBuildTest, test_included_ops) { std::vector inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}}; std::vector expected_dims_y = {1}; std::vector expected_values_y = {0.75}; - TestInference(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0, nullptr, nullptr); + TestInference(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0, + nullptr, nullptr); } TEST(ReducedOpsBuildTest, test_excluded_ops) { @@ -609,7 +562,8 @@ TEST(ReducedOpsBuildTest, test_excluded_ops) { bool failed = false; try { //only test model loading, exception expected - TestInference(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0, nullptr, nullptr, nullptr, true); + TestInference(*ort_env, model_uri, inputs, "Y", expected_dims_y, expected_values_y, 0, + nullptr, nullptr, nullptr, true); } catch (const Ort::Exception& e) { failed = e.GetOrtErrorCode() == ORT_NOT_IMPLEMENTED; } @@ -782,8 +736,8 @@ TEST(CApiTest, io_binding_cuda) { ASSERT_NE(output_data.get(), nullptr); // Create an OrtValue tensor backed by data on CUDA memory - Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(output_data.get()), expected_y.size(), - expected_y_shape.data(), expected_y_shape.size()); + Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(output_data.get()), + expected_y.size(), expected_y_shape.data(), expected_y_shape.size()); Ort::IoBinding binding(session); binding.BindInput("X", bound_x); @@ -854,7 +808,8 @@ TEST(CApiTest, create_tensor) { int64_t expected_len = 2; auto default_allocator = onnxruntime::make_unique(); - Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); + Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, + ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); Ort::ThrowOnError(Ort::GetApi().FillStringTensor(tensor, s, expected_len)); auto shape_info = tensor.GetTensorTypeAndShapeInfo(); @@ -874,7 +829,8 @@ TEST(CApiTest, fill_string_tensor) { int64_t expected_len = 2; auto default_allocator = onnxruntime::make_unique(); - Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); + Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, + ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); for (int64_t i = 0; i < expected_len; i++) { tensor.FillStringTensorElement(s[i], i); @@ -892,7 +848,8 @@ TEST(CApiTest, get_string_tensor_element) { int64_t element_index = 0; auto default_allocator = onnxruntime::make_unique(); - Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); + Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, + ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); tensor.FillStringTensor(s, expected_len); @@ -1007,7 +964,8 @@ TEST(CApiTest, override_initializer) { std::string f2_data{"f2_string"}; // Place a string into Tensor OrtValue and assign to the - Ort::Value f2_input_tensor = Ort::Value::CreateTensor(allocator.get(), dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); + Ort::Value f2_input_tensor = Ort::Value::CreateTensor(allocator.get(), dims.data(), dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); const char* const input_char_string[] = {f2_data.c_str()}; f2_input_tensor.FillStringTensor(input_char_string, 1U); @@ -1138,7 +1096,8 @@ TEST(CApiTest, model_metadata) { ASSERT_TRUE(version == 1); int64_t num_keys_in_custom_metadata_map; - char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map); + char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), + num_keys_in_custom_metadata_map); ASSERT_TRUE(num_keys_in_custom_metadata_map == 2); allocator.get()->Free(custom_metadata_map_keys[0]); @@ -1147,7 +1106,8 @@ TEST(CApiTest, model_metadata) { char* lookup_value_1 = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get()); ASSERT_TRUE(strcmp(lookup_value_1, - "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0); + "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, " + "\"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0); allocator.get()->Free(lookup_value_1); char* lookup_value_2 = model_metadata.LookupCustomMetadataMap("dummy_key", allocator.get()); @@ -1180,7 +1140,8 @@ TEST(CApiTest, model_metadata) { // Model does not contain custom metadata map int64_t num_keys_in_custom_metadata_map; - char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map); + char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), + num_keys_in_custom_metadata_map); ASSERT_TRUE(num_keys_in_custom_metadata_map == 0); ASSERT_TRUE(custom_metadata_map_keys == nullptr); } @@ -1235,9 +1196,9 @@ TEST(CApiTest, TestSharedAllocatorUsingCreateAndRegisterAllocator) { ASSERT_TRUE(api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg) == nullptr); // test for duplicates - std::unique_ptr status_releaser(api.CreateAndRegisterAllocator(env_ptr, mem_info, - arena_cfg), - api.ReleaseStatus); + std::unique_ptr status_releaser( + api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg), + api.ReleaseStatus); ASSERT_FALSE(status_releaser.get() == nullptr); Ort::SessionOptions session_options; diff --git a/onnxruntime/test/shared_lib/test_ort_format_models.cc b/onnxruntime/test/shared_lib/test_ort_format_models.cc new file mode 100644 index 0000000000..d288325715 --- /dev/null +++ b/onnxruntime/test/shared_lib/test_ort_format_models.cc @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// if we can't load an ORT format model we can't really test anything +#if defined(ENABLE_ORT_FORMAT_LOAD) + +#include "core/common/make_unique.h" +#include "core/graph/constants.h" +#include "core/session/onnxruntime_cxx_api.h" + +#include "test_allocator.h" +#include "utils.h" + +#include "gtest/gtest.h" + +extern std::unique_ptr ort_env; + +static void TestInference(Ort::Env& env, const std::basic_string& model_uri, + const std::vector& inputs, const char* output_name, + const std::vector& expected_dims_y, const std::vector& expected_values_y, + Ort::CustomOpDomain& custom_op_domain) { + Ort::SessionOptions session_options; + session_options.Add(custom_op_domain); + +#ifdef USE_CUDA + OrtCUDAProviderOptions cuda_options{}; + session_options.AppendExecutionProvider_CUDA(cuda_options); +#endif + + Ort::Session session(env, model_uri.c_str(), session_options); + + MockedOrtAllocator allocator; + std::vector ort_inputs; + std::vector input_names; + + for (size_t i = 0; i < inputs.size(); i++) { + // we put the data in a Tensor, and Tensor has a method to get a mutable pointer to the data. + // we never call that for an input, but need to do the const_cast to make this potential explicit + float* input_data = const_cast(inputs[i].values.data()); + + auto input_tensor = Ort::Value::CreateTensor(allocator.Info(), input_data, inputs[i].values.size(), + inputs[i].dims.data(), inputs[i].dims.size()); + + input_names.push_back(inputs[i].name); + ort_inputs.push_back(std::move(input_tensor)); + } + + std::vector ort_outputs; + ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, 1); + + ASSERT_EQ(ort_outputs.size(), size_t(1)); + const auto& output_tensor = &ort_outputs[0]; + + auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); + ASSERT_EQ(type_info.GetShape(), expected_dims_y); + size_t total_len = type_info.GetElementCount(); + ASSERT_EQ(expected_values_y.size(), total_len); + + const auto* f = output_tensor->GetTensorMutableData(); + for (size_t i = 0; i < total_len; ++i) { + ASSERT_EQ(expected_values_y[i], f[i]); + } +} + +#if !defined(ORT_MINIMAL_BUILD) +TEST(OrtFormatCustomOpTests, ConvertOnnxModelToOrt) { + const std::basic_string onnx_file = ORT_TSTR("testdata/foo_1.onnx"); + const std::basic_string ort_file = ORT_TSTR("testdata/foo_1.onnx.test_output.ort"); + +#ifdef USE_CUDA + MyCustomOp custom_op{onnxruntime::kCudaExecutionProvider}; +#else + MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider}; +#endif + Ort::CustomOpDomain custom_op_domain(""); + custom_op_domain.Add(&custom_op); + + // convert to ort by loading the onnx model + { + Ort::SessionOptions so; + so.Add(custom_op_domain); + so.SetLogId("CustomOp"); + so.SetOptimizedModelFilePath(ort_file.c_str()); + +#ifdef USE_CUDA + OrtCUDAProviderOptions cuda_options{}; + so.AppendExecutionProvider_CUDA(cuda_options); +#endif + + Ort::Session session(*ort_env, onnx_file.c_str(), so); + } + + // now load the ORT format model and execute it + std::vector inputs(1); + Input& input = inputs[0]; + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // model adds 1, 2, 3, 4, 5, 6 to the input values + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; + + TestInference(*ort_env, ort_file, inputs, "Y", expected_dims_y, expected_values_y, custom_op_domain); +} +#endif // if !defined(ORT_MINIMAL_BUILD) + +// the saved ORT format model has the CPU EP assigned to the custom op node, so we only test if we're not using the +// CUDA EP for the test. +#ifndef USE_CUDA +TEST(OrtFormatCustomOpTests, LoadOrtModel) { + const std::basic_string ort_file = ORT_TSTR("testdata/foo_1.onnx.ort"); + + MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider}; + Ort::CustomOpDomain custom_op_domain(""); + custom_op_domain.Add(&custom_op); + + // load the ORT format model and execute it + std::vector inputs(1); + Input& input = inputs[0]; + input.name = "X"; + input.dims = {3, 2}; + input.values = {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; + + // model adds 1, 2, 3, 4, 5, 6 to the input values + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {7.0f, 7.0f, 7.0f, 7.0f, 7.0f, 7.0f}; + + TestInference(*ort_env, ort_file, inputs, "Y", expected_dims_y, expected_values_y, custom_op_domain); +} +#endif + +#endif // #if defined(ENABLE_ORT_FORMAT_LOAD) diff --git a/onnxruntime/test/shared_lib/utils.cc b/onnxruntime/test/shared_lib/utils.cc new file mode 100644 index 0000000000..b27c9e8228 --- /dev/null +++ b/onnxruntime/test/shared_lib/utils.cc @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "utils.h" + +#ifdef USE_CUDA +#include +void cuda_add(int64_t, float*, const float*, const float*); +#endif + +void MyCustomKernel::Compute(OrtKernelContext* context) { + // Setup inputs + const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0); + const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1); + const float* X = ort_.GetTensorData(input_X); + const float* Y = ort_.GetTensorData(input_Y); + + // Setup output + OrtTensorDimensions dimensions(ort_, input_X); + OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size()); + float* out = ort_.GetTensorMutableData(output); + + OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output); + int64_t size = ort_.GetTensorShapeElementCount(output_info); + ort_.ReleaseTensorTypeAndShapeInfo(output_info); + + // Do computation +#ifdef USE_CUDA + cuda_add(size, out, X, Y); +#else + for (int64_t i = 0; i < size; i++) { + out[i] = X[i] + Y[i]; + } +#endif +} diff --git a/onnxruntime/test/shared_lib/utils.h b/onnxruntime/test/shared_lib/utils.h new file mode 100644 index 0000000000..560ead8708 --- /dev/null +++ b/onnxruntime/test/shared_lib/utils.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/onnxruntime_cxx_api.h" + +struct Input { + const char* name = nullptr; + std::vector dims; + std::vector values; +}; + +struct OrtTensorDimensions : std::vector { + OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) { + OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value); + std::vector::operator=(ort.GetTensorShape(info)); + ort.ReleaseTensorTypeAndShapeInfo(info); + } +}; + +struct MyCustomKernel { + MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) { + } + + void Compute(OrtKernelContext* context); + + private: + Ort::CustomOpApi ort_; +}; + +struct MyCustomOp : Ort::CustomOpBase { + explicit MyCustomOp(const char* provider) : provider_(provider) {} + + void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new MyCustomKernel(api, info); }; + const char* GetName() const { return "Foo"; }; + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return 2; }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + + size_t GetOutputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + + private: + const char* provider_; +}; diff --git a/onnxruntime/test/testdata/foo_1.onnx.ort b/onnxruntime/test/testdata/foo_1.onnx.ort new file mode 100644 index 0000000000000000000000000000000000000000..84b8f93da44a6cfa5f2f28af5e8c17f6f23de0a0 GIT binary patch literal 1312 zcmZ{kziU%b6vt0$v>|FKHAp}y9Wppbh|-~(!C#YNkgCENh7Ki_xnvuVUbzVH0F=brPu=iatzW_Rw+-$o&2o3v9_ zvLcpYncV_Eie`(y(o{ZXCTsW=?-$T-_P&08{%G(;wxnX`fvOqxBt?9)i(rQ$FG1^= z8NW;_RI*!=r|t1#_SY7J&Nf!)4z4s7#=et!Z6o5m0Q2&jJwVYpr`Xwu8<_FN~vB z5^sh}K|(m2!Ob|9naOd!%W?ep5;~g?1qawV$BQ&gHO~WLWpCpa3|oDb^X_+=N~`}2 zkp02O_R+Wntx3M)yE{5-lw{*F1UsMQ2(pgqy6vuSPQNE zynp25J3fBcw`>zbJvG0o4(RXqq~|iJ5cNthEPDR=9Nzod>ZLmsr%%M(d8dw?H|Jw^ zf+|J6?){f|!~dI$`iPgoR4+(7w08nap4Mw%7U+$d^{?=0@&kSF3hXt6 zS*fo?IXXi-tjAHbU0+Lr%`n)i_mXxPh0zn`ERN)hXP!TtHy|%Rr|Rl|O`g|<@avH} z=U6|&t=?T-4%)qqBzzhq>-KRZ|I9!6QD