diff --git a/.github/labeler.yml b/.github/labeler.yml
index 4b9838b5b5f..e98ba022b55 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -107,3 +107,8 @@
 - torch/csrc/dynamo/compiled_autograd.h
 - torch/_dynamo/compiled_autograd.py
 - torch/inductor/test_compiled_autograd.py
+
+"ciflow/xpu":
+- torch/csrc/inductor/aoti_include/xpu.h
+- torch/csrc/inductor/cpp_wrapper/device_internal/xpu.h
+- torch/csrc/inductor/cpp_wrapper/xpu.h
diff --git a/.lintrunner.toml b/.lintrunner.toml
index 601bdf04d19..2a0a5d4cb5c 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -557,7 +557,7 @@ exclude_patterns = [
 command = [
     'python3',
     'tools/linter/adapters/grep_linter.py',
-    '--pattern=#include ',
     '--linter-name=PYBIND11_INCLUDE',
     '--match-first-only',
diff --git a/setup.py b/setup.py
index 0583f696955..3886865acf3 100644
--- a/setup.py
+++ b/setup.py
@@ -1290,6 +1290,7 @@ def main():
         "include/torch/csrc/distributed/autograd/rpc_messages/*.h",
         "include/torch/csrc/dynamo/*.h",
         "include/torch/csrc/inductor/*.h",
+        "include/torch/csrc/inductor/aoti_include/*.h",
         "include/torch/csrc/inductor/aoti_package/*.h",
         "include/torch/csrc/inductor/aoti_runner/*.h",
         "include/torch/csrc/inductor/aoti_runtime/*.h",
@@ -1297,6 +1298,8 @@
         "include/torch/csrc/inductor/aoti_torch/c/*.h",
         "include/torch/csrc/inductor/aoti_torch/generated/*.h",
         "include/torch/csrc/inductor/aoti_torch/generated/extend/*.h",
+        "include/torch/csrc/inductor/cpp_wrapper/*.h",
+        "include/torch/csrc/inductor/cpp_wrapper/device_internal/*.h",
         "include/torch/csrc/jit/*.h",
         "include/torch/csrc/jit/backends/*.h",
         "include/torch/csrc/jit/generated/*.h",
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 1996a73aee6..11997cbaf23 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -693,7 +693,6 @@ def torch_key() -> bytes:
     # a hash representing the state of the source code.
     extra_files = (
         "codegen/aoti_runtime/interface.cpp",
-        "codegen/aoti_runtime/implementation.cpp",
         "codegen/cpp_prefix.h",
         "script.ld",
     )
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index e91c0ab897a..2341ccac351 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -253,9 +253,6 @@ class DeviceOpOverrides:
     def kernel_driver(self):
         raise NotImplementedError
 
-    def abi_compatible_header(self):
-        raise NotImplementedError
-
     def cpp_stream_type(self):
         raise NotImplementedError
 
diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py
index 0a85491d339..62f85911718 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -19,7 +19,7 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.symbol import symbol_is_type, SymT
 
 from .. import config, ir
-from ..utils import _align, ALIGN_BYTES, cache_on_self, normalize_name
+from ..utils import _align, cache_on_self, normalize_name
 from ..virtualized import V
 from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .common import get_device_op_overrides, IndentedBuffer, Kernel
@@ -129,23 +129,17 @@ class CppWrapperCpu(PythonWrapperCodegen):
         # include a hash so our code cache gives different constants different files
         self.header.writeline(f"// {name} {hashed}")
 
+    def get_device_include(self):
+        if V.graph.aot_mode:
+            return f"#include <torch/csrc/inductor/aoti_include/{self.device}.h>"
+        return f"#include <torch/csrc/inductor/cpp_wrapper/{self.device}.h>"
+
     def write_header(self):
         if V.graph.is_const_graph:
             # We do not write header for constant graph, it will be written by main module.
             return
 
-        if V.graph.aot_mode:
-            self.header.splice(
-                """
-                #include 
-                #include 
-                """
-            )
-            with open(
-                os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp")
-            ) as f:
-                self.header.splice(f.read())
-        else:
+        if not V.graph.aot_mode:
             self.header.splice(
                 """
                 import torch
@@ -153,61 +147,17 @@ class CppWrapperCpu(PythonWrapperCodegen):
                 cpp_wrapper_src = (
                 '''
-                #include 
-                #include 
-
-                #define PYBIND11_SIMPLE_GIL_MANAGEMENT
-                #include 
-                namespace py = pybind11;
-
-                class RAIIPyObject {
-                public:
-                    RAIIPyObject() : obj_(nullptr) {}
-                    RAIIPyObject(PyObject* obj) : obj_(obj) {}
-                    ~RAIIPyObject() {
-                        Py_XDECREF(obj_);
-                    }
-                    RAIIPyObject& operator=(const RAIIPyObject& other) {
-                        if (this != &other) {
-                            Py_XDECREF(obj_);
-                            obj_ = other.obj_;
-                            Py_XINCREF(obj_);
-                        }
-                        return *this;
-                    }
-                    operator PyObject*() {
-                        return obj_;
-                    }
-                    PyObject* get() {
-                        return obj_;
-                    }
-                private:
-                    PyObject* obj_;
-                };
-
-                #include 
-                #include 
-                using namespace torch::aot_inductor;
                 """
             )
 
-        self.header.splice(
-            f"""
-            #include 
-            #include 
-            #include 
-            #include 
+        self.header.splice(self.get_device_include())
 
-            #include 
-            typedef at::Half half;
-            typedef at::BFloat16 bfloat16;
+        if V.graph.aot_mode:
+            with open(
+                os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp")
+            ) as f:
+                self.header.splice(f.read())
 
-            // Round up to the nearest multiple of {ALIGN_BYTES}
-            [[maybe_unused]] static int64_t align(int64_t nbytes) {{
-                return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES};
-            }}
-            """
-        )
         extend_aoti_c_shim_include = (
             f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h"
         )
@@ -1566,8 +1516,10 @@ class CppWrapperCpu(PythonWrapperCodegen):
         return final_tmp_name
 
     def codegen_device_copy(self, src, dst, non_blocking: bool):
+        """This function is overridden by cpp_wrapper_cpu_array_ref, so we don't need to
+        handle cases where dst is not an AtenTensorHandle."""
         self.writeline(
-            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));"
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_({dst}, {src}, {non_blocking}));"
         )
 
     def codegen_multi_output(self, name, value):
diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py
index 44b76299efb..7b74d27b7ec 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py
@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-import os
 from itertools import count
 from typing import Callable, Optional
 
@@ -84,18 +83,11 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
             return DTYPE_TO_CPP[dtype]
         return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
 
-    def write_header(self):
-        if V.graph.is_const_graph:
-            # We do not write header for constant graph, it will be written by main module.
-            return
-
-        super().write_header()
-        with open(
-            os.path.join(
-                os.path.dirname(__file__), "aoti_runtime", "implementation.cpp"
-            )
-        ) as f:
-            self.header.splice(f.read())
+    def get_device_include(self):
+        assert self.device == "cpu", "ArrayRef only supported on CPU!"
+        if V.graph.aot_mode:
+            return "#include <torch/csrc/inductor/aoti_include/array_ref.h>"
+        return "#include <torch/csrc/inductor/cpp_wrapper/array_ref.h>"
 
     def codegen_input_numel_asserts(self):
         for name, buf in V.graph.graph_inputs.items():
diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py
index 64b5b2812b0..b5fb183bef5 100644
--- a/torch/_inductor/codegen/cpp_wrapper_gpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py
@@ -205,9 +205,6 @@ class CppWrapperGpu(CppWrapperCpu):
             return
 
         super().write_header()
-
-        self.header.splice("#include ")
-        self.header.splice(self.device_codegen.abi_compatible_header())
         self.header.splice(
             maybe_hipify_code_wrapper(self.device_codegen.kernel_driver())
         )
diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py
index a774be7844a..585ccceed8c 100644
--- a/torch/_inductor/codegen/cuda/device_op_overrides.py
+++ b/torch/_inductor/codegen/cuda/device_op_overrides.py
@@ -225,9 +225,6 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
         #endif
         """
 
-    def abi_compatible_header(self):
-        return "#include "
-
     def cpp_stream_type(self):
         return "cudaStream_t"
 
diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py
index 923c9d4d2ae..b303209c050 100644
--- a/torch/_inductor/codegen/debug_utils.py
+++ b/torch/_inductor/codegen/debug_utils.py
@@ -53,6 +53,7 @@ class DebugPrinterManager:
     def __init__(
         self,
         debug_printer_level,
+        use_array_ref: bool,
         args_to_print_or_save: Optional[list[str]] = None,
         kernel_name: str = "",
         kernel=None,
@@ -60,6 +61,7 @@
         kernel_type=None,
     ):
         self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
+        self.use_array_ref = use_array_ref
         if args_to_print_or_save is None:
             args_to_print_or_save = []
         self.args_to_print_or_save = args_to_print_or_save
@@ -155,12 +157,15 @@
             ]
             self.args_to_print_or_save = args_to_print_or_save_extern
         elif kernel_type == "cpp":
-            args_to_print_or_save_cpp = [
-                f"copy_arrayref_tensor_to_tensor({arg})"
+            self.args_to_print_or_save = [
+                (
+                    f"copy_arrayref_tensor_to_tensor({arg})"
+                    if self.use_array_ref
+                    else arg
+                )
                 for arg in args_to_print_or_save
                 if arg.startswith(("buf", "arg"))
             ]
-            self.args_to_print_or_save = args_to_print_or_save_cpp
         else:
             self.args_to_print_or_save = args_to_print_or_save
         self.kernel_name = kernel_name
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 8ba1a08664c..0ceb2b041b0 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -720,7 +720,8 @@ class PythonWrapperCodegen(CodeGen):
 
         # intermediate tensor value printing utility
         self.debug_printer = DebugPrinterManager(
-            debug_printer_level=config.aot_inductor.debug_intermediate_value_printer
+            debug_printer_level=config.aot_inductor.debug_intermediate_value_printer,
+            use_array_ref=config.aot_inductor.allow_stack_allocation,
         )
 
         # Additional files that are dependent to the wrapper (ex. cubin files)
diff --git a/torch/_inductor/codegen/xpu/device_op_overrides.py b/torch/_inductor/codegen/xpu/device_op_overrides.py
index 5945dd6f679..7b83d19ffa0 100644
--- a/torch/_inductor/codegen/xpu/device_op_overrides.py
+++ b/torch/_inductor/codegen/xpu/device_op_overrides.py
@@ -57,12 +57,6 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
         """
         return source_codes
 
-    def abi_compatible_header(self):
-        return """
-        #include 
-        #include 
-        """
-
     def cpp_stream_type(self):
         return "sycl::queue*"
 
diff --git a/torch/csrc/inductor/aoti_include/array_ref.h b/torch/csrc/inductor/aoti_include/array_ref.h
new file mode 100644
index 00000000000..35b7e168e69
--- /dev/null
+++ b/torch/csrc/inductor/aoti_include/array_ref.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
diff --git a/torch/csrc/inductor/aoti_include/common.h b/torch/csrc/inductor/aoti_include/common.h
new file mode 100644
index 00000000000..48a1aeb0303
--- /dev/null
+++ b/torch/csrc/inductor/aoti_include/common.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+
+#include 
+#include 
+using half = at::Half;
+using bfloat16 = at::BFloat16;
+
+// Round up to the nearest multiple of 64
+[[maybe_unused]] inline int64_t align(int64_t nbytes) {
+  return (nbytes + 64 - 1) & -64;
+}
diff --git a/torch/csrc/inductor/aoti_include/cpu.h b/torch/csrc/inductor/aoti_include/cpu.h
new file mode 100644
index 00000000000..508a15b4563
--- /dev/null
+++ b/torch/csrc/inductor/aoti_include/cpu.h
@@ -0,0 +1,4 @@
+#pragma once
+
+#include 
+#include 
diff --git a/torch/csrc/inductor/aoti_include/cuda.h b/torch/csrc/inductor/aoti_include/cuda.h
new file mode 100644
index 00000000000..59948abf171
--- /dev/null
+++ b/torch/csrc/inductor/aoti_include/cuda.h
@@ -0,0 +1,4 @@
+#pragma once
+
+#include 
+#include 
diff --git a/torch/csrc/inductor/aoti_include/xpu.h b/torch/csrc/inductor/aoti_include/xpu.h
new file mode 100644
index 00000000000..d0e15b13f11
--- /dev/null
+++ b/torch/csrc/inductor/aoti_include/xpu.h
@@ -0,0 +1,4 @@
+#pragma once
+
+#include 
+#include 
diff --git a/torch/_inductor/codegen/aoti_runtime/implementation.cpp b/torch/csrc/inductor/array_ref_impl.h
similarity index 89%
rename from torch/_inductor/codegen/aoti_runtime/implementation.cpp
rename to torch/csrc/inductor/array_ref_impl.h
index 017e7a104d5..9e3ec836f5f 100644
--- a/torch/_inductor/codegen/aoti_runtime/implementation.cpp
+++ b/torch/csrc/inductor/array_ref_impl.h
@@ -1,14 +1,11 @@
-// NOTE: Like interface.cpp, this file will be copied into AOTInductor
-// generated output. This file is intended to keep implementation
-// details separate from the implementation of the AOTI public
-// interface.
+#pragma once
+
 #include 
 #include 
 #include 
 #include 
 
