From 23fffb54d591c7b5ca6d19728d628dbe1e79d91c Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Mon, 3 Feb 2025 16:02:20 -0800 Subject: [PATCH] Use OrderedSet in _functorch/partitioners (#146102) In an attempt to make partitioning more deterministic, change all sets in partitioners.py to OrderedSets. Note that this change does not fix the non-determinism we're seeing in the internal model. But let's at least eliminate this potential source of non-determinism before investigating any changes to the mincut approach? Pull Request resolved: https://github.com/pytorch/pytorch/pull/146102 Approved by: https://github.com/oulgen --- .lintrunner.toml | 1 + torch/_functorch/partitioners.py | 104 +++++++++++++++++-------------- 2 files changed, 57 insertions(+), 48 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 512f470aaa2..347502fc07d 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1703,6 +1703,7 @@ command = [ ] include_patterns = [ "torch/_inductor/**/*.py", + "torch/_functorch/partitioners.py", ] is_formatter = true diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 4cae6c6e93c..c85db3b780c 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -25,6 +25,7 @@ from torch.fx.experimental.symbolic_shapes import ( is_symbol_binding_fx_node, ) from torch.fx.passes import graph_drawer +from torch.utils._ordered_set import OrderedSet from torch.utils.checkpoint import CheckpointPolicy from . import config @@ -55,11 +56,11 @@ prims = torch.ops.prims class OpTypes: """Class for keeping track of different operator categories""" - fusible_ops: set[Callable] - compute_intensive_ops: set[Callable] - random_ops: set[Callable] - view_ops: set[Callable] - recomputable_ops: set[Callable] + fusible_ops: OrderedSet[Callable] + compute_intensive_ops: OrderedSet[Callable] + random_ops: OrderedSet[Callable] + view_ops: OrderedSet[Callable] + recomputable_ops: OrderedSet[Callable] def is_fusible(self, node: fx.Node): return get_aten_target(node) in self.fusible_ops @@ -82,9 +83,9 @@ class NodeInfo: # Be careful about iterating over these explicitly, as their order may not # be deterministic inputs: list[fx.Node] - _required_fw_nodes: set[fx.Node] - required_bw_nodes: set[fx.Node] - unclaimed_nodes: set[fx.Node] + _required_fw_nodes: OrderedSet[fx.Node] + required_bw_nodes: OrderedSet[fx.Node] + unclaimed_nodes: OrderedSet[fx.Node] fw_order: dict[fx.Node, int] @functools.cached_property @@ -326,7 +327,7 @@ def _extract_fwd_bwd_modules( # we propagate all symbols which are referenced by backwards inputs. # These are not directly used in the graph but are required for downstream # sizevar assignment - saved_symbols: set[sympy.Symbol] = set() + saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet() saved_sym_nodes_binding = [] saved_sym_nodes_derived = [] @@ -426,9 +427,9 @@ def default_partition( forward_only_graph = _extract_graph_with_inputs_outputs( joint_module.graph, inputs, fwd_outputs, "forward" ) - forward_node_names = { + forward_node_names = OrderedSet( node.name for node in forward_only_graph.nodes if node.op != "output" - } + ) saved_values = [] saved_sym_nodes = [] @@ -580,7 +581,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: def insert_node_in_graph(node): cur_nodes = [node] - insertable_nodes = set() + insertable_nodes: OrderedSet[fx.Node] = OrderedSet() while len(cur_nodes) > 0: node = cur_nodes.pop() if node in insertable_nodes or node in env: @@ -817,19 +818,21 @@ def solve_min_cut( joint_graph: fx.Graph, node_info: NodeInfo, min_cut_options: MinCutOptions, - dont_ban=None, + dont_ban: Optional[OrderedSet[fx.Node]] = None, ): if dont_ban is None: - dont_ban = set() + dont_ban = OrderedSet() op_types = get_default_op_list() if AOT_PARTITIONER_DEBUG: - joint_module_ops = { + joint_module_ops = OrderedSet( str(node.target._overloadpacket) for node in joint_graph.nodes if node.op == "call_function" and hasattr(node.target, "_overloadpacket") - } - ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops} + ) + ops_ignored = joint_module_ops - OrderedSet( + str(i) for i in op_types.recomputable_ops + ) log.info("Ops banned from re-materialization: %s", ops_ignored) def can_fuse_into_auto_functionalized(a, b): @@ -888,7 +891,7 @@ def solve_min_cut( def is_materialized_backwards(node): if op_types.is_view(node): return False - cur_nodes = {node} + cur_nodes = OrderedSet([node]) while len(cur_nodes) > 0: cur = cur_nodes.pop() for user in cur.users: @@ -981,7 +984,7 @@ def solve_min_cut( return mem_sz * 2 nx_graph = nx.DiGraph() - banned_nodes = set() + banned_nodes: OrderedSet[fx.Node] = OrderedSet() def ban_recomputation_if_allowed(node): if op_types.is_view(node): @@ -1091,12 +1094,13 @@ def solve_min_cut( if node_info.is_required_fw(user): if node_info.get_fw_order(user) > max_range: continue - val = (node_info.get_fw_order(user), user, is_fusible(node, user)) + val: tuple[int, fx.Node, bool] = ( + node_info.get_fw_order(user), + user, + is_fusible(node, user), + ) if val not in sorted_nodes: - heapq.heappush( - sorted_nodes, - val, - ) + heapq.heappush(sorted_nodes, val) return max_range if min_cut_options.ban_if_used_far_apart: @@ -1141,11 +1145,13 @@ def solve_min_cut( # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36 if min_cut_options.ban_if_long_fusible_chains: - visited = set() + visited: OrderedSet[fx.Node] = OrderedSet() for start_node in joint_graph.nodes: if not node_info.is_required_fw(start_node): continue - fusible = [(node_info.get_fw_order(start_node), start_node)] + fusible: list[tuple[int, fx.Node]] = [ + (node_info.get_fw_order(start_node), start_node) + ] start_order = node_info.get_fw_order(start_node) while len(fusible) > 0: _, cur = heapq.heappop(fusible) @@ -1184,11 +1190,11 @@ def solve_min_cut( raise reachable, non_reachable = partition - cutset: set[tuple[str, str]] = set() + cutset: OrderedSet[tuple[str, str]] = OrderedSet() for u, nbrs in ((n, nx_graph[n]) for n in reachable): cutset.update((u, v) for v in nbrs if v in non_reachable) - cut_nodes = set() + cut_nodes: OrderedSet[str] = OrderedSet() for node_in, node_out in cutset: assert node_in[:-3] == node_out[:-4] node_name = node_in[:-3] @@ -1358,9 +1364,9 @@ def get_default_op_list() -> OpTypes: ] default_recomputable_ops += [method_to_operator(m) for m in magic_methods] - recomputable_ops = set(default_recomputable_ops) + recomputable_ops = OrderedSet(default_recomputable_ops) - random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] + random_ops = OrderedSet([aten.native_dropout, aten.rand_like, aten.randn_like]) compute_intensive_ops = [ aten.mm, aten.convolution, @@ -1375,13 +1381,13 @@ def get_default_op_list() -> OpTypes: aten._scaled_mm, ] # noqa: E501,B950 - fusible_ops = recomputable_ops | set(random_ops) + fusible_ops = recomputable_ops | random_ops return OpTypes( - set(fusible_ops), - set(compute_intensive_ops), - set(random_ops), - set(view_ops), - set(recomputable_ops), + fusible_ops, + OrderedSet(compute_intensive_ops), + random_ops, + OrderedSet(view_ops), + recomputable_ops, ) @@ -1567,9 +1573,11 @@ def choose_saved_values_set( from torch._inductor.fx_utils import get_node_storage - input_storages = {get_node_storage(node) for node in node_info.inputs} + input_storages = OrderedSet(get_node_storage(node) for node in node_info.inputs) - def get_recomputable_banned_nodes(banned_nodes: set[fx.Node]) -> list[fx.Node]: + def get_recomputable_banned_nodes( + banned_nodes: OrderedSet[fx.Node], + ) -> list[fx.Node]: return [ i for i in banned_nodes @@ -1653,7 +1661,7 @@ Activation Checkpointing - Knapsack Problem Summary: payload_fn=lambda: knapsack_summary, ) log.info(knapsack_summary) - dont_ban = set() + dont_ban: OrderedSet[fx.Node] = OrderedSet() for idx in recomputable_node_idxs: # if idx in all_recomputable_banned_nodes: try: @@ -1776,7 +1784,7 @@ def min_cut_rematerialization_partition( def classify_nodes(joint_module): name_to_node = get_name_to_node(joint_module.graph) - required_bw_nodes = set() + required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() for node in joint_module.graph.nodes: if node.op == "placeholder" and "tangents" in node.target: required_bw_nodes.add(node) @@ -1800,16 +1808,16 @@ def min_cut_rematerialization_partition( forward_only_graph = _extract_graph_with_inputs_outputs( joint_module.graph, inputs, fwd_outputs, "forward" ) - required_fw_nodes: set[fx.Node] = { + required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( name_to_node[node.name] for node in forward_only_graph.nodes if node.op != "output" - } - unclaimed_nodes = { + ) + unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( node for node in joint_module.graph.nodes if node not in required_fw_nodes and node not in required_bw_nodes - } + ) fw_cnt = 0 fw_order = {} for node in joint_module.graph.nodes: @@ -1879,12 +1887,12 @@ def min_cut_rematerialization_partition( # Log theoretical per activation storage sizes log.info("Theoretical Per Activation Storage Sizes: %s", sorted_sizes) - fw_module_nodes = { + fw_module_nodes = OrderedSet( node.name for node in fw_module.graph.nodes if node.op == "call_function" - } - bw_module_nodes = { + ) + bw_module_nodes = OrderedSet( node.name for node in bw_module.graph.nodes if node.op == "call_function" - } + ) remat_nodes = fw_module_nodes & bw_module_nodes counts: dict[str, int] = defaultdict(int)