add node mapping processing (#146103)

Summary:
Add `node_mapping = create_node_mapping(pre_grad_graph_id, inductor_post_to_pre_grad_nodes, debug_info)` to produce an `inductor_provenance_tracking_node_mappings.json` file. This file will be used by the provenance tracking highlighter tool to create provenance visualizations.

The `inductor_triton_kernel_to_post_grad_nodes.json` and `inductor_provenance_tracking_node_mappings.json` files are not dumped when they are both empty, so the corresponding expected artifact lines are removed from some of the `test_structured_trace` tests.

Test Plan:
CI
```
buck run mode/dev-nosan  fbcode//caffe2/test:fx -- -r graph_provenance

buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing

python test/dynamo/test_structured_trace.py
```

Differential Revision: D68190173

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146103
Approved by: https://github.com/chenyang78
This commit is contained in:
Shangdi Yu 2025-02-01 08:29:29 +00:00 committed by PyTorch MergeBot
parent f38d5b4a74
commit a4e4368157
7 changed files with 330 additions and 33 deletions

View file

@ -250,7 +250,6 @@ class StructuredTraceTest(TestCase):
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@ -278,7 +277,6 @@ class StructuredTraceTest(TestCase):
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@ -314,7 +312,6 @@ class StructuredTraceTest(TestCase):
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@ -332,7 +329,6 @@ class StructuredTraceTest(TestCase):
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
@ -360,7 +356,6 @@ class StructuredTraceTest(TestCase):
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@ -415,7 +410,6 @@ class StructuredTraceTest(TestCase):
{"aot_backward_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
@ -425,7 +419,6 @@ class StructuredTraceTest(TestCase):
{"compilation_metrics": "METRICS", "frame_id": 3, "frame_compile_id": 0, "attempt": 0}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
@ -625,7 +618,6 @@ class StructuredTraceTest(TestCase):
{"aot_backward_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@ -642,7 +634,6 @@ class StructuredTraceTest(TestCase):
{"aot_backward_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@ -679,7 +670,6 @@ class StructuredTraceTest(TestCase):
{"aot_inference_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@ -836,7 +826,6 @@ def forward(self, x, y):
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}

View file

@ -1,5 +1,4 @@
# Owner(s): ["module: fx"]
import json
import torch
from torch._inductor.compile_fx import aot_export_module
@ -98,7 +97,6 @@ class TestFXNodeSource(TestCase):
ep = torch.export.export(model, example_inputs, strict=True)
gm = ep.module()
provenance = get_graph_provenance_json(gm.graph)
provenance = json.loads(provenance)
self.assertEqual(
set(provenance.keys()), {"relu", "linear", "sigmoid", "linear_1"}
)
@ -130,7 +128,6 @@ class TestFXNodeSource(TestCase):
)
provenance = get_graph_provenance_json(gm.graph)
provenance = json.loads(provenance)
self.assertEqual(
set(provenance.keys()), {"t", "addmm", "relu", "t_1", "addmm_1", "sigmoid"}

View file

@ -9,6 +9,7 @@ from pathlib import Path
import torch
from torch._inductor import config
from torch._inductor.debug import create_node_mapping
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.triton_utils import requires_cuda
@ -48,14 +49,69 @@ class TestProvenanceTracingArtifact(TestCase):
"triton_poi_fused_mul_0": ["mul"],
"triton_poi_fused_addmm_gelu_1": [
"mul_3",
"erf",
"add_tensor",
"mul_1",
"add_tensor",
"add",
"erf",
"mul_2",
],
}
self.assertEqual(sorted(actual_data), sorted(expected_data))
self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items()))
filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json"
with open(filename) as f:
actual_data = json.load(f)
# check that the generated provenance tracing artifact is expected
expected_data = [
(
"cppCodeToPost",
{
"triton_poi_fused_mul_0": ["mul"],
"triton_poi_fused_addmm_gelu_1": [
"mul_3",
"mul_1",
"add_tensor",
"add",
"erf",
"mul_2",
],
},
),
(
"postToCppCode",
{
"mul": ["triton_poi_fused_mul_0"],
"mul_3": ["triton_poi_fused_addmm_gelu_1"],
"mul_1": ["triton_poi_fused_addmm_gelu_1"],
"add_tensor": ["triton_poi_fused_addmm_gelu_1"],
"add": ["triton_poi_fused_addmm_gelu_1"],
"erf": ["triton_poi_fused_addmm_gelu_1"],
"mul_2": ["triton_poi_fused_addmm_gelu_1"],
},
),
(
"postToPre",
{
"mul": ["mul"],
"mm_default": ["addmm"],
"add_tensor": ["addmm"],
"mul_1": ["gelu"],
"mul_2": ["gelu"],
"erf": ["gelu"],
"add": ["gelu"],
"mul_3": ["gelu"],
},
),
(
"preToPost",
{
"mul": ["mul"],
"addmm": ["mm_default", "add_tensor"],
"gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
},
),
]
self.assertEqual(sorted(actual_data.items()), sorted(expected_data))
def test_triton_kernel_to_post_grad_tracing(self):
a = torch.randn(10, 20, device="cuda")
@ -95,5 +151,103 @@ class TestProvenanceTracingArtifact(TestCase):
shutil.rmtree(filepath)
class TestProvenanceTracingNodeMapping(TestCase):
    def test_create_node_mapping(self):
        """Unit test for torch._inductor.debug.create_node_mapping.

        Feeds a hand-built post-grad -> pre-grad provenance tree plus a
        triton-kernel -> post-grad-node mapping, and checks all four derived
        bidirectional mappings in the result.
        """
        # id() of the pre-grad fx.Graph. Provenance records carrying any
        # other graph_id (intermediate graphs, or the -1 sentinel below)
        # must not contribute to preToPost/postToPre.
        pre_grad_graph_id = 140156815043952
        # Keys are post-grad node names; each value is a list of provenance
        # records whose nested "from_node" chains lead back toward pre-grad
        # nodes (the ones whose graph_id matches pre_grad_graph_id).
        post_to_pre_grad_nodes_json = {
            "add_tensor": [
                {
                    "from_node": [
                        {
                            "from_node": [
                                {
                                    "from_node": [],
                                    "graph_id": 140156815043952,
                                    "name": "linear",
                                }
                            ],
                            "graph_id": 140152856025632,
                            "name": "addmm",
                        }
                    ],
                    "graph_id": 140151961816272,
                    "name": "add",
                },
            ],
            "mm_default": [
                {
                    # Sentinel record with no usable provenance; must be
                    # skipped without error.
                    "from_node": [],
                    "graph_id": -1,
                    "name": "",
                },
                {
                    "from_node": [
                        {
                            "from_node": [
                                {
                                    "from_node": [],
                                    "graph_id": 140156815043952,
                                    "name": "linear",
                                }
                            ],
                            "graph_id": 140152856025632,
                            "name": "addmm",
                        }
                    ],
                    "graph_id": 140151961816272,
                    "name": "mm",
                },
            ],
            "permute": [
                {
                    "from_node": [],
                    "graph_id": 140156815043952,
                    "name": "linear",
                }
            ],
            "relu": [
                {
                    "from_node": [],
                    "graph_id": 140156815043952,
                    "name": "relu",
                }
            ],
        }
        # One generated triton kernel covering two post-grad nodes.
        triton_kernel_to_post_grad_json = {
            "triton_poi_fused_addmm_relu_sigmoid_0": ["relu", "add_tensor"]
        }

        result = create_node_mapping(
            pre_grad_graph_id,
            post_to_pre_grad_nodes_json,
            triton_kernel_to_post_grad_json,
        )
        self.assertEqual(
            result,
            {
                # Kernel -> post-grad nodes: passed through unchanged.
                "cppCodeToPost": {
                    "triton_poi_fused_addmm_relu_sigmoid_0": [
                        "relu",
                        "add_tensor",
                    ]
                },
                # Inverse of cppCodeToPost.
                "postToCppCode": {
                    "add_tensor": ["triton_poi_fused_addmm_relu_sigmoid_0"],
                    "relu": ["triton_poi_fused_addmm_relu_sigmoid_0"],
                },
                # Each post-grad node maps to the pre-grad ancestors found by
                # walking its from_node chains.
                "postToPre": {
                    "add_tensor": ["linear"],
                    "mm_default": ["linear"],
                    "permute": ["linear"],
                    "relu": ["relu"],
                },
                # Inverse of postToPre.
                "preToPost": {
                    "linear": ["add_tensor", "mm_default", "permute"],
                    "relu": ["relu"],
                },
            },
        )


if __name__ == "__main__":
    run_tests()

View file

@ -996,7 +996,10 @@ class _InProcessFxCompile(FxCompile):
"name": "inductor_post_to_pre_grad_nodes",
"encoding": "json",
},
payload_fn=lambda: provenance_tracking_json,
payload_fn=lambda: json.dumps(provenance_tracking_json),
)
torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
provenance_tracking_json
)
if config.is_fbcode():
log_optimus_to_scuba(
@ -1925,6 +1928,7 @@ def compile_fx(
colored=True,
),
)
torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
model_ = _recursive_pre_grad_passes(model_, example_inputs_)

