Fix GroupNorm tests failing when no providers are supported (#17054)

This commit is contained in:
Patrice Vignola 2023-08-09 13:14:13 -07:00 committed by GitHub
parent a7542f48d6
commit 4bc2287a85
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -737,6 +737,22 @@ TEST(GroupNormTest, GroupNorm_128) {
for (const int channels_last : channels_last_values) {
if (enable_cuda || enable_rocm || enable_dml) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (enable_cuda && channels_last != 0) {
execution_providers.push_back(DefaultCudaExecutionProvider());
}
if (enable_rocm && channels_last != 0) {
execution_providers.push_back(DefaultRocmExecutionProvider());
}
if (enable_dml) {
execution_providers.push_back(DefaultDmlExecutionProvider());
}
// Don't run the test if no providers are supported
if (execution_providers.empty()) {
continue;
}
OpTester test("GroupNorm", 1, onnxruntime::kMSDomain);
test.AddAttribute<float>("epsilon", 1e-05f);
test.AddAttribute<int64_t>("groups", 32);
@ -763,7 +779,12 @@ TEST(GroupNormTest, GroupNorm_128) {
test.AddInput<float>("gamma", {C}, gamma_data);
test.AddInput<float>("beta", {C}, beta_data);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
// Test float32, with activation
enable_cuda = HasCudaEnvironment(0);
if (enable_cuda || enable_rocm || enable_dml) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (enable_cuda && channels_last != 0) {
execution_providers.push_back(DefaultCudaExecutionProvider());
@ -775,14 +796,11 @@ TEST(GroupNormTest, GroupNorm_128) {
execution_providers.push_back(DefaultDmlExecutionProvider());
}
if (!execution_providers.empty()) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
// Don't run the test if no providers are supported
if (execution_providers.empty()) {
continue;
}
}
// Test float32, with activation
enable_cuda = HasCudaEnvironment(0);
if (enable_cuda || enable_rocm || enable_dml) {
OpTester test("GroupNorm", 1, onnxruntime::kMSDomain);
test.AddAttribute<float>("epsilon", 1e-05f);
test.AddAttribute<int64_t>("groups", 32);
@ -809,21 +827,7 @@ TEST(GroupNormTest, GroupNorm_128) {
test.AddInput<float>("gamma", {C}, gamma_data);
test.AddInput<float>("beta", {C}, beta_data);
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (enable_cuda && channels_last != 0) {
execution_providers.push_back(DefaultCudaExecutionProvider());
}
if (enable_rocm && channels_last != 0) {
execution_providers.push_back(DefaultRocmExecutionProvider());
}
if (enable_dml) {
execution_providers.push_back(DefaultDmlExecutionProvider());
}
if (!execution_providers.empty()) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}
}