pytorch/torch/_inductor/codegen/wrapper.py


import collections
import contextlib
import dataclasses
import functools
import hashlib
import os
from itertools import count
from typing import Any, Dict, List, Tuple
import sympy
from sympy import Expr
from torch._dynamo.utils import dynamo_timed
from .. import codecache, config, ir
from ..codecache import CudaKernelParamCache
from ..utils import (
cache_on_self,
get_benchmark_name,
has_triton,
LineContext,
sympy_dot,
sympy_product,
sympy_symbol,
)
from ..virtualized import V
from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter
pexpr = PythonPrinter().doprint
def buffer_reuse_key(node: ir.Buffer):
size = node.get_size()
stride = node.get_stride()
last_element = sympy_dot([s - 1 for s in size], stride)
return (
node.get_device(),
node.get_dtype(),
V.graph.sizevars.simplify(sympy_product(size)),
# Detect gaps in tensor storage caused by strides
V.graph.sizevars.size_hint(last_element),
)
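# A minimal sketch of the gap detection above (illustrative values only):
# sympy_dot computes sum(s * st for s, st in zip(a, b)), so for a buffer of
# size (2, 3):
#
#   contiguous strides (3, 1): last_element = 1*3 + 2*1 = 5  -> 6 storage slots
#   padded strides     (4, 1): last_element = 1*4 + 2*1 = 6  -> 7 storage slots
#
# Both buffers hold 6 logical elements, but their keys differ in the
# size_hint(last_element) component, so they are never pooled together.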
def is_int(s: str):
try:
int(s)
except ValueError:
return False
return True
def is_float(s: str):
try:
float(s)
except ValueError:
return False
return True
class MemoryPlanningState:
def __init__(self):
super().__init__()
self.reuse_pool: Dict[
Any, List["FreeIfNotReusedLine"]
] = collections.defaultdict(list)
def __contains__(self, key):
return bool(self.reuse_pool.get(key, None))
def pop(self, key) -> "FreeIfNotReusedLine":
item = self.reuse_pool[key].pop()
assert not item.is_reused
return item
def push(self, key, item: "FreeIfNotReusedLine"):
assert not item.is_reused
self.reuse_pool[key].append(item)
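# Usage sketch for the reuse pool (hypothetical key/line objects): freed
# buffers are parked per reuse key and handed out LIFO to later allocations:
#
#   state = MemoryPlanningState()
#   state.push(key, free_line)   # FreeIfNotReusedLine for a now-dead buffer
#   if key in state:             # a later AllocateLine with a matching key...
#       line = state.pop(key)    # ...claims the parked buffer for reuse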
@dataclasses.dataclass
class EnterCudaDeviceContextManagerLine:
device_idx: int
def codegen(self, code: IndentedBuffer, device_cm_stack: contextlib.ExitStack):
if V.graph.cpp_wrapper:
code.writeline("\n")
code.writeline(f"at::cuda::CUDAGuard device_guard({self.device_idx});")
else:
# Note _DeviceGuard has less overhead than device, but only accepts
# integers
code.writeline(f"with torch.cuda._DeviceGuard({self.device_idx}):")
device_cm_stack.enter_context(code.indent())
code.writeline(
f"torch.cuda.set_device({self.device_idx}) # no-op to ensure context"
)
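# For device_idx=0 this emits, in the default Python wrapper:
#
#   with torch.cuda._DeviceGuard(0):
#       torch.cuda.set_device(0) # no-op to ensure context
#
# and in cpp_wrapper mode a RAII guard instead:
#
#   at::cuda::CUDAGuard device_guard(0);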
class ExitCudaDeviceContextManagerLine:
def codegen(self, code: IndentedBuffer, device_cm_stack: contextlib.ExitStack):
if not V.graph.cpp_wrapper:
device_cm_stack.close()
@dataclasses.dataclass
class MemoryPlanningLine:
wrapper: "WrapperCodeGen"
def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
"""First pass to find reuse"""
return self
def codegen(self, code: IndentedBuffer):
"""Second pass to output code"""
pass
@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
node: ir.Buffer
def plan(self, state: MemoryPlanningState):
if self.node.get_name() in V.graph.removed_buffers:
return NullLine(self.wrapper)
# try to reuse a recently freed buffer
key = buffer_reuse_key(self.node)
if key in state:
free_line = state.pop(key)
free_line.is_reused = True
return ReuseLine(self.wrapper, free_line.node, self.node)
return self
def codegen(self, code: IndentedBuffer):
assert self.node.get_name() not in V.graph.removed_buffers
line = self.wrapper.make_buffer_allocation(self.node)
code.writeline(line)
@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
node: ir.Buffer
is_reused: bool = False
def plan(self, state: MemoryPlanningState):
assert not self.is_reused
if self.node.get_name() in V.graph.removed_buffers:
return NullLine(self.wrapper)
state.push(buffer_reuse_key(self.node), self)
return self
def codegen(self, code: IndentedBuffer):
assert self.node.get_name() not in V.graph.removed_buffers
if not self.is_reused:
code.writeline(self.wrapper.make_buffer_free(self.node))
@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
node: ir.Buffer
reused_as: ir.Buffer
def plan(self, state: MemoryPlanningState):
assert self.node.get_name() not in V.graph.removed_buffers
assert self.reused_as.get_name() not in V.graph.removed_buffers
return self
def codegen(self, code: IndentedBuffer):
assert self.node.get_name() not in V.graph.removed_buffers
assert self.reused_as.get_name() not in V.graph.removed_buffers
code.writeline(
self.wrapper.make_buffer_reuse(
self.node,
self.reused_as,
)
)
class NullLine(MemoryPlanningLine):
pass
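# Together these line classes implement a two-pass protocol, driven from
# WrapperCodeGen.generate() below: pass 1 calls plan() on each line, which
# rewrites an AllocateLine into a ReuseLine whenever a dead buffer with a
# matching buffer_reuse_key is parked in MemoryPlanningState; pass 2 calls
# codegen() on the survivors. Roughly (buffer names illustrative):
#
#   [AllocateLine(buf0), FreeIfNotReusedLine(buf0), AllocateLine(buf1)]
#       --plan()-->
#   [AllocateLine(buf0), FreeIfNotReusedLine(buf0, is_reused=True),
#    ReuseLine(buf0, buf1)]   # buf1 takes over buf0's storage; no free emitted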
class WrapperCodeGen(CodeGen):
"""
Generate outer wrapper in Python that calls the kernels.
"""
def __init__(self):
super().__init__()
self._names_iter = count()
self.header = IndentedBuffer()
self.prefix = IndentedBuffer()
self.wrapper_call = IndentedBuffer()
self.src_to_kernel = {}
self.kernel_to_hash = {}
self.lines = []
self.need_seed = False
self.declare = ""
self.ending = ""
self.open_bracket = "["
self.closed_bracket = "]"
self.comment = "#"
self.namespace = ""
self.none_str = "None"
self.size = "size()"
self.stride = "stride()"
self.write_header()
self.write_prefix()
for name, value in V.graph.constants.items():
# include a hash so our code cache gives different constants different files
hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
self.header.writeline(f"{name} = None # {hashed}")
self.allocated = set()
self.freed = set()
# maps from reusing buffer to reused buffer
self.reuses = dict()
self.write_get_cuda_stream = functools.lru_cache(None)(
self.write_get_cuda_stream
)
@functools.lru_cache(None)
def add_import_once(line):
self.header.writeline(line)
self.add_import_once = add_import_once
self._metas = {}
def write_header(self):
self.header.splice(
f"""
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from torch._inductor.utils import maybe_profile
from torch import empty_strided, as_strided, device
from {codecache.__name__} import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
"""
)
if has_triton():
self.header.splice(
"""
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
"""
)
def add_meta_once(self, meta):
meta = repr(meta)
if meta not in self._metas:
var = f"meta{len(self._metas)}"
self._metas[meta] = var
self.header.writeline(f"{var} = {meta}")
return self._metas[meta]
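    # E.g. (hypothetical meta dict): the first call emits a module-level
    # constant in the header and identical later metas are deduplicated to
    # the same name:
    #
    #   self.add_meta_once({"signature": ...})  # header: meta0 = {...}; returns "meta0"
    #   self.add_meta_once({"signature": ...})  # emits nothing; returns "meta0"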
@cache_on_self
def get_output_refs(self):
return [x.codegen_reference() for x in V.graph.graph_outputs]
def mark_output_type(self):
return
def write_prefix(self):
self.prefix.splice(
"""
async_compile.wait(globals())
del async_compile
def call(args):
"""
)
with self.prefix.indent():
if config.triton.debug_sync_graph:
self.prefix.writeline("torch.cuda.synchronize()")
inp_len = len(V.graph.graph_inputs.keys())
if inp_len != 0:
lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
self.prefix.writeline(f"{lhs} = args")
self.prefix.writeline("args.clear()")
for name in V.graph.randomness_seeds:
self.prefix.writeline(
f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})"
)
self.codegen_inputs(self.prefix, V.graph.graph_inputs)
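    # The generated prologue looks roughly like this (input names
    # illustrative):
    #
    #   async_compile.wait(globals())
    #   del async_compile
    #
    #   def call(args):
    #       arg0_1, arg1_1 = args
    #       args.clear()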
def append_precomputed_sizes_to_prefix(self):
with self.prefix.indent():
self.codegen_precomputed_sizes(self.prefix)
def write_get_cuda_stream(self, index):
name = f"stream{index}"
self.writeline(f"{name} = get_cuda_stream({index})")
return name
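    # __init__ wraps this method in functools.lru_cache, so the lookup line,
    # e.g. `stream0 = get_cuda_stream(0)`, is emitted at most once per device.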
def next_kernel_suffix(self):
return f"{next(self._names_iter)}"
def codegen_cuda_device_guard_enter(self, device_idx):
self.writeline(EnterCudaDeviceContextManagerLine(device_idx))
def codegen_cuda_device_guard_exit(self):
self.writeline(ExitCudaDeviceContextManagerLine())
def generate_return(self, output_refs):
if output_refs:
self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
else:
self.wrapper_call.writeline("return ()")
def generate_end(self, result):
return
def generate_extern_kernel_alloc(self, output_name, kernel, args):
self.writeline(
f"{self.declare}{output_name} = {kernel}({', '.join(args)}){self.ending}"
)
def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
if output_view:
args.append(f"out={output_view.codegen_reference()}")
else:
args.append(f"out={codegen_reference}")
self.writeline(f"{kernel}({', '.join(args)})")
def generate_fusion_ops_code(
self,
name,
kernel,
codegen_args,
cpp_op_schema,
cpp_kernel_key,
cpp_kernel_overload_name="",
):
self.writeline(f"{name} = {kernel}({', '.join(codegen_args)})")
@dynamo_timed
def generate(self):
result = IndentedBuffer()
result.splice(self.header)
out_names = V.graph.get_output_names()
with contextlib.ExitStack() as stack:
stack.enter_context(self.wrapper_call.indent())
if config.profiler_mark_wrapper_call:
self.generate_profiler_mark_wrapper_call(stack)
if config.profile_bandwidth:
self.wrapper_call.writeline("start_graph()")
while (
self.lines
and isinstance(self.lines[-1], MemoryPlanningLine)
and self.lines[-1].node.name not in out_names
):
                # trailing planning lines for non-output buffers are dead code; drop them
self.lines.pop()
# codegen allocations in two passes
planning_state = MemoryPlanningState()
for i in range(len(self.lines)):
if isinstance(self.lines[i], MemoryPlanningLine):
self.lines[i] = self.lines[i].plan(planning_state)
device_cm_stack = contextlib.ExitStack()
for line in self.lines:
if isinstance(line, MemoryPlanningLine):
line.codegen(self.wrapper_call)
elif isinstance(
line,
(
EnterCudaDeviceContextManagerLine,
ExitCudaDeviceContextManagerLine,
),
):
line.codegen(self.wrapper_call, device_cm_stack)
else:
self.wrapper_call.writeline(line)
output_refs = self.get_output_refs()
self.mark_output_type()
if config.triton.debug_sync_graph:
self.wrapper_call.writeline("torch.cuda.synchronize()")
if config.profile_bandwidth:
self.wrapper_call.writeline("end_graph()")
self.generate_return(output_refs)
self.append_precomputed_sizes_to_prefix()
result.splice(self.prefix)
with result.indent():
result.splice(self.wrapper_call)
self.generate_end(result)
self.add_benchmark_harness(result)
return result.getvaluewithlinemap()
def codegen_inputs(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]):
"""Assign all symbolic shapes to locals"""
if self.need_seed:
code.writeline(
"seed = torch.randint(2**31, size=(), dtype=torch.int32).item()"
)
@functools.lru_cache(None)
def sizeof(name):
code.writeline(
f"{self.declare}{name}_size = {name}.{self.size}{self.ending}"
)
return f"{name}_size"
@functools.lru_cache(None)
def strideof(name):
code.writeline(
f"{self.declare}{name}_stride = {name}.{self.stride}{self.ending}"
)
return f"{name}_stride"
# Assign all symbolic shapes needed to local variables
needed = set(V.graph.sizevars.var_to_val.keys()) - set(
V.graph.sizevars.replacements.keys()
)
def is_expr(x):
return isinstance(x[1], sympy.Expr)
graph_inputs_expr = list(filter(is_expr, graph_inputs.items()))
graph_inputs_tensors = list(
filter(lambda x: not is_expr(x), graph_inputs.items())
)
for name, shape in graph_inputs_expr:
shape = V.graph.sizevars.simplify(shape)
if shape in needed:
needed.remove(shape)
code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
for name, value in graph_inputs_tensors:
shapes = value.get_size()
for dim, shape in enumerate(shapes):
shape = V.graph.sizevars.simplify(shape)
if shape in needed:
needed.remove(shape)
code.writeline(
f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
)
for name, value in graph_inputs_tensors:
shapes = value.get_stride()
for dim, shape in enumerate(shapes):
shape = V.graph.sizevars.simplify(shape)
if shape in needed:
needed.remove(shape)
code.writeline(
f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
)
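    # E.g. a graph input x with size (s0, 64) and contiguous strides yields:
    #
    #   x_size = x.size()
    #   s0 = x_size[0]
    #
    # sizeof/strideof are lru_cached, so each input's .size()/.stride() is
    # fetched at most once.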
def codegen_precomputed_sizes(self, code: IndentedBuffer):
for sym, expr in V.graph.sizevars.inv_precomputed_replacements.items():
code.writeline(f"{self.declare}{sym} = {pexpr(expr)}")
def codegen_python_sizevar(self, x: Expr) -> str:
return pexpr(V.graph.sizevars.simplify(x))
def codegen_sizevar(self, x: Expr) -> str:
return self.codegen_python_sizevar(x)
def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
parts = list(map(self.codegen_python_sizevar, shape))
if len(parts) == 0:
return "()"
if len(parts) == 1:
return f"({parts[0]}, )"
return f"({', '.join(parts)})"
def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
return self.codegen_python_shape_tuple(shape)
def benchmark_compiled_module(self, output):
def add_fake_input(name, shape, stride, device, dtype):
output.writeline(
f"{name} = rand_strided("
f"{self.codegen_python_shape_tuple(shape)}, "
f"{self.codegen_python_shape_tuple(stride)}, "
f"device='{device}', dtype={dtype})"
)
def add_expr_input(name, val):
output.writeline(f"{name} = {val}")
output.writelines(
["", "", "def benchmark_compiled_module(times=10, repeat=10):"]
)
with output.indent():
output.splice(
"""
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
""",
strip=True,
)
for name, value in V.graph.constants.items():
                # the constants are module-level globals, so declare them
                # with `global` before assigning to them here
output.writeline(f"global {name}")
add_fake_input(
name, value.size(), value.stride(), value.device, value.dtype
)
for name, value in V.graph.graph_inputs.items():
                if isinstance(value, sympy.Expr):  # symbolic inputs are plain ints, no fake tensor needed
add_expr_input(name, V.graph.sizevars.size_hint(value))
else:
shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
add_fake_input(
name, shape, stride, value.get_device(), value.get_dtype()
)
call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])"
output.writeline(
f"return print_performance(lambda: {call_str}, times=times, repeat=repeat)"
)
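    # The emitted harness looks roughly like this (one fake input, shapes
    # illustrative):
    #
    #   def benchmark_compiled_module(times=10, repeat=10):
    #       from torch._dynamo.testing import rand_strided
    #       from torch._inductor.utils import print_performance
    #       arg0_1 = rand_strided((8, 8), (8, 1), device='cuda', dtype=torch.float32)
    #       return print_performance(lambda: call([arg0_1]), times=times, repeat=repeat)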
def add_benchmark_harness(self, output):
"""
Append a benchmark harness to generated code for debugging
"""
if not config.benchmark_harness:
return
self.benchmark_compiled_module(output)
output.writelines(["", "", 'if __name__ == "__main__":'])
with output.indent():
output.writelines(
[
"from torch._inductor.utils import compiled_module_main",
f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)",
]
)
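# The emitted entry point then reads (a sketch; 'resnet18' stands in for
# whatever get_benchmark_name() returns):
#
#   if __name__ == "__main__":
#       from torch._inductor.utils import compiled_module_main
#       compiled_module_main('resnet18', benchmark_compiled_module)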
def define_kernel(self, name: str, kernel: str, metadata: str = None):
metadata_comment = f"{metadata}\n" if metadata else ""
self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")
def wrap_kernel_call(self, name, call_args):
return f"{name}({', '.join(call_args)}){self.ending}"
def generate_profiler_mark_wrapper_call(self, stack):
self.wrapper_call.writeline("from torch.profiler import record_function")
self.wrapper_call.writeline("with record_function('inductor_wrapper_call'):")
stack.enter_context(self.wrapper_call.indent())
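# This wraps the generated call body in a profiler region, emitting:
#
#   from torch.profiler import record_function
#   with record_function('inductor_wrapper_call'):
#       ...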
def generate_kernel_call(self, name, call_args, device_index=None):
self.writeline(self.wrap_kernel_call(name, call_args))
def call_kernel(self, name: str, kernel: Kernel):
tmp = IndentedBuffer()
kernel.call_kernel(self, tmp, name)
for line in tmp.getvalue().split("\n"):
line = line.strip()
if line:
self.writeline(line)
def writeline(self, line):
self.lines.append(line)
def enter_context(self, ctx):
self.lines.append(LineContext(ctx))
def val_to_str(self, s):
return repr(s)
# The following methods are for memory management
def make_buffer_allocation(self, buffer):
device = buffer.get_device()
dtype = buffer.get_dtype()
shape = tuple(buffer.get_size())
stride = tuple(buffer.get_stride())
return (
f"{buffer.get_name()} = empty_strided("
f"{self.codegen_shape_tuple(shape)}, "
f"{self.codegen_shape_tuple(stride)}, "
f"device='{device.type}', dtype={dtype})"
)
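# A typical emitted allocation line (illustrative name, shape, and strides):
#
#   buf0 = empty_strided((8, 8), (8, 1), device='cuda', dtype=torch.float32)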
def make_buffer_free(self, buffer):
return f"del {buffer.get_name()}"
def make_buffer_reuse(self, old, new):
assert old.get_dtype() == new.get_dtype()
del_line = ""
if old.get_name() not in V.graph.get_output_names():
del_line = f"; {self.make_buffer_free(old)}"
if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
return f"{self.declare}{new.get_name()} = {old.get_name()}{del_line} {self.comment} reuse"
return (
f"{self.declare}{new.get_name()} = {self.namespace}as_strided({old.get_name()}, "
f"{self.codegen_shape_tuple(new.get_size())}, "
f"{self.codegen_shape_tuple(new.get_stride())}){del_line} {self.comment} reuse"
)
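# Illustrative output for the two cases, as the Python wrapper renders them
# (buffer names hypothetical; the C++ subclass swaps in "auto", "at::" and "//"):
#
#   buf1 = buf0; del buf0  # reuse
#   buf1 = as_strided(buf0, (4, 16), (16, 1)); del buf0  # reuse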
def codegen_deferred_allocation(self, name, layout):
self.writeline(
DeferredLine(
name,
f"{self.declare}{name} = {layout.view.codegen_reference()}{self.ending} {self.comment} alias",
)
)
def codegen_allocation(self, buffer):
name = buffer.get_name()
if name in V.graph.removed_buffers or name in self.allocated:
return
self.allocated.add(name)
if isinstance(
buffer,
(ir.ExternKernelAlloc, ir.MultiOutput),
):
return
layout = buffer.get_layout()
if isinstance(layout, ir.MutationLayout):
return
if isinstance(layout, ir.AliasedLayout):
assert isinstance(
layout.view, ir.ReinterpretView
), f"unexpected {type(layout.view)}: {layout.view}"
if not layout.maybe_guard_aligned():
V.graph.unaligned_buffers.add(name)
self.codegen_allocation(layout.view.data)
self.codegen_deferred_allocation(name, layout)
return
self.writeline(AllocateLine(self, buffer))
def codegen_free(self, buffer):
name = buffer.get_name()
# input buffers can be freed but never reused
if isinstance(buffer, ir.InputBuffer):
self.writeline(self.make_buffer_free(buffer))
return
if not self.can_reuse(buffer):
return
self.freed.add(name)
layout = buffer.get_layout()
if isinstance(layout, (ir.AliasedLayout, ir.MultiOutputLayout)):
self.writeline(self.make_buffer_free(buffer))
return
self.writeline(FreeIfNotReusedLine(self, buffer))
def can_reuse(self, buffer):
name = buffer.get_name()
if (
name in V.graph.removed_buffers
or name in V.graph.graph_inputs
or name in V.graph.constants
or name in self.freed
):
return False
return True
def did_reuse(self, buffer, reused_buffer):
# Check whether a given buffer was reused by a possible reuser during wrapper
# codegen. Can be consulted from inside IR codegen, e.g. to determine whether
# a copy is needed.
return (
buffer.get_name() in self.reuses
and self.reuses[buffer.get_name()] == reused_buffer.get_name()
)
def codegen_inplace_reuse(self, input_buffer, output_buffer):
assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
self.codegen_allocation(input_buffer)
self.freed.add(input_buffer.get_name())
self.allocated.add(output_buffer.get_name())
self.reuses[output_buffer.get_name()] = input_buffer.get_name()
self.writeline(ReuseLine(self, input_buffer, output_buffer))
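# Bookkeeping note: marking the input as freed and the output as allocated
# prevents a second free/alloc for either buffer, and the self.reuses entry
# is what did_reuse() consults later from IR codegen.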
class CppWrapperCodeGen(WrapperCodeGen):
"""
Generates cpp wrapper for running on CPU and calls cpp kernels
"""
def __init__(self):
super().__init__()
self.declare = "auto "
self.ending = ";"
self.open_bracket = "{"
self.closed_bracket = "}"
self.comment = "//"
self.namespace = "at::"
self.none_str = "at::Tensor()"
self.extern_call_ops = set()
self.size = "sizes()"
self.stride = "strides()"
self.call_func_name = "inductor_entry_cpp"
self.cuda = False
def seed(self):
"""
Seed is a special variable used to hold the rng seed for a graph.
Note this is only used by the CPU backend; for the CUDA backend we put
seeds in a 1-element tensor.
"""
self.need_seed = True
return sympy_symbol("seed")
def write_header(self):
if V.graph.aot_mode:
self.header.splice(
"""
/* AOTInductor generated code */
#include <ATen/ScalarOps.h>
"""
)
else:
self.header.splice(
"""
import torch
from torch.utils.cpp_extension import load_inline
cpp_wrapper_src = (
'''
"""
)
def mark_output_type(self):
# record which outputs are tensors so scalar outputs can later be unwrapped back to Python scalars
from ..ir import ShapeAsConstantBuffer
output_is_tensor = {
idx: not isinstance(x, ShapeAsConstantBuffer)
for idx, x in enumerate(V.graph.graph_outputs)
}
self.output_is_tensor = output_is_tensor
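# e.g. for graph_outputs = [<tensor buffer>, <ShapeAsConstantBuffer>] this
# records {0: True, 1: False}, which generate_end() uses to decide where to
# call .item() on the returned values.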
def write_prefix(self):
return
def write_wrapper_decl(self):
inputs_len = len(V.graph.graph_inputs.keys())
self.prefix.splice(
f"""std::vector<at::Tensor> {self.call_func_name}(const std::vector<at::Tensor>& args) {{"""
)
with self.wrapper_call.indent():
if inputs_len != 0:
for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
# symbolic inputs arrive as 0-dim tensors; unwrap them back to C++ scalars
if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
from ..graph import may_get_constant_buffer_dtype
from .cpp import DTYPE_TO_CPP
dtype = may_get_constant_buffer_dtype(
V.graph.graph_inputs[input_key]
)
assert (
dtype is not None
), "Fails to get the dtype of the sympy.Expr"
cpp_dtype = DTYPE_TO_CPP[dtype]
self.wrapper_call.writeline(f"{cpp_dtype} {input_key};")
self.wrapper_call.writeline(
f"{input_key} = args[{idx}].item<{cpp_dtype}>();"
)
else:
self.wrapper_call.writeline(f"at::Tensor {input_key};")
self.wrapper_call.writeline(f"{input_key} = args[{idx}];")
for name in V.graph.randomness_seeds:
self.wrapper_call.writeline(f"at::Tensor {name};")
self.wrapper_call.writeline(
f"{name} = at::randint(std::pow(2, 31), {{}}, at::ScalarType::Long);"
)
self.codegen_inputs(self.wrapper_call, V.graph.graph_inputs)
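# The resulting C++ prologue reads roughly like this (a sketch; the input
# names and the `long` dtype are hypothetical):
#
#   std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) {
#       at::Tensor arg0_1;
#       arg0_1 = args[0];
#       long s0;
#       s0 = args[1].item<long>();
#       ...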
def generate(self):
self.write_wrapper_decl()
return super().generate()
def define_kernel(self, name: str, kernel: str, kernel_path: str = None):
self.header.splice(f"\n{kernel}\n")
def generate_return(self, output_refs):
self.wrapper_call.writeline(f"return {{{', '.join(output_refs)}}};\n}}")
def generate_end(self, result):
if V.graph.aot_mode:
return
result.writeline("'''\n)")
# Emit a load_inline call to JIT-compile the generated C++ code and make it callable from Python
shared = codecache.get_shared()
warning_all_flag = codecache.get_warning_all_flag()
cpp_flags = codecache.cpp_flags()
ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths(
vec_isa=codecache.pick_vec_isa(),
cuda=self.cuda,
)
optimization_flags = codecache.optimization_flags(cuda=self.cuda)
use_custom_generated_macros = codecache.use_custom_generated_macros()
extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}"
extra_ldflags = f"{shared} {lpaths} {libs}"
extra_include_paths = f"{ipaths}"
# get the hash of the wrapper code to name the extension
wrapper_call_hash = codecache.code_hash(self.wrapper_call.getvalue())
result.splice(
f"""
module = load_inline(
name='inline_extension_{wrapper_call_hash}',
cpp_sources=[cpp_wrapper_src],
functions=['{self.call_func_name}'],
extra_cflags=['{extra_cflags}'],
extra_ldflags=['{extra_ldflags}'],
extra_include_paths=['{extra_include_paths}'])
"""
)
# unwrap scalar outputs (returned as 0-dim tensors) back to Python scalars
if all(self.output_is_tensor.values()):
# If no ShapeAsConstantBuffer in the output, directly return the output as tensors
return_str = "return f(args_tensor)"
else:
outputs = [
f"outputs[{i}]" if self.output_is_tensor[i] else f"outputs[{i}].item()"
for i in range(len(V.graph.graph_outputs))
]
outputs_str = f"[{', '.join(outputs)}]"
return_str = f"""
outputs = f(args_tensor)
return {outputs_str}
"""
# Wrap the func to support setting result._boxed_call = True
result.splice(
f"""
def _wrap_func(f):
def g(args):
args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
{return_str}
                return g
            call = _wrap_func(module.{self.call_func_name})
"""
)
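    # Illustrative sketch of the Python glue the splice above produces; the
    # module attribute and return expression below are hypothetical:
    #
    #   def _wrap_func(f):
    #       def g(args):
    #           args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
    #           return f(args_tensor)
    #       return g
    #   call = _wrap_func(module.call_0)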

    def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
        if output_view:
            output_as_strided = f"{output_view.codegen_reference()}"
            output_name = f"{output_view.get_name()}_as_strided"
            self.writeline(f"auto {output_name} = {output_as_strided};")
            args.insert(0, output_name)
        else:
            args.insert(0, f"{codegen_reference}")
        self.writeline(self.wrap_kernel_call(kernel, args))
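    # Illustrative sketch of the C++ this emits when an output view is
    # present; the kernel and buffer names below are hypothetical:
    #   auto buf1_as_strided = at::as_strided(buf1, {8, 8}, {8, 1});
    #   at::mm_out(buf1_as_strided, arg0_1, arg1_1);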

    def add_benchmark_harness(self, output):
        if V.graph.aot_mode:
            return
        super().add_benchmark_harness(output)

    def codegen_sizevar(self, x: Expr) -> str:
        from .cpp import cexpr

        return cexpr(V.graph.sizevars.simplify(x))

    def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
        parts = list(map(self.codegen_sizevar, shape))
        if len(parts) == 0:
            return "{}"
        if len(parts) == 1:
            return f"{{{parts[0]}, }}"
        return f"{{{', '.join(parts)}}}"
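    # Illustrative outputs for hypothetical shapes:
    #   ()       -> "{}"
    #   (8,)     -> "{8, }"   (trailing comma keeps the 1-element case unambiguous)
    #   (s0, 8)  -> "{s0, 8}"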

    def make_buffer_free(self, buffer):
        return f"{buffer.get_name()}.reset();"

    def generate_profiler_mark_wrapper_call(self, stack):
        self.wrapper_call.writeline(
            'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>({{}}));'
        )

    def make_buffer_allocation(self, buffer):
        from .cpp import DEVICE_TO_ATEN, DTYPE_TO_ATEN

        # TODO: map layout here
        device = buffer.get_device()
        dtype = buffer.get_dtype()
        shape = tuple(buffer.get_size())
        stride = tuple(buffer.get_stride())
        return (
            f"{self.declare}{buffer.get_name()} = {self.namespace}empty_strided("
            f"{self.codegen_shape_tuple(shape)}, "
            f"{self.codegen_shape_tuple(stride)}, "
            f"at::device({DEVICE_TO_ATEN[device.type]})"
            f".dtype({DTYPE_TO_ATEN[dtype]})){self.ending}"
        )
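    # Illustrative sketch of an emitted allocation; the buffer name, sizes,
    # strides, device, and dtype below are hypothetical:
    #   auto buf0 = at::empty_strided({8, 8}, {8, 1},
    #       at::device(at::kCPU).dtype(at::kFloat));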

    def generate_fusion_ops_code(
        self,
        name,
        kernel,
        codegen_args,
        cpp_op_schema,
        cpp_kernel_key,
        cpp_kernel_overload_name="",
    ):
        if cpp_kernel_key not in self.extern_call_ops:
            self.writeline(
                f"""
                static auto op_{cpp_kernel_key} =
                c10::Dispatcher::singleton()
                    .findSchemaOrThrow(
                        \"{kernel}\",
                        \"{cpp_kernel_overload_name}\")
                    .typed<{cpp_op_schema}>();
                """
            )
            self.extern_call_ops.add(cpp_kernel_key)
        self.writeline(
            f"auto {name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});"
        )
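    # Illustrative sketch of the C++ this emits for a fused op; the schema,
    # kernel name, and key below are hypothetical:
    #   static auto op_conv_key = c10::Dispatcher::singleton()
    #       .findSchemaOrThrow("mkldnn::_convolution_pointwise", "")
    #       .typed<at::Tensor(const at::Tensor&, ...)>();
    #   auto buf0 = op_conv_key.call(arg0_1, ...);
    # The static lookup is emitted once per cpp_kernel_key; later call sites
    # reuse it and only emit the .call(...) line.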

    def val_to_str(self, s):
        if s is None:
            return self.none_str
        elif isinstance(s, bool):
            return "true" if s else "false"
        elif isinstance(s, str):
            return f'"{s}"'
        elif isinstance(s, (list, tuple)):
            vals = ", ".join(map(self.val_to_str, s))
            return f"{{{vals}}}"
        else:
            return repr(s)
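    # Illustrative Python-to-C++ literal conversions (inputs hypothetical):
    #   None       -> self.none_str
    #   True       -> "true"
    #   "relu"     -> "\"relu\""
    #   [1, 2, 3]  -> "{1, 2, 3}"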


class CudaWrapperCodeGen(CppWrapperCodeGen):
    """
    Generates a C++ wrapper for running on GPU; the wrapper loads and calls
    CUDA kernels directly through the driver API.
    """

    def __init__(self):
        super().__init__()
        self.kernel_callsite_id = count()
        self.arg_var_id = count()
        self.cuda = True

    def write_prefix(self):
        self.prefix.splice(
            """
            #include <c10/util/Exception.h>
            #include <c10/cuda/CUDAGuard.h>

            #define AT_CUDA_DRIVER_CHECK_OVERRIDE(EXPR)                       \\
            do {                                                              \\
                CUresult __err = EXPR;                                        \\
                if (__err != CUDA_SUCCESS) {                                  \\
                    AT_ERROR("CUDA driver error: ", static_cast<int>(__err)); \\
                }                                                             \\
            } while (0)

            static inline CUfunction loadKernel(const std::string &filePath,
                                                const std::string &funcName) {
                CUmodule mod;
                CUfunction func;
                AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleLoad(&mod, filePath.c_str()));
                AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleGetFunction(&func, mod, funcName.c_str()));
                return func;
            }

            static inline void launchKernel(
                    CUfunction func,
                    int gridX,
                    int gridY,
                    int gridZ,
                    int numWarps,
                    int sharedMemBytes,
                    void* args[],
                    int device_index) {
                AT_CUDA_DRIVER_CHECK_OVERRIDE(cuLaunchKernel(
                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes,
                    at::cuda::getCurrentCUDAStream(device_index), args, nullptr));
            }
            """
        )
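    # Note: the emitted launchKernel converts warps to threads assuming the
    # CUDA warp size of 32, and AT_CUDA_DRIVER_CHECK_OVERRIDE wraps raw
    # driver-API calls so failures surface as c10 errors.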

    def define_kernel(self, name: str, kernel: str, kernel_path: str = None):
        pass

    def generate(self):
        self.prefix.writeline("\n")
        for kernel in self.src_to_kernel.values():
            self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
        self.prefix.writeline("\n")
        return super().generate()

    def generate_load_kernel(self, name, params):
        mangled_name = params.get("mangled_name", None)
        assert mangled_name is not None, "missing mangled_name"
        cubin_path = params.get("cubin_path", None)
        assert os.path.exists(
            cubin_path
        ), "cubin file should already exist at this moment"
        self.writeline(f"if ({name} == nullptr) {{")
        self.writeline(
            f"""    {name} = loadKernel("{cubin_path}", "{mangled_name}");"""
        )
        self.writeline("}")

    def generate_args_decl(self, call_args):
        # TODO: only works for constant now, need type info
        new_args = []
        for arg in call_args:
            var_name = f"var_{next(self.arg_var_id)}"
            if is_int(arg):
                self.writeline(f"int {var_name} = {arg};")
            elif is_float(arg):
                self.writeline(f"float {var_name} = {arg};")
            else:
                self.writeline(
                    f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
                )
            new_args.append(f"&{var_name}")
        return ", ".join(new_args)

    def generate_kernel_call(self, name, call_args, device_index):
        params = CudaKernelParamCache.get(self.kernel_to_hash.get(name, None))
        assert (
            params is not None
        ), "cuda kernel parameters should already exist at this moment"
        self.generate_load_kernel(name, params)
        call_args = self.generate_args_decl(call_args)
        kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
        self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
        self.writeline(
            "launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
                name,
                params["grid_x"],
                params["grid_y"],
                params["grid_z"],
                params["num_warps"],
                params["shared_mem"],
                kernel_args_var,
                device_index,
            )
        )
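    # Illustrative sketch of a complete emitted call site; the kernel name,
    # grid, and launch parameters below are hypothetical:
    #   if (triton_fused_0 == nullptr) {
    #       triton_fused_0 = loadKernel("/tmp/torchinductor/.../kernel.cubin", "triton_fused_0");
    #   }
    #   CUdeviceptr var_0 = reinterpret_cast<CUdeviceptr>(buf0.data_ptr());
    #   void* kernel_args_var_0[] = {&var_0};
    #   launchKernel(triton_fused_0, 1024, 1, 1, 4, 0, kernel_args_var_0, 0);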