From 9873319a420059ba2f64af4e31995b59b1a86f12 Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Tue, 4 Feb 2025 18:04:10 +0000 Subject: [PATCH] cpp_wrapper: fix set_.source_Tensor lowering (#145654) Adds a C-shim fallback for `set_.source_Tensor`, which is effectively required by `ir.SetSourceTensorKernel`. As a necessary prerequisite to use that IR node, updates `CppWrapperCpu` to handle in-place returns in C-shim ops (the arguments for those returns are silently dropped by `torchgen`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/145654 Approved by: https://github.com/desertfire ghstack dependencies: #145095 --- .ci/pytorch/test.sh | 4 ++++ test/inductor/test_cpu_repro.py | 15 +++++++++------ torch/_inductor/codegen/cpp_wrapper_cpu.py | 19 +++++++++++++------ .../aoti_torch/generated/c_shim_cpu.h | 1 + .../aoti_torch/generated/c_shim_cuda.h | 1 + torchgen/aoti/fallback_ops.py | 1 + 6 files changed, 29 insertions(+), 12 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 0535bb9066d..e9baab4d1b8 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -415,6 +415,10 @@ test_inductor_cpp_wrapper_shard() { # Run certain inductor unit tests with cpp wrapper. In the end state, we # should be able to run all the inductor unit tests with cpp_wrapper. python test/run_test.py --include inductor/test_torchinductor --verbose + python test/run_test.py \ + --include inductor/test_cpu_repro \ + -k 'test_set_source_Tensor' \ + --verbose # Run inductor benchmark tests with cpp wrapper. # Skip benchmark tests if it's in rerun-disabled-mode. diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 7b8323300d5..b7a37ae05ef 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -714,12 +714,15 @@ class CPUReproTests(TestCase): model = M(H=32, W=32, num_channels=4, num_colors=2) fn_opt = torch.compile(model, backend="inductor") v = (torch.rand(10, 32, 32, 4) > 0.5).to(torch.float32) - inps = [ - v.clone(), - ] - result, code = run_and_get_cpp_code(fn_opt, *inps) - self.assertTrue("aten.set_.source_Tensor" in code) - self.assertEqual(model(*inps), result) + inp = v.clone() + result, code = run_and_get_cpp_code(fn_opt, inp) + self.assertIn( + "aoti_torch_cpu_set__source_Tensor" + if config.cpp_wrapper + else "aten.set_.source_Tensor", + code, + ) + self.assertEqual(model(inp), result) @torch._dynamo.config.patch(dynamic_shapes=True) @torch._dynamo.config.patch(assume_static_by_default=False) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 5328fbf66a7..e86034fc106 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1042,16 +1042,23 @@ class CppWrapperCpu(PythonWrapperCodegen): f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));" ) - def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args): + def generate_c_shim_extern_kernel_alloc( + self, extern_kernel: ir.ExternKernelAlloc, args: list[str] + ) -> None: # registered output buffer name name = extern_kernel.name output_handle_name = f"{name}_handle" - self.writeline(f"AtenTensorHandle {output_handle_name};") - output_arg = f"&{output_handle_name}" - self.generate_c_shim_extern_kernel_call( - extern_kernel.get_kernel_name(), args + [output_arg] + is_inplace = ( + isinstance(extern_kernel.op_overload, torch._ops.OpOverload) + and torch.Tag.inplace_view in extern_kernel.op_overload.tags ) - self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") + + if not is_inplace: + self.writeline(f"AtenTensorHandle {output_handle_name};") + args = [*args, f"&{output_handle_name}"] + self.generate_c_shim_extern_kernel_call(extern_kernel.get_kernel_name(), args) + if not is_inplace: + self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") def generate_extern_kernel_alloc(self, extern_kernel, args): if getattr(extern_kernel, "outputs", None): diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 2a5eb60e9c8..924e77b28c2 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -127,6 +127,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTenso AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index 1ab10c71fe8..8bf3cd03c7d 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -134,6 +134,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTens AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index dead690831f..a0cdd6c402d 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -134,6 +134,7 @@ inductor_fallback_ops = { "aten.searchsorted.Tensor", "aten._segment_reduce_backward.default", "aten.segment_reduce.default", + "aten.set_.source_Tensor", "aten.slice.Tensor", "aten.soft_margin_loss_backward.default", "aten.sort.default",