mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
cpp_wrapper: fix set_.source_Tensor lowering (#145654)
Adds a C-shim fallback for `set_.source_Tensor`, which is effectively required by `ir.SetSourceTensorKernel`. As a necessary prerequisite to use that IR node, updates `CppWrapperCpu` to handle in-place returns in C-shim ops (the arguments for those returns are silently dropped by `torchgen`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145654
Approved by: https://github.com/desertfire
ghstack dependencies: #145095
This commit is contained in:
parent
7c0fe7a045
commit
9873319a42
6 changed files with 29 additions and 12 deletions
|
|
@ -415,6 +415,10 @@ test_inductor_cpp_wrapper_shard() {
|
|||
# Run certain inductor unit tests with cpp wrapper. In the end state, we
|
||||
# should be able to run all the inductor unit tests with cpp_wrapper.
|
||||
python test/run_test.py --include inductor/test_torchinductor --verbose
|
||||
python test/run_test.py \
|
||||
--include inductor/test_cpu_repro \
|
||||
-k 'test_set_source_Tensor' \
|
||||
--verbose
|
||||
|
||||
# Run inductor benchmark tests with cpp wrapper.
|
||||
# Skip benchmark tests if it's in rerun-disabled-mode.
|
||||
|
|
|
|||
|
|
@ -714,12 +714,15 @@ class CPUReproTests(TestCase):
|
|||
model = M(H=32, W=32, num_channels=4, num_colors=2)
|
||||
fn_opt = torch.compile(model, backend="inductor")
|
||||
v = (torch.rand(10, 32, 32, 4) > 0.5).to(torch.float32)
|
||||
inps = [
|
||||
v.clone(),
|
||||
]
|
||||
result, code = run_and_get_cpp_code(fn_opt, *inps)
|
||||
self.assertTrue("aten.set_.source_Tensor" in code)
|
||||
self.assertEqual(model(*inps), result)
|
||||
inp = v.clone()
|
||||
result, code = run_and_get_cpp_code(fn_opt, inp)
|
||||
self.assertIn(
|
||||
"aoti_torch_cpu_set__source_Tensor"
|
||||
if config.cpp_wrapper
|
||||
else "aten.set_.source_Tensor",
|
||||
code,
|
||||
)
|
||||
self.assertEqual(model(inp), result)
|
||||
|
||||
@torch._dynamo.config.patch(dynamic_shapes=True)
|
||||
@torch._dynamo.config.patch(assume_static_by_default=False)
|
||||
|
|
|
|||
|
|
@ -1042,16 +1042,23 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
|
||||
)
|
||||
|
||||
def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
|
||||
def generate_c_shim_extern_kernel_alloc(
|
||||
self, extern_kernel: ir.ExternKernelAlloc, args: list[str]
|
||||
) -> None:
|
||||
# registered output buffer name
|
||||
name = extern_kernel.name
|
||||
output_handle_name = f"{name}_handle"
|
||||
self.writeline(f"AtenTensorHandle {output_handle_name};")
|
||||
output_arg = f"&{output_handle_name}"
|
||||
self.generate_c_shim_extern_kernel_call(
|
||||
extern_kernel.get_kernel_name(), args + [output_arg]
|
||||
is_inplace = (
|
||||
isinstance(extern_kernel.op_overload, torch._ops.OpOverload)
|
||||
and torch.Tag.inplace_view in extern_kernel.op_overload.tags
|
||||
)
|
||||
self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
|
||||
|
||||
if not is_inplace:
|
||||
self.writeline(f"AtenTensorHandle {output_handle_name};")
|
||||
args = [*args, f"&{output_handle_name}"]
|
||||
self.generate_c_shim_extern_kernel_call(extern_kernel.get_kernel_name(), args)
|
||||
if not is_inplace:
|
||||
self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
|
||||
|
||||
def generate_extern_kernel_alloc(self, extern_kernel, args):
|
||||
if getattr(extern_kernel, "outputs", None):
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTenso
|
|||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||
|
|
|
|||
|
|
@ -134,6 +134,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTens
|
|||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||
|
|
|
|||
|
|
@ -134,6 +134,7 @@ inductor_fallback_ops = {
|
|||
"aten.searchsorted.Tensor",
|
||||
"aten._segment_reduce_backward.default",
|
||||
"aten.segment_reduce.default",
|
||||
"aten.set_.source_Tensor",
|
||||
"aten.slice.Tensor",
|
||||
"aten.soft_margin_loss_backward.default",
|
||||
"aten.sort.default",
|
||||
|
|
|
|||
Loading…
Reference in a new issue