[Inductor UT] Set input tensors to the corresponding device for test cases in test_aot_inductor.py

Fix #145247

ghstack-source-id: 3d452ab351
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145248
This commit is contained in:
parent 07dbc2d692
commit 6f5fddd814

3 changed files with 30 additions and 17 deletions
test/inductor/test_aot_inductor.py

@@ -16,6 +16,7 @@ from torch._inductor import config
 from torch._inductor.test_case import TestCase
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import IS_FBCODE
+from torch.testing._internal.inductor_utils import clone_preserve_strides_offset
 from torch.utils import _pytree as pytree
 
 
@@ -177,6 +178,18 @@ def check_model(
         torch.manual_seed(0)
         if not isinstance(model, types.FunctionType):
             model = model.to(self.device)
+
+        # For non-mixed device inputs with the default "cpu", set the device manually.
+        if all(
+            t.device.type == "cpu"
+            for t in example_inputs
+            if isinstance(t, torch.Tensor)
+        ):
+            example_inputs = tuple(
+                clone_preserve_strides_offset(x, device=self.device)
+                for x in example_inputs
+            )
+
         ref_model = copy.deepcopy(model)
         ref_inputs = copy.deepcopy(example_inputs)
         expected = ref_model(*ref_inputs)
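For illustration, a minimal sketch of how the new guard behaves outside the test harness. The `device` variable is a hypothetical stand-in for `self.device` (running it requires a GPU); the guard and the clone call mirror the added lines above.

import torch
from torch.testing._internal.inductor_utils import clone_preserve_strides_offset

device = "cuda"  # hypothetical stand-in for self.device

# Two CPU inputs, one of them a strided view with a nonzero storage offset.
example_inputs = (torch.randn(4, 4), torch.randn(8)[2:6])

# Only when *all* tensor inputs sit on the CPU are they moved, so
# deliberately mixed-device test cases keep their original placement.
if all(
    t.device.type == "cpu"
    for t in example_inputs
    if isinstance(t, torch.Tensor)
):
    example_inputs = tuple(
        clone_preserve_strides_offset(x, device=device) for x in example_inputs
    )

print([t.device for t in example_inputs])  # both tensors now on `device`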
test/inductor/test_torchinductor.py

@@ -121,6 +121,7 @@ from torch._inductor.compile_fx import (
 from torch._inductor.utils import has_torchvision_roi_align
 from torch.testing._internal.common_utils import slowTest
 from torch.testing._internal.inductor_utils import (
+    clone_preserve_strides_offset,
     GPU_TYPE,
     HAS_CPU,
     HAS_GPU,
@@ -372,20 +373,6 @@ def compute_grads(args, kwrags, results, grads):
     )
 
 
-def clone_preserve_strides(x, device=None):
-    if not isinstance(x, torch.Tensor):
-        return x
-    buffer = torch.as_strided(
-        x, (x.untyped_storage().size() // x.element_size(),), (1,), 0
-    )
-    if not device:
-        buffer = buffer.clone()
-    else:
-        buffer = buffer.to(device, copy=True)
-    out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
-    return out
-
-
 def check_model(
     self: TestCase,
     model,
@@ -409,7 +396,7 @@ def check_model(
     kwargs = kwargs or {}
     torch._dynamo.reset()
 
-    ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
+    ref_inputs = [clone_preserve_strides_offset(x) for x in example_inputs]
     ref_kwargs = kwargs
     has_lowp_args = False
 
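As context for the rename: these clones exist so the eager reference run sees its own copies of the inputs; a model that mutates an input in place would otherwise corrupt the data the compiled run later receives. A minimal sketch of that hazard, with a hypothetical in-place model:

import torch
from torch.testing._internal.inductor_utils import clone_preserve_strides_offset

# Hypothetical model that mutates its input in place.
def mutating_model(x):
    x.add_(1)
    return x * 2

example_inputs = [torch.zeros(3)]
ref_inputs = [clone_preserve_strides_offset(x) for x in example_inputs]

expected = mutating_model(*ref_inputs)  # mutates only the clones
assert example_inputs[0].eq(0).all()    # the original inputs stay pristine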
@@ -422,7 +409,7 @@ def check_model(
             # Eager model may fail if the dtype is not supported
             eager_result = None
 
-        ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
+        ref_inputs = [clone_preserve_strides_offset(x) for x in example_inputs]
         expect_dtypes = [
             x.dtype if isinstance(x, torch.Tensor) else None
             for x in pytree.tree_leaves(eager_result)
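The surrounding context lines collect the per-leaf dtype of the eager result so the compiled output can be checked against it. For illustration, the same comprehension run on a hypothetical nested result:

import torch
from torch.utils import _pytree as pytree

# Hypothetical eager result: a nested structure of tensors and non-tensors.
eager_result = (torch.randn(2), {"y": torch.ones(2, dtype=torch.int64)}, 3)
expect_dtypes = [
    x.dtype if isinstance(x, torch.Tensor) else None
    for x in pytree.tree_leaves(eager_result)
]
print(expect_dtypes)  # [torch.float32, torch.int64, None]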
@@ -625,7 +612,7 @@ def check_model_gpu(
 
     if copy_to_gpu:
         example_inputs = tuple(
-            clone_preserve_strides(x, device=GPU_TYPE) for x in example_inputs
+            clone_preserve_strides_offset(x, device=GPU_TYPE) for x in example_inputs
         )
 
     check_model(
torch/testing/_internal/inductor_utils.py

@@ -195,3 +195,16 @@ def maybe_skip_size_asserts(op):
         return torch._inductor.config.patch(size_asserts=False)
     else:
         return contextlib.nullcontext()
+
+
+def clone_preserve_strides_offset(x, device=None):
+    if not isinstance(x, torch.Tensor):
+        return x
+    buffer = torch.as_strided(
+        x, (x.untyped_storage().size() // x.element_size(),), (1,), 0
+    )
+    if not device:
+        buffer = buffer.clone()
+    else:
+        buffer = buffer.to(device, copy=True)
+    out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
+    return out
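What distinguishes this helper from a plain copy: it clones the tensor's entire underlying storage and then rebuilds the exact view on top, so strides and the storage offset survive, whereas a compact copy via Tensor.to collapses the offset to zero. A small sketch using only the helper as added above:

import torch
from torch.testing._internal.inductor_utils import clone_preserve_strides_offset

base = torch.arange(12.0)
view = base.as_strided((2, 3), (3, 1), 2)   # a view starting at storage offset 2

plain = view.to("cpu", copy=True)            # compact copy: offset drops to 0
kept = clone_preserve_strides_offset(view)   # full-storage clone: offset kept

print(view.storage_offset(), plain.storage_offset(), kept.storage_offset())  # 2 0 2
torch.testing.assert_close(kept, view)       # identical values either way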