diff --git a/onnxruntime/test/contrib_ops/group_norm_op_test.cc b/onnxruntime/test/contrib_ops/group_norm_op_test.cc index b02c135702..4983cb5abf 100644 --- a/onnxruntime/test/contrib_ops/group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_norm_op_test.cc @@ -737,6 +737,22 @@ TEST(GroupNormTest, GroupNorm_128) { for (const int channels_last : channels_last_values) { if (enable_cuda || enable_rocm || enable_dml) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + if (enable_dml) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + OpTester test("GroupNorm", 1, onnxruntime::kMSDomain); test.AddAttribute("epsilon", 1e-05f); test.AddAttribute("groups", 32); @@ -763,7 +779,12 @@ TEST(GroupNormTest, GroupNorm_128) { test.AddInput("gamma", {C}, gamma_data); test.AddInput("beta", {C}, beta_data); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + // Test float32, with activation + enable_cuda = HasCudaEnvironment(0); + if (enable_cuda || enable_rocm || enable_dml) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); @@ -775,14 +796,11 @@ TEST(GroupNormTest, GroupNorm_128) { execution_providers.push_back(DefaultDmlExecutionProvider()); } - if (!execution_providers.empty()) { - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; } - } - // Test float32, with activation - enable_cuda = HasCudaEnvironment(0); - if (enable_cuda || enable_rocm || enable_dml) { OpTester test("GroupNorm", 1, onnxruntime::kMSDomain); test.AddAttribute("epsilon", 1e-05f); test.AddAttribute("groups", 32); @@ -809,21 +827,7 @@ TEST(GroupNormTest, GroupNorm_128) { test.AddInput("gamma", {C}, gamma_data); test.AddInput("beta", {C}, beta_data); - - std::vector> execution_providers; - if (enable_cuda && channels_last != 0) { - execution_providers.push_back(DefaultCudaExecutionProvider()); - } - if (enable_rocm && channels_last != 0) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } - if (enable_dml) { - execution_providers.push_back(DefaultDmlExecutionProvider()); - } - - if (!execution_providers.empty()) { - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } }