From 79ba045d7426670fac634cb19fedc4e11eacffad Mon Sep 17 00:00:00 2001 From: raviskolli <48601275+raviskolli@users.noreply.github.com> Date: Mon, 22 Mar 2021 09:02:10 -0700 Subject: [PATCH] Enabled rocm support for graph transformations (#7057) --- .../core/optimizer/graph_transformer_utils.cc | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 19050458ea..016fff30e4 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -144,27 +144,27 @@ std::vector> GenerateTransformers(TransformerL transformers.emplace_back(onnxruntime::make_unique(cpu_execution_providers)); std::unordered_set cpu_acl_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kAclExecutionProvider}; - std::unordered_set cpu_cuda_acl_armnn_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider, onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider}; + std::unordered_set cpu_cuda_rocm_acl_armnn_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider}; - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_acl_armnn_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_acl_armnn_execution_providers)); - const std::unordered_set cuda_execution_providers = {onnxruntime::kCudaExecutionProvider}; - const std::unordered_set cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider}; - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); + const std::unordered_set cuda_rocm_execution_providers = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider}; + const std::unordered_set cpu_cuda_rocm_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider}; + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cuda_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cuda_rocm_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_rocm_execution_providers)); #endif } break; @@ -197,8 +197,8 @@ std::vector> GenerateTransformers(TransformerL // These transformers could only be enabled by custom transformer list. #ifndef DISABLE_CONTRIB_OPS if (level == TransformerLevel::Level2) { - std::unordered_set cuda_execution_providers = {onnxruntime::kCudaExecutionProvider}; - transformers.emplace_back(onnxruntime::make_unique(cuda_execution_providers)); + std::unordered_set cuda_rocm_execution_providers = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider}; + transformers.emplace_back(onnxruntime::make_unique(cuda_rocm_execution_providers)); } #endif