pytorch/test/cpp_extensions/open_registration_extension/setup.py
Zhenbin Lin cbb1ed2966 [1/N] OpenReg: Replace open_registration_extension.cpp with openreg (#141815)
As described in OpenReg [next-steps](https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/README.md#next-steps), here we replace the current `open_registration_extension.cpp` test in PyTorch CI with openreg.

The current `open_registration_extension.cpp` contains two parts:
1. Implementations to support the `PrivateUse1` backend.
2. Helper functions used by the unit tests in `test_cpp_extensions_open_device_registration.py` and `test_transformers.py`.

For the first part, we'll replace it with openreg. For the second part, we'll migrate the helper functions into the corresponding unit-test files step by step.

@albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141815
Approved by: https://github.com/albanD
2025-01-14 15:59:00 +00:00

74 lines
2 KiB
Python

import distutils.command.clean
import os
import shutil
import sys
from pathlib import Path
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
# Distribution name; also the top-level package that holds the ``_C`` extension.
PACKAGE_NAME = "pytorch_openreg"
# Package version. Kept as a string: setuptools expects a PEP 440 version
# string, and a float literal would silently turn e.g. ``1.10`` into ``"1.1"``.
version = "1.0"
# Directory containing this setup.py; every other path is resolved from here.
ROOT_DIR = Path(__file__).absolute().parent
# C++ sources for the ``pytorch_openreg._C`` extension (also used as its
# include directory).
CSRS_DIR = ROOT_DIR / "pytorch_openreg" / "csrc"
class clean(distutils.command.clean.clean):
    """``setup.py clean`` that also removes in-tree build artifacts.

    After running the default distutils ``clean`` behavior, this deletes the
    compiled extension modules found anywhere under ``pytorch_openreg/`` and
    the top-level ``build/`` directory.
    """

    # Glob patterns for compiled extension modules to delete. ``.pyd`` is the
    # extension-module suffix on Windows; ``.so`` is used elsewhere.
    _EXTENSION_GLOBS = ("**/*.so", "**/*.pyd")

    def run(self):
        # Run default behavior first
        super().run()

        # Remove the compiled pytorch_openreg extension module(s).
        package_dir = ROOT_DIR / "pytorch_openreg"
        for pattern in self._EXTENSION_GLOBS:
            for path in package_dir.glob(pattern):
                path.unlink()

        # Remove build directory
        build_dirs = [
            ROOT_DIR / "build",
        ]
        for path in build_dirs:
            if path.exists():
                shutil.rmtree(str(path), ignore_errors=True)
if __name__ == "__main__":
if sys.platform == "win32":
vc_version = os.getenv("VCToolsVersion", "")
if vc_version.startswith("14.16."):
CXX_FLAGS = ["/sdl"]
else:
CXX_FLAGS = ["/sdl", "/permissive-"]
else:
CXX_FLAGS = {"cxx": ["-g", "-Wall", "-Werror"]}
sources = list(CSRS_DIR.glob("*.cpp"))
# Note that we always compile with debug info
ext_modules = [
CppExtension(
name="pytorch_openreg._C",
sources=sorted(str(s) for s in sources),
include_dirs=[CSRS_DIR],
extra_compile_args=CXX_FLAGS,
)
]
setup(
name=PACKAGE_NAME,
version=version,
author="PyTorch Core Team",
description="Example for PyTorch out of tree registration",
packages=find_packages(exclude=("test",)),
package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so"]},
install_requires=[
"torch",
],
ext_modules=ext_modules,
python_requires=">=3.8",
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
"clean": clean,
},
)