cpp_wrapper: fix set_.source_Tensor lowering (#145654)
Adds a C-shim fallback for `set_.source_Tensor`, which is effectively required by `ir.SetSourceTensorKernel`. As a necessary prerequisite to use that IR node, updates `CppWrapperCpu` to handle in-place returns in C-shim ops (the arguments for those returns are silently dropped by `torchgen`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145654
Approved by: https://github.com/desertfire
ghstack dependencies: #145095
This commit is contained in:
  parent 7c0fe7a045
  commit 9873319a42

6 changed files with 29 additions and 12 deletions
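For context on the op being lowered: `set_.source_Tensor` rebinds `self` to view `source`'s storage instead of allocating a new result, which is why its C shim (added below) returns nothing through an out-argument. A minimal eager-mode illustration, for orientation only (not code from this PR):

    import torch

    a = torch.zeros(4)
    b = torch.arange(4.0)
    a.set_(b)  # dispatches to aten.set_.source_Tensor
    # `a` now views `b`'s storage; no new tensor was allocated.
    assert a.data_ptr() == b.data_ptr()
    assert torch.equal(a, b)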
@@ -415,6 +415,10 @@ test_inductor_cpp_wrapper_shard() {
   # Run certain inductor unit tests with cpp wrapper. In the end state, we
   # should be able to run all the inductor unit tests with cpp_wrapper.
   python test/run_test.py --include inductor/test_torchinductor --verbose
+  python test/run_test.py \
+    --include inductor/test_cpu_repro \
+    -k 'test_set_source_Tensor' \
+    --verbose

   # Run inductor benchmark tests with cpp wrapper.
   # Skip benchmark tests if it's in rerun-disabled-mode.
@@ -714,12 +714,15 @@ class CPUReproTests(TestCase):
         model = M(H=32, W=32, num_channels=4, num_colors=2)
         fn_opt = torch.compile(model, backend="inductor")
         v = (torch.rand(10, 32, 32, 4) > 0.5).to(torch.float32)
-        inps = [
-            v.clone(),
-        ]
-        result, code = run_and_get_cpp_code(fn_opt, *inps)
-        self.assertTrue("aten.set_.source_Tensor" in code)
-        self.assertEqual(model(*inps), result)
+        inp = v.clone()
+        result, code = run_and_get_cpp_code(fn_opt, inp)
+        self.assertIn(
+            "aoti_torch_cpu_set__source_Tensor"
+            if config.cpp_wrapper
+            else "aten.set_.source_Tensor",
+            code,
+        )
+        self.assertEqual(model(inp), result)

     @torch._dynamo.config.patch(dynamic_shapes=True)
     @torch._dynamo.config.patch(assume_static_by_default=False)
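The expected substring differs under `config.cpp_wrapper` because cpp_wrapper code calls the generated C shim rather than the ATen op directly, and the shim symbol mangles the overload name. A rough sketch of that naming convention, inferred from the generated headers further down (hypothetical helper, not part of the PR):

    def shim_symbol(aten_op: str, device: str) -> str:
        # "aten.set_.source_Tensor" -> "aoti_torch_cpu_set__source_Tensor"
        # "aten.sort.default"       -> "aoti_torch_cpu_sort"
        _, name, overload = aten_op.split(".")
        suffix = "" if overload == "default" else f"_{overload}"
        return f"aoti_torch_{device}_{name}{suffix}"

    assert shim_symbol("aten.set_.source_Tensor", "cpu") == "aoti_torch_cpu_set__source_Tensor"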
@@ -1042,16 +1042,23 @@ class CppWrapperCpu(PythonWrapperCodegen):
             f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
         )

-    def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
+    def generate_c_shim_extern_kernel_alloc(
+        self, extern_kernel: ir.ExternKernelAlloc, args: list[str]
+    ) -> None:
         # registered output buffer name
         name = extern_kernel.name
         output_handle_name = f"{name}_handle"
-        self.writeline(f"AtenTensorHandle {output_handle_name};")
-        output_arg = f"&{output_handle_name}"
-        self.generate_c_shim_extern_kernel_call(
-            extern_kernel.get_kernel_name(), args + [output_arg]
-        )
-        self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
+        is_inplace = (
+            isinstance(extern_kernel.op_overload, torch._ops.OpOverload)
+            and torch.Tag.inplace_view in extern_kernel.op_overload.tags
+        )
+        if not is_inplace:
+            self.writeline(f"AtenTensorHandle {output_handle_name};")
+            args = [*args, f"&{output_handle_name}"]
+        self.generate_c_shim_extern_kernel_call(extern_kernel.get_kernel_name(), args)
+
+        if not is_inplace:
+            self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")

     def generate_extern_kernel_alloc(self, extern_kernel, args):
         if getattr(extern_kernel, "outputs", None):
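The effect of the `is_inplace` branch: allocating fallbacks still declare an `AtenTensorHandle`, append it as a trailing out-argument, and wrap it in an RAII handle, while in-place ops such as `set_.source_Tensor` become a bare call that mutates their first argument (whose return out-argument `torchgen` drops, per the commit message). A self-contained sketch of that emission logic (a simplified stand-in for the method above, not the class itself):

    import torch

    def emit(shim_fn: str, op: torch._ops.OpOverload, args: list[str], buf: str) -> list[str]:
        lines = []
        is_inplace = torch.Tag.inplace_view in op.tags
        handle = f"{buf}_handle"
        if not is_inplace:
            lines.append(f"AtenTensorHandle {handle};")
            args = [*args, f"&{handle}"]
        lines.append(f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));")
        if not is_inplace:
            lines.append(f"RAIIAtenTensorHandle {buf}({handle});")
        return lines

    # In-place: a single call, no handle declaration and no RAII wrapper.
    print("\n".join(emit("aoti_torch_cpu_set__source_Tensor",
                         torch.ops.aten.set_.source_Tensor, ["buf0", "buf1"], "buf1")))
    # Allocating op: handle declared, passed as the trailing out-arg, then wrapped.
    print("\n".join(emit("aoti_torch_cpu_slice_Tensor",
                         torch.ops.aten.slice.Tensor, ["buf0", "0", "&s", "&e", "1"], "buf2")))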
@@ -127,6 +127,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTenso
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
@@ -134,6 +134,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTens
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
@@ -134,6 +134,7 @@ inductor_fallback_ops = {
     "aten.searchsorted.Tensor",
     "aten._segment_reduce_backward.default",
     "aten.segment_reduce.default",
+    "aten.set_.source_Tensor",
     "aten.slice.Tensor",
     "aten.soft_margin_loss_backward.default",
     "aten.sort.default",
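Both prerequisites can be sanity-checked from Python: the op must be listed as an inductor fallback so `torchgen` emits a shim for it, and it must carry the `inplace_view` tag that the new `CppWrapperCpu` branch keys on. A quick check, assuming the `torchgen.aoti.fallback_ops` module path (not part of the PR):

    import torch
    from torchgen.aoti.fallback_ops import inductor_fallback_ops

    print("aten.set_.source_Tensor" in inductor_fallback_ops)                 # True with this PR
    print(torch.Tag.inplace_view in torch.ops.aten.set_.source_Tensor.tags)  # expected: True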