From 8af31e30d7ebab01e2b2fb8ae8a80dc547ec23cf Mon Sep 17 00:00:00 2001
From: "Yanan Cao (PyTorch)"
Date: Wed, 5 Feb 2025 22:56:54 +0000
Subject: [PATCH] [Codemod][AddExplicitStrictExportArg] caffe2/torch (#146439)

Differential Revision: D69068432

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146439
Approved by: https://github.com/avikchaudhuri
---
 test/dynamo/test_export.py         |    8 +-
 test/fx/test_fx_xform_observer.py  |    4 +-
 test/inductor/test_aot_inductor.py |    2 +-
 test/test_fx.py                    | 1396 +++++++++++++++++-----------
 4 files changed, 846 insertions(+), 564 deletions(-)

diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 3a3e69c7052..a2e8f6b922d 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -4591,7 +4591,7 @@ class ExportTestsDevice(torch._dynamo.test_case.TestCase):
         random_inputs = (torch.rand([32, 3, 32, 32]).to(device),)
         dim_x = torch.export.Dim("dim_x", min=1, max=32)
         exp_program = torch.export.export(
-            model, random_inputs, dynamic_shapes={"x": {0: dim_x}}
+            model, random_inputs, dynamic_shapes={"x": {0: dim_x}}, strict=True
         )
         output_buffer = io.BytesIO()
         # Tests if we can restore saved nn.Parameters when we load them again
@@ -4621,7 +4621,9 @@ class ExportTestsDevice(torch._dynamo.test_case.TestCase):
         batchsize = torch.export.Dim("dim0", min=3, max=1024)
         dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]}

-        torch.export.export(model, (a, b), dynamic_shapes=dynamic_shape_spec)
+        torch.export.export(
+            model, (a, b), dynamic_shapes=dynamic_shape_spec, strict=True
+        )

     def test_export_fast_binary_broadcast_check_unbacked(self, device):
         class MyModel(torch.nn.Module):
@@ -4634,7 +4636,7 @@ class ExportTestsDevice(torch._dynamo.test_case.TestCase):
         model = MyModel().eval().to(device)
         numel = torch.tensor(10)
         scalar = torch.randn(1)
-        torch.export.export(model, (numel, scalar))
+        torch.export.export(model, (numel, scalar), strict=True)


 common_utils.instantiate_parametrized_tests(ExportTests)
diff --git a/test/fx/test_fx_xform_observer.py b/test/fx/test_fx_xform_observer.py
index b272af9b17f..7d4370b5dcf 100644
--- a/test/fx/test_fx_xform_observer.py
+++ b/test/fx/test_fx_xform_observer.py
@@ -144,7 +144,7 @@ class TestGraphTransformObserver(TestCase):
                 return torch.neg(x)

         model = SimpleLinearModel()
-        gm = torch.export.export(model, (torch.rand(10),)).module()
+        gm = torch.export.export(model, (torch.rand(10),), strict=True).module()

         with GraphTransformObserver(gm, "test"):
             add_node = gm.graph.call_function(torch.ops.aten.add.default, (1, 1))
@@ -171,7 +171,7 @@ class TestGraphTransformObserver(TestCase):
                 return torch.neg(x)

         model = SimpleLinearModel()
-        gm = torch.export.export(model, (torch.rand(10),)).module()
+        gm = torch.export.export(model, (torch.rand(10),), strict=True).module()

         with GraphTransformObserver(gm, "test"):
             gm2 = copy.deepcopy(gm)
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 23c7bc15379..f1c13adb64b 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -2213,7 +2213,7 @@ class AOTInductorTestsTemplate:
         example_inputs = (torch.randn(10, 10, device=self.device),)
         optimized = torch._inductor.aoti_load_package(
             torch._inductor.aoti_compile_and_package(
-                torch.export.export(Model(), example_inputs)
+                torch.export.export(Model(), example_inputs, strict=True)
             )
         )
         try:
