Refactor the model tests code in onnxruntime_test_all.exe (#11300)

Update the code to use OrtApis instead of the old onnxruntime::InferenceSession class. Mainly because the old one doesn't support custom op. We are trying to convert some EPs to custom ops. Hopefully they can continue to leverage this test set.
This commit is contained in:
Changming Sun 2022-04-25 11:52:51 -07:00 committed by GitHub
parent 6fb29f5b9a
commit aaa583e776
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,9 +6,12 @@
#include "core/session/onnxruntime_c_api.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/common/gsl_suppress.h"
#include "core/session/ort_apis.h"
#include "core/session/inference_session.h"
#include "core/session/ort_env.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include "test_allocator.h"
#include "asserts.h"
#include <core/platform/path_lib.h>
#include "default_providers.h"
@ -16,6 +19,30 @@
#include <codecvt>
#include <locale>
#ifdef USE_DNNL
#include "core/providers/dnnl/dnnl_provider_factory.h"
#endif
#ifdef USE_NUPHAR
#include "core/providers/nuphar/nuphar_provider_factory.h"
#endif
#ifdef USE_NNAPI
#include "core/providers/nnapi/nnapi_provider_factory.h"
#endif
#ifdef USE_RKNPU
#include "core/providers/rknpu/rknpu_provider_factory.h"
#endif
#ifdef USE_ACL
#include "core/providers/acl/acl_provider_factory.h"
#endif
#ifdef USE_ARMNN
#include "core/providers/armnn/armnn_provider_factory.h"
#endif
// test infrastructure
#include "test/onnx/TestCase.h"
#include "test/compare_ortvalue.h"
@ -25,6 +52,12 @@
extern std::unique_ptr<Ort::Env> ort_env;
#define ASSERT_ORT_STATUS_OK(function) \
do { \
OrtStatus* _tmp_status = (function); \
ASSERT_EQ(_tmp_status, nullptr) << OrtApis::GetErrorMessage(_tmp_status); \
} while (false)
using namespace onnxruntime::common;
namespace onnxruntime {
@ -72,11 +105,12 @@ TEST_P(ModelTest, Run) {
}
std::unique_ptr<OnnxModelInfo> model_info = std::make_unique<OnnxModelInfo>(model_path.c_str());
if (model_info->GetONNXOpSetVersion() != 14 && model_info->GetONNXOpSetVersion() != 15 && provider_name == "tensorrt") {
if (model_info->GetONNXOpSetVersion() != 14 && model_info->GetONNXOpSetVersion() != 15 &&
provider_name == "tensorrt") {
// TensorRT can run most of the model tests, but only part of
// them is enabled here to save CI build time.
// Besides saving CI build time, TRT isnt able to support full ONNX ops spec and therefore some testcases will fail.
// That's one of reasons we skip those testcases and only test latest ONNX opsets.
// Besides saving CI build time, TRT isnt able to support full ONNX ops spec and therefore some testcases will
// fail. That's one of reasons we skip those testcases and only test latest ONNX opsets.
SkipTest();
return;
}
@ -95,7 +129,8 @@ TEST_P(ModelTest, Run) {
#endif
// TODO: filter model based on opset
std::set<BrokenTest> broken_tests = {
{"slice_neg_steps", "Type parameter (Tind) bound to different types (tensor(int64) and tensor(int32) in node ()."},
{"slice_neg_steps",
"Type parameter (Tind) bound to different types (tensor(int64) and tensor(int32) in node ()."},
{"cast_BFLOAT16_to_FLOAT", "Unexpected input data type"},
{"loop13_seq", "Creation of empty sequences is currently not supported in the test runner"},
{"sequence_insert_at_front", "shape mismatch, expect {4} got {3}"},
@ -175,7 +210,9 @@ TEST_P(ModelTest, Run) {
{"nesterov_momentum", "not a registered function/op", {}}, // Op not registered.
{"softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight", "type error", {"onnx170"}},
{"softmax_cross_entropy_mean_weight_ignore_index_log_prob", "type error", {"onnx170"}},
{"softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob", "type error", {"onnx170"}},
{"softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob",
"type error",
{"onnx170"}},
{"softmax_cross_entropy_mean_weight_log_prob", "type error", {"onnx170"}},
{"softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob", "type error", {"onnx170"}},
{"softmax_cross_entropy_mean_weight_ignore_index_3d", "type error", {"onnx170"}},
@ -185,7 +222,9 @@ TEST_P(ModelTest, Run) {
{"softmax_cross_entropy_mean", "type error", {"onnx170"}},
{"softmax_cross_entropy_mean_log_prob", "type error", {"onnx170"}},
{"softmax_cross_entropy_mean_no_weight_ignore_index", "type error", {"onnx170"}},
{"softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob", "type error", {"onnx170"}},
{"softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob",
"type error",
{"onnx170"}},
{"softmax_cross_entropy_mean_3d_log_prob", "type error", {"onnx170"}},
{"softmax_cross_entropy_none_log_prob", "type error", {"onnx170"}},
{"softmax_cross_entropy_mean_3d", "type error", {"onnx170"}},
@ -199,7 +238,9 @@ TEST_P(ModelTest, Run) {
{"softmax_cross_entropy_mean_no_weight_ignore_index_3d", "type error", {"onnx170"}},
{"softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index", "type error", {"onnx170"}},
{"softmax_cross_entropy_sum", "type error", {"onnx170"}},
{"softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob", "type error", {"onnx170"}},
{"softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob",
"type error",
{"onnx170"}},
{"softmax_cross_entropy_none_weights", "type error", {"onnx170"}},
{"softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob", "type error", {"onnx170"}},
{"softmax_cross_entropy_none", "type error", {"onnx170"}},
@ -212,7 +253,8 @@ TEST_P(ModelTest, Run) {
// Some EPs may fail to pass some specific testcases.
// For example TenosrRT EP may fail on FLOAT16 related testcases if GPU doesn't support float16.
// Instead of list all these testcases, we can use following keyword set to filter out testcases wchich contain specific keyword.
// Instead of list all these testcases, we can use following keyword set to filter out testcases wchich contain
// specific keyword.
std::set<std::string> broken_tests_keyword_set = {};
if (provider_name == "nuphar") {
@ -391,13 +433,17 @@ TEST_P(ModelTest, Run) {
broken_tests.insert({"dynamicquantizelinear_max_adjusted_expanded", "It causes segmentation fault"});
broken_tests.insert({"basic_conv_with_padding",
"Cannot set more than one input unless network has Q/DQ layers. TensorRT EP could not build engine for fused node"});
"Cannot set more than one input unless network has Q/DQ layers. TensorRT EP could not build "
"engine for fused node"});
broken_tests.insert({"basic_conv_without_padding",
"Cannot set more than one input unless network has Q/DQ layers. TensorRT EP could not build engine for fused node"});
"Cannot set more than one input unless network has Q/DQ layers. TensorRT EP could not build "
"engine for fused node"});
broken_tests.insert({"conv_with_strides_no_padding",
"Cannot set more than one input unless network has Q/DQ layers. TensorRT EP could not build engine for fused node"});
"Cannot set more than one input unless network has Q/DQ layers. TensorRT EP could not build "
"engine for fused node"});
broken_tests.insert({"conv_with_autopad_same", "Internal Error (node_of_y: Cannot set more than one input unless network has Q/DQ layers.)"});
broken_tests.insert({"conv_with_autopad_same",
"Internal Error (node_of_y: Cannot set more than one input unless network has Q/DQ layers.)"});
// sce op is not supported
broken_tests_keyword_set.insert({"sce"});
@ -581,91 +627,135 @@ TEST_P(ModelTest, Run) {
for (bool is_single_thread : use_single_thread) {
for (ExecutionMode execution_mode : execution_modes) {
SessionOptions so;
if (!is_single_thread)
so.use_per_session_threads = false;
else
so.intra_op_param.thread_pool_size = 1; // Disable intra op thread pool
so.execution_mode = execution_mode;
so.session_logid = ToUTF8String(test_case_name);
so.session_log_severity_level = (int)logging::Severity::kERROR;
InferenceSession session_object(so, (**ort_env).GetEnvironment());
if (provider_name == "cuda") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
} else if (provider_name == "rocm") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider()));
} else if (provider_name == "dnnl") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultDnnlExecutionProvider()));
} else if (provider_name == "nuphar") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultNupharExecutionProvider()));
} else if (provider_name == "tensorrt") {
if (test_case_name.find(ORT_TSTR("FLOAT16")) != std::string::npos) {
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
1000,
1,
1 << 30,
1, // enable fp16
0,
nullptr,
0,
0,
0,
0,
0,
nullptr,
0,
nullptr,
0};
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(TensorrtExecutionProviderWithOptions(&params)));
} else {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultTensorrtExecutionProvider()));
}
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
} else if (provider_name == "migraphx") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultMIGraphXExecutionProvider()));
} else if (provider_name == "openvino") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultOpenVINOExecutionProvider()));
} else if (provider_name == "nnapi") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultNnapiExecutionProvider()));
} else if (provider_name == "rknpu") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRknpuExecutionProvider()));
} else if (provider_name == "acl") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultAclExecutionProvider()));
} else if (provider_name == "armnn") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultArmNNExecutionProvider()));
OrtSessionOptions* ortso;
ASSERT_ORT_STATUS_OK(OrtApis::CreateSessionOptions(&ortso));
std::unique_ptr<OrtSessionOptions, decltype(&OrtApis::ReleaseSessionOptions)> rel_ort_session_option(
ortso, &OrtApis::ReleaseSessionOptions);
if (!is_single_thread) {
ASSERT_ORT_STATUS_OK(OrtApis::DisablePerSessionThreads(ortso));
} else {
ASSERT_ORT_STATUS_OK(OrtApis::SetIntraOpNumThreads(ortso, 1));
}
ASSERT_STATUS_OK(session_object.Load(model_path));
auto st = session_object.Initialize();
if (st.Code() == NOT_IMPLEMENTED)
return;
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
ASSERT_ORT_STATUS_OK(OrtApis::SetSessionExecutionMode(ortso, execution_mode));
ASSERT_ORT_STATUS_OK(OrtApis::SetSessionLogId(ortso, ToUTF8String(test_case_name).c_str()));
ASSERT_ORT_STATUS_OK(OrtApis::SetSessionLogSeverityLevel(ortso, ORT_LOGGING_LEVEL_ERROR));
if (provider_name == "cuda") {
OrtCUDAProviderOptionsV2* cuda_options = nullptr;
ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options));
std::unique_ptr<OrtCUDAProviderOptionsV2, decltype(&OrtApis::ReleaseCUDAProviderOptions)> rel_cuda_options(
cuda_options, &OrtApis::ReleaseCUDAProviderOptions);
ASSERT_ORT_STATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_CUDA_V2(ortso, cuda_options));
} else if (provider_name == "rocm") {
OrtROCMProviderOptions ep_options;
ASSERT_ORT_STATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_ROCM(ortso, &ep_options));
}
#ifdef USE_DNNL
else if (provider_name == "dnnl") {
ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Dnnl(ortso, false));
}
#endif
#ifdef USE_NUPHAR
else if (provider_name == "nuphar") {
ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Nuphar(ortso, 1, ""));
}
#endif
else if (provider_name == "tensorrt") {
if (test_case_name.find(ORT_TSTR("FLOAT16")) != std::string::npos) {
OrtTensorRTProviderOptionsV2 params{0, 0, nullptr, 1000, 1, 1 << 30,
1, // enable fp16
0, nullptr, 0, 0, 0, 0, 0, nullptr, 0, nullptr, 0};
ASSERT_ORT_STATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(ortso, &params));
} else {
OrtTensorRTProviderOptionsV2* ep_option;
ASSERT_ORT_STATUS_OK(OrtApis::CreateTensorRTProviderOptions(&ep_option));
std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(&OrtApis::ReleaseTensorRTProviderOptions)>
rel_cuda_options(ep_option, &OrtApis::ReleaseTensorRTProviderOptions);
ASSERT_ORT_STATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(ortso, ep_option));
}
// Enable CUDA fallback
OrtCUDAProviderOptionsV2* cuda_options = nullptr;
ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options));
std::unique_ptr<OrtCUDAProviderOptionsV2, decltype(&OrtApis::ReleaseCUDAProviderOptions)> rel_cuda_options(
cuda_options, &OrtApis::ReleaseCUDAProviderOptions);
ASSERT_ORT_STATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_CUDA_V2(ortso, cuda_options));
} else if (provider_name == "migraphx") {
OrtMIGraphXProviderOptions ep_options;
ASSERT_ORT_STATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_MIGraphX(ortso, &ep_options));
} else if (provider_name == "openvino") {
OrtOpenVINOProviderOptions ep_options;
ASSERT_ORT_STATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO(ortso, &ep_options));
}
#ifdef USE_NNAPI
else if (provider_name == "nnapi") {
ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Nnapi(ortso, 0));
}
#endif
#ifdef USE_RKNPU
else if (provider_name == "rknpu") {
ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Rknpu(ortso));
}
#endif
#ifdef USE_ACL
else if (provider_name == "acl") {
ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, 0));
}
#endif
#ifdef USE_ARMNN
else if (provider_name == "armnn") {
ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ArmNN(ortso));
}
#endif
OrtSession* ort_session;
OrtStatus* ort_st = OrtApis::CreateSession(*ort_env, model_path.c_str(), ortso, &ort_session);
if (ort_st != nullptr) {
OrtErrorCode error_code = OrtApis::GetErrorCode(ort_st);
if (error_code == ORT_NOT_IMPLEMENTED) {
OrtApis::ReleaseStatus(ort_st);
continue;
}
FAIL() << OrtApis::GetErrorMessage(ort_st);
}
std::unique_ptr<OrtSession, decltype(&OrtApis::ReleaseSession)> rel_ort_session(ort_session,
&OrtApis::ReleaseSession);
const size_t data_count = l->GetDataCount();
auto default_allocator = std::make_unique<MockedOrtAllocator>();
for (size_t task_id = 0; task_id != data_count; ++task_id) {
onnxruntime::test::HeapBuffer holder;
std::unordered_map<std::string, Ort::Value> feeds;
l->LoadTestData(task_id, holder, feeds, true);
std::pair<common::Status, const OutputDefList*> output_meta_data = session_object.GetModelOutputs();
ASSERT_STATUS_OK(output_meta_data.first);
size_t output_count;
ASSERT_ORT_STATUS_OK(OrtApis::SessionGetOutputCount(ort_session, &output_count));
// Create output feed
size_t output_count = output_meta_data.second->size();
std::vector<std::string> output_names(output_count);
std::vector<char*> output_names(output_count);
for (size_t i = 0; i != output_count; ++i) {
output_names[i] = (*output_meta_data.second)[i]->Name();
ASSERT_ORT_STATUS_OK(
OrtApis::SessionGetOutputName(ort_session, i, default_allocator.get(), &output_names[i]));
}
std::vector<OrtValue> output_values(output_count);
std::vector<const char*> input_names;
std::vector<OrtValue*> input_values;
std::vector<OrtValue*> output_values(output_count);
{
std::unordered_map<std::string, OrtValue> input;
for (auto& p : feeds) {
const OrtValue* v = p.second;
input.emplace(p.first, *v);
input_names.push_back(p.first.c_str());
input_values.push_back(p.second);
}
ort_st = OrtApis::Run(ort_session, nullptr, input_names.data(), input_values.data(), input_values.size(),
output_names.data(), output_names.size(), output_values.data());
if (ort_st != nullptr) {
OrtErrorCode error_code = OrtApis::GetErrorCode(ort_st);
if (error_code == ORT_NOT_IMPLEMENTED) {
OrtApis::ReleaseStatus(ort_st);
for (char* p : output_names) {
default_allocator->Free(p);
}
for (OrtValue* v : output_values) {
OrtApis::ReleaseValue(v);
}
}
FAIL() << OrtApis::GetErrorMessage(ort_st);
}
ASSERT_STATUS_OK(session_object.Run(input, output_names, &output_values));
}
bool post_procesing = false;
@ -683,7 +773,7 @@ TEST_P(ModelTest, Run) {
size_t i = 0;
for (auto& output_name : output_names) {
// p_fetches is filled in the order of output_names.
name_fetch_output_map[output_name] = &output_values[i];
name_fetch_output_map[output_name] = output_values[i];
const ONNX_NAMESPACE::ValueInfoProto* infoProto = l->GetOutputInfoFromModel(i);
if (infoProto != nullptr)
name_output_value_info_proto.insert(std::make_pair(infoProto->name(), infoProto));
@ -714,6 +804,12 @@ TEST_P(ModelTest, Run) {
break;
}
}
for (char* p : output_names) {
default_allocator->Free(p);
}
for (OrtValue* v : output_values) {
OrtApis::ReleaseValue(v);
}
}
}
}
@ -797,69 +893,81 @@ TEST_P(ModelTest, Run) {
ORT_TSTR("operator_pow"),
};
static const ORTCHAR_T* cuda_flaky_tests[] = {
ORT_TSTR("fp16_inception_v1"),
ORT_TSTR("fp16_shufflenet"),
ORT_TSTR("fp16_tiny_yolov2"),
ORT_TSTR("candy"),
static const ORTCHAR_T* cuda_flaky_tests[] = {ORT_TSTR("fp16_inception_v1"),
ORT_TSTR("fp16_shufflenet"),
ORT_TSTR("fp16_tiny_yolov2"),
ORT_TSTR("candy"),
ORT_TSTR("tinyyolov3"),
ORT_TSTR("mlperf_ssd_mobilenet_300"),
ORT_TSTR("mlperf_ssd_resnet34_1200"),
ORT_TSTR("tf_inception_v1"),
ORT_TSTR("faster_rcnn"),
ORT_TSTR("split_zero_size_splits"),
ORT_TSTR("convtranspose_3d"),
ORT_TSTR("fp16_test_tiny_yolov2-Candy"),
ORT_TSTR("fp16_coreml_FNS-Candy"),
ORT_TSTR("fp16_test_tiny_yolov2"),
ORT_TSTR("fp16_test_shufflenet"),
ORT_TSTR("keras2coreml_SimpleRNN_ImageNet")};
static const ORTCHAR_T* openvino_disabled_tests[] = {
ORT_TSTR("tf_mobilenet_v1_1.0_224"),
ORT_TSTR("bertsquad"),
ORT_TSTR("yolov3"),
ORT_TSTR("LSTM_Seq_lens_unpacked"),
ORT_TSTR("tinyyolov3"),
ORT_TSTR("mlperf_ssd_mobilenet_300"),
ORT_TSTR("mlperf_ssd_resnet34_1200"),
ORT_TSTR("tf_inception_v1"),
ORT_TSTR("faster_rcnn"),
ORT_TSTR("split_zero_size_splits"),
ORT_TSTR("convtranspose_3d"),
ORT_TSTR("fp16_test_tiny_yolov2-Candy"),
ORT_TSTR("fp16_coreml_FNS-Candy"),
ORT_TSTR("fp16_test_tiny_yolov2"),
ORT_TSTR("fp16_test_shufflenet"),
ORT_TSTR("keras2coreml_SimpleRNN_ImageNet")};
static const ORTCHAR_T* openvino_disabled_tests[] = {ORT_TSTR("tf_mobilenet_v1_1.0_224"),
ORT_TSTR("bertsquad"),
ORT_TSTR("yolov3"),
ORT_TSTR("LSTM_Seq_lens_unpacked"),
ORT_TSTR("tinyyolov3"),
ORT_TSTR("faster_rcnn"),
ORT_TSTR("mask_rcnn"),
ORT_TSTR("coreml_FNS-Candy_ImageNet"),
ORT_TSTR("tf_mobilenet_v2_1.0_224"),
ORT_TSTR("tf_mobilenet_v2_1.4_224"),
ORT_TSTR("operator_permute2"),
ORT_TSTR("operator_repeat"),
ORT_TSTR("operator_repeat_dim_overflow"),
ORT_TSTR("mlperf_ssd_resnet34_1200"),
ORT_TSTR("candy"),
ORT_TSTR("cntk_simple_seg"),
ORT_TSTR("GPT2_LM_HEAD"),
ORT_TSTR("mlperf_ssd_mobilenet_300"),
ORT_TSTR("negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight"),
ORT_TSTR("negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded"),
ORT_TSTR("negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight"),
ORT_TSTR("negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob_expanded"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob_expanded")};
ORT_TSTR("mask_rcnn"),
ORT_TSTR("coreml_FNS-Candy_ImageNet"),
ORT_TSTR("tf_mobilenet_v2_1.0_224"),
ORT_TSTR("tf_mobilenet_v2_1.4_224"),
ORT_TSTR("operator_permute2"),
ORT_TSTR("operator_repeat"),
ORT_TSTR("operator_repeat_dim_overflow"),
ORT_TSTR("mlperf_ssd_resnet34_1200"),
ORT_TSTR("candy"),
ORT_TSTR("cntk_simple_seg"),
ORT_TSTR("GPT2_LM_HEAD"),
ORT_TSTR("mlperf_ssd_mobilenet_300"),
ORT_TSTR("negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight"),
ORT_TSTR("negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded"),
ORT_TSTR("negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight"),
ORT_TSTR("negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob_expanded"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob"),
ORT_TSTR("softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob_expanded")};
static const ORTCHAR_T* dml_disabled_tests[] = {ORT_TSTR("mlperf_ssd_resnet34_1200"),
ORT_TSTR("mlperf_ssd_mobilenet_300"), ORT_TSTR("mask_rcnn"),
ORT_TSTR("faster_rcnn"), ORT_TSTR("tf_pnasnet_large"),
ORT_TSTR("zfnet512"), ORT_TSTR("keras2coreml_Dense_ImageNet")};
static const ORTCHAR_T* dnnl_disabled_tests[] = {ORT_TSTR("densenet121"), ORT_TSTR("resnet18v2"),
ORT_TSTR("resnet34v2"), ORT_TSTR("resnet50v2"),
ORT_TSTR("mlperf_ssd_mobilenet_300"),
ORT_TSTR("mask_rcnn"),
ORT_TSTR("faster_rcnn"),
ORT_TSTR("tf_pnasnet_large"),
ORT_TSTR("zfnet512"),
ORT_TSTR("keras2coreml_Dense_ImageNet")};
static const ORTCHAR_T* dnnl_disabled_tests[] = {ORT_TSTR("densenet121"),
ORT_TSTR("resnet18v2"),
ORT_TSTR("resnet34v2"),
ORT_TSTR("resnet50v2"),
ORT_TSTR("resnet101v2"),
ORT_TSTR("resnet101v2"), ORT_TSTR("vgg19"),
ORT_TSTR("tf_inception_resnet_v2"), ORT_TSTR("tf_inception_v1"),
ORT_TSTR("tf_inception_v3"), ORT_TSTR("tf_inception_v4"),
ORT_TSTR("resnet101v2"),
ORT_TSTR("vgg19"),
ORT_TSTR("tf_inception_resnet_v2"),
ORT_TSTR("tf_inception_v1"),
ORT_TSTR("tf_inception_v3"),
ORT_TSTR("tf_inception_v4"),
ORT_TSTR("tf_mobilenet_v1_1.0_224"),
ORT_TSTR("tf_mobilenet_v2_1.0_224"),
ORT_TSTR("tf_mobilenet_v2_1.4_224"), ORT_TSTR("tf_nasnet_large"),
ORT_TSTR("tf_pnasnet_large"), ORT_TSTR("tf_resnet_v1_50"),
ORT_TSTR("tf_resnet_v1_101"), ORT_TSTR("tf_resnet_v1_101"),
ORT_TSTR("tf_resnet_v2_101"), ORT_TSTR("tf_resnet_v2_152"),
ORT_TSTR("tf_mobilenet_v2_1.4_224"),
ORT_TSTR("tf_nasnet_large"),
ORT_TSTR("tf_pnasnet_large"),
ORT_TSTR("tf_resnet_v1_50"),
ORT_TSTR("tf_resnet_v1_101"),
ORT_TSTR("tf_resnet_v1_101"),
ORT_TSTR("tf_resnet_v2_101"),
ORT_TSTR("tf_resnet_v2_152"),
ORT_TSTR("batchnorm_example_training_mode"),
ORT_TSTR("batchnorm_epsilon_training_mode"),
ORT_TSTR("mobilenetv2-1.0"),
@ -876,12 +984,17 @@ TEST_P(ModelTest, Run) {
ORT_TSTR("mul_uint8"),
ORT_TSTR("div_uint8")};
static const ORTCHAR_T* tensorrt_disabled_tests[] = {
ORT_TSTR("udnie"), ORT_TSTR("rain_princess"),
ORT_TSTR("pointilism"), ORT_TSTR("mosaic"),
ORT_TSTR("udnie"),
ORT_TSTR("rain_princess"),
ORT_TSTR("pointilism"),
ORT_TSTR("mosaic"),
ORT_TSTR("LSTM_Seq_lens_unpacked"),
ORT_TSTR("cgan"), ORT_TSTR("candy"),
ORT_TSTR("tinyyolov3"), ORT_TSTR("yolov3"),
ORT_TSTR("mlperf_ssd_resnet34_1200"), ORT_TSTR("mlperf_ssd_mobilenet_300"),
ORT_TSTR("cgan"),
ORT_TSTR("candy"),
ORT_TSTR("tinyyolov3"),
ORT_TSTR("yolov3"),
ORT_TSTR("mlperf_ssd_resnet34_1200"),
ORT_TSTR("mlperf_ssd_mobilenet_300"),
ORT_TSTR("mask_rcnn"),
ORT_TSTR("faster_rcnn"),
ORT_TSTR("fp16_shufflenet"),
@ -901,7 +1014,7 @@ TEST_P(ModelTest, Run) {
ORT_TSTR("convtranspose_3d"),
ORT_TSTR("conv_with_strides_and_asymmetric_padding"),
ORT_TSTR("conv_with_strides_padding"),
ORT_TSTR("size") //INVALID_ARGUMENT: Cannot find binding of given name: x
ORT_TSTR("size") // INVALID_ARGUMENT: Cannot find binding of given name: x
};
for (const ORTCHAR_T* provider_name : provider_names) {
std::unordered_set<std::basic_string<ORTCHAR_T>> all_disabled_tests(std::begin(immutable_broken_tests),
@ -1037,8 +1150,9 @@ auto ExpandModelName = [](const ::testing::TestParamInfo<ModelTest::ParamType>&
#endif
};
// The optional last argument is a function or functor that generates custom test name suffixes based on the test parameters.
// Specify the last argument to make test name more meaningful and clear instead of just the sequential number.
// The optional last argument is a function or functor that generates custom test name suffixes based on the test
// parameters. Specify the last argument to make test name more meaningful and clear instead of just the sequential
// number.
INSTANTIATE_TEST_SUITE_P(ModelTests, ModelTest, testing::ValuesIn(GetParameterStrings()), ExpandModelName);
} // namespace test