-namespace torch {
-namespace aot_inductor {
+namespace torch::aot_inductor {
 template <typename T>
 void convert_output_to_handle(
     const ArrayRefTensor<T>& output,
@@ -82,9 +79,9 @@ template <typename T>
 void assert_numel(const ArrayRefTensor<T>& tensor, uint64_t numel) {
   if (tensor.numel() != numel) {
     std::stringstream err;
-    err << "incorrect numel for input tensor. expected " << numel << ", got " << tensor.numel();
expected " << numel << ", got " + << tensor.numel(); throw std::runtime_error(err.str()); } } -} // namespace aot_inductor -} // namespace torch +} // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/cpp_wrapper/array_ref.h b/torch/csrc/inductor/cpp_wrapper/array_ref.h new file mode 100644 index 00000000000..de9a53d7df5 --- /dev/null +++ b/torch/csrc/inductor/cpp_wrapper/array_ref.h @@ -0,0 +1,7 @@ +#pragma once + +#include +#include +#include +#include +#include diff --git a/torch/csrc/inductor/cpp_wrapper/common.h b/torch/csrc/inductor/cpp_wrapper/common.h new file mode 100644 index 00000000000..77530ddc211 --- /dev/null +++ b/torch/csrc/inductor/cpp_wrapper/common.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include + +#define PYBIND11_SIMPLE_GIL_MANAGEMENT +#include +namespace py = pybind11; + +class RAIIPyObject { + public: + RAIIPyObject() : obj_(nullptr) {} + RAIIPyObject(PyObject* obj) : obj_(obj) {} + ~RAIIPyObject() { + Py_XDECREF(obj_); + } + RAIIPyObject& operator=(const RAIIPyObject& other) { + if (this != &other) { + Py_XDECREF(obj_); + obj_ = other.obj_; + Py_XINCREF(obj_); + } + return *this; + } + operator PyObject*() { + return obj_; + } + PyObject* get() { + return obj_; + } + + private: + PyObject* obj_; +}; + +#include +#include +using namespace torch::aot_inductor; + +#include +#include +using half = at::Half; +using bfloat16 = at::BFloat16; + +// Round up to the nearest multiple of 64 +[[maybe_unused]] inline int64_t align(int64_t nbytes) { + return (nbytes + 64 - 1) & -64; +} diff --git a/torch/csrc/inductor/cpp_wrapper/cpu.h b/torch/csrc/inductor/cpp_wrapper/cpu.h new file mode 100644 index 00000000000..76c2afd9160 --- /dev/null +++ b/torch/csrc/inductor/cpp_wrapper/cpu.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/torch/csrc/inductor/cpp_wrapper/cuda.h b/torch/csrc/inductor/cpp_wrapper/cuda.h new file mode 100644 index 00000000000..782a2b67727 --- /dev/null +++ b/torch/csrc/inductor/cpp_wrapper/cuda.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h b/torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h new file mode 100644 index 00000000000..c203906bb3f --- /dev/null +++ b/torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h @@ -0,0 +1,3 @@ +#pragma once + +#include diff --git a/torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h b/torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h new file mode 100644 index 00000000000..29eaadda4f1 --- /dev/null +++ b/torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/torch/csrc/inductor/cpp_wrapper/device_internal/xpu.h b/torch/csrc/inductor/cpp_wrapper/device_internal/xpu.h new file mode 100644 index 00000000000..32bce0f4e74 --- /dev/null +++ b/torch/csrc/inductor/cpp_wrapper/device_internal/xpu.h @@ -0,0 +1,5 @@ +#pragma once + +#include +#include +#include diff --git a/torch/csrc/inductor/cpp_wrapper/xpu.h b/torch/csrc/inductor/cpp_wrapper/xpu.h new file mode 100644 index 00000000000..e26dea0f3b6 --- /dev/null +++ b/torch/csrc/inductor/cpp_wrapper/xpu.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include