diff --git a/test/test_fx.py b/test/test_fx.py
index 9083081f1ae..c702adbfaf3 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -2,185 +2,235
@@ # ruff: noqa: F841 import builtins -import contextlib import collections +import contextlib import copy import functools import inspect +import io import math import numbers -import io import operator import os import pickle import sys -import torch import traceback -import typing import types -import warnings +import typing import unittest -from math import sqrt -from functorch.experimental import control_flow -from torch.multiprocessing import Process -from torch.testing import FileCheck -from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests -import torch.utils._pytree as pytree -import torch.fx._pytree as fx_pytree -from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen -from torch.fx.node import Target, Argument, _format_arg -from torch.fx.passes import shape_prop -from torch.fx.immutable_collections import immutable_dict, immutable_list -from torch.fx.experimental.rewriter import RewritingTracer -from torch.fx.operator_schemas import get_signature_for_torch_op -from copy import deepcopy +import warnings from collections import namedtuple +from copy import deepcopy +from math import sqrt +from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union -from torch.fx.proxy import TraceError -from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY -from torch.fx._symbolic_trace import PHBase, PHWithMeta -from fx.test_subgraph_rewriter import TestSubgraphRewriter # noqa: F401 -from fx.test_dce_pass import TestDCE # noqa: F401 -from fx.test_fx_const_fold import TestConstFold # noqa: F401 -from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow # noqa: F401 -from fx.test_pass_infra import TestPassManager # noqa: F401 +import torch +import torch.fx._pytree as fx_pytree +import torch.utils._pytree as pytree +from functorch.experimental import control_flow + +from fx.named_tup import MyNamedTup from fx.test_common_passes import TestCommonPass # noqa: F401 from fx.test_cse_pass import TestCSEPass # noqa: F401 -from fx.test_matcher_utils import TestMatcher # noqa: F401 -from fx.test_source_matcher_utils import TestSourceMatcher # noqa: F401 +from fx.test_dce_pass import TestDCE # noqa: F401 +from fx.test_fx_const_fold import TestConstFold # noqa: F401 +from fx.test_fx_param_shape_control_flow import ( # noqa: F401 + TestConstParamShapeInControlFlow, +) -from fx.test_gradual_type import AnnotationsTest # noqa: F401 -from fx.test_gradual_type import TypeCheckerTest # noqa: F401 -from typing import Any, Callable, NamedTuple, Optional, Union, Tuple, List +from fx.test_gradual_type import ( # noqa: F401 # noqa: F401 + AnnotationsTest, + TypeCheckerTest, +) +from fx.test_matcher_utils import TestMatcher # noqa: F401 +from fx.test_pass_infra import TestPassManager # noqa: F401 +from fx.test_source_matcher_utils import TestSourceMatcher # noqa: F401 +from fx.test_subgraph_rewriter import TestSubgraphRewriter # noqa: F401 +from torch.fx import ( + CodeGen, + Graph, + GraphModule, + Interpreter, + Node, + PH, + Proxy, + symbolic_trace, + Tracer, + Transformer, + wrap, +) +from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY +from torch.fx._symbolic_trace import PHBase, PHWithMeta +from torch.fx.experimental.rewriter import RewritingTracer +from torch.fx.immutable_collections import immutable_dict, immutable_list +from 
torch.fx.node import _format_arg, Argument, Target +from torch.fx.operator_schemas import get_signature_for_torch_op +from torch.fx.passes import shape_prop + +from torch.fx.proxy import TraceError +from torch.multiprocessing import Process +from torch.testing import FileCheck +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + onlyCPU, + ops, +) +from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( + find_library_location, IS_FBCODE, IS_MACOS, IS_WINDOWS, - find_library_location, run_tests, skipIfTorchDynamo, ) from torch.testing._internal.jit_utils import JitTestCase -from fx.named_tup import MyNamedTup - try: from torchvision import models as torchvision_models + HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") from torch.testing._internal.common_quantization import skipIfNoDynamoSupport + class SimpleTest(torch.nn.Module): def forward(self, x): return torch.relu(x + 3.0) + def a_non_torch_leaf(a, b): return a + b + # Used for test_autowrap_function. Autowrapped functions need to be global def fx_int(x: float) -> int: return int(x) + def fx_int_x2(x: float) -> int: return int(x) * 2 + # used in test_pytree. It's all the way out here because pickling a GraphModule # that uses Point errors out if Point is local to the function -Point = namedtuple('Point', ['x', 'y']) +Point = namedtuple("Point", ["x", "y"]) + # Test wrap() passing both a function name as well as a function # directly def a_lifted_leaf(a, b): return a[0] + a[1] + b -wrap('a_lifted_leaf') + +wrap("a_lifted_leaf") # Test wrapping twice doesn't break anything -wrap('a_lifted_leaf') +wrap("a_lifted_leaf") + def a_lifted_leaf2(a, b): return a[0] + a[1] + b + wrap(a_lifted_leaf2) -wrap('len') +wrap("len") + +wrap("getattr") -wrap('getattr') def wrapped_named_tup(p1, *, p2): return p1.x + p2.y + wrap(wrapped_named_tup) + @wrap def wrapped_via_decorator(a): return a + 1 -wrap('wrapped_with_submodule') + +wrap("wrapped_with_submodule") + def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d): return batchnorm1d(x) + def my_decorator(f): @functools.wraps(f) def wrapper_inside_decorator(*args, **kwargs): return f(*args, **kwargs) + return wrapper_inside_decorator + @wrap @my_decorator def wrapped_decorated_fn(x): return x + real_wrapped_via_decorator = wrapped_via_decorator real_a_lifed_leaf = a_lifted_leaf real_a_lifed_leaf2 = a_lifted_leaf2 _sqrt = sqrt -wrap('wrapper_fn') +wrap("wrapper_fn") + def wrapper_fn(x): return torch.foo(x) + class Pair(NamedTuple): - x : torch.Tensor - y : torch.Tensor + x: torch.Tensor + y: torch.Tensor def _custom_fx_repr_fn(self) -> str: return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})" + # for testing pytrees class Foo: # noqa: B209 def __init__(self, a, b): self.a = a self.b = b + class Add(torch.nn.Module): def forward(self, x): return x + x + @torch.fx.has_side_effect @torch.fx.wrap def side_effect_func(x: torch.Tensor): print(x) + class TestFX(JitTestCase): def setUp(self): super().setUp() # Checking for mutable operations whil tracing is feature flagged # Enable it in testing but not by default - self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations + self.orig_tracer_mutable_flag = ( + torch.fx.proxy.TracerBase.check_mutable_operations + ) torch.fx.proxy.TracerBase.check_mutable_operations = True if not (IS_FBCODE or IS_WINDOWS 
or IS_MACOS): - lib_file_path = find_library_location('libtorchbind_test.so') + lib_file_path = find_library_location("libtorchbind_test.so") torch.ops.load_library(str(lib_file_path)) def tearDown(self): super().tearDown() - torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + torch.fx.proxy.TracerBase.check_mutable_operations = ( + self.orig_tracer_mutable_flag + ) def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None): """Check that an nn.Module's results match the GraphModule version @@ -211,7 +261,9 @@ class TestFX(JitTestCase): def forward(self, A, B, c): t = torch.sigmoid(A) + self.lin(c) - return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3)) + return self.sub_mod( + t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3) + ) m = MyModule() gm = symbolic_trace(m) @@ -227,9 +279,8 @@ class TestFX(JitTestCase): gm2 = symbolic_trace(m2) class T(torch.nn.Module): - def forward(self, A, b=4, *args, c=5, **kwargs): - x = A + 1 + args[0] + kwargs['3'] + x = A + 1 + args[0] + kwargs["3"] return x t = T() @@ -258,8 +309,8 @@ class TestFX(JitTestCase): def test_custom_import(self): graph = torch.fx.Graph() - a = graph.placeholder('x') - b = graph.placeholder('y') + a = graph.placeholder("x") + b = graph.placeholder("y") c = graph.call_function(a_non_torch_leaf, (a, b)) d = graph.call_function(torch.sin, (c,)) graph.output(d) @@ -270,11 +321,11 @@ class TestFX(JitTestCase): def test_args_kwargs(self): class T(torch.nn.Module): def forward(self, *args, **kwargs): - x = args[0] + kwargs['foo'] + x = args[0] + kwargs["foo"] return x t = T() - self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)}) + self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {"foo": torch.rand(1)}) def test_varargs_concrete(self): class T(torch.nn.Module): @@ -298,8 +349,12 @@ class TestFX(JitTestCase): return torch.relu(args[1]) t = T() - with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'): - self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)}) + with self.assertRaisesRegex( + RuntimeError, r"cannot be part of \*args expansion" + ): + self.checkGraphModule( + t, (torch.rand(1), torch.rand(1)), {"foo": torch.rand(1)} + ) def test_fx_shifts(self): class MyModule(torch.nn.Module): @@ -324,9 +379,9 @@ class TestFX(JitTestCase): def test_dict(self): class MyDictMod(torch.nn.Module): def forward(self, d): - return d['3'].relu(), {'4' : d['3'].neg()} + return d["3"].relu(), {"4": d["3"].neg()} - input_dict = {'3': torch.rand(3, 4)} + input_dict = {"3": torch.rand(3, 4)} m = MyDictMod() self.checkGraphModule(m, (input_dict,)) @@ -359,18 +414,26 @@ class TestFX(JitTestCase): def f(x, y): x = control_flow.cond(x[0] == 0, true, false, [x, y]) - with self.assertRaisesRegex(RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)"): + with self.assertRaisesRegex( + RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)" + ): _ = symbolic_trace(f) def test_disallow_override(self): # Custom delegate to disallow in-place tensor operations class NoMutableCallTracer(Tracer): - def create_node(self, kind : str, target : Union[str, Callable], - args : tuple[Argument, ...], kwargs : dict[str, Any], name : Optional[str] = None, - type_expr : Optional[Any] = None) -> Node: + def create_node( + self, + kind: str, + target: Union[str, Callable], + args: tuple[Argument, ...], + kwargs: dict[str, Any], + name: Optional[str] = None, + type_expr: 
Optional[Any] = None, + ) -> Node: name = target if isinstance(target, str) else torch.typename(target) - if name[-1] == '_': - raise RuntimeError('In-place operations are not supported') + if name[-1] == "_": + raise RuntimeError("In-place operations are not supported") return super().create_node(kind, target, args, kwargs, name) # Test method @@ -381,7 +444,7 @@ class TestFX(JitTestCase): m = MyInplaceMod() - with self.assertRaisesRegex(RuntimeError, 'In-place operations'): + with self.assertRaisesRegex(RuntimeError, "In-place operations"): NoMutableCallTracer().trace(m) # Test free function @@ -389,8 +452,9 @@ class TestFX(JitTestCase): def forward(self, x): torch.log_(x) return x + m2 = MyInplaceMod2() - with self.assertRaisesRegex(RuntimeError, 'In-place operations'): + with self.assertRaisesRegex(RuntimeError, "In-place operations"): NoMutableCallTracer().trace(m2) # Test symbolic node as an arg @@ -399,8 +463,9 @@ class TestFX(JitTestCase): y = torch.ones(3, 4) y.add_(x) return x + m3 = MyInplaceMod3() - with self.assertRaisesRegex(RuntimeError, 'In-place operations'): + with self.assertRaisesRegex(RuntimeError, "In-place operations"): NoMutableCallTracer().trace(m3) def test_leaf_module(self): @@ -421,17 +486,21 @@ class TestFX(JitTestCase): mrm = MyReluMod() sym = NoLeafModulesTracer().trace(mrm) for node in sym.nodes: - self.assertNotEqual(node.op, 'call_module') + self.assertNotEqual(node.op, "call_module") sym.lint() def test_wrap(self): self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5)) def to_trace(y): - return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y) + return ( + a_lifted_leaf((4, y), 3) + + a_lifted_leaf((3, 4), 5) + + a_lifted_leaf((y, y), y) + ) m = symbolic_trace(to_trace) - self.assertIn('a_lifted_leaf', m.code) + self.assertIn("a_lifted_leaf", m.code) self.assertEqual(27, m(2)) self.assertIs(a_lifted_leaf, real_a_lifed_leaf) @@ -439,10 +508,14 @@ class TestFX(JitTestCase): self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5)) def to_trace(y): - return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y) + return ( + a_lifted_leaf2((4, y), 3) + + a_lifted_leaf2((3, 4), 5) + + a_lifted_leaf2((y, y), y) + ) m = symbolic_trace(to_trace) - self.assertIn('a_lifted_leaf2', m.code) + self.assertIn("a_lifted_leaf2", m.code) self.assertEqual(27, m(2)) self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2) @@ -453,7 +526,7 @@ class TestFX(JitTestCase): return wrapped_via_decorator(y) m = symbolic_trace(to_trace) - self.assertIn('wrapped_via_decorator', m.code) + self.assertIn("wrapped_via_decorator", m.code) self.assertEqual(m(0), 1) self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) @@ -465,19 +538,18 @@ class TestFX(JitTestCase): return wrapped_via_decorator(y) m = symbolic_trace(to_trace) - self.assertIn('wrapped_via_decorator', m.code) + self.assertIn("wrapped_via_decorator", m.code) self.assertEqual(m(0), 1) self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) transformed = torch.fx.Transformer(m).transform() - self.assertIn('wrapped_via_decorator', transformed.code) + self.assertIn("wrapped_via_decorator", transformed.code) self.assertEqual(transformed(0), 1) self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) def test_wrap_with_submodule(self): - class 
M(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -499,11 +571,11 @@ class TestFX(JitTestCase): return wrapped_via_decorator(y) m = symbolic_trace(to_trace) - self.assertIn('wrapped_via_decorator', m.code) + self.assertIn("wrapped_via_decorator", m.code) self.assertEqual(m(0), 1) retraced = symbolic_trace(m) - self.assertIn('wrapped_via_decorator', retraced.code) + self.assertIn("wrapped_via_decorator", retraced.code) self.assertEqual(retraced(0), 1) def test_wrap_decorated_function(self): @@ -511,17 +583,18 @@ class TestFX(JitTestCase): return wrapped_decorated_fn(y) m = symbolic_trace(to_trace) - self.assertIn('wrapped_decorated_fn', m.code) + self.assertIn("wrapped_decorated_fn", m.code) self.assertEqual(m(1), 1) def test_graph_edit_with_proxy(self): class M(torch.nn.Module): def forward(self, a, b): return a + b + m = M() g = symbolic_trace(m).graph new_g = torch.fx.Graph() - val_map : dict[Node, Node] = {} + val_map: dict[Node, Node] = {} output_val = new_g.graph_copy(g, val_map) t = Proxy(output_val) # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. @@ -586,8 +659,10 @@ class TestFX(JitTestCase): return x if val is None else x + val f = Foo() - traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None}) - with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'): + traced = torch.fx.symbolic_trace(f, concrete_args={"val": None}) + with self.assertRaisesRegex( + AssertionError, "val has been specialized to have value None" + ): traced(torch.randn(5), torch.randn(5)) x = torch.randn(5) @@ -632,16 +707,17 @@ class TestFX(JitTestCase): class M(torch.nn.Module): def forward(self, a, b): return a + b + m = M() g = symbolic_trace(m).graph new_g = torch.fx.Graph() - val_map : dict[Node, Node] = {} + val_map: dict[Node, Node] = {} output_val = new_g.graph_copy(g, val_map) t = Proxy(output_val) # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. 
new_g.output((t + t).node) gm = GraphModule(m, new_g) - seen_names : set[str] = set() + seen_names: set[str] = set() for node in gm.graph.nodes: assert node.name not in seen_names seen_names.add(node.name) @@ -658,15 +734,15 @@ class TestFX(JitTestCase): # saving the original list because we will insert new nodes as a part of a test orig_graph_nodes = list(graph.nodes) for node in orig_graph_nodes: - if node.op == 'output': + if node.op == "output": continue self.assertTrue(node.stack_trace is not None) - assert 'test_fx.py' in node.stack_trace + assert "test_fx.py" in node.stack_trace # verify that copying the node does not lose the stack trace new_node = graph.node_copy(node) self.assertTrue(new_node.stack_trace is not None) - assert 'test_fx.py' in new_node.stack_trace + assert "test_fx.py" in new_node.stack_trace def test_stack_traces_with_transformer(self): class M(torch.nn.Module): @@ -682,10 +758,10 @@ class TestFX(JitTestCase): # nodes after Transformer should still preserve the original node's stack trace for node in new_gm.graph.nodes: - if node.op in {'placeholder', 'output'}: + if node.op in {"placeholder", "output"}: continue self.assertTrue(node.stack_trace is not None) - assert 'test_fx.py' in node.stack_trace + assert "test_fx.py" in node.stack_trace def test_lineno_map(self): class M(torch.nn.Module): @@ -703,22 +779,25 @@ class TestFX(JitTestCase): # test custom codegen def transform_code(code): return ["print('hello!')\n", *code] + gm.graph.on_generate_code(lambda _: transform_code) gm.recompile() expected = {2: 2, 3: 3, 4: 4, 5: 5} self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items()))) def test_graph_unique_names_manual(self): - graph : torch.fx.Graph = torch.fx.Graph() - a : torch.fx.Node = graph.create_node('placeholder', 'x') - b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1') - c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1') - d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) + graph: torch.fx.Graph = torch.fx.Graph() + a: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node( + "call_module", "linear_mod", args=(a,), name="foo_1_1" + ) + c: torch.fx.Node = graph.create_node("get_attr", "y_attr", name="foo_1") + d: torch.fx.Node = graph.create_node("call_function", operator.add, args=(b, c)) graph.output(d) graph2 = torch.fx.Graph() - val_map : dict[Node, Node] = {} + val_map: dict[Node, Node] = {} graph2.graph_copy(graph, val_map) - seen_names : set[str] = set() + seen_names: set[str] = set() for node in graph2.nodes: assert node.name not in seen_names seen_names.add(node.name) @@ -765,7 +844,9 @@ class TestFX(JitTestCase): # a valid nn.Module, symbolically traces it, lowers the Module to some # representation, and wraps that representation up into another # nn.Module instance that handles dispatch to the compiled/lowered code. 
- def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module: + def lower_to_elementwise_interpreter( + orig_mod: torch.nn.Module, + ) -> torch.nn.Module: # ===== Stage 1: Symbolic trace the module ===== mod = symbolic_trace(orig_mod) @@ -776,12 +857,9 @@ class TestFX(JitTestCase): constants = {} fn_input_names = [] - target_to_name = { - operator.add : "add", - operator.mul : "mul" - } + target_to_name = {operator.add: "add", operator.mul: "mul"} - output_node : Optional[Node] = None + output_node: Optional[Node] = None # For each instruction, create a triple # (instruction_name : str, inputs : List[str], output : str) # to feed into the C++ interpreter @@ -789,31 +867,32 @@ class TestFX(JitTestCase): target, args, out_name = n.target, n.args, n.name assert len(n.kwargs) == 0, "kwargs currently not supported" - if n.op == 'placeholder': + if n.op == "placeholder": # Placeholders specify function argument names. Save these # for later when we generate the wrapper GraphModule fn_input_names.append(target) - elif n.op == 'call_function': + elif n.op == "call_function": assert target in target_to_name, "Unsupported call target " + target arg_names = [] for arg in args: if not isinstance(arg, Node): # Pull out constants. These constants will later be # fed to the interpreter C++ object via add_constant() - arg_name = f'constant_{constant_idx}' + arg_name = f"constant_{constant_idx}" constants[arg_name] = torch.tensor( - [arg] if isinstance(arg, numbers.Number) else arg) + [arg] if isinstance(arg, numbers.Number) else arg + ) arg_names.append(arg_name) constant_idx += 1 else: arg_names.append(arg.name) instructions.append((target_to_name[target], arg_names, out_name)) - elif n.op == 'output': + elif n.op == "output": if output_node is not None: - raise RuntimeError('Multiple output nodes!') + raise RuntimeError("Multiple output nodes!") output_node = n else: - raise RuntimeError('Unsupported opcode ' + n.op) + raise RuntimeError("Unsupported opcode " + n.op) interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter() # Load constants @@ -848,14 +927,17 @@ class TestFX(JitTestCase): # Add placeholders for fn inputs placeholder_nodes = [] for name in fn_input_names: - placeholder_nodes.append(graph.create_node('placeholder', name)) + placeholder_nodes.append(graph.create_node("placeholder", name)) # Get the interpreter object - interpreter_node = graph.create_node('get_attr', 'interpreter') + interpreter_node = graph.create_node("get_attr", "interpreter") # Add a node to call the interpreter instance output_node = graph.create_node( - op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes)) + op="call_method", + target="__call__", + args=(interpreter_node, placeholder_nodes), + ) # Register output graph.output(output_node) @@ -886,6 +968,7 @@ class TestFX(JitTestCase): def test_reserved_getattr(self): """Ensure that we do not name any nodes with a reserved builtin like `getattr`""" + class M(torch.nn.Module): def forward(self, a): return a.foo.bar.baz @@ -912,7 +995,7 @@ class TestFX(JitTestCase): x = torch.mm(x, self.mm_param) skip_connection = x x = torch.relu(x) - x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]] + x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] x = self.lin(x) x = torch.relu(x) x = x + skip_connection @@ -929,11 +1012,17 @@ class TestFX(JitTestCase): def test_node_tagging(self): class TaggingTracer(Tracer): - def create_node(self, kind : str, target : Union[str, Callable], - args : 
tuple[Argument, ...], kwargs : dict[str, Any], name : Optional[str] = None, - type_expr : Optional[Any] = None) -> Node: + def create_node( + self, + kind: str, + target: Union[str, Callable], + args: tuple[Argument, ...], + kwargs: dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: n = super().create_node(kind, target, args, kwargs, name) - n.tag = 'foo' + n.tag = "foo" return n class M(torch.nn.Module): @@ -944,8 +1033,8 @@ class TestFX(JitTestCase): g = TaggingTracer().trace(m) g.lint() for n in g.nodes: - self.assertTrue(hasattr(n, 'tag')) - self.assertEqual(n.tag, 'foo') + self.assertTrue(hasattr(n, "tag")) + self.assertEqual(n.tag, "foo") def test_tensor_attribute(self): class TensorAttribute(torch.nn.Module): @@ -974,11 +1063,10 @@ class TestFX(JitTestCase): traced2(torch.rand(4, 4)) def test_tensor_attribute_coalseced(self): - def count_attrs(fx_module): targets = set() for node in traced.graph.nodes: - if node.op == 'get_attr': + if node.op == "get_attr": targets.add(node.target) return len(targets) @@ -986,6 +1074,7 @@ class TestFX(JitTestCase): def f(x): return x + val + val + traced = symbolic_trace(f) traced.graph.lint() self.assertEqual(count_attrs(traced), 1) @@ -1005,11 +1094,7 @@ class TestFX(JitTestCase): def forward(self, x): return torch.neg(x) - seq = torch.nn.Sequential( - Simple(), - Simple(), - Simple() - ) + seq = torch.nn.Sequential(Simple(), Simple(), Simple()) traced = symbolic_trace(seq) traced.graph.lint() x = torch.rand(3, 4) @@ -1045,8 +1130,8 @@ class TestFX(JitTestCase): def test_pickle_custom_import(self): graph = torch.fx.Graph() - a = graph.placeholder('x') - b = graph.placeholder('y') + a = graph.placeholder("x") + b = graph.placeholder("y") c = graph.call_function(a_non_torch_leaf, (a, b)) d = graph.call_function(torch.sin, (c,)) graph.output(d) @@ -1058,12 +1143,12 @@ class TestFX(JitTestCase): self.assertEqual(loaded(x, y), gm(x, y)) def test_all_input_nodes(self): - graph : torch.fx.Graph = torch.fx.Graph() - a : torch.fx.Node = graph.placeholder('x') - b : torch.fx.Node = graph.call_module('linear_mod', args=(a,)) - c : torch.fx.Node = graph.get_attr('y_attr') - d : torch.fx.Node = graph.call_function(operator.add, args=(b, c)) - e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0)) + graph: torch.fx.Graph = torch.fx.Graph() + a: torch.fx.Node = graph.placeholder("x") + b: torch.fx.Node = graph.call_module("linear_mod", args=(a,)) + c: torch.fx.Node = graph.get_attr("y_attr") + d: torch.fx.Node = graph.call_function(operator.add, args=(b, c)) + e: torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0)) graph.output(e) graph.lint() @@ -1079,12 +1164,14 @@ class TestFX(JitTestCase): def transform(traced): new_graph = torch.fx.Graph() - val_map : dict[Node, Node] = {} + val_map: dict[Node, Node] = {} output_value = new_graph.graph_copy(traced.graph, val_map) relu_out = new_graph.create_node( - op='call_method', target='neg', args=(output_value,), kwargs={}) + op="call_method", target="neg", args=(output_value,), kwargs={} + ) new_graph.output(relu_out) return GraphModule(traced, new_graph) + transformed = transform(traced) transformed.graph.lint() copied = copy.deepcopy(transformed) @@ -1139,11 +1226,11 @@ class TestFX(JitTestCase): super().__init__() self.sa = SomeArgs() - def forward(self, x : list): + def forward(self, x: list): return self.sa(*x) ul = UnpacksList() - with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): + with 
self.assertRaisesRegex(TraceError, "Proxy object cannot be iterated."): symbolic_trace(ul) def test_unpack_dict_better_error(self): @@ -1156,11 +1243,11 @@ class TestFX(JitTestCase): super().__init__() self.sk = SomeKwargs() - def forward(self, x : dict): + def forward(self, x: dict): return self.sk(**x) ud = UnpacksDict() - with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): + with self.assertRaisesRegex(TraceError, "Proxy object cannot be iterated."): symbolic_trace(ud) def test_pretty_print_targets(self): @@ -1173,16 +1260,15 @@ class TestFX(JitTestCase): traced = symbolic_trace(SomeMod()) graph_str = str(traced.graph) - self.assertIn('builtins.getattr', graph_str) - self.assertIn('operator.add', graph_str) - self.assertIn('torch.add', graph_str) + self.assertIn("builtins.getattr", graph_str) + self.assertIn("operator.add", graph_str) + self.assertIn("torch.add", graph_str) def test_pretty_print_node(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.param: torch.nn.Parameter = torch.nn.Parameter( - torch.rand(3, 4)) + self.param: torch.nn.Parameter = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x: torch.Tensor, y: int = 2): @@ -1192,14 +1278,13 @@ class TestFX(JitTestCase): all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes]) - FileCheck().check("x").check("placeholder") \ - .check("y").check("placeholder") \ - .check("getitem").check("call_function") \ - .check("param").check("get_attr") \ - .check("add").check("call_function") \ - .check("linear").check("call_module") \ - .check("clamp").check("call_method") \ - .run(all_formatted) + FileCheck().check("x").check("placeholder").check("y").check( + "placeholder" + ).check("getitem").check("call_function").check("param").check( + "get_attr" + ).check("add").check("call_function").check("linear").check( + "call_module" + ).check("clamp").check("call_method").run(all_formatted) def test_script_tensor_constant(self): # TorchScript seems to ignore attributes that start with `__`. 
@@ -1226,7 +1311,7 @@ class TestFX(JitTestCase): # `int` would normally throw a TypeError as argument can't be `Proxy` tracer = Tracer(autowrap_functions=(fx_int,)) graph = tracer.trace(AutowrapFnTest()) - traced = GraphModule(tracer.root, graph, 'test') + traced = GraphModule(tracer.root, graph, "test") tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2)) tracer_2.trace(AutowrapFnTest2()) @@ -1235,7 +1320,7 @@ class TestFX(JitTestCase): self.assertEqual(traced_scripted(torch.rand(4)), 2) def test_tuple_no_subscript(self): - def foo(x : tuple): + def foo(x: tuple): return x[0] traced = torch.fx.symbolic_trace(foo) @@ -1286,7 +1371,7 @@ class TestFX(JitTestCase): def test_torch_fx_getattr(self): class FXGetattrTest(torch.nn.Module): def forward(self, x): - return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3])) + return getattr(x, "nonexistent_attr", torch.Tensor([2, 3])) traced = symbolic_trace(FXGetattrTest()) self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3])) @@ -1316,6 +1401,7 @@ class TestFX(JitTestCase): b = torch.ops.aten.sigmoid(a) c = torch.ops.aten.cat([a, b]) return torch.ops.aten.cat((c, c)) + m = M() input = torch.randn(3) ref_out = m(input) @@ -1329,6 +1415,7 @@ class TestFX(JitTestCase): def forward(self, a): b = torch.ops.aten.add.Tensor(a, a) return b + m = M() input = torch.randn(3) ref_out = m(input) @@ -1338,9 +1425,9 @@ class TestFX(JitTestCase): self.assertEqual(out, ref_out) for node in gm.graph.nodes: - if node.op == 'call_function': + if node.op == "call_function": assert isinstance(node.target, torch._ops.OpOverload) - assert node.target.__name__ == 'add.Tensor' + assert node.target.__name__ == "add.Tensor" def test_pickle_torch_custom_ops(self): class M(torch.nn.Module): @@ -1348,6 +1435,7 @@ class TestFX(JitTestCase): b = torch.ops.aten.sigmoid(a) c = torch.ops.aten.cat([a, b]) return torch.ops.aten.cat((c, c)) + m = M() input = torch.randn(3) ref_out = m(input) @@ -1362,18 +1450,19 @@ class TestFX(JitTestCase): traced = symbolic_trace(st) traced.graph.lint() printed = str(traced) - assert 'SimpleTest()' in printed - assert 'torch.relu' in printed + assert "SimpleTest()" in printed + assert "torch.relu" in printed def test_pretty_print_graph(self): class KwargPrintTest(torch.nn.Module): def forward(self, x): return torch.squeeze(x + 3.0, dim=2) + st = KwargPrintTest() traced = symbolic_trace(st) traced.graph.lint() stringed = str(traced.graph) - for s in ['args', 'kwargs', 'num_users']: + for s in ["args", "kwargs", "num_users"]: assert s in stringed def test_custom_proxy_type(self): @@ -1391,7 +1480,7 @@ class TestFX(JitTestCase): r = self.right * other.right return TensorPair(l, r) - def use_tensor_pair(x : TensorPair, y : TensorPair): + def use_tensor_pair(x: TensorPair, y: TensorPair): s = x.add(y) return s.mul(x) @@ -1421,7 +1510,7 @@ class TestFX(JitTestCase): r = self.right * other.right return TensorPair(l, r) - def use_tensor_pair_literal(x : TensorPair): + def use_tensor_pair_literal(x: TensorPair): s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3))) return s.mul(x) @@ -1450,7 +1539,7 @@ class TestFX(JitTestCase): r = self.right * other.right return TensorPair(l, r) - def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): + def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor): s = x.add(TensorPair(y, y)) return s.mul(x) @@ -1480,7 +1569,7 @@ class TestFX(JitTestCase): elif other.is_zero: return self - def use_zero_tensor(x : torch.Tensor, y : torch.Tensor): + def use_zero_tensor(x: torch.Tensor, y: torch.Tensor): 
return ZeroTensor(x + y) x, y = torch.randn(5, 3), torch.randn(5, 3) @@ -1496,10 +1585,10 @@ class TestFX(JitTestCase): def test_graph_fns(self): g = Graph() - a = g.placeholder('a') - b = g.call_module('linear', (a,)) - c = g.get_attr('bias') - d = g.call_method('add', (b, c)) + a = g.placeholder("a") + b = g.call_module("linear", (a,)) + c = g.get_attr("bias") + d = g.call_method("add", (b, c)) e = g.call_function(torch.sin, (d,)) g.output(e) mod = torch.nn.Module() @@ -1513,10 +1602,10 @@ class TestFX(JitTestCase): self.assertEqual(r, ref) def test_remove_uses(self): - g : torch.fx.Graph = Graph() - x : torch.fx.Node = g.placeholder('x') - relu : torch.fx.Node = g.call_function(torch.relu, (x,)) - neg : torch.fx.Node = g.call_function(torch.neg, (relu,)) + g: torch.fx.Graph = Graph() + x: torch.fx.Node = g.placeholder("x") + relu: torch.fx.Node = g.call_function(torch.relu, (x,)) + neg: torch.fx.Node = g.call_function(torch.neg, (relu,)) g.output(neg) neg.replace_all_uses_with(relu) @@ -1525,10 +1614,10 @@ class TestFX(JitTestCase): self.assertTrue(neg not in relu.users) def test_remove_uses_with_custom_filter(self): - g : torch.fx.Graph = Graph() - x : torch.fx.Node = g.placeholder('x') - relu : torch.fx.Node = g.call_function(torch.relu, (x,)) - neg : torch.fx.Node = g.call_function(torch.neg, (relu,)) + g: torch.fx.Graph = Graph() + x: torch.fx.Node = g.placeholder("x") + relu: torch.fx.Node = g.call_function(torch.relu, (x,)) + neg: torch.fx.Node = g.call_function(torch.neg, (relu,)) g.output(neg) neg.replace_all_uses_with(relu, lambda x: x != neg) @@ -1540,7 +1629,7 @@ class TestFX(JitTestCase): symbolic_trace(eb) def test_pickle_nonetype_annotation(self): - eb = torch.nn.EmbeddingBag(10, 3, mode='sum') + eb = torch.nn.EmbeddingBag(10, 3, mode="sum") traced = symbolic_trace(eb) pickled = pickle.dumps(traced) loaded = pickle.loads(pickled) @@ -1559,28 +1648,28 @@ class TestFX(JitTestCase): self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1))) def test_construct_root_dict(self): - graph : torch.fx.Graph = torch.fx.Graph() - a : torch.fx.Node = graph.create_node('placeholder', 'x') - b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,)) - c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam') - d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) + graph: torch.fx.Graph = torch.fx.Graph() + a: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node("call_module", "foo.bar.baz", args=(a,)) + c: torch.fx.Node = graph.create_node("get_attr", "zip.zap.zam") + d: torch.fx.Node = graph.create_node("call_function", operator.add, args=(b, c)) graph.output(d) - linear_mod : torch.nn.Module = torch.nn.Linear(3, 4) - add_param : torch.Tensor = torch.rand(3, 4) - gm : torch.fx.GraphModule = torch.fx.GraphModule( - {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph) + linear_mod: torch.nn.Module = torch.nn.Linear(3, 4) + add_param: torch.Tensor = torch.rand(3, 4) + gm: torch.fx.GraphModule = torch.fx.GraphModule( + {"foo.bar.baz": linear_mod, "zip.zap.zam": add_param}, graph + ) gm.graph.lint() - assert 'self.foo.bar.baz' in gm.code + assert "self.foo.bar.baz" in gm.code - x : torch.Tensor = torch.rand(3, 3) - out : torch.Tensor = gm(x) - ref_out : torch.Tensor = linear_mod(x) + add_param + x: torch.Tensor = torch.rand(3, 3) + out: torch.Tensor = gm(x) + ref_out: torch.Tensor = linear_mod(x) + add_param self.assertEqual(out, ref_out) def 
test_symbolic_trace_assert(self): - class AssertsTensorShape(torch.nn.Module): def forward(self, x): torch._assert(x.shape[1] > 4, "assert_foobar") @@ -1658,26 +1747,34 @@ class TestFX(JitTestCase): copied = torch.fx.Graph() for node in g.nodes: copied.node_copy(node) - with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'): + with self.assertRaisesRegex(RuntimeError, "does not belong to this Graph"): copied.lint() def test_wrong_topo(self): - graph : torch.fx.Graph = torch.fx.Graph() - a : torch.fx.Node = graph.create_node('placeholder', 'x') - b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,)) - c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam') - d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) + graph: torch.fx.Graph = torch.fx.Graph() + a: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node("call_module", "foo.bar.baz", args=(a,)) + c: torch.fx.Node = graph.create_node("get_attr", "zip.zap.zam") + d: torch.fx.Node = graph.create_node("call_function", operator.add, args=(b, c)) graph.output(d) nodes = list(graph.nodes) nodes[3].append(nodes[2]) - with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'): + with self.assertRaisesRegex( + RuntimeError, "was used before it has been defined" + ): graph.lint() def test_wrong_target_type(self): - graph : torch.fx.Graph = torch.fx.Graph() + graph: torch.fx.Graph = torch.fx.Graph() with self.assertRaises(ValueError): - n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo', - args=(), kwargs={}) + n = torch.fx.Node( + graph=graph, + name="foo", + op="call_function", + target="foo", + args=(), + kwargs={}, + ) def test_example_shape_prop(self): class TestCase(torch.nn.Module): @@ -1688,6 +1785,7 @@ class TestFX(JitTestCase): def forward(self, x): return torch.neg(self.submod(x.relu() + self.attr)) + tc = TestCase() tc_traced = symbolic_trace(tc) ref_out = tc_traced(torch.rand(3, 4)) @@ -1695,15 +1793,24 @@ class TestFX(JitTestCase): # Make sure we're testing all opcodes opcodes = set() - output_shape : Optional[torch.Shape] = None - output_stride : Optional[tuple[int]] = None + output_shape: Optional[torch.Shape] = None + output_stride: Optional[tuple[int]] = None for node in tc_traced.graph.nodes: opcodes.add(node.op) - if node.op == 'output': - output_shape = node.args[0].meta['tensor_meta'].shape - output_stride = node.args[0].meta['tensor_meta'].stride - self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method', - 'call_module', 'output'}) + if node.op == "output": + output_shape = node.args[0].meta["tensor_meta"].shape + output_stride = node.args[0].meta["tensor_meta"].stride + self.assertEqual( + opcodes, + { + "placeholder", + "get_attr", + "call_function", + "call_method", + "call_module", + "output", + }, + ) # Test shape propagation and make sure results match actual self.assertEqual(output_shape, ref_out.shape) @@ -1724,8 +1831,10 @@ class TestFX(JitTestCase): x = torch.randn(5, 5, 224, 224) shape_prop.ShapeProp(traced).propagate(x) - assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format - for node in traced.graph.nodes) + assert all( + node.meta["tensor_meta"].memory_format is torch.contiguous_format + for node in traced.graph.nodes + ) x_channels_last = x.contiguous(memory_format=torch.channels_last) traced.to(memory_format=torch.channels_last) @@ -1734,8 +1843,10 @@ class TestFX(JitTestCase): # NB: the 
implementation of conv may not preserve the memory format, # unfortunately. The best we can do is just check that the placeholder # node is channels-last - if node.op in {'placeholder'}: - self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last) + if node.op in {"placeholder"}: + self.assertEqual( + node.meta["tensor_meta"].memory_format, torch.channels_last + ) def test_shape_prop_aggregate(self): class ReturnTwo(torch.nn.Module): @@ -1762,9 +1873,9 @@ class TestFX(JitTestCase): shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4)) for node in mod.graph.nodes: - if node.op == 'call_module': - assert 'tensor_meta' in node.meta - tensor_meta = node.meta['tensor_meta'] + if node.op == "call_module": + assert "tensor_meta" in node.meta + tensor_meta = node.meta["tensor_meta"] assert tensor_meta[0] == 3 assert tensor_meta[1].shape == torch.Size([]) @@ -1781,8 +1892,10 @@ class TestFX(JitTestCase): traced_3d = symbolic_trace(test_mod_3d) x_3d = torch.randn(5, 5, 224, 224, 15) shape_prop.ShapeProp(traced_3d).propagate(x_3d) - assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format - for node in traced_3d.graph.nodes) + assert all( + node.meta["tensor_meta"].memory_format is torch.contiguous_format + for node in traced_3d.graph.nodes + ) x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d) traced_3d.to(memory_format=torch.channels_last_3d) @@ -1791,8 +1904,10 @@ class TestFX(JitTestCase): # NB: the implementation of conv may not preserve the memory format, # unfortunately. The best we can do is just check that the placeholder # node is channels-last - if node.op in {'placeholder'}: - self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d) + if node.op in {"placeholder"}: + self.assertEqual( + node.meta["tensor_meta"].memory_format, torch.channels_last_3d + ) def test_shape_prop_unbacked_sym(self): from torch._dynamo.utils import detect_fake_mode @@ -1802,11 +1917,9 @@ class TestFX(JitTestCase): return torch.nonzero(x) inp = (torch.tensor([1, 0, 1, 0]),) - gm = torch.export.export(M(), inp).module() + gm = torch.export.export(M(), inp, strict=True).module() fake_inputs = [ - node.meta.get("val") - for node in gm.graph.nodes - if node.op == "placeholder" + node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder" ] inp = fake_inputs fake_mode = detect_fake_mode(inp) @@ -1834,10 +1947,12 @@ class TestFX(JitTestCase): gm = torch.fx.symbolic_trace(m) mod_stack = {} - expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))), - ('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))] + expected_stack = [ + ("sub_mod", ("sub_mod", type(m.sub_mod))), + ("sub_mod.conv_mod", ("sub_mod.conv_mod", type(m.sub_mod.conv_mod))), + ] for node in gm.graph.nodes: - mod_stack = node.meta.get('nn_module_stack', {}) + mod_stack = node.meta.get("nn_module_stack", {}) if mod_stack: break stack_list = list(mod_stack.items()) @@ -1856,13 +1971,13 @@ class TestFX(JitTestCase): graph = tracer.trace(M()) gm = GraphModule(tracer.root, graph) for node in gm.graph.nodes: - if node.op == 'get_attr': + if node.op == "get_attr": node.meta["nn_module_stack"] = "self" node.meta["stack_trace"] = "stack_trace" node.meta["source_fn_stack"] = "source_fn_stack" new_gm = Transformer(gm).transform() for node in new_gm.graph.nodes: - if node.op == 'get_attr': + if node.op == "get_attr": self.assertEqual(node.meta["nn_module_stack"], "self") self.assertEqual(node.meta["stack_trace"], "stack_trace") 
self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack") @@ -1920,7 +2035,7 @@ class TestFX(JitTestCase): def __init__(self, module): super().__init__(module) - def run_node(self, n : Node) -> Any: + def run_node(self, n: Node) -> Any: result = super().run_node(n) n.cached_value = result return result @@ -1928,23 +2043,22 @@ class TestFX(JitTestCase): input = torch.randn(3, 4) RunNodeInterpreter(gm).run(input) for node in gm.graph.nodes: - assert hasattr(node, 'cached_value') + assert hasattr(node, "cached_value") def test_interpreter_onthefly_swap(self): - def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) class NegSigmSwapInterpreter(Interpreter): - def call_function(self, target : Target, args : tuple, kwargs : dict) -> Any: + def call_function(self, target: Target, args: tuple, kwargs: dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) # noqa: F821 - def call_method(self, target : Target, args : tuple, kwargs : dict) -> Any: - if target == 'neg': + def call_method(self, target: Target, args: tuple, kwargs: dict) -> Any: + if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) # noqa: F821 @@ -1967,13 +2081,15 @@ class TestFX(JitTestCase): interp = Interpreter(gm) env = {} for node in gm.graph.nodes: - if node.op == 'call_module' and node.target == 'linear': + if node.op == "call_module" and node.target == "linear": env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0 break assert len(env) == 1 x = torch.randn(3, 4) result = interp.run(x, initial_env=env) - self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0)) + self.assertEqual( + result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0) + ) def test_interpreter_star_args(self): def with_star_args(x, *args): @@ -1998,7 +2114,7 @@ class TestFX(JitTestCase): inp = torch.rand(5, 3, 224, 224) out = interp.run(inp) env_key_names = {n.name for n in interp.env.keys()} - self.assertEqual(env_key_names, {'output'}) + self.assertEqual(env_key_names, {"output"}) def test_interpreter_default_args(self): class Model(torch.nn.Module): @@ -2023,8 +2139,10 @@ class TestFX(JitTestCase): interp = Interpreter(gm) x = torch.randn(5, 3) - with self.assertRaisesRegex(RuntimeError, - 'Expected positional argument for parameter y, but one was not passed in'): + with self.assertRaisesRegex( + RuntimeError, + "Expected positional argument for parameter y, but one was not passed in", + ): out = interp.run(x) def test_transformer_noop(self): @@ -2046,20 +2164,19 @@ class TestFX(JitTestCase): self.assertEqual(new_gm(input), gm(input)) def test_transformer_op_swap(self): - def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) class NegSigmSwapXformer(Transformer): - def call_function(self, target : Target, args : tuple, kwargs : dict) -> Any: + def call_function(self, target: Target, args: tuple, kwargs: dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) # noqa: F821 - def call_method(self, target : Target, args : tuple, kwargs : dict) -> Any: - if target == 'neg': + def call_method(self, target: Target, args: tuple, kwargs: dict) -> Any: + if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) # noqa: F821 @@ -2090,8 +2207,10 @@ class TestFX(JitTestCase): def test_fn_type_annotations(self): class 
Foo(torch.nn.Module): - def forward(self, p : Pair, z : torch.Tensor, i : int) -> dict[str, torch.Tensor]: - return {'a': p.x + p.y + z + i} + def forward( + self, p: Pair, z: torch.Tensor, i: int + ) -> dict[str, torch.Tensor]: + return {"a": p.x + p.y + z + i} foo_scripted = torch.jit.script(Foo()) foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) @@ -2101,8 +2220,9 @@ class TestFX(JitTestCase): fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) def test_fn_type_annotation_empty(self): - def forward(a : list[torch.Tensor]): + def forward(a: list[torch.Tensor]): return a[0] + torch.jit.script(symbolic_trace(forward)) def test_wrapped_method(self): @@ -2110,6 +2230,7 @@ class TestFX(JitTestCase): @functools.wraps(fn) def wrapper(*args, **kwargs): return torch.relu(fn(*args, **kwargs)) + return wrapper class Foo(torch.nn.Module): @@ -2146,18 +2267,21 @@ class TestFX(JitTestCase): self.checkGraphModule(m, (torch.rand(3, 4),)) def test_typename_print(self): - graph : torch.fx.Graph = torch.fx.Graph() - x : torch.fx.Node = graph.create_node('placeholder', 'x') - b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,), - type_expr=List[float]) - output : torch.fx.Node = graph.output(b) + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,), type_expr=List[float] + ) + output: torch.fx.Node = graph.output(b) - self.assertTrue('typing.List[float]' in str(graph)) + self.assertTrue("typing.List[float]" in str(graph)) def test_layout(self): class M(torch.nn.Module): def forward(self, x): - return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0) + return torch.empty_like( + x, layout=torch.strided, pin_memory=False + ).fill_(0) traced = symbolic_trace(M()) x = torch.rand(5, 9, 3, 4) @@ -2175,27 +2299,31 @@ class TestFX(JitTestCase): def test_inf_nan(self): class FooMod(torch.nn.Module): def forward(self, x): - return x + float('inf'), x + float('-inf'), x + float('nan') + return x + float("inf"), x + float("-inf"), x + float("nan") fm = FooMod() self.checkGraphModule(fm, (torch.rand(3, 4),)) def test_inf_nan_kwds(self): - graph : torch.fx.Graph = torch.fx.Graph() - x : torch.fx.Node = graph.create_node('placeholder', 'x') - b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf') - c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan') + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node( + "call_function", operator.add, (x, float("inf")), {}, name="inf" + ) + c: torch.fx.Node = graph.create_node( + "call_function", operator.add, (x, float("nan")), {}, name="nan" + ) graph.output((b, c)) gm = torch.fx.GraphModule(torch.nn.Module(), graph) x = torch.rand(3, 4) - self.assertEqual(gm(x), (x + float('inf'), x + float('nan'))) + self.assertEqual(gm(x), (x + float("inf"), x + float("nan"))) def test_deepcopy_recursion_depth(self): depth = sys.getrecursionlimit() + 20 g = torch.fx.Graph() - x = g.placeholder('x') + x = g.placeholder("x") for i in range(depth): x = g.call_function(torch.relu, (x,)) g.output(x) @@ -2217,7 +2345,7 @@ class TestFX(JitTestCase): rn18 = torchvision_models.resnet18() class LowerReluTracer(torch.fx.Tracer): - def is_leaf_module(self, m : torch.nn.Module, qualname : str): + def 
is_leaf_module(self, m: torch.nn.Module, qualname: str): if isinstance(m, torch.nn.ReLU): return False return super().is_leaf_module(m, qualname) @@ -2226,13 +2354,17 @@ class TestFX(JitTestCase): to_erase = [] for node in rn18_traced.graph.nodes: - if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]: + if node.op == "call_function" and node.target in [ + torch.relu, + torch.nn.functional.relu, + ]: kwargs = node.kwargs.copy() # Neg doesn't have in-place - kwargs.pop('inplace') + kwargs.pop("inplace") with rn18_traced.graph.inserting_before(node): new_node = rn18_traced.graph.call_function( - the_function=torch.neg, args=node.args, kwargs=node.kwargs) + the_function=torch.neg, args=node.args, kwargs=node.kwargs + ) node.replace_all_uses_with(replace_with=new_node) to_erase.append(node) @@ -2240,11 +2372,13 @@ class TestFX(JitTestCase): rn18_traced.graph.erase_node(node) def test_replace_input(self): - graph : torch.fx.Graph = torch.fx.Graph() - x : torch.fx.Node = graph.create_node('placeholder', 'x') - y : torch.fx.Node = graph.create_node('placeholder', 'y') - b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) - output : torch.fx.Node = graph.output(b) + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + y: torch.fx.Node = graph.create_node("placeholder", "y") + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,) + ) + output: torch.fx.Node = graph.output(b) b.replace_input_with(x, y) @@ -2255,13 +2389,15 @@ class TestFX(JitTestCase): self.assertEqual(gm(input_x, input_y), torch.relu(input_y)) def test_insertion_point(self): - graph : torch.fx.Graph = torch.fx.Graph() - x : torch.fx.Node = graph.create_node('placeholder', 'x') - b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) - output : torch.fx.Node = graph.output(b) + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,) + ) + output: torch.fx.Node = graph.output(b) with graph.inserting_before(b): - neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) + neg: torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) _, *relu_args = b.args b.args = (neg, *relu_args) @@ -2271,11 +2407,13 @@ class TestFX(JitTestCase): self.assertEqual(gm(input), torch.relu(torch.neg(input))) def test_update_args_api(self): - graph : torch.fx.Graph = torch.fx.Graph() - x : torch.fx.Node = graph.create_node('placeholder', 'x') - y : torch.fx.Node = graph.create_node('placeholder', 'y') - b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) - output : torch.fx.Node = graph.output(b) + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + y: torch.fx.Node = graph.create_node("placeholder", "y") + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,) + ) + output: torch.fx.Node = graph.output(b) orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph) inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) @@ -2286,17 +2424,19 @@ class TestFX(JitTestCase): self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) def test_update_kwargs_api(self): - graph : torch.fx.Graph = torch.fx.Graph() - x : torch.fx.Node = graph.create_node('placeholder', 'x') - y : torch.fx.Node = 
graph.create_node('placeholder', 'y') - b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x}) - output : torch.fx.Node = graph.output(b) + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + y: torch.fx.Node = graph.create_node("placeholder", "y") + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.relu, kwargs={"input": x} + ) + output: torch.fx.Node = graph.output(b) orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph) inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) - b.update_kwarg('input', y) + b.update_kwarg("input", y) new_gm = torch.fx.GraphModule(torch.nn.Module(), graph) self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) @@ -2313,7 +2453,7 @@ class TestFX(JitTestCase): def test_immutable_dict_pytree_ops(self): rand_tensor = torch.randn(5, 3) - d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]}) + d = immutable_dict({"a": 3, "b": [rand_tensor, 42]}) flattened, spec = pytree.tree_flatten(d) assert flattened == [3, rand_tensor, 42] @@ -2323,12 +2463,14 @@ class TestFX(JitTestCase): assert isinstance(unflattened, immutable_dict) def test_move_before(self): - graph : torch.fx.Graph = torch.fx.Graph() - x : torch.fx.Node = graph.create_node('placeholder', 'x') - b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) - output : torch.fx.Node = graph.output(b) + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,) + ) + output: torch.fx.Node = graph.output(b) - neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) + neg: torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) _, *relu_args = b.args b.args = (neg, *relu_args) b.prepend(neg) @@ -2339,10 +2481,12 @@ class TestFX(JitTestCase): self.assertEqual(gm(input), torch.relu(torch.neg(input))) def test_prepend_self(self): - graph : torch.fx.Graph = torch.fx.Graph() - x : torch.fx.Node = graph.create_node('placeholder', 'x') - b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) - output : torch.fx.Node = graph.output(b) + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,) + ) + output: torch.fx.Node = graph.output(b) b.prepend(b) x.append(b) @@ -2355,7 +2499,9 @@ class TestFX(JitTestCase): for node in traced.graph.nodes: # Test deleting with uses both in another Node and at the output if node.target in [operator.add, torch.relu]: - with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'): + with self.assertRaisesRegex( + RuntimeError, "but it still had .* users in the graph" + ): traced.graph.erase_node(node) def test_copy_it(self): @@ -2373,7 +2519,7 @@ class TestFX(JitTestCase): def test_find_uses(self): graph = torch.fx.Graph() - x = torch.fx.Proxy(graph.placeholder('x')) + x = torch.fx.Proxy(graph.placeholder("x")) y = torch.relu(x) z = x + x @@ -2383,7 +2529,7 @@ class TestFX(JitTestCase): users_of_x = x.node.users self.assertEqual(len(users_of_x), 3) - expected_ops = {'relu', 'add', 'neg'} + expected_ops = {"relu", "add", "neg"} for use in users_of_x: assert any(use.name.startswith(prefix) for prefix in expected_ops) @@ -2403,9 +2549,9 @@ class 
TestFX(JitTestCase): output_node = combined_graph.graph_copy(inline_into.graph, {}) input_node = next(iter(to_inline.graph.nodes)) - assert input_node and input_node.op == 'placeholder' + assert input_node and input_node.op == "placeholder" - val_map = {input_node : output_node} + val_map = {input_node: output_node} output = combined_graph.graph_copy(to_inline.graph, val_map) combined_graph.output(output) @@ -2416,7 +2562,7 @@ class TestFX(JitTestCase): def test_multi_insert_point(self): graph = torch.fx.Graph() - x = torch.fx.Proxy(graph.placeholder('x')) + x = torch.fx.Proxy(graph.placeholder("x")) relu = torch.relu(x) with graph.inserting_before(relu.node): @@ -2426,13 +2572,13 @@ class TestFX(JitTestCase): graph.output((relu.node, z.node)) graph.lint() - expected_ops = ['x', 'neg', 'tanh', 'relu'] + expected_ops = ["x", "neg", "tanh", "relu"] for node, expected in zip(graph.nodes, expected_ops): assert expected in node.name def test_reassign_args_kwargs_uses(self): graph = torch.fx.Graph() - x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y')) + x, y = Proxy(graph.placeholder("x")), Proxy(graph.placeholder("y")) z = x + y zed = z + z + z graph.output(zed.node) @@ -2465,7 +2611,7 @@ class TestFX(JitTestCase): bar: torch.Tensor class ModuleReturnDataclass(torch.nn.Module): - def forward(self, d : torch.Tensor): + def forward(self, d: torch.Tensor): return MyOutput(foo=d + d, bar=d * 3) module = ModuleReturnDataclass() @@ -2489,7 +2635,7 @@ class TestFX(JitTestCase): bar: torch.Tensor class ModuleReturnDataclass(torch.nn.Module): - def forward(self, d : torch.Tensor): + def forward(self, d: torch.Tensor): return MyOutput(foo=d + d, bar=d * 3) class CallsModule(torch.nn.Module): @@ -2514,12 +2660,13 @@ class TestFX(JitTestCase): """ Test case for Module that return namedtuple """ + class MyOutput(NamedTuple): foo: torch.Tensor bar: torch.Tensor class ModuleReturnNamedTuple(torch.nn.Module): - def forward(self, d : torch.Tensor): + def forward(self, d: torch.Tensor): return MyOutput(foo=d, bar=d) module = ModuleReturnNamedTuple() @@ -2534,7 +2681,7 @@ class TestFX(JitTestCase): def test_trace_dict_int_keys(self): class ModWithDictArg(torch.nn.Module): - def forward(self, d : dict[int, torch.Tensor]): + def forward(self, d: dict[int, torch.Tensor]): return d[42] class CallsModWithDict(torch.nn.Module): @@ -2546,14 +2693,16 @@ class TestFX(JitTestCase): return self.m({42: x}) class MyTracer(torch.fx.Tracer): - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + def is_leaf_module( + self, m: torch.nn.Module, module_qualified_name: str + ) -> bool: return isinstance(m, ModWithDictArg) traced_graph = MyTracer().trace(CallsModWithDict()) def test_trace_dict_proxy_keys(self): class ModWithDictArg(torch.nn.Module): - def forward(self, d : dict[torch.Tensor, torch.Tensor]): + def forward(self, d: dict[torch.Tensor, torch.Tensor]): return d[42] class CallsModWithDict(torch.nn.Module): @@ -2565,10 +2714,12 @@ class TestFX(JitTestCase): return self.m({x: x}) class MyTracer(torch.fx.Tracer): - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + def is_leaf_module( + self, m: torch.nn.Module, module_qualified_name: str + ) -> bool: return isinstance(m, ModWithDictArg) - with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'): + with self.assertRaisesRegex(RuntimeError, "cannot contain a Node"): traced_graph = MyTracer().trace(CallsModWithDict()) def test_module_deepcopy_edit_nodes(self): @@ -2608,7 +2759,7 @@ 
class TestFX(JitTestCase): return self.a.b, self.a.b.t(), self.a.b.view(12) traced = torch.fx.symbolic_trace(Foo()) - assert all('constant' not in node.target for node in traced.graph.nodes) + assert all("constant" not in node.target for node in traced.graph.nodes) def test_single_default_arg(self): class M(torch.nn.Module): @@ -2658,12 +2809,14 @@ class TestFX(JitTestCase): def test_update_args_kwargs_yells_at_you(self): symtraced = symbolic_trace(SimpleTest()) node = next(iter(symtraced.graph.nodes)) - with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'): + with self.assertRaisesRegex(AttributeError, "__update_args_kwargs"): node.__update_args_kwargs((), {}) def test_torchbind_class_attribute_in_fx(self): if IS_FBCODE or IS_WINDOWS or IS_MACOS: - self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping") + self.skipTest( + "torch.classes._TorchScriptTesting._StackString is registered, skipping" + ) class FooBar1234(torch.nn.Module): def __init__(self) -> None: @@ -2678,7 +2831,9 @@ class TestFX(JitTestCase): def test_torchbind_class_attribute_in_fx_tensor_arg(self): if IS_FBCODE or IS_WINDOWS or IS_MACOS: - self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping") + self.skipTest( + "torch.classes._TorchScriptTesting._ReLUClass is registered, skipping" + ) class FooBar2341(torch.nn.Module): def __init__(self) -> None: @@ -2694,7 +2849,7 @@ class TestFX(JitTestCase): input = torch.randn(3, 4) self.assertEqual(traced(input), m(input)) - self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes)) + self.assertTrue(any(n.op == "call_method" for n in traced.graph.nodes)) def test_script_method_trace(self): class Scripted(torch.nn.Module): @@ -2714,7 +2869,7 @@ class TestFX(JitTestCase): input = torch.randn(3, 4) self.assertEqual(traced(input), h(input)) - self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes)) + self.assertTrue(any(n.op == "call_method" for n in traced.graph.nodes)) def test_namedtuple_return_trace(self): class NamedTupReturn(torch.nn.Module): @@ -2780,7 +2935,7 @@ class TestFX(JitTestCase): class GetItem1(GetItemBase): def forward(self, x): - return self.pe[:, :x.size(0)] + return self.pe[:, : x.size(0)] class GetItem2(GetItemBase): def forward(self, x): @@ -2794,8 +2949,10 @@ class TestFX(JitTestCase): self.checkGraphModule(GetItem2(), [torch.zeros(4)]) self.checkGraphModule(GetItem3(), [torch.zeros(4)]) - @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1", - "Will be checked in test_getitem_subproc") + @unittest.skipUnless( + os.environ.get("FX_PATCH_GETITEM") == "1", + "Will be checked in test_getitem_subproc", + ) def test_getitem(self): self.getitem_inner() @@ -2813,9 +2970,12 @@ class TestFX(JitTestCase): traced = torch.fx.symbolic_trace(fn) - with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is " - "being compiled since it was called" - " from 'fn.forward'"): + with self.assertRaisesRegex( + RuntimeError, + "'wrapper_fn' is " + "being compiled since it was called" + " from 'fn.forward'", + ): scripted = torch.jit.script(traced) def test_user_friendly_call_provenance_with_module(self): @@ -2825,20 +2985,23 @@ class TestFX(JitTestCase): traced = torch.fx.symbolic_trace(M()) - with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is " - "being compiled since it was called" - " from 'M.forward'"): + with self.assertRaisesRegex( + RuntimeError, + "'wrapper_fn' is " "being compiled since it was called" " from 'M.forward'", + ): scripted = 
torch.jit.script(traced) def test_snake_case(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.activations = torch.nn.ModuleDict([ - ["snake_case", torch.nn.ReLU()], - ["PascalCase", torch.nn.LeakyReLU()], - ["ALL_CAPS", torch.nn.PReLU()] - ]) + self.activations = torch.nn.ModuleDict( + [ + ["snake_case", torch.nn.ReLU()], + ["PascalCase", torch.nn.LeakyReLU()], + ["ALL_CAPS", torch.nn.PReLU()], + ] + ) def forward(self, x): a = self.activations["snake_case"](x) @@ -2851,7 +3014,7 @@ class TestFX(JitTestCase): check = [ ("activations_snake_case", "activations.snake_case"), ("activations_pascal_case", "activations.PascalCase"), - ("activations_all_caps", "activations.ALL_CAPS") + ("activations_all_caps", "activations.ALL_CAPS"), ] i = 0 @@ -2867,6 +3030,7 @@ class TestFX(JitTestCase): def test_no_mutation(self): from torch.fx.immutable_collections import immutable_list + x = immutable_list([3, 4]) with self.assertRaisesRegex(NotImplementedError, "new_args"): x[0] = 4 @@ -2878,9 +3042,10 @@ class TestFX(JitTestCase): return 2 * x else: return x + mod = Foo() - mod_true = symbolic_trace(mod, concrete_args={'y': True}) - mod_false = symbolic_trace(mod, concrete_args={'y': False}) + mod_true = symbolic_trace(mod, concrete_args={"y": True}) + mod_false = symbolic_trace(mod, concrete_args={"y": False}) self.assertEqual(mod_true(3, True), 6) print(mod_true.code) assert any(i.target == torch._assert for i in mod_true.graph.nodes) @@ -2893,7 +3058,7 @@ class TestFX(JitTestCase): def f_higher(a, f): return f(a) - nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2}) + nf = symbolic_trace(f_higher, concrete_args={"f": lambda x: x * 2}) self.assertEqual(nf(3, lambda x: x * 2), 6) def test_custom_traceback_raised_when_exception_source_is_graphmodule(self): @@ -2909,8 +3074,7 @@ class TestFX(JitTestCase): out = [n for n in traced.graph.nodes if n.op == "output"][-1] with traced.graph.inserting_before(out): - relu_out = traced.graph.call_method(method_name='relu', - args=(out.args[0],)) + relu_out = traced.graph.call_method(method_name="relu", args=(out.args[0],)) out.args = (relu_out,) traced.recompile() @@ -2919,9 +3083,11 @@ class TestFX(JitTestCase): with self.assertRaises(TypeError): traced(5) - self.assertRegex(captured[0], - r"Call using an FX-traced Module, line .* of the " - r"traced Module's generated forward function:") + self.assertRegex( + captured[0], + r"Call using an FX-traced Module, line .* of the " + r"traced Module's generated forward function:", + ) def test_custom_traceback_not_raised_when_exception_source_is_submodule(self): class M(torch.nn.Module): @@ -2941,9 +3107,11 @@ class TestFX(JitTestCase): except RuntimeError: captured = traceback.format_exc() - self.assertNotRegex(captured, - r"Call using an FX-traced Module, line .* of the " - r"traced Module's generated forward function:") + self.assertNotRegex( + captured, + r"Call using an FX-traced Module, line .* of the " + r"traced Module's generated forward function:", + ) def test_graph_module_replicate_for_dp(self): class Foo(torch.nn.Module): @@ -2994,7 +3162,9 @@ class TestFX(JitTestCase): check_mutable_operations = True tracer = MyTracer() - with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'): + with self.assertRaisesRegex( + RuntimeError, "mutable operation aten::sigmoid.out" + ): traced_graph = tracer.trace(foo) def test_ast_rewriter_reassigns_submodules(self): @@ -3050,27 +3220,35 @@ class TestFX(JitTestCase): def 
test_profiler_ranges_side_effect(self): g = torch.fx.Graph() - handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',)) + handle = g.call_function( + torch.ops.profiler._record_function_enter_new, ("test_range",) + ) g.call_function(torch.ops.profiler._record_function_exit, (handle,)) g.output(None) found_targets = {} for node in g.nodes: - if node.op == 'call_function': + if node.op == "call_function": found_targets.setdefault(node.target) self.assertEqual( list(found_targets.keys()), - [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit] + [ + torch.ops.profiler._record_function_enter_new, + torch.ops.profiler._record_function_exit, + ], ) g.eliminate_dead_code() found_targets = {} for node in g.nodes: - if node.op == 'call_function': + if node.op == "call_function": found_targets.setdefault(node.target) self.assertEqual( list(found_targets.keys()), - [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit] + [ + torch.ops.profiler._record_function_enter_new, + torch.ops.profiler._record_function_exit, + ], ) def test_ast_rewriter_wrapped_via_decorator(self): @@ -3163,8 +3341,9 @@ class TestFX(JitTestCase): conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1] with a.graph.inserting_before(conv): with warnings.catch_warnings(record=True) as w: - dropout = a.graph.call_module(module_name="net_b.net_c.dropout", - args=conv.args) + dropout = a.graph.call_module( + module_name="net_b.net_c.dropout", args=conv.args + ) self.assertEqual(len(w), 0) conv.replace_all_uses_with(dropout) @@ -3175,12 +3354,14 @@ class TestFX(JitTestCase): return any(path == name for name, _ in gm.named_modules()) def parameter_exists(gm: GraphModule, path: str) -> bool: - return (any(path == name for name, _ in gm.named_parameters()) - and any(path == name for name in gm.state_dict().keys())) + return any(path == name for name, _ in gm.named_parameters()) and any( + path == name for name in gm.state_dict().keys() + ) def buffer_exists(gm: GraphModule, path: str) -> bool: - return (any(path == name for name, _ in gm.named_buffers()) - and any(path == name for name in gm.state_dict().keys())) + return any(path == name for name, _ in gm.named_buffers()) and any( + path == name for name in gm.state_dict().keys() + ) # Test that we added the "dropout" submodule self.assertTrue(module_exists(a, "net_b.net_c.dropout")) @@ -3204,23 +3385,24 @@ class TestFX(JitTestCase): self.assertFalse(module_exists(a, "net_b.net_c.conv")) # Test `get_submodule` with a deleted submodule - with self.assertRaisesRegex(AttributeError, "has no attribute " - "`conv`"): + with self.assertRaisesRegex(AttributeError, "has no attribute " "`conv`"): self.assertIsNone(a.get_submodule("net_b.net_c.conv")) # Test `get_attr` warnings cat = [n for n in a.graph.nodes if n.target == torch.cat][-1] with a.graph.inserting_before(cat): - with warnings.catch_warnings(record=True) as w: param = a.graph.get_attr(qualified_name="net_b.net_c.param") self.assertEqual(len(w), 0) - with self.assertWarnsRegex(UserWarning, "Attempted to " - "insert a get_attr Node with no " - "underlying reference in the " - "owning GraphModule"): + with self.assertWarnsRegex( + UserWarning, + "Attempted to " + "insert a get_attr Node with no " + "underlying reference in the " + "owning GraphModule", + ): bad_param = a.graph.get_attr(qualified_name="net_b.param") a.graph.erase_node(bad_param) @@ -3232,20 +3414,16 @@ class TestFX(JitTestCase): # Test `get_parameter` 
a.get_parameter("net_b.net_c.param") - with self.assertRaisesRegex(AttributeError, "is not an " - "nn.Parameter"): + with self.assertRaisesRegex(AttributeError, "is not an " "nn.Parameter"): a.get_parameter("net_b.buf") - with self.assertRaisesRegex(AttributeError, "has no attribute " - "`param`"): + with self.assertRaisesRegex(AttributeError, "has no attribute " "`param`"): a.get_parameter("net_b.param") # Test `get_buffer` a.get_buffer("net_b.buf") - with self.assertRaisesRegex(AttributeError, "is not a " - "buffer"): + with self.assertRaisesRegex(AttributeError, "is not a " "buffer"): a.get_buffer("net_b.net_c.param") - with self.assertRaisesRegex(AttributeError, "has no attribute " - "`buf`"): + with self.assertRaisesRegex(AttributeError, "has no attribute " "`buf`"): a.get_buffer("net_b.net_c.buf") # Test non-nested attributes @@ -3297,7 +3475,9 @@ class TestFX(JitTestCase): model = Model() class MyCustomTracer(torch.fx.Tracer): - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + def is_leaf_module( + self, m: torch.nn.Module, module_qualified_name: str + ) -> bool: return module_qualified_name == "submod" inputs = torch.randn(1, 10) @@ -3321,9 +3501,7 @@ class TestFX(JitTestCase): weight = torch.tensor([[1.0]], requires_grad=True) bias = torch.tensor([0.0], requires_grad=True) buffer = torch.tensor([0.0]) - parameters = {'l1.weight': weight, - 'l1.bias': bias, - 'buffer': buffer} + parameters = {"l1.weight": weight, "l1.bias": bias, "buffer": buffer} fx_module = torch.fx.symbolic_trace(module) res = torch.func.functional_call(fx_module, parameters, x) res.backward() @@ -3519,12 +3697,14 @@ class TestFX(JitTestCase): return torch.add(x, x) class M(torch.nn.Module): - def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor': + def forward(self, x: "torch.Tensor", a: "A") -> "torch.Tensor": return a(x) self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self): + def test_annotations_with_non_torch_reference_and_no_internal_forward_references( + self, + ): class A: def __call__(self, x: torch.Tensor): return torch.add(x, x) @@ -3541,14 +3721,14 @@ class TestFX(JitTestCase): return torch.add(x, x) class M(torch.nn.Module): - def forward(self, x: list['torch.Tensor'], a: A) -> 'torch.Tensor': + def forward(self, x: list["torch.Tensor"], a: A) -> "torch.Tensor": return a(x)[0] self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) def test_annotation_with_future(self): try: - import fx.test_future # noqa: F401 + import fx.test_future # noqa: F401 finally: del sys.modules["__future__"] @@ -3565,24 +3745,25 @@ class TestFX(JitTestCase): traced(x, y) - FileCheck().check("typing_Tuple[()]") \ - .check("typing_Tuple[str,typing_Tuple[()]]") \ - .run(traced.code) + FileCheck().check("typing_Tuple[()]").check( + "typing_Tuple[str,typing_Tuple[()]]" + ).run(traced.code) scripted = torch.jit.script(traced) scripted(x, y) - FileCheck().check("Tuple[()]") \ - .check("Tuple[str, Tuple[()]]") \ - .run(scripted.code) + FileCheck().check("Tuple[()]").check("Tuple[str, Tuple[()]]").run(scripted.code) - @unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108") + @unittest.skipIf( + IS_WINDOWS, "Python Windows bug? 
https://bugs.python.org/issue45108" + ) @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10") def test_assert(self): def f(x): assert x > 1 return x + 1 + try: torch.fx.proxy.TracerBase.trace_asserts = True traced = symbolic_trace(f) @@ -3614,7 +3795,7 @@ class TestFX(JitTestCase): return new_dict def f_dict_add(x): - return x['a'] + sum(x['z']) + return x["a"] + sum(x["z"]) def f_namedtuple_add(x): return x.x + x.y @@ -3639,42 +3820,47 @@ class TestFX(JitTestCase): (f_sum, [PH, PH, PH]), (f_sum, []), (f_sum, [PHTest(), PHTest(), PHTest()]), - (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}), - (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}), + (f_sum_dict, {"a": PH, "b": PH, "c": PH}), + (f_dict_list_map, {"a": (PH, PH), "b": [PH], "c": []}), (f_dict_list_map, {5: (PH, PH, PH)}), - (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}), - (f_dict_add, {'a': PH, 'z': []}), + (f_dict_add, {"a": PH, "z": (PH, PH, PH)}), + (f_dict_add, {"a": PH, "z": []}), (f_custom, Foo(PH, PH)), (f_custom, Foo(PH, 3)), - (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)), + (f_custom_dict, Foo({"a": PH, "b": PH}, PH)), # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees (f_namedtuple_add, Point(PH, PH)), ] def verify_pytree(f, inp): - val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp) + val = pytree.tree_map( + lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp + ) num_flat_args = len(pytree.tree_leaves(inp)) orig_out = f(val) - nf = symbolic_trace(f, concrete_args={'x': inp}) + nf = symbolic_trace(f, concrete_args={"x": inp}) self.assertEqual(nf(val), orig_out) bare_fx = GraphModule({}, copy.deepcopy(nf.graph)) bare_fx.graph.set_codegen(CodeGen()) bare_fx.recompile() - self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out) + self.assertEqual( + nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), + orig_out, + ) assert num_flat_args == 0 or "tree_flatten_spec" in nf.code - assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args + assert sum(i.op == "placeholder" for i in nf.graph.nodes) == num_flat_args nf = symbolic_trace(nf) self.assertEqual(nf(val), orig_out) assert "tree_flatten_spec" not in nf.code - assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == 1 + assert sum(i.op == "placeholder" for i in nf.graph.nodes) == 1 - nf = symbolic_trace(nf, concrete_args={'x': inp}) + nf = symbolic_trace(nf, concrete_args={"x": inp}) self.assertEqual(nf(val), orig_out) assert num_flat_args == 0 or "tree_flatten_spec" in nf.code - assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args + assert sum(i.op == "placeholder" for i in nf.graph.nodes) == num_flat_args pickled = pickle.dumps(nf) nf = pickle.loads(pickled) @@ -3686,11 +3872,11 @@ class TestFX(JitTestCase): def test_pytree_concrete(self): def f(b, a): if b: - return a['a'] + return a["a"] else: - return a['z'] + return a["z"] - inp = {'a': {'a': PH, 'z': PH}, 'b': True} + inp = {"a": {"a": PH, "z": PH}, "b": True} nf = symbolic_trace(f, concrete_args=inp) val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp) self.assertEqual(nf(**val), f(**val)) @@ -3718,25 +3904,36 @@ class TestFX(JitTestCase): verify_metadata( gm=symbolic_trace( f_sum, - concrete_args={"a": PHWithMeta(ph_key="a"), "b": PHWithMeta(ph_key="b")} + concrete_args={ + "a": PHWithMeta(ph_key="a"), + "b": PHWithMeta(ph_key="b"), + }, ), arg_names=["a_1", "b_1"], - metadata=["a", "b"] + metadata=["a", "b"], ) 
verify_metadata( gm=symbolic_trace( f_dict, - concrete_args={"a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")}} + concrete_args={ + "a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")} + }, ), arg_names=["a_1", "a_2"], - metadata=["f1", "f2"] + metadata=["f1", "f2"], ) # Ensures that tags on nodes are NOT overwritten by PH attributes with same attr name (tag) class TaggingTracer(Tracer): - def create_node(self, kind : str, target : Union[str, Callable], - args : tuple[Argument, ...], kwargs : dict[str, Any], name : Optional[str] = None, - type_expr : Optional[Any] = None) -> Node: + def create_node( + self, + kind: str, + target: Union[str, Callable], + args: tuple[Argument, ...], + kwargs: dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: n = super().create_node(kind, target, args, kwargs, name) n.tag = "foo" return n @@ -3747,7 +3944,9 @@ class TestFX(JitTestCase): self.tag = tag - g = TaggingTracer().trace(f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")}) + g = TaggingTracer().trace( + f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")} + ) for n in g.nodes: self.assertTrue(hasattr(n, "tag")) # Ensure that tag is still "foo" and not "bar" (from PHWithTag) @@ -3762,7 +3961,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: return lst_unpack def additional_globals(self): - return [('List', list)] + return [("List", list)] def process_inputs(self, *inputs): assert len(inputs) == 1 @@ -3783,7 +3982,9 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: bare_fx.recompile() self.assertEqual(nf(vals), f(*vals)) - self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals)) + self.assertEqual( + nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals) + ) ts_f = torch.jit.script(nf) self.assertEqual(nf(vals), ts_f(vals)) @@ -3797,7 +3998,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: return lst_unpack def additional_globals(self): - return [('List', list)] + return [("List", list)] def process_inputs(self, *inputs): assert len(inputs) == 1 @@ -3826,14 +4027,14 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: return lst_unpack def additional_globals(self): - return [('List', list)] + return [("List", list)] def process_inputs(self, *inputs): assert len(inputs) == 1 return inputs[0] def generate_output(self, output_args): - return f'return list({repr(output_args)})' + return f"return list({repr(output_args)})" def process_outputs(self, outputs): return list(outputs) @@ -3870,19 +4071,22 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: tracer_after = copy.deepcopy(tracer) self.assertEqual(str(tracer.graph), str(tracer_after.graph)) - self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph)) + self.assertTrue( + not hasattr(tracer_before, "graph") + or str(tracer.graph) != str(tracer_before.graph) + ) def test_deepcopy_graphmodule(self): m = symbolic_trace(SimpleTest()) - m.meta['hello'] = 'world' + m.meta["hello"] = "world" copy_m = copy.deepcopy(m) - self.assertEqual(copy_m.meta['hello'], 'world') + self.assertEqual(copy_m.meta["hello"], "world") def test_deepcopy_no_recursion(self): m = symbolic_trace(SimpleTest()) - m.meta['hello'] = m # circular reference + m.meta["hello"] = m # circular reference copy_m = copy.deepcopy(m) # 
finishes - self.assertEqual(id(copy_m), id(copy_m.meta['hello'])) + self.assertEqual(id(copy_m), id(copy_m.meta["hello"])) def test_enum(self): from enum import Enum @@ -3948,8 +4152,10 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag + def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch + _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) try: TestFX().getitem_inner() @@ -3961,11 +4167,15 @@ class TestOperatorSignatures(JitTestCase): def setUp(self): # Checking for mutable operations whil tracing is feature flagged # Enable it in testing but not by default - self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations + self.orig_tracer_mutable_flag = ( + torch.fx.proxy.TracerBase.check_mutable_operations + ) torch.fx.proxy.TracerBase.check_mutable_operations = True def tearDown(self): - torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + torch.fx.proxy.TracerBase.check_mutable_operations = ( + self.orig_tracer_mutable_flag + ) @onlyCPU @ops(op_db, allowed_dtypes=(torch.float,)) @@ -3975,20 +4185,22 @@ class TestOperatorSignatures(JitTestCase): sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) schemas = get_signature_for_torch_op(op.op) if not schemas: - raise RuntimeError('No Schemas Returned') + raise RuntimeError("No Schemas Returned") for sample_input in sample_inputs_itr: # Iterate through overloads until we hit a match. If we exit this # loop via `else`, we haven't found a match for schema in schemas: try: - bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs) + bound_args = schema.bind( + sample_input.input, *sample_input.args, **sample_input.kwargs + ) bound_args.apply_defaults() op(*bound_args.args, **bound_args.kwargs) break except TypeError as e: pass else: - raise RuntimeError(f'Did not match any schemas for op {op.name}!') + raise RuntimeError(f"Did not match any schemas for op {op.name}!") class TestFXAPIBackwardCompatibility(JitTestCase): @@ -3998,13 +4210,16 @@ class TestFXAPIBackwardCompatibility(JitTestCase): # Checking for mutable operations whil tracing is feature flagged # Enable it in testing but not by default - self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations + self.orig_tracer_mutable_flag = ( + torch.fx.proxy.TracerBase.check_mutable_operations + ) torch.fx.proxy.TracerBase.check_mutable_operations = True def tearDown(self): super().tearDown() - torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag - + torch.fx.proxy.TracerBase.check_mutable_operations = ( + self.orig_tracer_mutable_flag + ) def _fn_to_stable_annotation_str(self, obj): """ @@ -4016,77 +4231,87 @@ class TestFXAPIBackwardCompatibility(JitTestCase): signature = inspect.signature(obj) - sig_str = f'{fn_name}{signature}' + sig_str = f"{fn_name}{signature}" arg_strs = [] for k, v in signature.parameters.items(): - maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\ - if v.annotation is not inspect.Signature.empty else '' + maybe_type_annotation = ( + f": {self._annotation_type_to_stable_str(v.annotation, sig_str)}" + if v.annotation is not inspect.Signature.empty + else "" + ) def default_val_str(val): if isinstance(val, (tuple, list)): - str_pieces = ['(' if isinstance(val, tuple) else '['] - 
str_pieces.append(', '.join(default_val_str(v) for v in val))
+                    str_pieces = ["(" if isinstance(val, tuple) else "["]
+                    str_pieces.append(", ".join(default_val_str(v) for v in val))
                     if isinstance(val, tuple) and len(str_pieces) == 2:
-                        str_pieces.append(',')
-                    str_pieces.append(')' if isinstance(val, tuple) else ']')
-                    return ''.join(str_pieces)
+                        str_pieces.append(",")
+                    str_pieces.append(")" if isinstance(val, tuple) else "]")
+                    return "".join(str_pieces)
 
                 # Need to fix up some default value strings.
                 # First case: modules. Default module `repr` contains the FS path of the module.
                 # Don't leak that
                 if isinstance(val, types.ModuleType):
-                    return f'<module {val.__name__}>'
+                    return f"<module {val.__name__}>"
 
                 # Second case: callables. Callables (such as lambdas) encode their address in
                 # their string repr. Don't do that
                 if callable(val):
-                    return f'<callable {val.__name__}>'
+                    return f"<callable {val.__name__}>"
 
                 return str(val)
 
             if v.default is not inspect.Signature.empty:
-                default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
-                maybe_default = f' = {default_val_str}'
+                default_val_str = (
+                    default_val_str(v.default)
+                    if not isinstance(v.default, str)
+                    else f"'{v.default}'"
+                )
+                maybe_default = f" = {default_val_str}"
             else:
-                maybe_default = ''
-            maybe_stars = ''
+                maybe_default = ""
+            maybe_stars = ""
             if v.kind == inspect.Parameter.VAR_POSITIONAL:
