mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
[CoreML EP] Enable coreml for onnx_test_runner and onnxruntime_perf_test (macOS only) (#6642)
This commit is contained in:
parent
78e408dbe9
commit
f11b5d3072
6 changed files with 33 additions and 11 deletions
|
|
@ -81,7 +81,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
|
|||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::vector<size_t>> GetSupportedNodes(const GraphViewer& graph_viewer, const logging::Logger& logger) {
|
||||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer, const logging::Logger& logger) {
|
||||
std::vector<std::vector<size_t>> supported_node_vecs;
|
||||
if (!util::HasRequiredBaseOS()) {
|
||||
LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because we do not have supported OS";
|
||||
|
|
@ -97,15 +97,16 @@ std::vector<std::vector<size_t>> GetSupportedNodes(const GraphViewer& graph_view
|
|||
std::vector<size_t> supported_node_vec;
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (size_t i = 0; i < node_indices.size(); i++) {
|
||||
const auto* node(graph_viewer.GetNode(node_indices[i]));
|
||||
auto node_idx = node_indices[i];
|
||||
const auto* node(graph_viewer.GetNode(node_idx));
|
||||
bool supported = IsNodeSupported(*node, graph_viewer, logger);
|
||||
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
|
||||
<< "] index: [" << i
|
||||
<< "] index: [" << node_idx
|
||||
<< "] name: [" << node->Name()
|
||||
<< "] supported: [" << supported
|
||||
<< "]";
|
||||
if (supported) {
|
||||
supported_node_vec.push_back(i);
|
||||
supported_node_vec.push_back(node_idx);
|
||||
} else {
|
||||
if (!supported_node_vec.empty()) {
|
||||
supported_node_vecs.push_back(supported_node_vec);
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <core/common/status.h>
|
||||
#include <core/graph/basic_types.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -24,8 +25,8 @@ bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logg
|
|||
bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
|
||||
|
||||
// Get a list of groups of supported nodes, each group represents a subgraph supported by CoreML EP
|
||||
std::vector<std::vector<size_t>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
const logging::Logger& logger);
|
||||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
const logging::Logger& logger);
|
||||
|
||||
} // namespace coreml
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ void usage() {
|
|||
"\t-v: verbose\n"
|
||||
"\t-n [test_case_name]: Specifies a single test case to run.\n"
|
||||
"\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', "
|
||||
"'openvino', 'nuphar', 'migraphx', 'acl' or 'armnn'. "
|
||||
"'openvino', 'nuphar', 'migraphx', 'acl', 'armnn', 'nnapi' or 'coreml'. "
|
||||
"Default: 'cpu'.\n"
|
||||
"\t-p: Pause after launch, can attach debugger and continue\n"
|
||||
"\t-x: Use parallel executor, default (without -x): sequential executor.\n"
|
||||
|
|
@ -98,6 +98,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
bool enable_tensorrt = false;
|
||||
bool enable_mem_pattern = true;
|
||||
bool enable_nnapi = false;
|
||||
bool enable_coreml = false;
|
||||
bool enable_dml = false;
|
||||
bool enable_acl = false;
|
||||
bool enable_armnn = false;
|
||||
|
|
@ -165,6 +166,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
enable_tensorrt = true;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) {
|
||||
enable_nnapi = true;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("coreml"))) {
|
||||
enable_coreml = true;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("dml"))) {
|
||||
enable_dml = true;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("acl"))) {
|
||||
|
|
@ -285,8 +288,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
double per_sample_tolerance = 1e-3;
|
||||
// when cuda is enabled, set it to a larger value for resolving random MNIST test failure
|
||||
// when openvino is enabled, set it to a larger value for resolving MNIST accuracy mismatch
|
||||
double relative_per_sample_tolerance = enable_cuda ? 0.017 : enable_openvino ? 0.009
|
||||
: 1e-3;
|
||||
double relative_per_sample_tolerance = enable_cuda ? 0.017 : enable_openvino ? 0.009 : 1e-3;
|
||||
|
||||
Ort::SessionOptions sf;
|
||||
|
||||
|
|
@ -373,6 +375,14 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
#else
|
||||
fprintf(stderr, "NNAPI is not supported in this build");
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
if (enable_coreml) {
|
||||
#ifdef USE_COREML
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0));
|
||||
#else
|
||||
fprintf(stderr, "CoreML is not supported in this build");
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
if (enable_dml) {
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ namespace perftest {
|
|||
"\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n"
|
||||
"\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n"
|
||||
"\t-e [cpu|cuda|dnnl|tensorrt|openvino|nuphar|dml|acl]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', "
|
||||
"'openvino', 'nuphar', 'dml' or 'acl'. "
|
||||
"'openvino', 'nuphar', 'dml', 'acl', 'nnapi' or 'coreml'. "
|
||||
"Default:'cpu'.\n"
|
||||
"\t-b [tf|ort]: backend to use. Default:ort\n"
|
||||
"\t-r [repeated_times]: Specifies the repeated times if running in 'times' test mode.Default:1000.\n"
|
||||
|
|
@ -93,6 +93,8 @@ namespace perftest {
|
|||
test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) {
|
||||
test_config.machine_config.provider_type_name = onnxruntime::kNnapiExecutionProvider;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("coreml"))) {
|
||||
test_config.machine_config.provider_type_name = onnxruntime::kCoreMLExecutionProvider;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("nuphar"))) {
|
||||
test_config.machine_config.provider_type_name = onnxruntime::kNupharExecutionProvider;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("dml"))) {
|
||||
|
|
|
|||
|
|
@ -77,6 +77,12 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, 0));
|
||||
#else
|
||||
ORT_THROW("NNAPI is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kCoreMLExecutionProvider) {
|
||||
#ifdef USE_COREML
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, 0));
|
||||
#else
|
||||
ORT_THROW("COREML is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kDmlExecutionProvider) {
|
||||
#ifdef USE_DML
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@
|
|||
#ifdef USE_NNAPI
|
||||
#include "core/providers/nnapi/nnapi_provider_factory.h"
|
||||
#endif
|
||||
#ifdef USE_COREML
|
||||
#include "core/providers/coreml/coreml_provider_factory.h"
|
||||
#endif
|
||||
#ifdef USE_DML
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
#endif
|
||||
|
|
@ -34,4 +37,3 @@
|
|||
#ifdef USE_MIGRAPHX
|
||||
#include "core/providers/migraphx/migraphx_provider_factory.h"
|
||||
#endif
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue