mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Improve the compilation speed when compiling for multiple architectures. (#12490)
* improve the compilation speed when compiling for multiple architectures. * formatting * fix * use 0 by default * fix comments
This commit is contained in:
parent
56bd96a3f5
commit
a2dc3e9eac
4 changed files with 32 additions and 5 deletions
|
|
@ -4,12 +4,18 @@
|
|||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from setuptools import setup, Extension
|
||||
|
||||
from setuptools import Extension, setup
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
filename = os.path.join(os.path.dirname(__file__), "torch_interop_utils.cc")
|
||||
extra_compile_args = {"cxx": ["-O3"]}
|
||||
setup(
|
||||
name="torch_interop_utils",
|
||||
ext_modules=[cpp_extension.CppExtension(name="torch_interop_utils", sources=[filename])],
|
||||
ext_modules=[
|
||||
cpp_extension.CppExtension(
|
||||
name="torch_interop_utils", sources=[filename], extra_compile_args=extra_compile_args
|
||||
)
|
||||
],
|
||||
cmdclass={"build_ext": cpp_extension.BuildExtension},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ filenames = [
|
|||
use_rocm = True if os.environ["ONNXRUNTIME_ROCM_VERSION"] else False
|
||||
extra_compile_args = {"cxx": ["-O3"]}
|
||||
if not use_rocm:
|
||||
extra_compile_args.update({"nvcc": ["-lineinfo", "-O3", "--use_fast_math"]})
|
||||
extra_compile_args.update({"nvcc": os.environ["ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS"].split(",")})
|
||||
|
||||
setup(
|
||||
name="fused_ops",
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import sys
|
|||
from setuptools import setup
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
|
||||
# TODO: Implement a cleaner way to auto-generate torch_gpu_allocator.cc
|
||||
use_rocm = True if os.environ["ONNXRUNTIME_ROCM_VERSION"] else False
|
||||
gpu_identifier = "hip" if use_rocm else "cuda"
|
||||
|
|
@ -24,8 +23,16 @@ with fileinput.FileInput(filename, inplace=True) as file:
|
|||
line = line.replace("___gpu_allocator_header___", gpu_allocator_header)
|
||||
sys.stdout.write(line)
|
||||
|
||||
extra_compile_args = {"cxx": ["-O3"]}
|
||||
if not use_rocm:
|
||||
extra_compile_args.update({"nvcc": os.environ["ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS"].split(",")})
|
||||
|
||||
setup(
|
||||
name="torch_gpu_allocator",
|
||||
ext_modules=[cpp_extension.CUDAExtension(name="torch_gpu_allocator", sources=[filename])],
|
||||
ext_modules=[
|
||||
cpp_extension.CUDAExtension(
|
||||
name="torch_gpu_allocator", sources=[filename], extra_compile_args=extra_compile_args
|
||||
)
|
||||
],
|
||||
cmdclass={"build_ext": cpp_extension.BuildExtension},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from glob import glob
|
|||
from shutil import copyfile
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
|
|
@ -38,6 +39,16 @@ def _install_extension(ext_name, ext_path, cwd):
|
|||
sys.exit(ret_code)
|
||||
|
||||
|
||||
def _get_cuda_extra_build_params():
|
||||
nvcc_extra_args = ["-lineinfo", "-O3", "--use_fast_math"]
|
||||
cuda_version = torch.version.cuda
|
||||
if cuda_version is not None and version.parse(cuda_version) > version.parse("11.2"):
|
||||
# If number is 0, the number of threads used is the number of CPUs on the machine.
|
||||
nvcc_extra_args += ["--threads", "0"]
|
||||
|
||||
os.environ["ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS"] = ",".join(nvcc_extra_args)
|
||||
|
||||
|
||||
def build_torch_cpp_extensions():
|
||||
"""Builds PyTorch CPP extensions and returns metadata."""
|
||||
# Run this from within onnxruntime package folder
|
||||
|
|
@ -54,6 +65,9 @@ def build_torch_cpp_extensions():
|
|||
ortmodule.ONNXRUNTIME_ROCM_VERSION if ortmodule.ONNXRUNTIME_ROCM_VERSION is not None else ""
|
||||
)
|
||||
|
||||
if torch.version.cuda is not None and ortmodule.ONNXRUNTIME_CUDA_VERSION is not None:
|
||||
_get_cuda_extra_build_params()
|
||||
|
||||
############################################################################
|
||||
# Pytorch CPP Extensions that DO require CUDA/ROCM
|
||||
############################################################################
|
||||
|
|
|
|||
Loading…
Reference in a new issue