-                maybe_stars = '*'
+                maybe_stars = "*"
             elif v.kind == inspect.Parameter.VAR_KEYWORD:
-                maybe_stars = '**'
-            arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')
+                maybe_stars = "**"
+            arg_strs.append(f"{maybe_stars}{k}{maybe_type_annotation}{maybe_default}")
 
-        return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
-            if signature.return_annotation is not inspect.Signature.empty else ''
+        return_annot = (
+            f" -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}"
+            if signature.return_annotation is not inspect.Signature.empty
+            else ""
+        )
 
         return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
 
     _trivial_mappings = {
-        str : 'str',
-        int : 'int',
-        float: 'float',
-        bool: 'bool',
-        torch.dtype: 'torch.dtype',
-        torch.Tensor: 'torch.Tensor',
-        torch.device: 'torch.device',
-        torch.memory_format: 'torch.memory_format',
-        slice: 'slice',
-        torch.nn.Module: 'torch.nn.modules.module.Module',
-        torch.fx.Graph : 'torch.fx.graph.Graph',
-        torch.fx.Node : 'torch.fx.node.Node',
-        torch.fx.Proxy : 'torch.fx.proxy.Proxy',
-        torch.fx.node.Target : 'torch.fx.node.Target',
-        torch.fx.node.Argument : 'torch.fx.node.Argument',
-        torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
-        torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
-        torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
-        Ellipsis : '...',
-        typing.Any: 'Any',
-        type(None): 'NoneType',
-        None: 'None',
-        typing.Iterator: 'Iterator',
-        collections.abc.Iterator: 'Iterator',
+        str: "str",
+        int: "int",
+        float: "float",
+        bool: "bool",
+        torch.dtype: "torch.dtype",
+        torch.Tensor: "torch.Tensor",
+        torch.device: "torch.device",
+        torch.memory_format: "torch.memory_format",
+        slice: "slice",
+        torch.nn.Module: "torch.nn.modules.module.Module",
+        torch.fx.Graph: "torch.fx.graph.Graph",
+        torch.fx.Node: "torch.fx.node.Node",
+        torch.fx.Proxy: "torch.fx.proxy.Proxy",
+        torch.fx.node.Target: "torch.fx.node.Target",
+        torch.fx.node.Argument: "torch.fx.node.Argument",
+        torch.fx.graph.PythonCode: "torch.fx.graph.PythonCode",
+        torch.fx.graph_module.GraphModule: "torch.fx.graph_module.GraphModule",
+        
torch.fx.subgraph_rewriter.Match: "torch.fx.subgraph_rewriter.Match", + Ellipsis: "...", + typing.Any: "Any", + type(None): "NoneType", + None: "None", + typing.Iterator: "Iterator", + collections.abc.Iterator: "Iterator", } _UNBOUND_TYPES = { @@ -4104,7 +4329,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase): def _annotation_type_to_stable_str(self, t, sig_str, recursive: bool = False): if t is inspect.Signature.empty: - return '' + return "" # Forward ref if isinstance(t, str): @@ -4112,9 +4337,9 @@ class TestFXAPIBackwardCompatibility(JitTestCase): return t else: return f"'{t}'" - if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef): + if hasattr(typing, "ForwardRef") and isinstance(t, typing.ForwardRef): return t.__forward_arg__ - if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef): + if hasattr(typing, "_ForwardRef") and isinstance(t, typing._ForwardRef): return t.__forward_arg__ mapping = self._trivial_mappings.get(t, None) @@ -4122,7 +4347,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase): return mapping # Handle types with contained types - contained = getattr(t, '__args__', None) or [] + contained = getattr(t, "__args__", None) or [] # Callables contain a bare List for arguments contained = t if isinstance(t, list) else contained @@ -4131,39 +4356,49 @@ class TestFXAPIBackwardCompatibility(JitTestCase): if all(isinstance(ct, typing.TypeVar) for ct in contained): contained = [] - contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str, True) for ct in contained] - contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else '' + contained_type_annots = [ + self._annotation_type_to_stable_str(ct, sig_str, True) for ct in contained + ] + contained_type_str = ( + f'[{", ".join(contained_type_annots)}]' + if len(contained_type_annots) > 0 + else "" + ) - - origin = getattr(t, '__origin__', None) + origin = getattr(t, "__origin__", None) if origin is None: # Unbound types don't have `__origin__` in some Python versions, so fix that up here. origin = t if t in self._UNBOUND_TYPES else origin if origin in {tuple, tuple}: - return f'Tuple{contained_type_str}' + return f"Tuple{contained_type_str}" if origin in {typing.Union}: # Annoying hack to detect Optional - if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)): - not_none_param = contained[0] if contained[0] is not type(None) else contained[1] - return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str, True)}]' - return f'Union{contained_type_str}' + if len(contained) == 2 and (contained[0] is type(None)) ^ ( + contained[1] is type(None) + ): + not_none_param = ( + contained[0] if contained[0] is not type(None) else contained[1] + ) + return f"Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str, True)}]" + return f"Union{contained_type_str}" if origin in {dict, dict}: - return f'Dict{contained_type_str}' + return f"Dict{contained_type_str}" if origin in {list, list}: - return f'List{contained_type_str}' + return f"List{contained_type_str}" if origin in {type, type}: - return f'Type{contained_type_str}' + return f"Type{contained_type_str}" if isinstance(t, typing.Callable): if len(contained) > 0 and contained[0] is not Ellipsis: return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]' else: - return f'Callable{contained_type_str}' - - raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.' 
- f'Please add support for this type and confirm with the ' - f'FX team that your signature change is valid.') + return f"Callable{contained_type_str}" + raise RuntimeError( + f"Unrecognized type {t} used in BC-compatible type signature {sig_str}." + f"Please add support for this type and confirm with the " + f"FX team that your signature change is valid." + ) def test_function_back_compat(self): """ @@ -4183,14 +4418,18 @@ class TestFXAPIBackwardCompatibility(JitTestCase): signature_strs.sort() try: - self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures') + self.assertExpected( + "\n".join(signature_strs) + "\n", "fx_backcompat_function_signatures" + ) except AssertionError as e: - msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \ - f"as backwards-compatible has experienced a signature change. See the " \ - f"above exception context for more information. If this change was " \ - f"unintended, please revert it. If it was intended, check with the FX " \ - f"team to ensure that the proper deprecation protocols have been followed " \ - f"and subsequently --accept the change." + msg = ( + f"{e}\n****** ERROR ******\nAn FX function that has been marked " + f"as backwards-compatible has experienced a signature change. See the " + f"above exception context for more information. If this change was " + f"unintended, please revert it. If it was intended, check with the FX " + f"team to ensure that the proper deprecation protocols have been followed " + f"and subsequently --accept the change." + ) raise AssertionError(msg) # noqa: B904 def test_class_member_back_compat(self): @@ -4203,39 +4442,47 @@ class TestFXAPIBackwardCompatibility(JitTestCase): for obj in _BACK_COMPAT_OBJECTS: if isinstance(obj, type): - public_members = [name for name in obj.__dict__ if not name.startswith('_')] - class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}') + public_members = [ + name for name in obj.__dict__ if not name.startswith("_") + ] + class_method_strs.append( + f"{torch.typename(obj)} {sorted(public_members)}" + ) class_method_strs.sort() try: - self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members') + self.assertExpected( + "\n".join(class_method_strs), "fx_backcompat_class_members" + ) except AssertionError as e: - msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \ - f"as backwards-compatible has experienced change in its public members. See the " \ - f"above exception context for more information. If this change was " \ - f"unintended, please revert it. If it was intended, check with the FX " \ - f"team to ensure that the proper deprecation protocols have been followed " \ - f"and subsequently --accept the change." + msg = ( + f"{e}\n****** ERROR ******\nAn FX class that has been marked " + f"as backwards-compatible has experienced change in its public members. See the " + f"above exception context for more information. If this change was " + f"unintended, please revert it. If it was intended, check with the FX " + f"team to ensure that the proper deprecation protocols have been followed " + f"and subsequently --accept the change." 
+ ) raise AssertionError(msg) from e def test_public_api_surface(self): non_back_compat_objects = {} def check_symbols_have_bc_designation(m, seen): - if not m.__name__.startswith('torch.fx'): + if not m.__name__.startswith("torch.fx"): return - if m.__name__.startswith('torch.fx.experimental'): + if m.__name__.startswith("torch.fx.experimental"): return # It's really common for inner functions to point to random modules # - make sure we don't recurse into modules we've already checked. seen.add(m.__name__) for k, v in m.__dict__.items(): - if hasattr(v, '__name__') and v.__name__ in seen: + if hasattr(v, "__name__") and v.__name__ in seen: continue if v is m: continue - if k.startswith('_'): + if k.startswith("_"): continue if isinstance(v, types.ModuleType): check_symbols_have_bc_designation(v, seen) @@ -4246,20 +4493,30 @@ class TestFXAPIBackwardCompatibility(JitTestCase): check_symbols_have_bc_designation(torch.fx, set()) check_symbols_have_bc_designation(torch.fx.passes, set()) - non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()] + non_back_compat_strs = [ + torch.typename(obj) for obj in non_back_compat_objects.keys() + ] # Only want objects in torch.fx non_back_compat_strs = [ - s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')] + s + for s in non_back_compat_strs + if s.startswith("torch.fx") and not s.startswith("torch.fx.experimental") + ] # Only want objects in public namespaces non_back_compat_strs = [ - s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))] + s + for s in non_back_compat_strs + if all(not atom.startswith("_") for atom in s.split(".")) + ] non_back_compat_strs.sort() if len(non_back_compat_strs) != 0: - raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a " - f"backwards-compatibility classification! Please decorate these " - f"API(s) with `@torch.fx._compatibility.compatibility` to specify " - f"BC guarantees.") + raise AssertionError( + f"Public FX API(s) {non_back_compat_strs} introduced but not given a " + f"backwards-compatibility classification! Please decorate these " + f"API(s) with `@torch.fx._compatibility.compatibility` to specify " + f"BC guarantees." 
+ ) def test_adding_side_effect_function(self): class TestModule(torch.nn.Module): @@ -4274,7 +4531,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase): self.assertEqual(len(gm.graph.nodes), 3) found = False for node in gm.graph.nodes: - if node.op == 'call_function' and node.target == side_effect_func: + if node.op == "call_function" and node.target == side_effect_func: found = True self.assertTrue(found) @@ -4293,36 +4550,51 @@ class TestFXAPIBackwardCompatibility(JitTestCase): self.assertTrue(hasattr(reload_gm, "dummy_buffer")) self.assertTrue(hasattr(reload_gm, "dummy_parameter")) + # This is failing on Python 3.12 : https://github.com/pytorch/pytorch/issues/119454 -@unittest.skipIf( - sys.version_info >= (3, 12), "Failing on python 3.12+" -) +@unittest.skipIf(sys.version_info >= (3, 12), "Failing on python 3.12+") class TestFunctionalTracing(JitTestCase): def setUp(self): super().setUp() # Checking for mutable operations whil tracing is feature flagged # Enable it in testing but not by default - self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations + self.orig_tracer_mutable_flag = ( + torch.fx.proxy.TracerBase.check_mutable_operations + ) torch.fx.proxy.TracerBase.check_mutable_operations = True def tearDown(self): super().tearDown() - torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + torch.fx.proxy.TracerBase.check_mutable_operations = ( + self.orig_tracer_mutable_flag + ) - IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary", - "has_torch_function_variadic", "handle_torch_function", - "boolean_dispatch") - TO_PATCH = {"has_torch_function": None, - "has_torch_function_unary": None, - "has_torch_function_variadic": None} + IGNORE_FUNCS = ( + "has_torch_function", + "has_torch_function_unary", + "has_torch_function_variadic", + "handle_torch_function", + "boolean_dispatch", + ) + TO_PATCH = { + "has_torch_function": None, + "has_torch_function_unary": None, + "has_torch_function_variadic": None, + } BUILT_IN_FUNC = (AssertionError, "") PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable") PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default") ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$") - CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow") - INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined") + CONTROL_FLOW = ( + TraceError, + r"symbolically traced variables cannot be used as inputs to control flow", + ) + INTERPOLATE_ARGS_CONFLICT = ( + ValueError, + r"only one of size or scale_factor should be defined", + ) MUTABLE = (RuntimeError, r"Tried to trace mutable operation") UNTRACEABLE_FUNCTIONALS = { @@ -4363,13 +4635,11 @@ class TestFunctionalTracing(JitTestCase): "softplus": BUILT_IN_FUNC, "softshrink": BUILT_IN_FUNC, "threshold_": BUILT_IN_FUNC, - "adaptive_avg_pool2d": LEN_ERROR, "adaptive_avg_pool3d": LEN_ERROR, "adaptive_max_pool2d_with_indices": LEN_ERROR, "adaptive_max_pool3d_with_indices": LEN_ERROR, "instance_norm": CONTROL_FLOW, - "adaptive_max_pool1d": PROXY_ITERABLE, "adaptive_max_pool2d": PROXY_ITERABLE, "adaptive_max_pool3d": PROXY_ITERABLE, @@ -4378,7 +4648,6 @@ class TestFunctionalTracing(JitTestCase): "max_pool1d": PROXY_ITERABLE, "max_pool2d": PROXY_ITERABLE, "max_pool3d": PROXY_ITERABLE, - "lp_pool2d": PROXY_ITERATED, "lp_pool3d": PROXY_ITERATED, "max_unpool1d": PROXY_ITERATED, 
@@ -4386,14 +4655,12 @@ class TestFunctionalTracing(JitTestCase): "max_unpool3d": PROXY_ITERATED, "fold": PROXY_ITERATED, "unfold": PROXY_ITERATED, - "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH, "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH, "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH, "layer_norm": ARG_TYPE_MISMATCH, "rms_norm": ARG_TYPE_MISMATCH, "lp_pool1d": ARG_TYPE_MISMATCH, - "affine_grid": CONTROL_FLOW, "alpha_dropout": CONTROL_FLOW, "batch_norm": CONTROL_FLOW, @@ -4449,7 +4716,6 @@ class TestFunctionalTracing(JitTestCase): "triplet_margin_loss": CONTROL_FLOW, "triplet_margin_with_distance_loss": CONTROL_FLOW, "upsample": CONTROL_FLOW, - "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT, "upsample_nearest": INTERPOLATE_ARGS_CONFLICT, } @@ -4484,8 +4750,7 @@ class TestFunctionalTracing(JitTestCase): "max_pool1d": PROXY_ITERATED, "max_pool2d": PROXY_ITERATED, "max_pool3d": PROXY_ITERATED, - - "group_norm": CONTROL_FLOW + "group_norm": CONTROL_FLOW, } @classmethod @@ -4495,7 +4760,7 @@ class TestFunctionalTracing(JitTestCase): if not f.islower(): continue # Ignore internal functions - if f.startswith('_'): + if f.startswith("_"): continue # Ignore supporting functions if f in cls.IGNORE_FUNCS: @@ -4509,7 +4774,9 @@ class TestFunctionalTracing(JitTestCase): sig = inspect.signature(fn) has_tensor_arg = False for param in sig.parameters.values(): - if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor): + if isinstance(param.annotation, type) and issubclass( + param.annotation, torch.Tensor + ): has_tensor_arg = True if not has_tensor_arg: continue @@ -4521,10 +4788,12 @@ class TestFunctionalTracing(JitTestCase): @classmethod def generate_test_func(cls, func_name, fn): - def functional_test(self): - if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \ - sys.version_info >= (3, 8) and sys.version_info < (3, 12): + if ( + func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 + and sys.version_info >= (3, 8) + and sys.version_info < (3, 12) + ): exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name] with self.assertRaisesRegex(exc, err): symbolic_trace(fn) @@ -4534,6 +4803,7 @@ class TestFunctionalTracing(JitTestCase): symbolic_trace(fn) else: symbolic_trace(fn) + return functional_test @classmethod @@ -4546,7 +4816,6 @@ class TestFunctionalTracing(JitTestCase): @classmethod def setUpClass(cls): - def no(*args, **kwargs): return False @@ -4559,27 +4828,33 @@ class TestFunctionalTracing(JitTestCase): for name in cls.TO_PATCH.keys(): setattr(torch.nn.functional, name, cls.TO_PATCH[name]) + TestFunctionalTracing.generate_tests() instantiate_device_type_tests(TestOperatorSignatures, globals()) + @skipIfTorchDynamo("too slow") @skipIfNoTorchVision class TestVisionTracing(JitTestCase): def setUp(self): # Checking for mutable operations while tracing is feature flagged # Enable it in testing but not by default - self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations + self.orig_tracer_mutable_flag = ( + torch.fx.proxy.TracerBase.check_mutable_operations + ) torch.fx.proxy.TracerBase.check_mutable_operations = True def tearDown(self): - torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + torch.fx.proxy.TracerBase.check_mutable_operations = ( + self.orig_tracer_mutable_flag + ) PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") INCONSISTENT_TYPE = ( RuntimeError, - r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type 
Tensor" + r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor", ) UNTRACEABLE_MODELS = { @@ -4627,7 +4902,7 @@ class TestVisionTracing(JitTestCase): graph = symbolic_trace(model) else: out_transform = self.output_transform.get(name, lambda x: x) - graph : torch.fx.GraphModule = symbolic_trace(model) + graph: torch.fx.GraphModule = symbolic_trace(model) a = out_transform(model(x)) b = out_transform(graph(x)) self.assertEqual(a, b) @@ -4646,8 +4921,12 @@ class TestVisionTracing(JitTestCase): @classmethod def generate_classification_tests(cls): for k in torchvision_models.list_models(module=torchvision_models): - test_name = 'test_torchvision_models_' + k - x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224) + test_name = "test_torchvision_models_" + k + x = ( + torch.rand(1, 3, 299, 299) + if k in ["inception_v3"] + else torch.rand(1, 3, 224, 224) + ) kwargs = dict(num_classes=50) model_test = cls.generate_test_fn(k, x, kwargs) setattr(cls, test_name, model_test) @@ -4655,7 +4934,7 @@ class TestVisionTracing(JitTestCase): @classmethod def generate_segmentation_tests(cls): for k in torchvision_models.list_models(module=torchvision_models.segmentation): - test_name = 'test_torchvision_models_segmentation_' + k + test_name = "test_torchvision_models_segmentation_" + k x = torch.rand(1, 3, 32, 32) kwargs = dict(num_classes=10, pretrained_backbone=False) model_test = cls.generate_test_fn(k, x, kwargs) @@ -4664,7 +4943,7 @@ class TestVisionTracing(JitTestCase): @classmethod def generate_detection_tests(cls): for k in torchvision_models.list_models(module=torchvision_models.detection): - test_name = 'test_torchvision_models_detection_' + k + test_name = "test_torchvision_models_detection_" + k x = [torch.rand(3, 300, 300)] kwargs = dict(num_classes=10, pretrained_backbone=False) model_test = cls.generate_test_fn(k, x, kwargs) @@ -4673,7 +4952,7 @@ class TestVisionTracing(JitTestCase): @classmethod def generate_video_tests(cls): for k in torchvision_models.list_models(module=torchvision_models.video): - test_name = 'test_torchvision_models_video_' + k + test_name = "test_torchvision_models_video_" + k x = ( torch.rand(1, 3, 4, 112, 112) if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"} @@ -4690,8 +4969,9 @@ class TestVisionTracing(JitTestCase): cls.generate_segmentation_tests() cls.generate_video_tests() + if HAS_TORCHVISION: TestVisionTracing.generate_tests() -if __name__ == '__main__': +if __name__ == "__main__": run_tests()