pytorch/test/dynamo/test_debug_utils.py
Elias Ellison 930b60f5aa Add Debug Utility To Generate Inputs for AOT Graphs (#119409)
```
    Takes in a function which has been printed with print_readable() and constructs kwargs to run it.
    Currently only handles Tensor inputs and a graph module which might have tensor constants.
    Example:
        Consider a function `forward` defined as follows:
        >>> def forward(self, primals_1: "f32[1001, 6]"):
        ...     _tensor_constant0: "i64[4190]" = self._tensor_constant0
        ...     # Further implementation
        >>> kwargs = aot_graph_input_parser(forward)
        >>> forward(**kwargs)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119409
Approved by: https://github.com/shunting314
2024-02-09 03:55:19 +00:00
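
For instance, pasting a print_readable() dump into a standalone function and
materializing inputs for it might look like this (a minimal sketch, not from
this PR: the shapes, the `f32` alias, and the graph body are illustrative,
and a CUDA device is assumed, as in the test below):

```
import torch
from torch._dynamo.debug_utils import aot_graph_input_parser

f32 = torch.float32  # dtype shorthand used in print_readable() annotations

def forward(self, primals_1: "f32[8, 4]", primals_2: "f32[4, 16]"):
    # body pasted from a print_readable() dump
    return torch.ops.aten.mm.default(primals_1, primals_2)

# constructs a random tensor for each annotated input and returns them as kwargs
kwargs = aot_graph_input_parser(forward, device="cuda")
out = forward(**kwargs)
```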


# Owner(s): ["module: dynamo"]
import unittest

import torch
from functorch import make_fx
from torch._dynamo import debug_utils
from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.test_case import TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


class TestDebugUtils(TestCase):
    def test_cast_model_to_fp64_dtype_args(self):
        # Test that dtype arguments are converted to fp64
        def fn(x):
            return (
                torch.ops.prims.convert_element_type(x, torch.float16),
                x.to(torch.float16),
                torch.full(x.shape, 2, dtype=torch.float32, device=x.device),
                x.new_empty(x.shape),
            )

        x = torch.randn(32, device="cpu")
        decomps = torch._decomp.core_aten_decompositions()
        fx = make_fx(fn, decomposition_table=decomps)(x)
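
        # Decomposition bakes concrete dtype arguments (float16 / float32)
        # into the graph; these are what cast_to_fp64 must rewrite below.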
        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
    """,  # NOQA: B950
        )

        fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
        self.assertEqual(fp64_examples, (x.to(torch.float64),))
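
        # cast_to_fp64 mutates the graph in place, so the same fx.code now
        # reports float64 for every floating-point dtype argument.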
        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
    """,  # NOQA: B950
        )

    @requires_cuda
    def test_aot_graph_parser(self):
        from torch import device

        f32 = torch.float32
        i64 = torch.int64
        i32 = torch.int32
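
        # Graph signature and body below were pasted from print_readable()
        # output; the aliases above mirror its dtype shorthand (e.g. "f32").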
        def forward(
            self,
            primals_1: "f32[1001, 6]",
            primals_2: "f32[1001]",
            primals_3: "f32[1001, 64]",
            primals_4: "f32[4190]",
            primals_5: "f32[4190]",
            primals_6: "f32[1739, 4190]",
            primals_48: "f32[6144, 4191]",
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0
            lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant0
            )
            _tensor_constant0 = None
            index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy]
            )
            lift_fresh_copy = None
            _tensor_constant1: "i64[6]" = self._tensor_constant1
            lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant1
            )
            _tensor_constant1 = None
            index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy_1]
            )
            primals_48 = lift_fresh_copy_1 = None
            permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default(
                primals_2, index_1, permute
            )
            primals_2 = permute = None
            amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True)
            sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax)
            exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub)
            sub = None
            sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
            div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1)
            exp = None
            full_default: "i32[6144, 1001]" = torch.ops.aten.full.default(
                [6144, 1001],
                1,
                dtype=torch.int32,
                layout=torch.strided,
                device=device(type="cuda", index=0),
                pin_memory=False,
            )
            iota: "i32[1001]" = torch.ops.prims.iota.default(
                1001,
                start=0,
                step=1,
                dtype=torch.int32,
                device=device(type="cuda"),
                requires_grad=False,
            )
            mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota)
            full_default = iota = None
            iota_1: "i32[6144]" = torch.ops.prims.iota.default(
                6144,
                start=0,
                step=1001,
                dtype=torch.int32,
                device=device(type="cuda", index=0),
                requires_grad=False,
            )
            view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1])
            mul = None
            view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1])
            div = None
            _embedding_bag = torch.ops.aten._embedding_bag.default(
                primals_3, view, iota_1, False, 0, False, view_1
            )
            return _embedding_bag

        kwargs = aot_graph_input_parser(forward, device="cuda")
        # runs successfully
        forward(**kwargs)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()