View file

@ -12,6 +12,7 @@ import pickle
import pstats
import shutil
import subprocess
import traceback
from collections.abc import Iterator
from typing import Any, Callable, IO, Optional, Union
from unittest.mock import patch
@ -309,8 +310,17 @@ def enable_aot_logging() -> Iterator[None]:
stack.close()
# Used for provenance tracking
# They are not stored in DebugContext because they are not set in
# _inductor_triton_kernel_to_post_grad_node_info's Debug Context
_inductor_post_to_pre_grad_nodes: dict[str, Any] = {}
_pre_grad_graph_id: Optional[int] = None
class DebugContext:
_counter = itertools.count()
# Used for provenance tracking
_inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {}
@staticmethod
@ -551,12 +561,22 @@ class DebugFormatter:
def log_inductor_triton_kernel_to_post_grad_node_info(
self, filename: str = "inductor_triton_kernel_to_post_grad_nodes.json"
) -> dict[str, list[str]]:
) -> tuple[dict[str, list[str]], dict[str, Any]]:
debug_info = {}
with self.fopen(filename, "w") as fd:
log.info("Writing provenance tracing debugging info to %s", fd.name)
debug_info = DebugContext._inductor_triton_kernel_to_post_grad_node_info
json.dump(debug_info, fd)
return debug_info
node_mapping = {}
if _pre_grad_graph_id:
with self.fopen(
"inductor_provenance_tracking_node_mappings.json", "w"
) as fd:
node_mapping = create_node_mapping(
_pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info
)
json.dump(node_mapping, fd)
return debug_info, node_mapping
def log_autotuning_results(
self,
@ -656,6 +676,124 @@ class TensorMetadataHolder:
save_args_cnt = itertools.count()
def create_node_mapping(
    pre_grad_graph_id: int,
    post_to_pre_grad_nodes_json: dict[str, Any],
    triton_kernel_to_post_grad_json: dict[str, Any],
) -> dict[str, dict[str, Any]]:
    """Create bidirectional mappings between:

    - pre_grad graph nodes and post_grad graph code nodes, and vice versa
    - triton kernel name and post_grad graph code nodes, and vice versa

    Args:
        pre_grad_graph_id: id() of the pre-grad fx.Graph; only provenance
            records whose "graph_id" equals this value count as pre-grad nodes.
        post_to_pre_grad_nodes_json: post-grad node name -> list of provenance
            records, each a dict with "graph_id", "name" and a nested
            "from_node" list of the same shape.
        triton_kernel_to_post_grad_json: kernel name -> list of post-grad
            node names.

    Returns:
        A dict with keys "preToPost", "postToPre", "cppCodeToPost" and
        "postToCppCode", each mapping a name to a list of names. On malformed
        input or any unexpected error, an all-empty mapping is returned so
        that provenance logging never interferes with compilation.
    """
    # Dummy result returned on any validation failure or unexpected error.
    empty_return: dict[str, dict[str, Any]] = {
        "preToPost": {},
        "postToPre": {},
        "cppCodeToPost": {},
        "postToCppCode": {},
    }

    log.info("Creating node mappings for provenance tracking")

    # The inputs originate from JSON dumps produced elsewhere in the compile
    # pipeline, so defend against schema drift instead of raising.
    if not isinstance(post_to_pre_grad_nodes_json, dict):
        log.error(
            "Provenance tracking error: post_to_pre_grad_nodes_json is not a dict"
        )
        return empty_return
    if not isinstance(triton_kernel_to_post_grad_json, dict):
        log.error(
            "Provenance tracking error: triton_kernel_to_post_grad_json is not a dict"
        )
        return empty_return
    if not isinstance(pre_grad_graph_id, int):
        log.error("Provenance tracking error: pre_grad_graph_id is not an int")
        return empty_return

    # OrderedSet keeps insertion order deterministic for the JSON output
    # while deduplicating repeated provenance hits.
    pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet)
    post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet)
    post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet)

    try:
        # Invert kernel -> post-grad nodes into post-grad node -> kernels.
        for outer_key, node_array in triton_kernel_to_post_grad_json.items():
            if not isinstance(node_array, list):
                log.error(
                    "Provenance tracking error: triton_kernel_to_post_grad_json value is not a list"
                )
                return empty_return
            for curr_node in node_array:
                post_to_cpp_code[curr_node].add(outer_key)

        def check_format(node: dict[str, Any]) -> bool:
            # A provenance record must be a dict carrying graph_id/name/from_node.
            if not isinstance(node, dict):
                log.error(
                    "Provenance tracking error: node provenance in post_to_pre_grad_nodes_json is not a dict"
                )
                return False
            if "graph_id" not in node or "name" not in node or "from_node" not in node:
                log.error(
                    "Provenance tracking error: node provenance in post_to_pre_grad_nodes_json has wrong format"
                )
                return False
            return True

        for outer_key, node_array in post_to_pre_grad_nodes_json.items():
            if not isinstance(node_array, list):
                log.error(
                    "Provenance tracking error: post_to_pre_grad_nodes_json value is not a list"
                )
                return empty_return
            for node in node_array:
                if not check_format(node):
                    return empty_return
                # Check the current node first
                if node.get("graph_id") == pre_grad_graph_id:
                    pre_to_post[node["name"]].add(outer_key)
                    post_to_pre[outer_key].add(node["name"])
                # Walk the nested from_node chain iteratively (no recursion);
                # every ancestor belonging to the pre-grad graph is mapped
                # to/from the originating post-grad node (outer_key).
                stack = [(n, outer_key) for n in node.get("from_node", [])]
                while stack:
                    current_node, parent_key = stack.pop()
                    if not check_format(current_node):
                        return empty_return
                    if current_node.get("graph_id") == pre_grad_graph_id:
                        pre_to_post[current_node["name"]].add(parent_key)
                        post_to_pre[parent_key].add(current_node["name"])
                    stack.extend(
                        (n, parent_key) for n in current_node.get("from_node", [])
                    )

        def convert_sets_to_lists(d: dict[str, Any]) -> None:
            # Convert to list in place because a set is not JSON serializable.
            for key in d:
                d[key] = list(d[key])

        convert_sets_to_lists(pre_to_post)
        convert_sets_to_lists(post_to_pre)
        convert_sets_to_lists(post_to_cpp_code)
        return {
            "preToPost": pre_to_post,
            "postToPre": post_to_pre,
            "cppCodeToPost": triton_kernel_to_post_grad_json,
            "postToCppCode": post_to_cpp_code,
        }
    except Exception as e:
        # Since this is just logging code, it should never interfere with regular
        # program execution, so we use this try-except to guard against any error
        log.error("Unexpected error in create_node_mapping: %s", e)
        log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json)
        log.error(
            "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json
        )
        log.error("pre_grad_graph_id: %s", pre_grad_graph_id)
        log.error(traceback.format_exc())
        return empty_return
def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
"""
This function is used to save arguments for a compile_fx_inner function call

View file

@ -1938,16 +1938,32 @@ class GraphLowering(torch.fx.Interpreter):
"Finished codegen for all nodes. The list of kernel names available: %s",
V.graph.all_codegen_kernel_names,
)
# Dump the inductor_triton_kernel_to_post_grad_node_info to a json file for debugging trace
debug_info = V.debug.log_inductor_triton_kernel_to_post_grad_node_info()
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "inductor_triton_kernel_to_post_grad_nodes",
"encoding": "json",
},
payload_fn=lambda: json.dumps(debug_info),
# Dump provenance artifacts for debugging trace
provenance_info = (
V.debug.log_inductor_triton_kernel_to_post_grad_node_info()
)
# provenance_info might be None if config.trace.enabled is not set
if provenance_info:
(
debug_info,
node_mappings,
) = provenance_info
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "inductor_triton_kernel_to_post_grad_nodes",
"encoding": "json",
},
payload_fn=lambda: json.dumps(debug_info),
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "inductor_provenance_tracking_node_mappings",
"encoding": "json",
},
payload_fn=lambda: json.dumps(node_mappings),
)
result = self.wrapper_code.generate(self.is_inference)
self.wrapper_code.pop_codegened_graph()

View file

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import copy
import json
import traceback
from contextlib import contextmanager
from enum import Enum
@ -224,9 +223,9 @@ def get_current_meta() -> dict[str, Any]:
@compatibility(is_backward_compatible=False)
def get_graph_provenance_json(graph: Graph) -> str:
def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
"""
Given an fx.Graph, return a json string that contains the provenance information of each node.
Given an fx.Graph, return a json that contains the provenance information of each node.
"""
provenance_tracking_json = {}
for node in graph.nodes:
@ -236,4 +235,4 @@ def get_graph_provenance_json(graph: Graph) -> str:
if "from_node" in node.meta
else []
)
return json.dumps(provenance_tracking_json)
return provenance_tracking_json