Fix torch cpp ext build when CPU wheel is installed but GPU card is present (#11608)

* Fix torch cpp ext build when CPU wheel is installed but GPU card is present.
  Also a minor improvement to the ATen operator: operator names are now accepted
  both with and without the "aten::" domain prefix (e.g. "aten::op" and "op").
* Fix flake8 false positive.
This commit is contained in:
parent 147a1737f9
commit 427230431a

3 changed files with 25 additions and 13 deletions
.flake8 (+2 −0)

@@ -24,4 +24,6 @@ exclude =
     ./orttraining,
     # ignore server code for now
     ./server,
+    # ignore issues from different git branches
+    ./.git,
 ignore = W503, E203
@@ -119,9 +119,16 @@ class ATenOperatorCache {
   }
 
   const ATenOperator& GetOperator(const std::string& op_name, const std::string& overload_name) {
-    auto key = std::make_pair(op_name, overload_name);
+    // PyTorch ONNX converter creates ATen operators with name without domain
+    std::string final_op_name = op_name;
+    auto pos = op_name.find("::");
+    if (pos == std::string::npos) {
+      final_op_name = std::string("aten::" + op_name);
+    }
+
+    auto key = std::make_pair(final_op_name, overload_name);
     if (ops_.find(key) == ops_.end()) {
-      c10::OperatorName full_name(op_name, overload_name);
+      c10::OperatorName full_name(final_op_name, overload_name);
       auto op = torch::jit::findOperatorFor(full_name);
       TORCH_INTERNAL_ASSERT(op);
       ATenOperator aten_op;
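
For illustration, here is a minimal Python sketch of the name normalization this hunk introduces (the real logic is the C++ above, inside ATenOperatorCache::GetOperator; the helper name below is hypothetical): an operator name with no "::" domain qualifier is assumed to belong to the "aten" domain.

# Hypothetical Python rendering of the C++ normalization; not onnxruntime code.
def normalize_aten_op_name(op_name: str) -> str:
    """Qualify a bare ATen operator name with the "aten::" domain."""
    # "gelu" -> "aten::gelu"; an already-qualified name passes through.
    if "::" not in op_name:
        return "aten::" + op_name
    return op_name

assert normalize_aten_op_name("gelu") == "aten::gelu"
assert normalize_aten_op_name("aten::gelu") == "aten::gelu"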
@@ -3,13 +3,15 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 
-from onnxruntime.training import ortmodule
-
-from glob import glob
-from shutil import copyfile
 import os
 import subprocess
 import sys
+from glob import glob
+from shutil import copyfile
+
+import torch
+
+from onnxruntime.training import ortmodule
 
 
 def _list_extensions(path):
@@ -30,25 +32,26 @@ def _list_cuda_extensions():
 
 
 def _install_extension(ext_name, ext_path, cwd):
-    ret_code = subprocess.call(f"{sys.executable} {ext_path} build", cwd=cwd, shell=True)
+    ret_code = subprocess.call((sys.executable, ext_path, "build"), cwd=cwd)
     if ret_code != 0:
-        print(f'There was an error compiling "{ext_name}" PyTorch CPP extension')
+        print(f"There was an error compiling '{ext_name}' PyTorch CPP extension")
         sys.exit(ret_code)
 
 
 def build_torch_cpp_extensions():
-    """Builds PyTorch CPP extensions and returns metadata"""
-
+    """Builds PyTorch CPP extensions and returns metadata."""
     # Run this from within onnxruntime package folder
-    is_gpu_available = ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or ortmodule.ONNXRUNTIME_ROCM_VERSION is not None
+    is_gpu_available = torch.cuda.is_available() and (
+        ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or ortmodule.ONNXRUNTIME_ROCM_VERSION is not None
+    )
     os.chdir(ortmodule.ORTMODULE_TORCH_CPP_DIR)
 
     # Extensions might leverage CUDA/ROCM versions internally
     os.environ["ONNXRUNTIME_CUDA_VERSION"] = (
-        ortmodule.ONNXRUNTIME_CUDA_VERSION if not ortmodule.ONNXRUNTIME_CUDA_VERSION is None else ""
+        ortmodule.ONNXRUNTIME_CUDA_VERSION if ortmodule.ONNXRUNTIME_CUDA_VERSION is not None else ""
     )
     os.environ["ONNXRUNTIME_ROCM_VERSION"] = (
-        ortmodule.ONNXRUNTIME_ROCM_VERSION if not ortmodule.ONNXRUNTIME_ROCM_VERSION is None else ""
+        ortmodule.ONNXRUNTIME_ROCM_VERSION if ortmodule.ONNXRUNTIME_ROCM_VERSION is not None else ""
     )
 
     ############################################################################
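
For illustration, a hedged sketch of why the _install_extension change matters: the old call handed a single f-string to the shell, so any whitespace in sys.executable or the extension path was split into separate arguments; passing an argument sequence (with no shell involved) delivers each argument verbatim. The path below is made up.

import subprocess
import sys

ext_path = "/opt/my tools/aten_op_executor/setup.py"  # made-up path with a space

# Old style: the shell splits on the space, parsing "/opt/my" as one argument
# and "tools/aten_op_executor/setup.py" as another, so the build breaks.
# subprocess.call(f"{sys.executable} {ext_path} build", shell=True)

# New style: each element reaches the child process unchanged.
subprocess.call((sys.executable, ext_path, "build"))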
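
And a minimal sketch, using hypothetical stand-ins for the ortmodule version attributes, of the decision the new is_gpu_available expression encodes: GPU extensions are only attempted when torch can actually use a device AND the installed wheel was built against CUDA or ROCm.

# Hypothetical helper mirroring the new check; not the actual build script.
def should_build_gpu_extensions(torch_sees_gpu, wheel_cuda_version, wheel_rocm_version):
    return torch_sees_gpu and (wheel_cuda_version is not None or wheel_rocm_version is not None)

# CPU wheel on a machine with a GPU card (the case in the commit title):
assert should_build_gpu_extensions(True, None, None) is False
# CUDA wheel where torch sees the device ("11.6" is an illustrative value):
assert should_build_gpu_extensions(True, "11.6", None) is True
# CUDA wheel but a CPU-only torch build:
assert should_build_gpu_extensions(False, "11.6", None) is False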