mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
…inductor codegen. Fixes #136260 **Note**: this is my first code contribution to torch so please let me know if there's anything I need to fix/some other convention I should follow. Regarding the bug, re-running the issue's reproduction code: ``` import torch def fn(x): return x.to(device="cuda", non_blocking=True) inp = torch.randn(3, 4) torch.compile(fn)(inp) ``` We now have the non_blocking being passed on to codegen properly: ``` V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] TRACED GRAPH V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] ===== pre insert_deferred_runtime_asserts __compiled_fn_1 ===== V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] <eval_with_key>.0 class GraphModule(torch.nn.Module): V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] def forward(self, L_x_: "f32[3, 4]"): V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] l_x_ = L_x_ V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True) V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] to: "f32[3, 4]" = l_x_.to(device = 'cuda', non_blocking = True); l_x_ = None V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] return (to,) V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] TRACED GRAPH V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] ===== __compiled_fn_1 ===== V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] def forward(self, L_x_: "f32[3, 4][4, 1]cpu"): V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] l_x_ = L_x_ V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True) V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] to: "f32[3, 4][4, 1]cuda:0" = l_x_.to(device = 'cuda', non_blocking = True); l_x_ = None V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] return (to,) V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] V0922 20:33:25.404000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:114] [0/0] [__aot_graphs] aot_config id: 0, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=False, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[], subclass_inp_meta=[0], subclass_fw_graph_out_meta=[0], subclass_tangent_meta=[], is_train=False, traced_tangent_metas=None, num_symints_saved_for_bw=None, grad_enabled_mutation=None, deterministic=None, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None, num_backward_tokens=0),subclass_metadata=None I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] TRACED GRAPH I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] ===== Forward graph 0 ===== I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module): I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] def forward(self, arg0_1: "f32[3, 4][4, 1]cpu"): I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True) I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] device_put: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.device_put.default(arg0_1, device(type='cuda', index=0), True); arg0_1 = None I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] convert_element_type: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.convert_element_type.default(device_put, torch.float32); device_put = None I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] return (convert_element_type,) I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1134] [0/0] [__output_code] Output code written to: /tmp/torchinductor_niklasz/ha/chaai264g6ribfw3q2qhl6ayjtaqaavku5wivxtzw4nabgd6htsv.py V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] Output code: V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] # AOT ID: ['0_inference'] V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import torch V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import math V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import random V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import os V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import tempfile V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from math import inf, nan V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.utils import maybe_profile V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch import device, empty_strided V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] aten = torch.ops.aten V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] inductor_ops = torch.ops.inductor V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] _quantized = torch.ops._quantized V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile = AsyncCompile() V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile.wait(globals()) V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] del async_compile V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def call(args): V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] arg0_1, = args V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] args.clear() V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] assert_size_stride(arg0_1, (3, 4), (4, 1)) V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] with torch.cuda._DeviceGuard(0): V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] torch.cuda.set_device(0) V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] buf0 = empty_strided_cuda((3, 4), (4, 1), torch.float32) V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] buf0.copy_(arg0_1, True) V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] del arg0_1 V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] return (buf0, ) V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10): V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._dynamo.testing import rand_strided V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.utils import print_performance V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] arg0_1 = rand_strided((3, 4), (4, 1), device='cpu', dtype=torch.float32) V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] fn = lambda: call([arg0_1]) V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] return print_performance(fn, times=times, repeat=repeat) V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] if __name__ == "__main__": V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.wrapper_benchmark import compiled_module_main V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] compiled_module_main('None', benchmark_compiled_module) V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] ``` See above line `buf0.copy_(arg0_1, True)`. Specific log setting used: `export TORCH_LOGS="graph_code,aot_graphs,output_code"` Pull Request resolved: https://github.com/pytorch/pytorch/pull/136513 Approved by: https://github.com/eellison |
||
|---|---|---|
| .. | ||
| __init__.py | ||
| context.py | ||
| debug_prims.py | ||
| executor.py | ||
| rng_prims.py | ||