Fix torch cpp ext build when CPU wheel is installed but GPU card is present (#11608)

* Fix torch cpp ext build when CPU wheel is installed but GPU card is present

Also includes a minor improvement to the ATen operator lookup that accepts both
the bare "op" and the fully qualified "aten::op" forms of operator names.

* Fix flake8 false positive
This commit is contained in:
Thiago Crepaldi 2022-05-25 09:44:26 -04:00 committed by GitHub
parent 147a1737f9
commit 427230431a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 13 deletions

View file

@ -24,4 +24,6 @@ exclude =
./orttraining,
# ignore server code for now
./server,
# ignore issues from different git branches
./.git,
ignore = W503, E203

View file

@ -119,9 +119,16 @@ class ATenOperatorCache {
}
const ATenOperator& GetOperator(const std::string& op_name, const std::string& overload_name) {
auto key = std::make_pair(op_name, overload_name);
// PyTorch ONNX converter creates ATen operators with name without domain
std::string final_op_name = op_name;
auto pos = op_name.find("::");
if (pos == std::string::npos) {
final_op_name = std::string("aten::" + op_name);
}
auto key = std::make_pair(final_op_name, overload_name);
if (ops_.find(key) == ops_.end()) {
c10::OperatorName full_name(op_name, overload_name);
c10::OperatorName full_name(final_op_name, overload_name);
auto op = torch::jit::findOperatorFor(full_name);
TORCH_INTERNAL_ASSERT(op);
ATenOperator aten_op;

View file

@ -3,13 +3,15 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from onnxruntime.training import ortmodule
from glob import glob
from shutil import copyfile
import os
import subprocess
import sys
from glob import glob
from shutil import copyfile
import torch
from onnxruntime.training import ortmodule
def _list_extensions(path):
@ -30,25 +32,26 @@ def _list_cuda_extensions():
def _install_extension(ext_name, ext_path, cwd):
ret_code = subprocess.call(f"{sys.executable} {ext_path} build", cwd=cwd, shell=True)
ret_code = subprocess.call((sys.executable, ext_path, "build"), cwd=cwd)
if ret_code != 0:
print(f'There was an error compiling "{ext_name}" PyTorch CPP extension')
print(f"There was an error compiling '{ext_name}' PyTorch CPP extension")
sys.exit(ret_code)
def build_torch_cpp_extensions():
"""Builds PyTorch CPP extensions and returns metadata"""
"""Builds PyTorch CPP extensions and returns metadata."""
# Run this from within onnxruntime package folder
is_gpu_available = ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or ortmodule.ONNXRUNTIME_ROCM_VERSION is not None
is_gpu_available = torch.cuda.is_available() and (
ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or ortmodule.ONNXRUNTIME_ROCM_VERSION is not None
)
os.chdir(ortmodule.ORTMODULE_TORCH_CPP_DIR)
# Extensions might leverage CUDA/ROCM versions internally
os.environ["ONNXRUNTIME_CUDA_VERSION"] = (
ortmodule.ONNXRUNTIME_CUDA_VERSION if not ortmodule.ONNXRUNTIME_CUDA_VERSION is None else ""
ortmodule.ONNXRUNTIME_CUDA_VERSION if ortmodule.ONNXRUNTIME_CUDA_VERSION is not None else ""
)
os.environ["ONNXRUNTIME_ROCM_VERSION"] = (
ortmodule.ONNXRUNTIME_ROCM_VERSION if not ortmodule.ONNXRUNTIME_ROCM_VERSION is None else ""
ortmodule.ONNXRUNTIME_ROCM_VERSION if ortmodule.ONNXRUNTIME_ROCM_VERSION is not None else ""
)
############################################################################