mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
generate device context managers in inductor code (#90934)
Fixes https://github.com/pytorch/torchdynamo/issues/1717, https://github.com/pytorch/torchdynamo/issues/1990 <s>TODO: add test with multiple devices, figure out extra context initialization</s> Problems: <s>It still initializes context on 0-th device that it shouldn't, I'll take a look where that happens and fix before landing</s> It adds a python device context manages, that is absurdly slow and takes ~2.5 us (should be nanoseconds). That's not a problem for real models, because it'll be called just once, but it is a bit of an inconvenience for microbenchmarking, we should make that context manager more performant (won't fix in this PR) It still can have bugs for graphs that run on multiple devices and can have buffers incorrectly shared between multiple device by memory reuse, if that happens that'll need to be solved separately. Generated code: ``` def call(args): arg0_1, arg1_1 = args args.clear() with torch.cuda.device(1): buf0 = empty_strided((4, ), (1, ), device='cuda', dtype=torch.float32) stream1 = get_cuda_stream(1) triton_fused_div_0.run(arg0_1, arg1_1, buf0, 4, grid=grid(4), stream=stream1) del arg0_1 del arg1_1 return (buf0, ) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/90934 Approved by: https://github.com/wconstab
This commit is contained in:
parent
9d8fa78d2c
commit
a10b3ce876
6 changed files with 79 additions and 33 deletions
|
|
@ -70,8 +70,12 @@ except (ImportError, AssertionError) as e:
|
|||
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
|
||||
|
||||
HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
|
||||
aten = torch.ops.aten
|
||||
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
|
||||
requires_multigpu = functools.partial(
|
||||
unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
|
||||
)
|
||||
|
||||
torch._inductor.config.triton.autotune = False # too slow
|
||||
|
||||
|
|
@ -2053,6 +2057,15 @@ class CommonTemplate:
|
|||
check_lowp=False, # cpu doesn't understand fp16, and there are explicit .cpu() calls
|
||||
)
|
||||
|
||||
@requires_multigpu()
|
||||
def test_multi_gpu_device(self):
|
||||
def fn(x, y):
|
||||
r = torch.ops.aten.div(x, y)
|
||||
r = r.to("cuda:1")
|
||||
return 2 * r
|
||||
|
||||
self.common(fn, (torch.randn(4), torch.randn(4)), check_lowp=False)
|
||||
|
||||
def test_unbind(self):
|
||||
def fn(a):
|
||||
return torch.unbind(a), torch.unbind(a, -1)
|
||||
|
|
|
|||
|
|
@ -524,7 +524,7 @@ class TritonFuture:
|
|||
|
||||
class AsyncCompile:
|
||||
def __init__(self):
|
||||
self._context_keepalive = None
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(1)
|
||||
|
|
@ -614,9 +614,6 @@ class AsyncCompile:
|
|||
|
||||
def triton(self, source_code):
|
||||
_compile_start()
|
||||
if self._context_keepalive is None:
|
||||
# Workaround `CUDA: Error- context is destroyed`
|
||||
self._context_keepalive = torch.tensor([1], device="cuda")
|
||||
|
||||
if config.compile_threads > 1:
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
|
|
|
|||
|
|
@ -234,29 +234,27 @@ class TritonTemplateKernel(TritonKernel):
|
|||
OUT_H = self.args_dict["OUT_H"]
|
||||
OUT_W = self.args_dict["OUT_W"]
|
||||
KERNEL_N = self.args_dict["KERNEL_N"]
|
||||
with code.indent():
|
||||
code.splice(
|
||||
f"""
|
||||
def grid_{name}(META):
|
||||
return (
|
||||
triton.cdiv({BATCH} * {OUT_H} * {OUT_W}, META["BLOCK_M"]),
|
||||
triton.cdiv({KERNEL_N}, META["BLOCK_N"]),
|
||||
)
|
||||
"""
|
||||
)
|
||||
code.splice(
|
||||
f"""
|
||||
def grid_{name}(META):
|
||||
return (
|
||||
triton.cdiv({BATCH} * {OUT_H} * {OUT_W}, META["BLOCK_M"]),
|
||||
triton.cdiv({KERNEL_N}, META["BLOCK_N"]),
|
||||
)
|
||||
"""
|
||||
)
|
||||
if isinstance(self.node, ir.MatrixMultiply):
|
||||
M = self.args_dict["M"]
|
||||
N = self.args_dict["N"]
|
||||
with code.indent():
|
||||
code.splice(
|
||||
f"""
|
||||
def grid_{name}(META):
|
||||
return (
|
||||
triton.cdiv({M}, META["BLOCK_M"]) * triton.cdiv({N}, META["BLOCK_N"]),
|
||||
META["SPLIT_K"],
|
||||
)
|
||||
"""
|
||||
)
|
||||
code.splice(
|
||||
f"""
|
||||
def grid_{name}(META):
|
||||
return (
|
||||
triton.cdiv({M}, META["BLOCK_M"]) * triton.cdiv({N}, META["BLOCK_N"]),
|
||||
META["SPLIT_K"],
|
||||
)
|
||||
"""
|
||||
)
|
||||
return code.getvalue()
|
||||
|
||||
def call_kernel(self, wrapper, name: str):
|
||||
|
|
@ -272,7 +270,9 @@ class TritonTemplateKernel(TritonKernel):
|
|||
", " + extra_args if extra_args and len(extra_args) > 0 else ""
|
||||
)
|
||||
args_kwargs = args + ", " + self_const_kwargs
|
||||
wrapper.writeline(self.gen_grid(name))
|
||||
lines = self.gen_grid(name).split("\n")
|
||||
for l in lines:
|
||||
wrapper.writeline(l)
|
||||
wrapper.writeline(f"{name}[grid_{name}]({args_kwargs})")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -92,6 +92,19 @@ class MemoryPlanningState:
|
|||
self.reuse_pool[key].append(item)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class EnterCudaDeviceContextManagerLine:
|
||||
device_idx: int
|
||||
|
||||
def codegen(self, code: IndentedBuffer):
|
||||
code.writeline(f"with torch.cuda.device({self.device_idx}):")
|
||||
|
||||
|
||||
class ExitCudaDeviceContextManagerLine:
|
||||
def codegen(self, code: IndentedBuffer):
|
||||
pass
|
||||
|
||||
|
||||
class MemoryPlanningLine:
|
||||
def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
|
||||
"""First pass to find reuse"""
|
||||
|
|
@ -428,6 +441,12 @@ class WrapperCodeGen(CodeGen):
|
|||
self.allocated.add(output_buffer.get_name())
|
||||
self.write_reuse_line(input_buffer, output_buffer)
|
||||
|
||||
def codegen_cuda_device_guard_enter(self, device_idx):
|
||||
self.lines.append(EnterCudaDeviceContextManagerLine(device_idx))
|
||||
|
||||
def codegen_cuda_device_guard_exit(self):
|
||||
self.lines.append(ExitCudaDeviceContextManagerLine())
|
||||
|
||||
def generate_return(self, output_refs):
|
||||
if output_refs:
|
||||
self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
|
||||
|
|
@ -477,9 +496,15 @@ class WrapperCodeGen(CodeGen):
|
|||
if isinstance(self.lines[i], MemoryPlanningLine):
|
||||
self.lines[i] = self.lines[i].plan(planning_state)
|
||||
|
||||
device_cm_stack = contextlib.ExitStack()
|
||||
for line in self.lines:
|
||||
if isinstance(line, MemoryPlanningLine):
|
||||
line.codegen(self.wrapper_call)
|
||||
elif isinstance(line, EnterCudaDeviceContextManagerLine):
|
||||
line.codegen(self.wrapper_call)
|
||||
device_cm_stack.enter_context(self.wrapper_call.indent())
|
||||
elif isinstance(line, ExitCudaDeviceContextManagerLine):
|
||||
device_cm_stack.close()
|
||||
else:
|
||||
self.wrapper_call.writeline(line)
|
||||
|
||||
|
|
@ -509,7 +534,7 @@ class WrapperCodeGen(CodeGen):
|
|||
f"{name} = rand_strided("
|
||||
f"{V.graph.sizevars.codegen_benchmark_shape_tuple(shape)}, "
|
||||
f"{V.graph.sizevars.codegen_benchmark_shape_tuple(stride)}, "
|
||||
f"device='{device.type}', dtype={dtype})"
|
||||
f"device='{device}', dtype={dtype})"
|
||||
)
|
||||
|
||||
output.writelines(["", "", 'if __name__ == "__main__":'])
|
||||
|
|
|
|||
|
|
@ -1119,6 +1119,16 @@ class Scheduler:
|
|||
or node.is_template()
|
||||
):
|
||||
self.flush()
|
||||
if device != self.current_device:
|
||||
if device.type == "cuda":
|
||||
if self.current_device and self.current_device.type == "cuda":
|
||||
V.graph.wrapper_code.codegen_cuda_device_guard_exit()
|
||||
assert device.index is not None, "device should have an index"
|
||||
V.graph.wrapper_code.codegen_cuda_device_guard_enter(
|
||||
device.index
|
||||
)
|
||||
elif self.current_device and self.current_device.type == "cuda":
|
||||
V.graph.wrapper_code.codegen_cuda_device_guard_exit()
|
||||
self.current_device = device
|
||||
|
||||
self.buffer_names_to_free.update(node.last_usage)
|
||||
|
|
|
|||
|
|
@ -69,7 +69,6 @@ class CachingAutotuner(KernelInterface):
|
|||
compile_meta["constants"][self.fn.arg_names.index(k)] = v
|
||||
compile_meta["num_warps"] = cfg.num_warps
|
||||
compile_meta["num_stages"] = cfg.num_stages
|
||||
|
||||
if warm_cache_only_with_cc:
|
||||
triton.compile(
|
||||
self.fn,
|
||||
|
|
@ -79,12 +78,14 @@ class CachingAutotuner(KernelInterface):
|
|||
)
|
||||
return
|
||||
|
||||
torch.cuda.set_device(torch.cuda.current_device())
|
||||
|
||||
binary = triton.compile(
|
||||
self.fn,
|
||||
**compile_meta,
|
||||
)
|
||||
# load binary to the correct device
|
||||
with torch.cuda.device(compile_meta["device"]):
|
||||
# need to initialize context
|
||||
torch.cuda.synchronize(torch.cuda.current_device())
|
||||
binary = triton.compile(
|
||||
self.fn,
|
||||
**compile_meta,
|
||||
)
|
||||
|
||||
call_args = [
|
||||
arg
|
||||
|
|
|
|||
Loading…
Reference in a new issue