cpp_wrapper: fix set_.source_Tensor lowering (#145654)

Adds a C-shim fallback for `set_.source_Tensor`, which is effectively required by `ir.SetSourceTensorKernel`. As a necessary prerequisite to use that IR node, updates `CppWrapperCpu` to handle in-place returns in C-shim ops (the arguments for those returns are silently dropped by `torchgen`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145654
Approved by: https://github.com/desertfire
ghstack dependencies: #145095
This commit is contained in:
Benjamin Glass 2025-02-04 18:04:10 +00:00 committed by PyTorch MergeBot
parent 7c0fe7a045
commit 9873319a42
6 changed files with 29 additions and 12 deletions

View file

@@ -415,6 +415,10 @@ test_inductor_cpp_wrapper_shard() {
   # Run certain inductor unit tests with cpp wrapper. In the end state, we
   # should be able to run all the inductor unit tests with cpp_wrapper.
   python test/run_test.py --include inductor/test_torchinductor --verbose
+  python test/run_test.py \
+    --include inductor/test_cpu_repro \
+    -k 'test_set_source_Tensor' \
+    --verbose
   # Run inductor benchmark tests with cpp wrapper.
   # Skip benchmark tests if it's in rerun-disabled-mode.

View file

@@ -714,12 +714,15 @@ class CPUReproTests(TestCase):
         model = M(H=32, W=32, num_channels=4, num_colors=2)
         fn_opt = torch.compile(model, backend="inductor")
         v = (torch.rand(10, 32, 32, 4) > 0.5).to(torch.float32)
-        inps = [
-            v.clone(),
-        ]
-        result, code = run_and_get_cpp_code(fn_opt, *inps)
-        self.assertTrue("aten.set_.source_Tensor" in code)
-        self.assertEqual(model(*inps), result)
+        inp = v.clone()
+        result, code = run_and_get_cpp_code(fn_opt, inp)
+        self.assertIn(
+            "aoti_torch_cpu_set__source_Tensor"
+            if config.cpp_wrapper
+            else "aten.set_.source_Tensor",
+            code,
+        )
+        self.assertEqual(model(inp), result)

     @torch._dynamo.config.patch(dynamic_shapes=True)
     @torch._dynamo.config.patch(assume_static_by_default=False)

View file

@@ -1042,16 +1042,23 @@ class CppWrapperCpu(PythonWrapperCodegen):
             f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
         )

-    def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
+    def generate_c_shim_extern_kernel_alloc(
+        self, extern_kernel: ir.ExternKernelAlloc, args: list[str]
+    ) -> None:
         # registered output buffer name
         name = extern_kernel.name
         output_handle_name = f"{name}_handle"
-        self.writeline(f"AtenTensorHandle {output_handle_name};")
-        output_arg = f"&{output_handle_name}"
-        self.generate_c_shim_extern_kernel_call(
-            extern_kernel.get_kernel_name(), args + [output_arg]
-        )
-        self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
+        is_inplace = (
+            isinstance(extern_kernel.op_overload, torch._ops.OpOverload)
+            and torch.Tag.inplace_view in extern_kernel.op_overload.tags
+        )
+        if not is_inplace:
+            self.writeline(f"AtenTensorHandle {output_handle_name};")
+            args = [*args, f"&{output_handle_name}"]
+        self.generate_c_shim_extern_kernel_call(extern_kernel.get_kernel_name(), args)
+        if not is_inplace:
+            self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")

     def generate_extern_kernel_alloc(self, extern_kernel, args):
         if getattr(extern_kernel, "outputs", None):

View file

@@ -127,6 +127,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTenso
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);

View file

@@ -134,6 +134,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTens
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);

View file

@@ -134,6 +134,7 @@ inductor_fallback_ops = {
     "aten.searchsorted.Tensor",
     "aten._segment_reduce_backward.default",
     "aten.segment_reduce.default",
+    "aten.set_.source_Tensor",
     "aten.slice.Tensor",
     "aten.soft_margin_loss_backward.default",
     "aten.sort.default",