[Inductor UT] Set input tensors to corresponding device for test case in

test_aot_inductor.py

Fix #145247

ghstack-source-id: 3d452ab351
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145248
This commit is contained in:
xinan.lin 2025-02-09 19:20:50 -08:00
parent 07dbc2d692
commit 6f5fddd814
3 changed files with 30 additions and 17 deletions

View file

@ -16,6 +16,7 @@ from torch._inductor import config
from torch._inductor.test_case import TestCase
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_FBCODE
from torch.testing._internal.inductor_utils import clone_preserve_strides_offset
from torch.utils import _pytree as pytree
@ -177,6 +178,18 @@ def check_model(
torch.manual_seed(0)
if not isinstance(model, types.FunctionType):
model = model.to(self.device)
# For non-mixed device inputs with the default "cpu", set the device manually.
if all(
t.device.type == "cpu"
for t in example_inputs
if isinstance(t, torch.Tensor)
):
example_inputs = tuple(
clone_preserve_strides_offset(x, device=self.device)
for x in example_inputs
)
ref_model = copy.deepcopy(model)
ref_inputs = copy.deepcopy(example_inputs)
expected = ref_model(*ref_inputs)

View file

@ -121,6 +121,7 @@ from torch._inductor.compile_fx import (
from torch._inductor.utils import has_torchvision_roi_align
from torch.testing._internal.common_utils import slowTest
from torch.testing._internal.inductor_utils import (
clone_preserve_strides_offset,
GPU_TYPE,
HAS_CPU,
HAS_GPU,
@ -372,20 +373,6 @@ def compute_grads(args, kwrags, results, grads):
)
def clone_preserve_strides(x, device=None):
    """Copy *x* while keeping its exact layout.

    The returned tensor has the same size, strides, and storage offset as
    the input. The entire underlying storage is duplicated so that the
    original layout (including any storage offset) can be re-applied on
    top of the copy. Non-tensor inputs are returned unchanged.

    Args:
        x: candidate input; only ``torch.Tensor`` values are cloned.
        device: optional target device for the copy; when falsy the clone
            stays on the source device.

    Returns:
        A layout-preserving copy of ``x``, or ``x`` itself when it is not
        a tensor.
    """
    if not isinstance(x, torch.Tensor):
        return x
    # Number of elements in the full untyped storage backing ``x``.
    storage_numel = x.untyped_storage().size() // x.element_size()
    # Flat, contiguous view over the whole storage starting at offset 0.
    flat = torch.as_strided(x, (storage_numel,), (1,), 0)
    copied = flat.clone() if not device else flat.to(device, copy=True)
    # Re-apply the original size/stride/offset onto the copied storage.
    return torch.as_strided(copied, x.size(), x.stride(), x.storage_offset())
def check_model(
self: TestCase,
model,
@ -409,7 +396,7 @@ def check_model(
kwargs = kwargs or {}
torch._dynamo.reset()
ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
ref_inputs = [clone_preserve_strides_offset(x) for x in example_inputs]
ref_kwargs = kwargs
has_lowp_args = False
@ -422,7 +409,7 @@ def check_model(
# Eager model may fail if the dtype is not supported
eager_result = None
ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
ref_inputs = [clone_preserve_strides_offset(x) for x in example_inputs]
expect_dtypes = [
x.dtype if isinstance(x, torch.Tensor) else None
for x in pytree.tree_leaves(eager_result)
@ -625,7 +612,7 @@ def check_model_gpu(
if copy_to_gpu:
example_inputs = tuple(
clone_preserve_strides(x, device=GPU_TYPE) for x in example_inputs
clone_preserve_strides_offset(x, device=GPU_TYPE) for x in example_inputs
)
check_model(

View file

@ -195,3 +195,16 @@ def maybe_skip_size_asserts(op):
return torch._inductor.config.patch(size_asserts=False)
else:
return contextlib.nullcontext()
def clone_preserve_strides_offset(x, device=None):
    """Return a copy of *x* with identical size, strides, and storage offset.

    The whole underlying storage is duplicated (not just the visible
    elements) so the original layout — including the storage offset — can
    be restored on the copy. Values that are not tensors pass through
    untouched.

    Args:
        x: value to copy; only ``torch.Tensor`` inputs are cloned.
        device: optional destination device; when falsy the copy remains
            on the tensor's current device.

    Returns:
        A layout-preserving clone of ``x``, or ``x`` unchanged when it is
        not a tensor.
    """
    if not isinstance(x, torch.Tensor):
        return x
    # Total element count of the backing untyped storage.
    total = x.untyped_storage().size() // x.element_size()
    # View the complete storage as one flat run starting at offset 0.
    flat = torch.as_strided(x, (total,), (1,), 0)
    if device:
        flat = flat.to(device, copy=True)
    else:
        flat = flat.clone()
    # Rebuild the original view (size/stride/offset) over the new storage.
    return torch.as_strided(flat, x.size(), x.stride(), x.storage_offset())