From 6f5fddd814645b9353400849b40f6915f6e44972 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Sun, 9 Feb 2025 19:20:50 -0800 Subject: [PATCH] [Inductor UT] Set input tensors to corresponding device for test case in test_aot_inductor.py Fix #145247 ghstack-source-id: 3d452ab351e9c64a8d98230e702941176e44ea54 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145248 --- test/inductor/test_aot_inductor_utils.py | 13 +++++++++++++ test/inductor/test_torchinductor.py | 21 ++++----------------- torch/testing/_internal/inductor_utils.py | 13 +++++++++++++ 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 46f8885541b..37e10166bda 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -16,6 +16,7 @@ from torch._inductor import config from torch._inductor.test_case import TestCase from torch.testing import FileCheck from torch.testing._internal.common_utils import IS_FBCODE +from torch.testing._internal.inductor_utils import clone_preserve_strides_offset from torch.utils import _pytree as pytree @@ -177,6 +178,18 @@ def check_model( torch.manual_seed(0) if not isinstance(model, types.FunctionType): model = model.to(self.device) + + # For non-mixed device inputs with default "cpu", set the device manually. 
+ if all( + t.device.type == "cpu" + for t in example_inputs + if isinstance(t, torch.Tensor) + ): + example_inputs = tuple( + clone_preserve_strides_offset(x, device=self.device) + for x in example_inputs + ) + ref_model = copy.deepcopy(model) ref_inputs = copy.deepcopy(example_inputs) expected = ref_model(*ref_inputs) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 72bd6055368..b9ce55fefef 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -121,6 +121,7 @@ from torch._inductor.compile_fx import ( from torch._inductor.utils import has_torchvision_roi_align from torch.testing._internal.common_utils import slowTest from torch.testing._internal.inductor_utils import ( + clone_preserve_strides_offset, GPU_TYPE, HAS_CPU, HAS_GPU, @@ -372,20 +373,6 @@ def compute_grads(args, kwrags, results, grads): ) -def clone_preserve_strides(x, device=None): - if not isinstance(x, torch.Tensor): - return x - buffer = torch.as_strided( - x, (x.untyped_storage().size() // x.element_size(),), (1,), 0 - ) - if not device: - buffer = buffer.clone() - else: - buffer = buffer.to(device, copy=True) - out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) - return out - - def check_model( self: TestCase, model, @@ -409,7 +396,7 @@ def check_model( kwargs = kwargs or {} torch._dynamo.reset() - ref_inputs = [clone_preserve_strides(x) for x in example_inputs] + ref_inputs = [clone_preserve_strides_offset(x) for x in example_inputs] ref_kwargs = kwargs has_lowp_args = False @@ -422,7 +409,7 @@ def check_model( # Eager model may fail if the dtype is not supported eager_result = None - ref_inputs = [clone_preserve_strides(x) for x in example_inputs] + ref_inputs = [clone_preserve_strides_offset(x) for x in example_inputs] expect_dtypes = [ x.dtype if isinstance(x, torch.Tensor) else None for x in pytree.tree_leaves(eager_result) @@ -625,7 +612,7 @@ def check_model_gpu( if copy_to_gpu: 
example_inputs = tuple( - clone_preserve_strides(x, device=GPU_TYPE) for x in example_inputs + clone_preserve_strides_offset(x, device=GPU_TYPE) for x in example_inputs ) check_model( diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 13de003b330..c5e2b19c419 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -195,3 +195,16 @@ def maybe_skip_size_asserts(op): return torch._inductor.config.patch(size_asserts=False) else: return contextlib.nullcontext() + +def clone_preserve_strides_offset(x, device=None): + if not isinstance(x, torch.Tensor): + return x + buffer = torch.as_strided( + x, (x.untyped_storage().size() // x.element_size(),), (1,), 0 + ) + if not device: + buffer = buffer.clone() + else: + buffer = buffer.to(device, copy=True) + out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) + return out