diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index c129049c3f0..1f31346c96e 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -250,7 +250,6 @@ class StructuredTraceTest(TestCase): {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -278,7 +277,6 @@ class StructuredTraceTest(TestCase): {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, 
"attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -314,7 +312,6 @@ class StructuredTraceTest(TestCase): {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -332,7 +329,6 @@ class StructuredTraceTest(TestCase): {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": 
"HASH"} {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -360,7 +356,6 @@ class StructuredTraceTest(TestCase): {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -415,7 +410,6 @@ class StructuredTraceTest(TestCase): {"aot_backward_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": 
"aotautograd_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} @@ -425,7 +419,6 @@ class StructuredTraceTest(TestCase): {"compilation_metrics": "METRICS", "frame_id": 3, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} @@ -625,7 +618,6 @@ class StructuredTraceTest(TestCase): {"aot_backward_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_bypass", 
"encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -642,7 +634,6 @@ class StructuredTraceTest(TestCase): {"aot_backward_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -679,7 +670,6 @@ class StructuredTraceTest(TestCase): {"aot_inference_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, 
"has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -836,7 +826,6 @@ def forward(self, x, y): {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index c65720fdf65..9fb98a525fe 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -1,5 +1,4 @@ # Owner(s): ["module: fx"] -import json import torch from torch._inductor.compile_fx import aot_export_module @@ -98,7 +97,6 @@ class TestFXNodeSource(TestCase): ep = torch.export.export(model, example_inputs, strict=True) gm = ep.module() provenance = get_graph_provenance_json(gm.graph) - provenance = json.loads(provenance) self.assertEqual( set(provenance.keys()), {"relu", "linear", "sigmoid", "linear_1"} ) @@ -130,7 +128,6 @@ class TestFXNodeSource(TestCase): ) provenance = get_graph_provenance_json(gm.graph) - provenance = json.loads(provenance) self.assertEqual( set(provenance.keys()), {"t", "addmm", "relu", "t_1", "addmm_1", "sigmoid"} diff --git 
a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index e6b0a25b149..58bb1026d65 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -9,6 +9,7 @@ from pathlib import Path import torch from torch._inductor import config +from torch._inductor.debug import create_node_mapping from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.triton_utils import requires_cuda @@ -48,14 +49,69 @@ class TestProvenanceTracingArtifact(TestCase): "triton_poi_fused_mul_0": ["mul"], "triton_poi_fused_addmm_gelu_1": [ "mul_3", - "erf", - "add_tensor", "mul_1", + "add_tensor", "add", + "erf", "mul_2", ], } - self.assertEqual(sorted(actual_data), sorted(expected_data)) + self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items())) + + filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json" + with open(filename) as f: + actual_data = json.load(f) + # check that the generated provenance tracing artifact is expected + expected_data = [ + ( + "cppCodeToPost", + { + "triton_poi_fused_mul_0": ["mul"], + "triton_poi_fused_addmm_gelu_1": [ + "mul_3", + "mul_1", + "add_tensor", + "add", + "erf", + "mul_2", + ], + }, + ), + ( + "postToCppCode", + { + "mul": ["triton_poi_fused_mul_0"], + "mul_3": ["triton_poi_fused_addmm_gelu_1"], + "mul_1": ["triton_poi_fused_addmm_gelu_1"], + "add_tensor": ["triton_poi_fused_addmm_gelu_1"], + "add": ["triton_poi_fused_addmm_gelu_1"], + "erf": ["triton_poi_fused_addmm_gelu_1"], + "mul_2": ["triton_poi_fused_addmm_gelu_1"], + }, + ), + ( + "postToPre", + { + "mul": ["mul"], + "mm_default": ["addmm"], + "add_tensor": ["addmm"], + "mul_1": ["gelu"], + "mul_2": ["gelu"], + "erf": ["gelu"], + "add": ["gelu"], + "mul_3": ["gelu"], + }, + ), + ( + "preToPost", + { + "mul": ["mul"], + "addmm": ["mm_default", "add_tensor"], + "gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"], + }, + ), + ] + 
self.assertEqual(sorted(actual_data.items()), sorted(expected_data)) def test_triton_kernel_to_post_grad_tracing(self): a = torch.randn(10, 20, device="cuda") @@ -95,5 +151,103 @@ class TestProvenanceTracingArtifact(TestCase): shutil.rmtree(filepath) +class TestProvenanceTracingNodeMapping(TestCase): + def test_create_node_mapping(self): + pre_grad_graph_id = 140156815043952 + post_to_pre_grad_nodes_json = { + "add_tensor": [ + { + "from_node": [ + { + "from_node": [ + { + "from_node": [], + "graph_id": 140156815043952, + "name": "linear", + } + ], + "graph_id": 140152856025632, + "name": "addmm", + } + ], + "graph_id": 140151961816272, + "name": "add", + }, + ], + "mm_default": [ + { + "from_node": [], + "graph_id": -1, + "name": "", + }, + { + "from_node": [ + { + "from_node": [ + { + "from_node": [], + "graph_id": 140156815043952, + "name": "linear", + } + ], + "graph_id": 140152856025632, + "name": "addmm", + } + ], + "graph_id": 140151961816272, + "name": "mm", + }, + ], + "permute": [ + { + "from_node": [], + "graph_id": 140156815043952, + "name": "linear", + } + ], + "relu": [ + { + "from_node": [], + "graph_id": 140156815043952, + "name": "relu", + } + ], + } + triton_kernel_to_post_grad_json = { + "triton_poi_fused_addmm_relu_sigmoid_0": ["relu", "add_tensor"] + } + + result = create_node_mapping( + pre_grad_graph_id, + post_to_pre_grad_nodes_json, + triton_kernel_to_post_grad_json, + ) + self.assertEqual( + result, + { + "cppCodeToPost": { + "triton_poi_fused_addmm_relu_sigmoid_0": [ + "relu", + "add_tensor", + ] + }, + "postToCppCode": { + "add_tensor": ["triton_poi_fused_addmm_relu_sigmoid_0"], + "relu": ["triton_poi_fused_addmm_relu_sigmoid_0"], + }, + "postToPre": { + "add_tensor": ["linear"], + "mm_default": ["linear"], + "permute": ["linear"], + "relu": ["relu"], + }, + "preToPost": { + "linear": ["add_tensor", "mm_default", "permute"], + "relu": ["relu"], + }, + }, + ) + + if __name__ == "__main__": run_tests() diff --git 
a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index b7ad9f1fde8..2da15dc0a8b 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -996,7 +996,10 @@ class _InProcessFxCompile(FxCompile): "name": "inductor_post_to_pre_grad_nodes", "encoding": "json", }, - payload_fn=lambda: provenance_tracking_json, + payload_fn=lambda: json.dumps(provenance_tracking_json), + ) + torch._inductor.debug._inductor_post_to_pre_grad_nodes = ( + provenance_tracking_json ) if config.is_fbcode(): log_optimus_to_scuba( @@ -1925,6 +1928,7 @@ def compile_fx( colored=True, ), ) + torch._inductor.debug._pre_grad_graph_id = id(model_.graph) model_ = _recursive_pre_grad_passes(model_, example_inputs_) diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index f46dae63fcd..3098448894a 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -12,6 +12,7 @@ import pickle import pstats import shutil import subprocess +import traceback from collections.abc import Iterator from typing import Any, Callable, IO, Optional, Union from unittest.mock import patch @@ -309,8 +310,17 @@ def enable_aot_logging() -> Iterator[None]: stack.close() +# Used for provenance tracking +# They are not stored in DebugContext because they are not set in +# _inductor_triton_kernel_to_post_grad_node_info's Debug Context +_inductor_post_to_pre_grad_nodes: dict[str, Any] = {} +_pre_grad_graph_id: Optional[int] = None + + class DebugContext: _counter = itertools.count() + + # Used for provenance tracking _inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {} @staticmethod @@ -551,12 +561,22 @@ class DebugFormatter: def log_inductor_triton_kernel_to_post_grad_node_info( self, filename: str = "inductor_triton_kernel_to_post_grad_nodes.json" - ) -> dict[str, list[str]]: + ) -> tuple[dict[str, list[str]], dict[str, Any]]: + debug_info = {} with self.fopen(filename, "w") as fd: log.info("Writing provenance tracing debugging info to %s", 
fd.name) debug_info = DebugContext._inductor_triton_kernel_to_post_grad_node_info json.dump(debug_info, fd) - return debug_info + node_mapping = {} + if _pre_grad_graph_id: + with self.fopen( + "inductor_provenance_tracking_node_mappings.json", "w" + ) as fd: + node_mapping = create_node_mapping( + _pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info + ) + json.dump(node_mapping, fd) + return debug_info, node_mapping def log_autotuning_results( self, @@ -656,6 +676,124 @@ class TensorMetadataHolder: save_args_cnt = itertools.count() +def create_node_mapping( + pre_grad_graph_id: int, + post_to_pre_grad_nodes_json: dict[str, Any], + triton_kernel_to_post_grad_json: dict[str, Any], +) -> dict[str, dict[str, Any]]: + """Create bidirectional mappings between: + + - pre_grad graph nodes and post_grad graph code nodes, and vice versa + - triton kernel name and post_grad graph code nodes, and vice versa + """ + + # return a dummy dict if there's any error + empty_return: dict[str, dict[str, Any]] = { + "preToPost": {}, + "postToPre": {}, + "cppCodeToPost": {}, + "postToCppCode": {}, + } + + log.info("Creating node mappings for provenance tracking") + + if not isinstance(post_to_pre_grad_nodes_json, dict): + log.error("Provenance tracking error: post_to_pre_grad_nodes_json is not a dict") + return empty_return + + if not isinstance(triton_kernel_to_post_grad_json, dict): + log.error( + "Provenance tracking error: triton_kernel_to_post_grad_json is not a dict" + ) + return empty_return + + if not isinstance(pre_grad_graph_id, int): + log.error("Provenance tracking error: pre_grad_graph_id is not an int") + return empty_return + + pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet) + post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet) + + post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet) + + try: + for outer_key, node_array in triton_kernel_to_post_grad_json.items(): + if not isinstance(node_array, list): + log.error( 
+ "Provenance tracking error: triton_kernel_to_post_grad_json value is not a list" + ) + return empty_return + for curr_node in node_array: + post_to_cpp_code[curr_node].add(outer_key) + + def check_format(node: dict[str, Any]) -> bool: + if not isinstance(node, dict): + log.error( + "Provenance tracking error: node provenance in post_to_pre_grad_nodes_json is not a dict" + ) + return False + if "graph_id" not in node or "name" not in node or "from_node" not in node: + log.error( + "Provenance tracking error: node provenance in post_to_pre_grad_nodes_json has wrong format" + ) + return False + return True + + for outer_key, node_array in post_to_pre_grad_nodes_json.items(): + if not isinstance(node_array, list): + log.error( + "Provenance tracking error: post_to_pre_grad_nodes_json value is not a list" + ) + return empty_return + for node in node_array: + if not check_format(node): + return empty_return + # Check the current node first + if node.get("graph_id") == pre_grad_graph_id: + pre_to_post[node["name"]].add(outer_key) + post_to_pre[outer_key].add(node["name"]) + + # Check nested from_node array recursively, add node with the right graph_id to the map + stack = [(n, outer_key) for n in node.get("from_node", [])] + while stack: + current_node, parent_key = stack.pop() + if not check_format(current_node): + return empty_return + if current_node.get("graph_id") == pre_grad_graph_id: + pre_to_post[current_node["name"]].add(parent_key) + post_to_pre[parent_key].add(current_node["name"]) + stack.extend( + (n, parent_key) for n in current_node.get("from_node", []) + ) + + def convert_sets_to_lists(d: dict[str, Any]) -> None: + for key in d: + d[key] = list(d[key]) + d = dict(d) + + # convert to list because set is not JSON serializable + convert_sets_to_lists(pre_to_post) + convert_sets_to_lists(post_to_pre) + convert_sets_to_lists(post_to_cpp_code) + return { + "preToPost": pre_to_post, + "postToPre": post_to_pre, + "cppCodeToPost": triton_kernel_to_post_grad_json, + 
"postToCppCode": post_to_cpp_code, + } + except Exception as e: + # Since this is just logging code, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + log.error("Unexpected error in create_node_mapping: %s", e) + log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json) + log.error( + "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json + ) + log.error("pre_grad_graph_id: %s", pre_grad_graph_id) + log.error(traceback.format_exc()) + return empty_return + + def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: """ This function is used to save arguments for a compile_fx_inner function call diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 76a3fee4860..c8ea9db64ba 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1938,16 +1938,32 @@ class GraphLowering(torch.fx.Interpreter): "Finished codegen for all nodes. The list of kernel names available: %s", V.graph.all_codegen_kernel_names, ) - # Dump the inductor_triton_kernel_to_post_grad_node_info to a json file for debugging trace - debug_info = V.debug.log_inductor_triton_kernel_to_post_grad_node_info() - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "inductor_triton_kernel_to_post_grad_nodes", - "encoding": "json", - }, - payload_fn=lambda: json.dumps(debug_info), + # Dump provenance artifacts for debugging trace + provenance_info = ( + V.debug.log_inductor_triton_kernel_to_post_grad_node_info() ) + # provenance_info might be None if config.trace.enabled is not set + if provenance_info: + ( + debug_info, + node_mappings, + ) = provenance_info + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_triton_kernel_to_post_grad_nodes", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(debug_info), + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_provenance_tracking_node_mappings", + 
"encoding": "json", + }, + payload_fn=lambda: json.dumps(node_mappings), + ) result = self.wrapper_code.generate(self.is_inference) self.wrapper_code.pop_codegened_graph() diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 095b3f1b27b..3ec156005a0 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import copy -import json import traceback from contextlib import contextmanager from enum import Enum @@ -224,9 +223,9 @@ def get_current_meta() -> dict[str, Any]: @compatibility(is_backward_compatible=False) -def get_graph_provenance_json(graph: Graph) -> str: +def get_graph_provenance_json(graph: Graph) -> dict[str, Any]: """ - Given an fx.Graph, return a json string that contains the provenance information of each node. + Given an fx.Graph, return a json that contains the provenance information of each node. """ provenance_tracking_json = {} for node in graph.nodes: @@ -236,4 +235,4 @@ def get_graph_provenance_json(graph: Graph) -> str: if "from_node" in node.meta else [] ) - return json.dumps(provenance_tracking_json) + return provenance_tracking_json