From a2dc3e9eacdfd67e3e4538d95a26bd52685e5a2e Mon Sep 17 00:00:00 2001 From: pengwa Date: Tue, 9 Aug 2022 11:52:26 +0800 Subject: [PATCH] Improve the compilation speed when compiling for multiple architectures. (#12490) * improve the compilation speed when compiling for multiple architectures. * formatting * fix * use 0 by default * fix comments --- .../cpu/torch_interop_utils/setup.py | 10 ++++++++-- .../torch_cpp_extensions/cuda/fused_ops/setup.py | 2 +- .../cuda/torch_gpu_allocator/setup.py | 11 +++++++++-- .../ortmodule/torch_cpp_extensions/install.py | 14 ++++++++++++++ 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py index 42fc1c747a..0ab8a0c189 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py @@ -4,12 +4,18 @@ # -------------------------------------------------------------------------- import os -from setuptools import setup, Extension + +from setuptools import Extension, setup from torch.utils import cpp_extension filename = os.path.join(os.path.dirname(__file__), "torch_interop_utils.cc") +extra_compile_args = {"cxx": ["-O3"]} setup( name="torch_interop_utils", - ext_modules=[cpp_extension.CppExtension(name="torch_interop_utils", sources=[filename])], + ext_modules=[ + cpp_extension.CppExtension( + name="torch_interop_utils", sources=[filename], extra_compile_args=extra_compile_args + ) + ], cmdclass={"build_ext": cpp_extension.BuildExtension}, ) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py index 86a9369cd5..71d44292d8 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py @@ -21,7 +21,7 @@ filenames = [ use_rocm = True if os.environ["ONNXRUNTIME_ROCM_VERSION"] else False extra_compile_args = {"cxx": ["-O3"]} if not use_rocm: - extra_compile_args.update({"nvcc": ["-lineinfo", "-O3", "--use_fast_math"]}) + extra_compile_args.update({"nvcc": os.environ["ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS"].split(",")}) setup( name="fused_ops", diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py index 7a71c95a3b..169c500b57 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py @@ -10,7 +10,6 @@ import sys from setuptools import setup from torch.utils import cpp_extension - # TODO: Implement a cleaner way to auto-generate torch_gpu_allocator.cc use_rocm = True if os.environ["ONNXRUNTIME_ROCM_VERSION"] else False gpu_identifier = "hip" if use_rocm else "cuda" @@ -24,8 +23,16 @@ with fileinput.FileInput(filename, inplace=True) as file: line = line.replace("___gpu_allocator_header___", gpu_allocator_header) sys.stdout.write(line) +extra_compile_args = {"cxx": ["-O3"]} +if not use_rocm: + extra_compile_args.update({"nvcc": os.environ["ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS"].split(",")}) + setup( name="torch_gpu_allocator", - ext_modules=[cpp_extension.CUDAExtension(name="torch_gpu_allocator", sources=[filename])], + ext_modules=[ + cpp_extension.CUDAExtension( + name="torch_gpu_allocator", sources=[filename], extra_compile_args=extra_compile_args + ) + ], cmdclass={"build_ext": cpp_extension.BuildExtension}, ) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py index 68a0332c22..6c1f805310 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py @@ -10,6 +10,7 @@ from glob import glob from shutil import copyfile import torch +from packaging import version from onnxruntime.training import ortmodule @@ -38,6 +39,16 @@ def _install_extension(ext_name, ext_path, cwd): sys.exit(ret_code) +def _get_cuda_extra_build_params(): + nvcc_extra_args = ["-lineinfo", "-O3", "--use_fast_math"] + cuda_version = torch.version.cuda + if cuda_version is not None and version.parse(cuda_version) > version.parse("11.2"): + # If number is 0, the number of threads used is the number of CPUs on the machine. + nvcc_extra_args += ["--threads", "0"] + + os.environ["ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS"] = ",".join(nvcc_extra_args) + + def build_torch_cpp_extensions(): """Builds PyTorch CPP extensions and returns metadata.""" # Run this from within onnxruntime package folder @@ -54,6 +65,9 @@ def build_torch_cpp_extensions(): ortmodule.ONNXRUNTIME_ROCM_VERSION if ortmodule.ONNXRUNTIME_ROCM_VERSION is not None else "" ) + if torch.version.cuda is not None and ortmodule.ONNXRUNTIME_CUDA_VERSION is not None: + _get_cuda_extra_build_params() + ############################################################################ # Pytorch CPP Extensions that DO require CUDA/ROCM ############################################################################