diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 4b61e3e4e87..302bdd05a80 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -70,8 +70,12 @@ except (ImportError, AssertionError) as e:
 
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
+HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
 aten = torch.ops.aten
 requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
+requires_multigpu = functools.partial(
+    unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
+)
 
 torch._inductor.config.triton.autotune = False  # too slow
 
@@ -2053,6 +2057,15 @@ class CommonTemplate:
             check_lowp=False,  # cpu doesn't understand fp16, and there are explicit .cpu() calls
         )
 
+    @requires_multigpu()
+    def test_multi_gpu_device(self):
+        def fn(x, y):
+            r = torch.ops.aten.div(x, y)
+            r = r.to("cuda:1")
+            return 2 * r
+
+        self.common(fn, (torch.randn(4), torch.randn(4)), check_lowp=False)
+
     def test_unbind(self):
         def fn(a):
             return torch.unbind(a), torch.unbind(a, -1)
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 747a7850b56..c34fefd9b60 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -524,7 +524,7 @@ class TritonFuture:
 
 class AsyncCompile:
     def __init__(self):
-        self._context_keepalive = None
+        pass
 
     @staticmethod
     @functools.lru_cache(1)
@@ -614,9 +614,6 @@ class AsyncCompile:
     def triton(self, source_code):
         _compile_start()
 
-        if self._context_keepalive is None:
-            # Workaround `CUDA: Error- context is destroyed`
-            self._context_keepalive = torch.tensor([1], device="cuda")
         if config.compile_threads > 1:
             major, minor = torch.cuda.get_device_capability()
 
diff --git a/torch/_inductor/codegen/triton_template.py b/torch/_inductor/codegen/triton_template.py
index cd1c2bed6bb..de5e6f7c5d2 100644
--- a/torch/_inductor/codegen/triton_template.py
+++ b/torch/_inductor/codegen/triton_template.py
@@ -234,29 +234,27 @@ class TritonTemplateKernel(TritonKernel):
             OUT_H = self.args_dict["OUT_H"]
             OUT_W = self.args_dict["OUT_W"]
             KERNEL_N = self.args_dict["KERNEL_N"]
-            with code.indent():
-                code.splice(
-                    f"""
-                    def grid_{name}(META):
-                        return (
-                            triton.cdiv({BATCH} * {OUT_H} * {OUT_W}, META["BLOCK_M"]),
-                            triton.cdiv({KERNEL_N}, META["BLOCK_N"]),
-                        )
-                    """
-                )
+            code.splice(
+                f"""
+                def grid_{name}(META):
+                    return (
+                        triton.cdiv({BATCH} * {OUT_H} * {OUT_W}, META["BLOCK_M"]),
+                        triton.cdiv({KERNEL_N}, META["BLOCK_N"]),
+                    )
+                """
+            )
         if isinstance(self.node, ir.MatrixMultiply):
             M = self.args_dict["M"]
             N = self.args_dict["N"]
-            with code.indent():
-                code.splice(
-                    f"""
-                    def grid_{name}(META):
-                        return (
-                            triton.cdiv({M}, META["BLOCK_M"]) * triton.cdiv({N}, META["BLOCK_N"]),
-                            META["SPLIT_K"],
-                        )
-                    """
-                )
+            code.splice(
+                f"""
+                def grid_{name}(META):
+                    return (
+                        triton.cdiv({M}, META["BLOCK_M"]) * triton.cdiv({N}, META["BLOCK_N"]),
+                        META["SPLIT_K"],
+                    )
+                """
+            )
         return code.getvalue()
 
     def call_kernel(self, wrapper, name: str):
@@ -272,7 +270,9 @@ class TritonTemplateKernel(TritonKernel):
             ", " + extra_args if extra_args and len(extra_args) > 0 else ""
         )
         args_kwargs = args + ", " + self_const_kwargs
-        wrapper.writeline(self.gen_grid(name))
+        lines = self.gen_grid(name).split("\n")
+        for l in lines:
+            wrapper.writeline(l)
         wrapper.writeline(f"{name}[grid_{name}]({args_kwargs})")
 
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index bfb6e89c2dd..19c1142e814 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -92,6 +92,19 @@ class MemoryPlanningState:
         self.reuse_pool[key].append(item)
 
 
+@dataclasses.dataclass
+class EnterCudaDeviceContextManagerLine:
+    device_idx: int
+
+    def codegen(self, code: IndentedBuffer):
+        code.writeline(f"with torch.cuda.device({self.device_idx}):")
+
+
+class ExitCudaDeviceContextManagerLine:
+    def codegen(self, code: IndentedBuffer):
+        pass
+
+
 class MemoryPlanningLine:
     def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
         """First pass to find reuse"""
@@ -428,6 +441,12 @@ class WrapperCodeGen(CodeGen):
         self.allocated.add(output_buffer.get_name())
         self.write_reuse_line(input_buffer, output_buffer)
 
+    def codegen_cuda_device_guard_enter(self, device_idx):
+        self.lines.append(EnterCudaDeviceContextManagerLine(device_idx))
+
+    def codegen_cuda_device_guard_exit(self):
+        self.lines.append(ExitCudaDeviceContextManagerLine())
+
     def generate_return(self, output_refs):
         if output_refs:
             self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
@@ -477,9 +496,15 @@ class WrapperCodeGen(CodeGen):
             if isinstance(self.lines[i], MemoryPlanningLine):
                 self.lines[i] = self.lines[i].plan(planning_state)
 
+        device_cm_stack = contextlib.ExitStack()
         for line in self.lines:
             if isinstance(line, MemoryPlanningLine):
                 line.codegen(self.wrapper_call)
+            elif isinstance(line, EnterCudaDeviceContextManagerLine):
+                line.codegen(self.wrapper_call)
+                device_cm_stack.enter_context(self.wrapper_call.indent())
+            elif isinstance(line, ExitCudaDeviceContextManagerLine):
+                device_cm_stack.close()
             else:
                 self.wrapper_call.writeline(line)
 
@@ -509,7 +534,7 @@ class WrapperCodeGen(CodeGen):
                 f"{name} = rand_strided("
                 f"{V.graph.sizevars.codegen_benchmark_shape_tuple(shape)}, "
                 f"{V.graph.sizevars.codegen_benchmark_shape_tuple(stride)}, "
-                f"device='{device.type}', dtype={dtype})"
+                f"device='{device}', dtype={dtype})"
             )
 
         output.writelines(["", "", 'if __name__ == "__main__":'])
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 6d7ac0c815e..3425baf1b3b 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -1119,6 +1119,16 @@ class Scheduler:
                     or node.is_template()
                 ):
                     self.flush()
+                if device != self.current_device:
+                    if device.type == "cuda":
+                        if self.current_device and self.current_device.type == "cuda":
+                            V.graph.wrapper_code.codegen_cuda_device_guard_exit()
+                        assert device.index is not None, "device should have an index"
+                        V.graph.wrapper_code.codegen_cuda_device_guard_enter(
+                            device.index
+                        )
+                    elif self.current_device and self.current_device.type == "cuda":
+                        V.graph.wrapper_code.codegen_cuda_device_guard_exit()
                 self.current_device = device
 
             self.buffer_names_to_free.update(node.last_usage)
diff --git a/torch/_inductor/triton_ops/autotune.py b/torch/_inductor/triton_ops/autotune.py
index a61927eb018..dde5de62bd9 100644
--- a/torch/_inductor/triton_ops/autotune.py
+++ b/torch/_inductor/triton_ops/autotune.py
@@ -69,7 +69,6 @@ class CachingAutotuner(KernelInterface):
             compile_meta["constants"][self.fn.arg_names.index(k)] = v
         compile_meta["num_warps"] = cfg.num_warps
         compile_meta["num_stages"] = cfg.num_stages
-
         if warm_cache_only_with_cc:
             triton.compile(
                 self.fn,
@@ -79,12 +78,14 @@ class CachingAutotuner(KernelInterface):
             )
             return
 
-        torch.cuda.set_device(torch.cuda.current_device())
-
-        binary = triton.compile(
-            self.fn,
-            **compile_meta,
-        )
+        # load binary to the correct device
+        with torch.cuda.device(compile_meta["device"]):
+            # need to initialize context
+            torch.cuda.synchronize(torch.cuda.current_device())
+            binary = triton.compile(
+                self.fn,
+                **compile_meta,
+            )
 
         call_args = [
             arg
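
Note: the wrapper/scheduler changes above reduce to one pattern in the generated wrapper code. Each stretch of kernels scheduled on a given CUDA device is wrapped in a torch.cuda.device(...) guard: EnterCudaDeviceContextManagerLine opens the guard and indents what follows, ExitCudaDeviceContextManagerLine closes it before the scheduler switches devices. A minimal sketch of that generated shape, assuming a machine with two CUDA devices; the function and buffer names below are illustrative only and not taken from this PR:

    import torch

    # Hypothetical shape of what WrapperCodeGen emits for a graph that
    # touches two devices (mirrors test_multi_gpu_device above).
    def call(x, y):
        with torch.cuda.device(0):  # EnterCudaDeviceContextManagerLine
            buf0 = torch.div(x, y)  # work for cuda:0 runs inside the guard
        # ExitCudaDeviceContextManagerLine closes the indent here, then the
        # scheduler opens a fresh guard for the next device
        with torch.cuda.device(1):
            buf1 = buf0.to("cuda:1") * 2
        return buf1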