mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Fix GroupNorm tests failing when no providers are supported (#17054)
This commit is contained in:
parent
a7542f48d6
commit
4bc2287a85
1 changed files with 25 additions and 21 deletions
|
|
@ -737,6 +737,22 @@ TEST(GroupNormTest, GroupNorm_128) {
|
|||
|
||||
for (const int channels_last : channels_last_values) {
|
||||
if (enable_cuda || enable_rocm || enable_dml) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
if (enable_cuda && channels_last != 0) {
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
}
|
||||
if (enable_rocm && channels_last != 0) {
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
}
|
||||
if (enable_dml) {
|
||||
execution_providers.push_back(DefaultDmlExecutionProvider());
|
||||
}
|
||||
|
||||
// Don't run the test if no providers are supported
|
||||
if (execution_providers.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
OpTester test("GroupNorm", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute<float>("epsilon", 1e-05f);
|
||||
test.AddAttribute<int64_t>("groups", 32);
|
||||
|
|
@ -763,7 +779,12 @@ TEST(GroupNormTest, GroupNorm_128) {
|
|||
|
||||
test.AddInput<float>("gamma", {C}, gamma_data);
|
||||
test.AddInput<float>("beta", {C}, beta_data);
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
||||
// Test float32, with activation
|
||||
enable_cuda = HasCudaEnvironment(0);
|
||||
if (enable_cuda || enable_rocm || enable_dml) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
if (enable_cuda && channels_last != 0) {
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
|
|
@ -775,14 +796,11 @@ TEST(GroupNormTest, GroupNorm_128) {
|
|||
execution_providers.push_back(DefaultDmlExecutionProvider());
|
||||
}
|
||||
|
||||
if (!execution_providers.empty()) {
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
// Don't run the test if no providers are supported
|
||||
if (execution_providers.empty()) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Test float32, with activation
|
||||
enable_cuda = HasCudaEnvironment(0);
|
||||
if (enable_cuda || enable_rocm || enable_dml) {
|
||||
OpTester test("GroupNorm", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute<float>("epsilon", 1e-05f);
|
||||
test.AddAttribute<int64_t>("groups", 32);
|
||||
|
|
@ -809,21 +827,7 @@ TEST(GroupNormTest, GroupNorm_128) {
|
|||
|
||||
test.AddInput<float>("gamma", {C}, gamma_data);
|
||||
test.AddInput<float>("beta", {C}, beta_data);
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
if (enable_cuda && channels_last != 0) {
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
}
|
||||
if (enable_rocm && channels_last != 0) {
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
}
|
||||
if (enable_dml) {
|
||||
execution_providers.push_back(DefaultDmlExecutionProvider());
|
||||
}
|
||||
|
||||
if (!execution_providers.empty()) {
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue