diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 53219556dac..b8564caeb3d 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -190,6 +190,12 @@ aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst); +// Make the tensor referred to by dst an alias for the tensor referred +// to by src. The two tensors must still be deleted with +// aoti_torch_delete_tensor separately (or not) as before the call. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_assign_tensors(AtenTensorHandle src, AtenTensorHandle dst); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out( AtenTensorHandle out, AtenTensorHandle self, diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 96dd85666f6..1b1a9639183 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -297,6 +297,16 @@ AOTITorchError aoti_torch_tensor_copy_( }); } +AOTITorchError aoti_torch_assign_tensors( + AtenTensorHandle src, + AtenTensorHandle dst) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::Tensor* src_tensor = tensor_handle_to_tensor_pointer(src); + at::Tensor* dst_tensor = tensor_handle_to_tensor_pointer(dst); + *dst_tensor = *src_tensor; + }); +} + // TODO: implement a more efficient version instead of calling into aten AOTITorchError aoti_torch_addmm_out( AtenTensorHandle out,