From 427230431a88a016b3f35893bece52878941a588 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Wed, 25 May 2022 09:44:26 -0400 Subject: [PATCH] Fix torch cpp ext build when CPU wheel is installed but GPU card is present (#11608) * Fix torch cpp ext build when CPU wheel is installed but GPU card is present Also there is a minor improvement for ATen operator that allows both "::op" and "aten::op" name for operators * Fix flake8 false positive --- .flake8 | 2 ++ .../cpu/aten_op_executor/aten_op_executor.cc | 11 ++++++-- .../ortmodule/torch_cpp_extensions/install.py | 25 +++++++++++-------- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/.flake8 b/.flake8 index 4684de71eb..c5c1eac4d9 100644 --- a/.flake8 +++ b/.flake8 @@ -24,4 +24,6 @@ exclude = ./orttraining, # ignore server code for now ./server, + # ignore issues from different git branches + ./.git, ignore = W503, E203 diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/aten_op_executor.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/aten_op_executor.cc index 6e44430bb8..e1dbc56814 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/aten_op_executor.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/aten_op_executor.cc @@ -119,9 +119,16 @@ class ATenOperatorCache { } const ATenOperator& GetOperator(const std::string& op_name, const std::string& overload_name) { - auto key = std::make_pair(op_name, overload_name); + // PyTorch ONNX converter creates ATen operators with name without domain + std::string final_op_name = op_name; + auto pos = op_name.find("::"); + if (pos == std::string::npos) { + final_op_name = std::string("aten::" + op_name); + } + + auto key = std::make_pair(final_op_name, overload_name); if (ops_.find(key) == ops_.end()) { - c10::OperatorName full_name(op_name, overload_name); + c10::OperatorName full_name(final_op_name, overload_name); auto op = torch::jit::findOperatorFor(full_name); TORCH_INTERNAL_ASSERT(op); ATenOperator aten_op; diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py index fbd11797d2..c44b124d56 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py @@ -3,13 +3,15 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from onnxruntime.training import ortmodule - -from glob import glob -from shutil import copyfile import os import subprocess import sys +from glob import glob +from shutil import copyfile + +import torch + +from onnxruntime.training import ortmodule def _list_extensions(path): @@ -30,25 +32,26 @@ def _list_cuda_extensions(): def _install_extension(ext_name, ext_path, cwd): - ret_code = subprocess.call(f"{sys.executable} {ext_path} build", cwd=cwd, shell=True) + ret_code = subprocess.call((sys.executable, ext_path, "build"), cwd=cwd) if ret_code != 0: - print(f'There was an error compiling "{ext_name}" PyTorch CPP extension') + print(f"There was an error compiling '{ext_name}' PyTorch CPP extension") sys.exit(ret_code) def build_torch_cpp_extensions(): - """Builds PyTorch CPP extensions and returns metadata""" - + """Builds PyTorch CPP extensions and returns metadata.""" # Run this from within onnxruntime package folder - is_gpu_available = ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or ortmodule.ONNXRUNTIME_ROCM_VERSION is not None + is_gpu_available = torch.cuda.is_available() and ( + ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or ortmodule.ONNXRUNTIME_ROCM_VERSION is not None + ) os.chdir(ortmodule.ORTMODULE_TORCH_CPP_DIR) # Extensions might leverage CUDA/ROCM versions internally os.environ["ONNXRUNTIME_CUDA_VERSION"] = ( - ortmodule.ONNXRUNTIME_CUDA_VERSION if not ortmodule.ONNXRUNTIME_CUDA_VERSION is None else "" + ortmodule.ONNXRUNTIME_CUDA_VERSION if ortmodule.ONNXRUNTIME_CUDA_VERSION is not None else "" ) os.environ["ONNXRUNTIME_ROCM_VERSION"] = ( - ortmodule.ONNXRUNTIME_ROCM_VERSION if not ortmodule.ONNXRUNTIME_ROCM_VERSION is None else "" + ortmodule.ONNXRUNTIME_ROCM_VERSION if ortmodule.ONNXRUNTIME_ROCM_VERSION is not None else "" ) ############################################################################