import collections
import contextlib
import dataclasses
import functools
import hashlib
import os
from itertools import count
from typing import Any, Dict, List, Tuple

import sympy
from sympy import Expr

from torch._dynamo.utils import dynamo_timed

from .. import codecache, config, ir
from ..codecache import CudaKernelParamCache
from ..utils import (
    cache_on_self,
    get_benchmark_name,
    has_triton,
    LineContext,
    sympy_dot,
    sympy_product,
    sympy_symbol,
)
from ..virtualized import V
from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter

pexpr = PythonPrinter().doprint


def buffer_reuse_key(node: ir.Buffer):
    size = node.get_size()
    stride = node.get_stride()
    last_element = sympy_dot([s - 1 for s in size], stride)
    return (
        node.get_device(),
        node.get_dtype(),
        V.graph.sizevars.simplify(sympy_product(size)),
        # Detect gaps in tensor storage caused by strides
        V.graph.sizevars.size_hint(last_element),
    )
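
# Illustrative example (hypothetical buffers): contiguous float32 CUDA tensors
# of size (4, 8)/stride (8, 1) and size (8, 4)/stride (4, 1) produce the same
# key, since both span 32 elements with the last element at offset 31, so the
# planner may hand one buffer's storage to the other.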


def is_int(s: str):
    try:
        int(s)
    except ValueError:
        return False
    return True


def is_float(s: str):
    try:
        float(s)
    except ValueError:
        return False
    return True


class MemoryPlanningState:
    """Tracks buffers freed so far, keyed by buffer_reuse_key, so that later
    allocations can reuse their storage."""

    def __init__(self):
        super().__init__()
        self.reuse_pool: Dict[
            Any, List["FreeIfNotReusedLine"]
        ] = collections.defaultdict(list)

    def __contains__(self, key):
        return bool(self.reuse_pool.get(key, None))

    def pop(self, key) -> "FreeIfNotReusedLine":
        item = self.reuse_pool[key].pop()
        assert not item.is_reused
        return item

    def push(self, key, item: "FreeIfNotReusedLine"):
        assert not item.is_reused
        self.reuse_pool[key].append(item)


@dataclasses.dataclass
class EnterCudaDeviceContextManagerLine:
    device_idx: int

    def codegen(self, code: IndentedBuffer, device_cm_stack: contextlib.ExitStack):
        if V.graph.cpp_wrapper:
            code.writeline("\n")
            code.writeline(f"at::cuda::CUDAGuard device_guard({self.device_idx});")
        else:
            # Note _DeviceGuard has less overhead than device, but only accepts
            # integers
            code.writeline(f"with torch.cuda._DeviceGuard({self.device_idx}):")
            device_cm_stack.enter_context(code.indent())
            code.writeline(
                f"torch.cuda.set_device({self.device_idx}) # no-op to ensure context"
            )
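
# In the Python wrapper this emits, e.g. for device 1:
#     with torch.cuda._DeviceGuard(1):
#         torch.cuda.set_device(1) # no-op to ensure context
# while the C++ wrapper emits an RAII at::cuda::CUDAGuard instead, which
# needs no matching exit line.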


class ExitCudaDeviceContextManagerLine:
    def codegen(self, code: IndentedBuffer, device_cm_stack: contextlib.ExitStack):
        if not V.graph.cpp_wrapper:
            device_cm_stack.close()


@dataclasses.dataclass
class MemoryPlanningLine:
    wrapper: "WrapperCodeGen"

    def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
        """First pass to find reuse"""
        return self

    def codegen(self, code: IndentedBuffer):
        """Second pass to output code"""
        pass


@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
    node: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine(self.wrapper)

        # try to reuse a recently freed buffer
        key = buffer_reuse_key(self.node)
        if key in state:
            free_line = state.pop(key)
            free_line.is_reused = True
            return ReuseLine(self.wrapper, free_line.node, self.node)

        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        line = self.wrapper.make_buffer_allocation(self.node)
        code.writeline(line)


@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
    node: ir.Buffer
    is_reused: bool = False

    def plan(self, state: MemoryPlanningState):
        assert not self.is_reused
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine(self.wrapper)
        state.push(buffer_reuse_key(self.node), self)
        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        if not self.is_reused:
            code.writeline(self.wrapper.make_buffer_free(self.node))


@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
    node: ir.Buffer
    reused_as: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        code.writeline(
            self.wrapper.make_buffer_reuse(
                self.node,
                self.reused_as,
            )
        )


class NullLine(MemoryPlanningLine):
    pass
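
# Sketch of the two-pass protocol (buffer names are hypothetical): plan() may
# rewrite a stream like
#     FreeIfNotReusedLine(buf0) ... AllocateLine(buf1)
# into a ReuseLine(buf0, buf1) when buffer_reuse_key matches, so that
# codegen() later emits a reuse of buf0's storage (e.g. an as_strided-style
# view) instead of a fresh allocation for buf1.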


class WrapperCodeGen(CodeGen):
    """
    Generate outer wrapper in Python that calls the kernels.
    """

    def __init__(self):
        super().__init__()
        self._names_iter = count()
        self.header = IndentedBuffer()
        self.prefix = IndentedBuffer()
        self.wrapper_call = IndentedBuffer()
        self.src_to_kernel = {}
        self.kernel_to_hash = {}
        self.lines = []
        self.need_seed = False
        # Printing knobs for codegen paths shared with subclasses; the
        # defaults target Python, and the C++ wrapper is expected to override
        # them (e.g. with "auto " declarations and ";" line endings).
        self.declare = ""
        self.ending = ""
        self.open_bracket = "["
        self.closed_bracket = "]"
        self.comment = "#"
        self.namespace = ""
        self.none_str = "None"
        self.size = "size()"
        self.stride = "stride()"

        self.write_header()
        self.write_prefix()

        for name, value in V.graph.constants.items():
            # include a hash so our code cache gives different constants different files
            hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
            self.header.writeline(f"{name} = None # {hashed}")

        self.allocated = set()
        self.freed = set()

        # maps from reusing buffer to reused buffer
        self.reuses = dict()

        self.write_get_cuda_stream = functools.lru_cache(None)(
            self.write_get_cuda_stream
        )

        @functools.lru_cache(None)
        def add_import_once(line):
            self.header.writeline(line)

        self.add_import_once = add_import_once
        self._metas = {}

    def write_header(self):
        self.header.splice(
            f"""
                from ctypes import c_void_p, c_long
                import torch
                import math
                import random
                import os
                import tempfile
                from torch._inductor.utils import maybe_profile

                from torch import empty_strided, as_strided, device
                from {codecache.__name__} import AsyncCompile
                from torch._inductor.select_algorithm import extern_kernels

                aten = torch.ops.aten
                assert_size_stride = torch._C._dynamo.guards.assert_size_stride
                async_compile = AsyncCompile()

            """
        )

        if has_triton():
            self.header.splice(
                """
                import triton
                import triton.language as tl
                from torch._inductor.triton_heuristics import grid, start_graph, end_graph
                from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
                """
            )

    def add_meta_once(self, meta):
        meta = repr(meta)
        if meta not in self._metas:
            var = f"meta{len(self._metas)}"
            self._metas[meta] = var
            self.header.writeline(f"{var} = {meta}")
        return self._metas[meta]
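
    # Illustrative: kernels sharing the meta dict {'BLOCK': 1024}
    # (hypothetical contents) all receive the same header variable, e.g.
    # meta0, because entries are deduplicated by their repr().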

    @cache_on_self
    def get_output_refs(self):
        return [x.codegen_reference() for x in V.graph.graph_outputs]

    def mark_output_type(self):
        return

    def write_prefix(self):
        self.prefix.splice(
            """

            async_compile.wait(globals())
            del async_compile

            def call(args):
            """
        )
        with self.prefix.indent():
            if config.triton.debug_sync_graph:
                self.prefix.writeline("torch.cuda.synchronize()")
            inp_len = len(V.graph.graph_inputs.keys())
            if inp_len != 0:
                # a single input needs a trailing comma so "arg0_1, = args" unpacks
                lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
                self.prefix.writeline(f"{lhs} = args")
                self.prefix.writeline("args.clear()")
            for name in V.graph.randomness_seeds:
                self.prefix.writeline(
                    f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})"
                )
            self.codegen_inputs(self.prefix, V.graph.graph_inputs)
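
    # For a two-input graph the generated prologue looks like:
    #     def call(args):
    #         arg0_1, arg1_1 = args
    #         args.clear()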

    def append_precomputed_sizes_to_prefix(self):
        with self.prefix.indent():
            self.codegen_precomputed_sizes(self.prefix)

    def write_get_cuda_stream(self, index):
        name = f"stream{index}"
        self.writeline(f"{name} = get_cuda_stream({index})")
        return name
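
    # Emits e.g. "stream1 = get_cuda_stream(1)"; __init__ wraps this method
    # in functools.lru_cache, so each device's stream is fetched at most once
    # per generated wrapper.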

    def next_kernel_suffix(self):
        return f"{next(self._names_iter)}"

    def codegen_cuda_device_guard_enter(self, device_idx):
        self.writeline(EnterCudaDeviceContextManagerLine(device_idx))

    def codegen_cuda_device_guard_exit(self):
        self.writeline(ExitCudaDeviceContextManagerLine())

    def generate_return(self, output_refs):
        if output_refs:
            self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
        else:
            self.wrapper_call.writeline("return ()")

    def generate_end(self, result):
        return

    def generate_extern_kernel_alloc(self, output_name, kernel, args):
        self.writeline(
            f"{self.declare}{output_name} = {kernel}({', '.join(args)}){self.ending}"
        )

    def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
        if output_view:
            args.append(f"out={output_view.codegen_reference()}")
        else:
            args.append(f"out={codegen_reference}")
        self.writeline(f"{kernel}({', '.join(args)})")
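
    # Emits e.g. "extern_kernels.mm(arg0_1, arg1_1, out=buf0)" (illustrative
    # names), routing the extern kernel's result into an existing buffer or
    # view instead of a new allocation.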

    def generate_fusion_ops_code(
        self,
        name,
        kernel,
        codegen_args,
        cpp_op_schema,
        cpp_kernel_key,
        cpp_kernel_overload_name="",
    ):
        self.writeline(f"{name} = {kernel}({', '.join(codegen_args)})")

    @dynamo_timed
    def generate(self):
        result = IndentedBuffer()
        result.splice(self.header)

        out_names = V.graph.get_output_names()
        with contextlib.ExitStack() as stack:
            stack.enter_context(self.wrapper_call.indent())
            if config.profiler_mark_wrapper_call:
                self.generate_profiler_mark_wrapper_call(stack)
            if config.profile_bandwidth:
                self.wrapper_call.writeline("start_graph()")

            while (
                self.lines
                and isinstance(self.lines[-1], MemoryPlanningLine)
                and self.lines[-1].node.name not in out_names
            ):
                # these lines will be pointless
                self.lines.pop()

            # codegen allocations in two passes
            planning_state = MemoryPlanningState()
            for i in range(len(self.lines)):
                if isinstance(self.lines[i], MemoryPlanningLine):
                    self.lines[i] = self.lines[i].plan(planning_state)

            device_cm_stack = contextlib.ExitStack()
            for line in self.lines:
                if isinstance(line, MemoryPlanningLine):
                    line.codegen(self.wrapper_call)
                elif isinstance(
                    line,
                    (
                        EnterCudaDeviceContextManagerLine,
                        ExitCudaDeviceContextManagerLine,
                    ),
                ):
                    line.codegen(self.wrapper_call, device_cm_stack)
                else:
                    self.wrapper_call.writeline(line)

            output_refs = self.get_output_refs()
            self.mark_output_type()
            if config.triton.debug_sync_graph:
                self.wrapper_call.writeline("torch.cuda.synchronize()")

            if config.profile_bandwidth:
                self.wrapper_call.writeline("end_graph()")

            self.generate_return(output_refs)

        self.append_precomputed_sizes_to_prefix()
        result.splice(self.prefix)

        with result.indent():
            result.splice(self.wrapper_call)

        self.generate_end(result)

        self.add_benchmark_harness(result)

        return result.getvaluewithlinemap()
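
    # A complete generated call(), taken from the multi-device example in
    # pytorch/pytorch#90934:
    #     def call(args):
    #         arg0_1, arg1_1 = args
    #         args.clear()
    #         with torch.cuda.device(1):
    #             buf0 = empty_strided((4, ), (1, ), device='cuda', dtype=torch.float32)
    #             stream1 = get_cuda_stream(1)
    #             triton_fused_div_0.run(arg0_1, arg1_1, buf0, 4, grid=grid(4), stream=stream1)
    #             del arg0_1
    #             del arg1_1
    #         return (buf0, )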

    def codegen_inputs(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]):
        """Assign all symbolic shapes to locals"""
        if self.need_seed:
            code.writeline(
                "seed = torch.randint(2**31, size=(), dtype=torch.int32).item()"
            )

        @functools.lru_cache(None)
        def sizeof(name):
            code.writeline(
                f"{self.declare}{name}_size = {name}.{self.size}{self.ending}"
            )
            return f"{name}_size"

        @functools.lru_cache(None)
        def strideof(name):
            code.writeline(
                f"{self.declare}{name}_stride = {name}.{self.stride}{self.ending}"
            )
            return f"{name}_stride"

        # Assign all symbolic shapes needed to local variables
        needed = set(V.graph.sizevars.var_to_val.keys()) - set(
            V.graph.sizevars.replacements.keys()
        )

        def is_expr(x):
            return isinstance(x[1], sympy.Expr)

        graph_inputs_expr = list(filter(is_expr, graph_inputs.items()))
        graph_inputs_tensors = list(
            filter(lambda x: not is_expr(x), graph_inputs.items())
        )

        for name, shape in graph_inputs_expr:
            shape = V.graph.sizevars.simplify(shape)
            if shape in needed:
                needed.remove(shape)
                code.writeline(f"{self.declare}{shape} = {name}{self.ending}")

        for name, value in graph_inputs_tensors:
            shapes = value.get_size()
            for dim, shape in enumerate(shapes):
                shape = V.graph.sizevars.simplify(shape)
                if shape in needed:
                    needed.remove(shape)
                    code.writeline(
                        f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
                    )

        for name, value in graph_inputs_tensors:
            shapes = value.get_stride()
            for dim, shape in enumerate(shapes):
                shape = V.graph.sizevars.simplify(shape)
                if shape in needed:
                    needed.remove(shape)
                    code.writeline(
                        f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
                    )
|
|
|
|
|
|
|
|
|
|
def codegen_precomputed_sizes(self, code: IndentedBuffer):
|
|
|
|
|
for sym, expr in V.graph.sizevars.inv_precomputed_replacements.items():
|
|
|
|
|
code.writeline(f"{self.declare}{sym} = {pexpr(expr)}")
|
|
|
|
|
|
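    # Illustrative sketch of the emitted bindings (names hypothetical): for a
    # graph input `arg0_1` whose first dimension is the symbol `s0`, the Python
    # wrapper (empty `declare`/`ending`) would emit
    #     arg0_1_size = arg0_1.size()
    #     s0 = arg0_1_size[0]
    # while CppWrapperCodeGen (declare="auto ", size="sizes()", ending=";") emits
    #     auto arg0_1_size = arg0_1.sizes();
    #     auto s0 = arg0_1_size[0];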
[inductor] Fix benchmark_compiled_module codegen with CppWrapperCodeGen (#98608)
The python function `benchmark_compiled_module` ends up using the C++ expression printer to print the size for `rand_strided`, so you get a set literal, e.g. `{2, 17}`, instead of a tuple `(2, 17)`. Here is a complete example from master:
```python
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided({2, 17}, {17, 1}, device='cpu', dtype=torch.float32)
arg1_1 = rand_strided({2, 17}, {17, 1}, device='cpu', dtype=torch.uint8)
return print_performance(lambda: call([arg0_1, arg1_1]), times=times, repeat=repeat)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98608
Approved by: https://github.com/ngimel
2023-04-07 17:41:52 +00:00
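To make the failure mode concrete, here is a minimal, self-contained sketch (the helper names are hypothetical; only the tuple-printing logic mirrors `codegen_python_shape_tuple` below):
```python
# Stand-ins for the two expression printers: the C++ one brace-wraps a shape
# (a C++ initializer list, but a set literal if pasted into Python), while the
# Python one parenthesizes it, special-casing empty and 1-element tuples.
def cpp_shape_tuple(shape):
    return "{" + ", ".join(map(str, shape)) + "}"

def python_shape_tuple(shape):
    parts = [str(s) for s in shape]
    if len(parts) == 0:
        return "()"
    if len(parts) == 1:
        return f"({parts[0]}, )"
    return f"({', '.join(parts)})"

print(cpp_shape_tuple((2, 17)))     # {2, 17}  -> wrong in generated Python
print(python_shape_tuple((2, 17)))  # (2, 17)  -> what rand_strided expects
```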
    def codegen_python_sizevar(self, x: Expr) -> str:
        return pexpr(V.graph.sizevars.simplify(x))
    def codegen_sizevar(self, x: Expr) -> str:
        return self.codegen_python_sizevar(x)

    def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
        parts = list(map(self.codegen_python_sizevar, shape))
        if len(parts) == 0:
            return "()"
        if len(parts) == 1:
            return f"({parts[0]}, )"
        return f"({', '.join(parts)})"
    def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
        return self.codegen_python_shape_tuple(shape)

    def benchmark_compiled_module(self, output):
        def add_fake_input(name, shape, stride, device, dtype):
            output.writeline(
                f"{name} = rand_strided("
f"{self.codegen_python_shape_tuple(shape)}, "
|
|
|
|
|
f"{self.codegen_python_shape_tuple(stride)}, "
|
f"device='{device}', dtype={dtype})"
|
2022-10-13 23:18:06 +00:00
|
|
|
)
|
|
|
|
|
|
2023-02-27 20:26:18 +00:00
|
|
|
        def add_expr_input(name, val):
            output.writeline(f"{name} = {val}")

        output.writelines(
            ["", "", "def benchmark_compiled_module(times=10, repeat=10):"]
        )
        with output.indent():
            output.splice(
                """
                from torch._dynamo.testing import rand_strided
                from torch._inductor.utils import print_performance
                """,
                strip=True,
            )

            for name, value in V.graph.constants.items():
                # all the constants are global variables, that's why we need
                # these 'global var_name' lines
                output.writeline(f"global {name}")
                add_fake_input(
                    name, value.size(), value.stride(), value.device, value.dtype
                )

            for name, value in V.graph.graph_inputs.items():
                if isinstance(value, sympy.Expr):  # Don't need to add symbolic
                    add_expr_input(name, V.graph.sizevars.size_hint(value))
                else:
                    shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
                    stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
                    add_fake_input(
                        name, shape, stride, value.get_device(), value.get_dtype()
                    )

            call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])"
            output.writeline(
                f"return print_performance(lambda: {call_str}, times=times, repeat=repeat)"
            )
    def add_benchmark_harness(self, output):
        """
        Append a benchmark harness to generated code for debugging
        """
        if not config.benchmark_harness:
            return

        self.benchmark_compiled_module(output)

        output.writelines(["", "", 'if __name__ == "__main__":'])
        with output.indent():
            output.writelines(
                [
                    "from torch._inductor.utils import compiled_module_main",
                    f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)",
                ]
            )
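    # Illustrative only: with config.benchmark_harness enabled, the generated
    # module therefore ends with
    #     if __name__ == "__main__":
    #         from torch._inductor.utils import compiled_module_main
    #         compiled_module_main('<benchmark name>', benchmark_compiled_module)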
    def define_kernel(self, name: str, kernel: str, metadata: str = None):
        metadata_comment = f"{metadata}\n" if metadata else ""
        self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")
    def wrap_kernel_call(self, name, call_args):
        return f"{name}({', '.join(call_args)}){self.ending}"
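    # Illustrative only: wrap_kernel_call("triton_fused_add_0", ["buf0", "buf1"])
    # returns "triton_fused_add_0(buf0, buf1)" for the Python wrapper and
    # "triton_fused_add_0(buf0, buf1);" under CppWrapperCodeGen (ending=";").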
    def generate_profiler_mark_wrapper_call(self, stack):
        self.wrapper_call.writeline("from torch.profiler import record_function")
        self.wrapper_call.writeline("with record_function('inductor_wrapper_call'):")
        stack.enter_context(self.wrapper_call.indent())

    def generate_kernel_call(self, name, call_args, device_index=None):
        self.writeline(self.wrap_kernel_call(name, call_args))
    def call_kernel(self, name: str, kernel: Kernel):
        tmp = IndentedBuffer()
        kernel.call_kernel(self, tmp, name)
        for line in tmp.getvalue().split("\n"):
            line = line.strip()
            if line:
                self.writeline(line)

    def writeline(self, line):
        self.lines.append(line)
    def enter_context(self, ctx):
        self.lines.append(LineContext(ctx))

    def val_to_str(self, s):
        return repr(s)
    # The following methods are for memory management
    def make_buffer_allocation(self, buffer):
        device = buffer.get_device()
        dtype = buffer.get_dtype()
        shape = tuple(buffer.get_size())
        stride = tuple(buffer.get_stride())
        return (
            f"{buffer.get_name()} = empty_strided("
            f"{self.codegen_shape_tuple(shape)}, "
            f"{self.codegen_shape_tuple(stride)}, "
            f"device='{device.type}', dtype={dtype})"
        )

    def make_buffer_free(self, buffer):
        return f"del {buffer.get_name()}"

    def make_buffer_reuse(self, old, new):
        assert old.get_dtype() == new.get_dtype()
        del_line = ""
        if old.get_name() not in V.graph.get_output_names():
            del_line = f"; {self.make_buffer_free(old)}"
        if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
            return f"{self.declare}{new.get_name()} = {old.get_name()}{del_line} {self.comment} reuse"

        return (
            f"{self.declare}{new.get_name()} = {self.namespace}as_strided({old.get_name()}, "
            f"{self.codegen_shape_tuple(new.get_size())}, "
            f"{self.codegen_shape_tuple(new.get_stride())}){del_line} {self.comment} reuse"
        )

    def codegen_deferred_allocation(self, name, layout):
        self.writeline(
            DeferredLine(
                name,
                f"{self.declare}{name} = {layout.view.codegen_reference()}{self.ending} {self.comment} alias",
            )
        )

    def codegen_allocation(self, buffer):
        name = buffer.get_name()
        if name in V.graph.removed_buffers or name in self.allocated:
            return
        self.allocated.add(name)
        if isinstance(
            buffer,
            (ir.ExternKernelAlloc, ir.MultiOutput),
        ):
            return

        layout = buffer.get_layout()
        if isinstance(layout, ir.MutationLayout):
            return
        if isinstance(layout, ir.AliasedLayout):
            assert isinstance(
                layout.view, ir.ReinterpretView
            ), f"unexpected {type(layout.view)}: {layout.view}"
            if not layout.maybe_guard_aligned():
                V.graph.unaligned_buffers.add(name)
            self.codegen_allocation(layout.view.data)
            self.codegen_deferred_allocation(name, layout)
            return

        self.writeline(AllocateLine(self, buffer))

    def codegen_free(self, buffer):
        name = buffer.get_name()

        # can be freed but not reused
        if isinstance(buffer, ir.InputBuffer):
            self.writeline(self.make_buffer_free(buffer))
            return

        if not self.can_reuse(buffer):
            return
        self.freed.add(name)

        layout = buffer.get_layout()
        if isinstance(layout, (ir.AliasedLayout, ir.MultiOutputLayout)):
            self.writeline(self.make_buffer_free(buffer))
            return

        self.writeline(FreeIfNotReusedLine(self, buffer))

    def can_reuse(self, buffer):
        name = buffer.get_name()
        if (
            name in V.graph.removed_buffers
            or name in V.graph.graph_inputs
            or name in V.graph.constants
            or name in self.freed
        ):
            return False
        return True

    def did_reuse(self, buffer, reused_buffer):
        # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
        # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
        return (
            buffer.get_name() in self.reuses
            and self.reuses[buffer.get_name()] == reused_buffer.get_name()
        )

    def codegen_inplace_reuse(self, input_buffer, output_buffer):
        assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
        self.codegen_allocation(input_buffer)
        self.freed.add(input_buffer.get_name())
        self.allocated.add(output_buffer.get_name())
        self.reuses[output_buffer.get_name()] = input_buffer.get_name()
        self.writeline(ReuseLine(self, input_buffer, output_buffer))
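    # Illustrative only (buffer names hypothetical): in the default Python
    # wrapper, where declare/ending are empty and comment is "#", these helpers
    # emit lines such as
    #     buf0 = empty_strided((8, 8), (8, 1), device='cuda', dtype=torch.float32)
    #     buf1 = buf0; del buf0  # reuse
    # CppWrapperCodeGen emits the same shapes via at::as_strided with "//" comments.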

class CppWrapperCodeGen(WrapperCodeGen):
    """
    Generates cpp wrapper for running on CPU and calls cpp kernels
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__()
|
2023-03-13 14:20:36 +00:00
|
|
|
self.declare = "auto "
|
|
|
|
|
self.ending = ";"
|
2023-04-13 15:41:03 +00:00
|
|
|
self.open_bracket = "{"
|
|
|
|
|
self.closed_bracket = "}"
|
2023-03-16 13:54:10 +00:00
|
|
|
self.comment = "//"
|
|
|
|
|
self.namespace = "at::"
|
2023-04-11 23:55:44 +00:00
|
|
|
self.none_str = "at::Tensor()"
|
2023-03-31 05:06:39 +00:00
|
|
|
self.extern_call_ops = set()
|
2023-04-04 23:57:50 +00:00
|
|
|
self.size = "sizes()"
|
|
|
|
|
self.stride = "strides()"
|
2023-04-24 00:58:21 +00:00
|
|
|
self.call_func_name = "inductor_entry_cpp"
|
2023-04-06 21:00:39 +00:00
|
|
|
self.cuda = False
|
2023-03-13 14:20:36 +00:00
|
|
|
|
|
|
|
|
def seed(self):
|
|
|
|
|
"""
|
|
|
|
|
Seed is a special variable used to hold the rng seed for a graph.
|
|
|
|
|
|
|
|
|
|
Note this is only used by the CPU backend, we put seeds in a
|
|
|
|
|
1-element tensor for the CUDA backend.
|
|
|
|
|
"""
|
|
|
|
|
self.need_seed = True
|
|
|
|
|
return sympy_symbol("seed")
|
    def write_header(self):
        if V.graph.aot_mode:
            self.header.splice(
                """
                /* AOTInductor generated code */

                #include <ATen/ScalarOps.h>
                """
            )
        else:
            self.header.splice(
                """
                import torch
                from torch.utils.cpp_extension import load_inline

                cpp_wrapper_src = (
                '''
                """
            )
    def mark_output_type(self):
        # mark output types so non-tensor outputs can be unwrapped back to Python scalars
        from ..ir import ShapeAsConstantBuffer

        output_is_tensor = dict()
        for idx, x in enumerate(V.graph.graph_outputs):
            if isinstance(x, ShapeAsConstantBuffer):
                output_is_tensor[idx] = False
            else:
                output_is_tensor[idx] = True

        self.output_is_tensor = output_is_tensor
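    # Illustrative only (hypothetical graph returning a tensor and a
    # ShapeAsConstantBuffer): output_is_tensor becomes {0: True, 1: False},
    # and generate_end later rewraps output 1 with `.item()` on the Python side.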
    def write_prefix(self):
        return

    def write_wrapper_decl(self):
        inputs_len = len(V.graph.graph_inputs.keys())
        self.prefix.splice(
            f"""std::vector<at::Tensor> {self.call_func_name}(const std::vector<at::Tensor>& args) {{"""
        )
        with self.wrapper_call.indent():
            if inputs_len != 0:
                for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
                    # unwrap input tensor back to scalar
                    if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
                        from ..graph import may_get_constant_buffer_dtype
                        from .cpp import DTYPE_TO_CPP

                        dtype = may_get_constant_buffer_dtype(
                            V.graph.graph_inputs[input_key]
                        )
                        assert (
                            dtype is not None
                        ), "Failed to get the dtype of the sympy.Expr"
                        cpp_dtype = DTYPE_TO_CPP[dtype]
                        self.wrapper_call.writeline(f"{cpp_dtype} {input_key};")
                        self.wrapper_call.writeline(
                            f"{input_key} = args[{idx}].item<{cpp_dtype}>();"
                        )
                    else:
                        self.wrapper_call.writeline(f"at::Tensor {input_key};")
                        self.wrapper_call.writeline(f"{input_key} = args[{idx}];")
            for name in V.graph.randomness_seeds:
                self.wrapper_call.writeline(f"at::Tensor {name};")
                self.wrapper_call.writeline(
                    f"{name} = at::randint(std::pow(2, 31), {{}}, at::ScalarType::Long);"
                )
            self.codegen_inputs(self.wrapper_call, V.graph.graph_inputs)

    def generate(self):
        self.write_wrapper_decl()
        return super().generate()
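    # Illustrative only (hypothetical graph with one tensor input and one
    # symbolic-int input); the emitted C++ entry looks roughly like
    #     std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) {
    #         at::Tensor arg0_1;
    #         arg0_1 = args[0];
    #         long s0;
    #         s0 = args[1].item<long>();
    #         ...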
    def define_kernel(self, name: str, kernel: str, kernel_path: str = None):
        self.header.splice(f"\n{kernel}\n")
    def generate_return(self, output_refs):
        self.wrapper_call.writeline(f"return {{{', '.join(output_refs)}}};\n}}")
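    # Illustrative only: generate_return(["buf0", "buf1"]) closes the C++
    # entry function with
    #     return {buf0, buf1};
    #     }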
    def generate_end(self, result):
        if V.graph.aot_mode:
            return

        result.writeline("'''\n)")
        # Generate a load_inline call to JIT-compile the generated cpp code and make it callable from Python
        shared = codecache.get_shared()
        warning_all_flag = codecache.get_warning_all_flag()
        cpp_flags = codecache.cpp_flags()
        ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths(
            vec_isa=codecache.pick_vec_isa(),
            cuda=self.cuda,
        )
        optimization_flags = codecache.optimization_flags(cuda=self.cuda)
        use_custom_generated_macros = codecache.use_custom_generated_macros()

        extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}"
        extra_ldflags = f"{shared} {lpaths} {libs}"
        extra_include_paths = f"{ipaths}"

        # get the hash of the wrapper code to name the extension
        wrapper_call_hash = codecache.code_hash(self.wrapper_call.getvalue())
        result.splice(
            f"""
            module = load_inline(
                name='inline_extension_{wrapper_call_hash}',
                cpp_sources=[cpp_wrapper_src],
                functions=['{self.call_func_name}'],
extra_cflags=['{extra_cflags}'],
|
|
|
|
|
extra_ldflags=['{extra_ldflags}'],
|
|
|
|
|
extra_include_paths=['{extra_include_paths}'])
|
|
|
|
|
"""
|
|
|
|
|
)
|
2023-04-04 23:57:50 +00:00
|
|
|
|
|
|
|
|
# Unwrap output tensors back to Python scalars
|
|
|
|
|
if all(self.output_is_tensor.values()):
|
|
|
|
|
# If there is no ShapeAsConstantBuffer in the output, return the outputs directly as tensors
|
|
|
|
|
return_str = "return f(args_tensor)"
|
|
|
|
|
else:
|
|
|
|
|
outputs = [
|
|
|
|
|
f"outputs[{i}]" if self.output_is_tensor[i] else f"outputs[{i}].item()"
|
|
|
|
|
for i in range(len(V.graph.graph_outputs))
|
|
|
|
|
]
|
|
|
|
|
outputs_str = f"[{', '.join(outputs)}]"
|
|
|
|
|
return_str = f"""
|
|
|
|
|
outputs = f(args_tensor)
|
|
|
|
|
return {outputs_str}
|
|
|
|
|
"""
|
2022-11-30 10:35:05 +00:00
|
|
|
# Wrap the func to support setting result._boxed_call = True
|
|
|
|
|
result.splice(
|
|
|
|
|
f"""
|
|
|
|
|
def _wrap_func(f):
|
|
|
|
|
def g(args):
|
2023-04-04 23:57:50 +00:00
|
|
|
args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
|
|
|
|
|
{return_str}
|
2022-11-30 10:35:05 +00:00
|
|
|
return g
|
2023-04-05 21:34:58 +00:00
|
|
|
call = _wrap_func(module.{self.call_func_name})
|
2022-11-30 10:35:05 +00:00
|
|
|
"""
|
|
|
|
|
)
|
2022-12-14 15:43:31 +00:00
|
|
|
|
2023-04-11 23:55:44 +00:00
|
|
|
def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
|
2022-12-14 15:43:31 +00:00
|
|
|
if output_view:
|
|
|
|
|
output_as_strided = f"{output_view.codegen_reference()}"
|
|
|
|
|
output_name = f"{output_view.get_name()}_as_strided"
|
|
|
|
|
self.writeline(f"auto {output_name} = {output_as_strided};")
|
|
|
|
|
|
|
|
|
|
args.insert(0, output_name)
|
|
|
|
|
else:
|
|
|
|
|
args.insert(0, f"{codegen_reference}")
|
2023-04-13 15:41:03 +00:00
|
|
|
self.writeline(self.wrap_kernel_call(kernel, args))
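A hedged sketch of what this emits for an out-variant GEMM with an output view (buffer names and strides are illustrative):
```python
# Roughly the C++ produced by generate_extern_kernel_out in that case:
emitted = """
auto buf1_as_strided = at::as_strided(buf1, {8, 8}, {1, 8});
at::mm_out(buf1_as_strided, arg0_1, buf0);
"""
```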
|
2023-03-13 14:20:29 +00:00
|
|
|
|
2023-04-24 00:58:21 +00:00
|
|
|
def add_benchmark_harness(self, output):
|
|
|
|
|
if V.graph.aot_mode:
|
|
|
|
|
return
|
|
|
|
|
super().add_benchmark_harness(output)
|
|
|
|
|
|
2023-04-04 23:57:50 +00:00
|
|
|
def codegen_sizevar(self, x: Expr) -> str:
|
|
|
|
|
from .cpp import cexpr
|
|
|
|
|
|
|
|
|
|
return cexpr(V.graph.sizevars.simplify(x))
|
|
|
|
|
|
2023-03-13 14:20:36 +00:00
|
|
|
def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
|
|
|
|
|
parts = list(map(self.codegen_sizevar, shape))
|
|
|
|
|
if len(parts) == 0:
|
|
|
|
|
return "{}"
|
|
|
|
|
if len(parts) == 1:
|
|
|
|
|
return f"{{{parts[0]}, }}"
|
|
|
|
|
return f"{{{', '.join(parts)}}}"
|
|
|
|
|
|
2023-03-16 13:54:10 +00:00
|
|
|
def make_buffer_free(self, buffer):
|
|
|
|
|
return f"{buffer.get_name()}.reset();"
|
|
|
|
|
|
2023-03-20 01:46:20 +00:00
|
|
|
def generate_profiler_mark_wrapper_call(self, stack):
|
|
|
|
|
self.wrapper_call.writeline(
|
|
|
|
|
'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>({{}}));'
|
|
|
|
|
)
|
|
|
|
|
|
2023-03-16 13:54:10 +00:00
|
|
|
def make_buffer_allocation(self, buffer):
|
2023-04-03 17:21:01 +00:00
|
|
|
from .cpp import DEVICE_TO_ATEN, DTYPE_TO_ATEN
|
2023-03-16 13:54:10 +00:00
|
|
|
|
2023-04-03 17:21:01 +00:00
|
|
|
# TODO: map layout here
|
|
|
|
|
device = buffer.get_device()
|
2023-03-16 13:54:10 +00:00
|
|
|
dtype = buffer.get_dtype()
|
|
|
|
|
shape = tuple(buffer.get_size())
|
|
|
|
|
stride = tuple(buffer.get_stride())
|
|
|
|
|
return (
|
|
|
|
|
f"{self.declare}{buffer.get_name()} = {self.namespace}empty_strided("
|
|
|
|
|
f"{self.codegen_shape_tuple(shape)}, "
|
|
|
|
|
f"{self.codegen_shape_tuple(stride)}, "
|
2023-04-03 17:21:01 +00:00
|
|
|
f"at::device({DEVICE_TO_ATEN[device.type]})"
|
|
|
|
|
f".dtype({DTYPE_TO_ATEN[dtype]})){self.ending}"
|
2023-03-16 13:54:10 +00:00
|
|
|
)
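A hedged sketch of the resulting allocation line for a contiguous 8x8 float CPU buffer (the exact aten spellings come from DEVICE_TO_ATEN and DTYPE_TO_ATEN):
```python
# Roughly what make_buffer_allocation returns (illustrative names):
emitted = (
    "auto buf0 = at::empty_strided({8, 8}, {8, 1}, "
    "at::device(at::kCPU).dtype(at::kFloat));"
)
```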
|
|
|
|
|
|
2023-03-31 05:06:39 +00:00
|
|
|
def generate_fusion_ops_code(
|
|
|
|
|
self,
|
|
|
|
|
name,
|
|
|
|
|
kernel,
|
|
|
|
|
codegen_args,
|
|
|
|
|
cpp_op_schema,
|
|
|
|
|
cpp_kernel_key,
|
|
|
|
|
cpp_kernel_overload_name="",
|
|
|
|
|
):
|
|
|
|
|
if cpp_kernel_key not in self.extern_call_ops:
|
|
|
|
|
self.writeline(
|
|
|
|
|
f"""
|
|
|
|
|
static auto op_{cpp_kernel_key} =
|
|
|
|
|
c10::Dispatcher::singleton()
|
|
|
|
|
.findSchemaOrThrow(
|
2023-04-11 23:55:44 +00:00
|
|
|
\"{kernel}\",
|
2023-03-31 05:06:39 +00:00
|
|
|
\"{cpp_kernel_overload_name}\")
|
|
|
|
|
.typed<{cpp_op_schema}>();
|
|
|
|
|
"""
|
|
|
|
|
)
|
|
|
|
|
self.extern_call_ops.add(cpp_kernel_key)
|
|
|
|
|
|
|
|
|
|
self.writeline(
|
|
|
|
|
f"auto {name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});"
|
|
|
|
|
)
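So the typed dispatcher handle is emitted once per kernel key and reused at later call sites; a sketch of the generated C++ for a hypothetical fused op (name, schema, and arguments are made up):
```python
emitted = """
static auto op_convolution_pointwise =
    c10::Dispatcher::singleton()
        .findSchemaOrThrow("mkldnn::_convolution_pointwise", "")
        .typed<at::Tensor(const at::Tensor&, const at::Tensor&)>();
auto buf0 = op_convolution_pointwise.call(arg0_1, arg1_1);
"""
```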
|
|
|
|
|
|
2023-04-11 23:55:44 +00:00
|
|
|
def val_to_str(self, s):
|
|
|
|
|
if s is None:
|
|
|
|
|
return self.none_str
|
|
|
|
|
elif isinstance(s, bool):
|
|
|
|
|
return "true" if s else "false"
|
|
|
|
|
elif isinstance(s, str):
|
|
|
|
|
return f'"{s}"'
|
|
|
|
|
elif isinstance(s, (list, tuple)):
|
2023-04-13 15:41:03 +00:00
|
|
|
vals = ", ".join(list(map(self.val_to_str, s)))
|
|
|
|
|
return f"{{{vals}}}"
|
2023-04-11 23:55:44 +00:00
|
|
|
else:
|
|
|
|
|
return repr(s)
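Illustrative mappings, as a sketch:
```python
# val_to_str examples:
#   None       -> self.none_str (whatever the wrapper sets it to)
#   True       -> "true"
#   "relu"     -> '"relu"'
#   [1, 2, 3]  -> "{1, 2, 3}"
#   2.5        -> "2.5"   # falls through to repr()
```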
|
|
|
|
|
|
2023-03-13 14:20:29 +00:00
|
|
|
|
2023-04-05 21:34:58 +00:00
|
|
|
class CudaWrapperCodeGen(CppWrapperCodeGen):
|
2023-03-13 14:20:29 +00:00
|
|
|
"""
|
2023-04-05 21:34:58 +00:00
|
|
|
Generates a cpp wrapper that runs on GPU and calls CUDA kernels
|
2023-04-03 17:21:01 +00:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__()
|
2023-04-06 21:00:39 +00:00
|
|
|
self.kernel_callsite_id = count()
|
|
|
|
|
self.arg_var_id = count()
|
|
|
|
|
self.cuda = True
|
2023-04-03 17:21:01 +00:00
|
|
|
|
|
|
|
|
def write_prefix(self):
|
|
|
|
|
self.prefix.splice(
|
|
|
|
|
"""
|
2023-04-06 21:00:39 +00:00
|
|
|
#include <c10/util/Exception.h>
|
2023-04-03 17:21:01 +00:00
|
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
|
|
|
|
2023-04-13 15:41:03 +00:00
|
|
|
#define AT_CUDA_DRIVER_CHECK_OVERRIDE(EXPR) \\
|
|
|
|
|
do { \\
|
|
|
|
|
CUresult __err = EXPR; \\
|
|
|
|
|
if (__err != CUDA_SUCCESS) { \\
|
|
|
|
|
AT_ERROR("CUDA driver error: ", static_cast<int>(__err)); \\
|
|
|
|
|
} \\
|
2023-04-06 21:00:39 +00:00
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
static inline CUfunction loadKernel(const std::string &filePath,
|
|
|
|
|
const std::string &funcName) {
|
2023-04-03 17:21:01 +00:00
|
|
|
CUmodule mod;
|
|
|
|
|
CUfunction func;
|
2023-04-06 21:00:39 +00:00
|
|
|
AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleLoad(&mod, filePath.c_str()));
|
|
|
|
|
AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleGetFunction(&func, mod, funcName.c_str()));
|
2023-04-03 17:21:01 +00:00
|
|
|
return func;
|
|
|
|
|
}
|
|
|
|
|
|
2023-04-06 21:00:39 +00:00
|
|
|
static inline void launchKernel(
|
2023-04-03 17:21:01 +00:00
|
|
|
CUfunction func,
|
|
|
|
|
int gridX,
|
|
|
|
|
int gridY,
|
|
|
|
|
int gridZ,
|
|
|
|
|
int numWarps,
|
|
|
|
|
int sharedMemBytes,
|
|
|
|
|
void* args[],
|
|
|
|
|
int device_index) {
|
2023-04-06 21:00:39 +00:00
|
|
|
AT_CUDA_DRIVER_CHECK_OVERRIDE(cuLaunchKernel(
|
2023-04-03 17:21:01 +00:00
|
|
|
func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes,
|
|
|
|
|
at::cuda::getCurrentCUDAStream(device_index), args, nullptr));
|
|
|
|
|
}
|
|
|
|
|
"""
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def define_kernel(self, name: str, kernel: str, kernel_path: str = None):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def generate(self):
|
|
|
|
|
self.prefix.writeline("\n")
|
2023-04-06 21:00:39 +00:00
|
|
|
for kernel in self.src_to_kernel.values():
|
2023-04-03 17:21:01 +00:00
|
|
|
self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
|
|
|
|
|
self.prefix.writeline("\n")
|
|
|
|
|
return super().generate()
|
|
|
|
|
|
2023-04-06 21:00:39 +00:00
|
|
|
def generate_load_kernel(self, name, params):
|
2023-04-03 17:21:01 +00:00
|
|
|
mangled_name = params.get("mangled_name", None)
|
|
|
|
|
assert mangled_name is not None, "missing mangled_name"
|
2023-04-06 21:00:39 +00:00
|
|
|
cubin_path = params.get("cubin_path", None)
|
2023-04-03 17:21:01 +00:00
|
|
|
assert os.path.exists(
|
|
|
|
|
cubin_path
|
|
|
|
|
), "cubin file should already exist at this moment"
|
|
|
|
|
|
|
|
|
|
self.writeline(f"if ({name} == nullptr) {{")
|
|
|
|
|
self.writeline(
|
|
|
|
|
f""" {name} = loadKernel("{cubin_path}", "{mangled_name}");"""
|
|
|
|
|
)
|
|
|
|
|
self.writeline("}")
|
|
|
|
|
|
|
|
|
|
def generate_args_decl(self, call_args):
|
|
|
|
|
# TODO: only works for constants now; need type info
|
|
|
|
|
new_args = []
|
|
|
|
|
for arg in call_args:
|
2023-04-06 21:00:39 +00:00
|
|
|
var_name = f"var_{next(self.arg_var_id)}"
|
2023-04-03 17:21:01 +00:00
|
|
|
if is_int(arg):
|
|
|
|
|
self.writeline(f"int {var_name} = {arg};")
|
|
|
|
|
elif is_float(arg):
|
|
|
|
|
self.writeline(f"float {var_name} = {arg};")
|
|
|
|
|
else:
|
|
|
|
|
self.writeline(
|
|
|
|
|
f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
|
|
|
|
|
)
|
|
|
|
|
new_args.append(f"&{var_name}")
|
|
|
|
|
|
|
|
|
|
return ", ".join(new_args)
|
|
|
|
|
|
|
|
|
|
def generate_kernel_call(self, name, call_args, device_index):
|
2023-04-06 21:00:39 +00:00
|
|
|
params = CudaKernelParamCache.get(self.kernel_to_hash.get(name, None))
|
2023-04-03 17:21:01 +00:00
|
|
|
assert (
|
|
|
|
|
params is not None
|
|
|
|
|
), "cuda kernel parameters should already exist at this moment"
|
|
|
|
|
|
2023-04-06 21:00:39 +00:00
|
|
|
self.generate_load_kernel(name, params)
|
2023-04-03 17:21:01 +00:00
|
|
|
|
|
|
|
|
call_args = self.generate_args_decl(call_args)
|
2023-04-06 21:00:39 +00:00
|
|
|
kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
|
|
|
|
|
self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
|
2023-04-03 17:21:01 +00:00
|
|
|
self.writeline(
|
2023-04-06 21:00:39 +00:00
|
|
|
"launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
|
|
|
|
|
name,
|
|
|
|
|
params["grid_x"],
|
|
|
|
|
params["grid_y"],
|
|
|
|
|
params["grid_z"],
|
|
|
|
|
params["num_warps"],
|
|
|
|
|
params["shared_mem"],
|
|
|
|
|
kernel_args_var,
|
|
|
|
|
device_index,
|
|
|
|
|
)
|
2023-04-03 17:21:01 +00:00
|
|
|
)
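Putting the pieces together, one emitted call site would come out roughly as follows (all values illustrative):
```python
emitted = """
if (kernel0 == nullptr) {
    kernel0 = loadKernel("/tmp/torchinductor_user/abc/kernel0.cubin", "kernel0_mangled");
}
CUdeviceptr var_0 = reinterpret_cast<CUdeviceptr>(arg0_1.data_ptr());
int var_1 = 64;
void* kernel_args_var_0[] = {&var_0, &var_1};
launchKernel(kernel0, 8, 1, 1, 4, 0, kernel_args_var_0, 0);
"""
```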
|