[CoreML EP] Enable coreml for onnx_test_runner and onnxruntime_perf_test (macOS only) (#6642)

This commit is contained in:
Guoyu Wang 2021-02-12 10:41:36 -08:00 committed by GitHub
parent 78e408dbe9
commit f11b5d3072
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 33 additions and 11 deletions

View file

@ -81,7 +81,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
return true;
}
std::vector<std::vector<size_t>> GetSupportedNodes(const GraphViewer& graph_viewer, const logging::Logger& logger) {
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer, const logging::Logger& logger) {
std::vector<std::vector<size_t>> supported_node_vecs;
if (!util::HasRequiredBaseOS()) {
LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because we do not have supported OS";
@ -97,15 +97,16 @@ std::vector<std::vector<size_t>> GetSupportedNodes(const GraphViewer& graph_view
std::vector<size_t> supported_node_vec;
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (size_t i = 0; i < node_indices.size(); i++) {
const auto* node(graph_viewer.GetNode(node_indices[i]));
auto node_idx = node_indices[i];
const auto* node(graph_viewer.GetNode(node_idx));
bool supported = IsNodeSupported(*node, graph_viewer, logger);
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
<< "] index: [" << i
<< "] index: [" << node_idx
<< "] name: [" << node->Name()
<< "] supported: [" << supported
<< "]";
if (supported) {
supported_node_vec.push_back(i);
supported_node_vec.push_back(node_idx);
} else {
if (!supported_node_vec.empty()) {
supported_node_vecs.push_back(supported_node_vec);

View file

@ -4,6 +4,7 @@
#pragma once
#include <core/common/status.h>
#include <core/graph/basic_types.h>
namespace onnxruntime {
@ -24,8 +25,8 @@ bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logg
bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
// Get a list of groups of supported nodes, each group represents a subgraph supported by CoreML EP
std::vector<std::vector<size_t>> GetSupportedNodes(const GraphViewer& graph_viewer,
const logging::Logger& logger);
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const logging::Logger& logger);
} // namespace coreml
} // namespace onnxruntime

View file

@ -36,7 +36,7 @@ void usage() {
"\t-v: verbose\n"
"\t-n [test_case_name]: Specifies a single test case to run.\n"
"\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', "
"'openvino', 'nuphar', 'migraphx', 'acl' or 'armnn'. "
"'openvino', 'nuphar', 'migraphx', 'acl', 'armnn', 'nnapi' or 'coreml'. "
"Default: 'cpu'.\n"
"\t-p: Pause after launch, can attach debugger and continue\n"
"\t-x: Use parallel executor, default (without -x): sequential executor.\n"
@ -98,6 +98,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
bool enable_tensorrt = false;
bool enable_mem_pattern = true;
bool enable_nnapi = false;
bool enable_coreml = false;
bool enable_dml = false;
bool enable_acl = false;
bool enable_armnn = false;
@ -165,6 +166,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
enable_tensorrt = true;
} else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) {
enable_nnapi = true;
} else if (!CompareCString(optarg, ORT_TSTR("coreml"))) {
enable_coreml = true;
} else if (!CompareCString(optarg, ORT_TSTR("dml"))) {
enable_dml = true;
} else if (!CompareCString(optarg, ORT_TSTR("acl"))) {
@ -285,8 +288,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
double per_sample_tolerance = 1e-3;
// when cuda is enabled, set it to a larger value for resolving random MNIST test failure
// when openvino is enabled, set it to a larger value for resolving MNIST accuracy mismatch
double relative_per_sample_tolerance = enable_cuda ? 0.017 : enable_openvino ? 0.009
: 1e-3;
double relative_per_sample_tolerance = enable_cuda ? 0.017 : enable_openvino ? 0.009 : 1e-3;
Ort::SessionOptions sf;
@ -373,6 +375,14 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
#else
fprintf(stderr, "NNAPI is not supported in this build");
return -1;
#endif
}
if (enable_coreml) {
#ifdef USE_COREML
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0));
#else
fprintf(stderr, "CoreML is not supported in this build");
return -1;
#endif
}
if (enable_dml) {

View file

@ -34,7 +34,7 @@ namespace perftest {
"\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n"
"\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n"
"\t-e [cpu|cuda|dnnl|tensorrt|openvino|nuphar|dml|acl]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', "
"'openvino', 'nuphar', 'dml' or 'acl'. "
"'openvino', 'nuphar', 'dml', 'acl', 'nnapi' or 'coreml'. "
"Default:'cpu'.\n"
"\t-b [tf|ort]: backend to use. Default:ort\n"
"\t-r [repeated_times]: Specifies the repeated times if running in 'times' test mode.Default:1000.\n"
@ -93,6 +93,8 @@ namespace perftest {
test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider;
} else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) {
test_config.machine_config.provider_type_name = onnxruntime::kNnapiExecutionProvider;
} else if (!CompareCString(optarg, ORT_TSTR("coreml"))) {
test_config.machine_config.provider_type_name = onnxruntime::kCoreMLExecutionProvider;
} else if (!CompareCString(optarg, ORT_TSTR("nuphar"))) {
test_config.machine_config.provider_type_name = onnxruntime::kNupharExecutionProvider;
} else if (!CompareCString(optarg, ORT_TSTR("dml"))) {

View file

@ -77,6 +77,12 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, 0));
#else
ORT_THROW("NNAPI is not supported in this build\n");
#endif
} else if (provider_name == onnxruntime::kCoreMLExecutionProvider) {
#ifdef USE_COREML
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, 0));
#else
ORT_THROW("COREML is not supported in this build\n");
#endif
} else if (provider_name == onnxruntime::kDmlExecutionProvider) {
#ifdef USE_DML

View file

@ -22,6 +22,9 @@
#ifdef USE_NNAPI
#include "core/providers/nnapi/nnapi_provider_factory.h"
#endif
#ifdef USE_COREML
#include "core/providers/coreml/coreml_provider_factory.h"
#endif
#ifdef USE_DML
#include "core/providers/dml/dml_provider_factory.h"
#endif
@ -34,4 +37,3 @@
#ifdef USE_MIGRAPHX
#include "core/providers/migraphx/migraphx_provider_factory.h"
#endif