mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
add node mapping processing (#146103)
Summary: Add `node_mapping = create_node_mapping(pre_grad_graph_id, inductor_post_to_pre_grad_nodes, debug_info)`, to produce a `inductor_provenance_tracking_node_mappings.json` file. This file will be used by the provenance tracking highlighter tool to create provenance visualization. `inductor_triton_kernel_to_post_grad_nodes.json` and `inductor_provenance_tracking_node_mappings.json` files are not dumped if they are both empty. So it's removed from some of the `test_structured_trace` tests. Test Plan: CI ``` buck run mode/dev-nosan fbcode//caffe2/test:fx -- -r graph_provenance buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing python test/dynamo/test_structured_trace.py ``` Differential Revision: D68190173 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146103 Approved by: https://github.com/chenyang78
This commit is contained in:
parent
f38d5b4a74
commit
a4e4368157
7 changed files with 330 additions and 33 deletions
|
|
@ -250,7 +250,6 @@ class StructuredTraceTest(TestCase):
|
|||
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
@ -278,7 +277,6 @@ class StructuredTraceTest(TestCase):
|
|||
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
@ -314,7 +312,6 @@ class StructuredTraceTest(TestCase):
|
|||
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
@ -332,7 +329,6 @@ class StructuredTraceTest(TestCase):
|
|||
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
@ -360,7 +356,6 @@ class StructuredTraceTest(TestCase):
|
|||
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
@ -415,7 +410,6 @@ class StructuredTraceTest(TestCase):
|
|||
{"aot_backward_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
|
|
@ -425,7 +419,6 @@ class StructuredTraceTest(TestCase):
|
|||
{"compilation_metrics": "METRICS", "frame_id": 3, "frame_compile_id": 0, "attempt": 0}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
|
||||
{"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
|
||||
|
|
@ -625,7 +618,6 @@ class StructuredTraceTest(TestCase):
|
|||
{"aot_backward_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
@ -642,7 +634,6 @@ class StructuredTraceTest(TestCase):
|
|||
{"aot_backward_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
@ -679,7 +670,6 @@ class StructuredTraceTest(TestCase):
|
|||
{"aot_inference_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
@ -836,7 +826,6 @@ def forward(self, x, y):
|
|||
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "inductor_triton_kernel_to_post_grad_nodes", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# Owner(s): ["module: fx"]
|
||||
import json
|
||||
|
||||
import torch
|
||||
from torch._inductor.compile_fx import aot_export_module
|
||||
|
|
@ -98,7 +97,6 @@ class TestFXNodeSource(TestCase):
|
|||
ep = torch.export.export(model, example_inputs, strict=True)
|
||||
gm = ep.module()
|
||||
provenance = get_graph_provenance_json(gm.graph)
|
||||
provenance = json.loads(provenance)
|
||||
self.assertEqual(
|
||||
set(provenance.keys()), {"relu", "linear", "sigmoid", "linear_1"}
|
||||
)
|
||||
|
|
@ -130,7 +128,6 @@ class TestFXNodeSource(TestCase):
|
|||
)
|
||||
|
||||
provenance = get_graph_provenance_json(gm.graph)
|
||||
provenance = json.loads(provenance)
|
||||
|
||||
self.assertEqual(
|
||||
set(provenance.keys()), {"t", "addmm", "relu", "t_1", "addmm_1", "sigmoid"}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from pathlib import Path
|
|||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch._inductor.debug import create_node_mapping
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch.testing._internal.triton_utils import requires_cuda
|
||||
|
||||
|
|
@ -48,14 +49,69 @@ class TestProvenanceTracingArtifact(TestCase):
|
|||
"triton_poi_fused_mul_0": ["mul"],
|
||||
"triton_poi_fused_addmm_gelu_1": [
|
||||
"mul_3",
|
||||
"erf",
|
||||
"add_tensor",
|
||||
"mul_1",
|
||||
"add_tensor",
|
||||
"add",
|
||||
"erf",
|
||||
"mul_2",
|
||||
],
|
||||
}
|
||||
self.assertEqual(sorted(actual_data), sorted(expected_data))
|
||||
self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items()))
|
||||
|
||||
filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json"
|
||||
with open(filename) as f:
|
||||
actual_data = json.load(f)
|
||||
# check that the generated provenance tracing artifact is expected
|
||||
expected_data = [
|
||||
(
|
||||
"cppCodeToPost",
|
||||
{
|
||||
"triton_poi_fused_mul_0": ["mul"],
|
||||
"triton_poi_fused_addmm_gelu_1": [
|
||||
"mul_3",
|
||||
"mul_1",
|
||||
"add_tensor",
|
||||
"add",
|
||||
"erf",
|
||||
"mul_2",
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
"postToCppCode",
|
||||
{
|
||||
"mul": ["triton_poi_fused_mul_0"],
|
||||
"mul_3": ["triton_poi_fused_addmm_gelu_1"],
|
||||
"mul_1": ["triton_poi_fused_addmm_gelu_1"],
|
||||
"add_tensor": ["triton_poi_fused_addmm_gelu_1"],
|
||||
"add": ["triton_poi_fused_addmm_gelu_1"],
|
||||
"erf": ["triton_poi_fused_addmm_gelu_1"],
|
||||
"mul_2": ["triton_poi_fused_addmm_gelu_1"],
|
||||
},
|
||||
),
|
||||
(
|
||||
"postToPre",
|
||||
{
|
||||
"mul": ["mul"],
|
||||
"mm_default": ["addmm"],
|
||||
"add_tensor": ["addmm"],
|
||||
"mul_1": ["gelu"],
|
||||
"mul_2": ["gelu"],
|
||||
"erf": ["gelu"],
|
||||
"add": ["gelu"],
|
||||
"mul_3": ["gelu"],
|
||||
},
|
||||
),
|
||||
(
|
||||
"preToPost",
|
||||
{
|
||||
"mul": ["mul"],
|
||||
"addmm": ["mm_default", "add_tensor"],
|
||||
"gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
|
||||
},
|
||||
),
|
||||
]
|
||||
self.assertEqual(sorted(actual_data.items()), sorted(expected_data))
|
||||
|
||||
def test_triton_kernel_to_post_grad_tracing(self):
|
||||
a = torch.randn(10, 20, device="cuda")
|
||||
|
|
@ -95,5 +151,103 @@ class TestProvenanceTracingArtifact(TestCase):
|
|||
shutil.rmtree(filepath)
|
||||
|
||||
|
||||
class TestProvenanceTracingNodeMapping(TestCase):
|
||||
def test_create_node_mapping(self):
|
||||
pre_grad_graph_id = 140156815043952
|
||||
post_to_pre_grad_nodes_json = {
|
||||
"add_tensor": [
|
||||
{
|
||||
"from_node": [
|
||||
{
|
||||
"from_node": [
|
||||
{
|
||||
"from_node": [],
|
||||
"graph_id": 140156815043952,
|
||||
"name": "linear",
|
||||
}
|
||||
],
|
||||
"graph_id": 140152856025632,
|
||||
"name": "addmm",
|
||||
}
|
||||
],
|
||||
"graph_id": 140151961816272,
|
||||
"name": "add",
|
||||
},
|
||||
],
|
||||
"mm_default": [
|
||||
{
|
||||
"from_node": [],
|
||||
"graph_id": -1,
|
||||
"name": "",
|
||||
},
|
||||
{
|
||||
"from_node": [
|
||||
{
|
||||
"from_node": [
|
||||
{
|
||||
"from_node": [],
|
||||
"graph_id": 140156815043952,
|
||||
"name": "linear",
|
||||
}
|
||||
],
|
||||
"graph_id": 140152856025632,
|
||||
"name": "addmm",
|
||||
}
|
||||
],
|
||||
"graph_id": 140151961816272,
|
||||
"name": "mm",
|
||||
},
|
||||
],
|
||||
"permute": [
|
||||
{
|
||||
"from_node": [],
|
||||
"graph_id": 140156815043952,
|
||||
"name": "linear",
|
||||
}
|
||||
],
|
||||
"relu": [
|
||||
{
|
||||
"from_node": [],
|
||||
"graph_id": 140156815043952,
|
||||
"name": "relu",
|
||||
}
|
||||
],
|
||||
}
|
||||
triton_kernel_to_post_grad_json = {
|
||||
"triton_poi_fused_addmm_relu_sigmoid_0": ["relu", "add_tensor"]
|
||||
}
|
||||
|
||||
result = create_node_mapping(
|
||||
pre_grad_graph_id,
|
||||
post_to_pre_grad_nodes_json,
|
||||
triton_kernel_to_post_grad_json,
|
||||
)
|
||||
self.assertEqual(
|
||||
result,
|
||||
{
|
||||
"cppCodeToPost": {
|
||||
"triton_poi_fused_addmm_relu_sigmoid_0": [
|
||||
"relu",
|
||||
"add_tensor",
|
||||
]
|
||||
},
|
||||
"postToCppCode": {
|
||||
"add_tensor": ["triton_poi_fused_addmm_relu_sigmoid_0"],
|
||||
"relu": ["triton_poi_fused_addmm_relu_sigmoid_0"],
|
||||
},
|
||||
"postToPre": {
|
||||
"add_tensor": ["linear"],
|
||||
"mm_default": ["linear"],
|
||||
"permute": ["linear"],
|
||||
"relu": ["relu"],
|
||||
},
|
||||
"preToPost": {
|
||||
"linear": ["add_tensor", "mm_default", "permute"],
|
||||
"relu": ["relu"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -996,7 +996,10 @@ class _InProcessFxCompile(FxCompile):
|
|||
"name": "inductor_post_to_pre_grad_nodes",
|
||||
"encoding": "json",
|
||||
},
|
||||
payload_fn=lambda: provenance_tracking_json,
|
||||
payload_fn=lambda: json.dumps(provenance_tracking_json),
|
||||
)
|
||||
torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
|
||||
provenance_tracking_json
|
||||
)
|
||||
if config.is_fbcode():
|
||||
log_optimus_to_scuba(
|
||||
|
|
@ -1925,6 +1928,7 @@ def compile_fx(
|
|||
colored=True,
|
||||
),
|
||||
)
|
||||
torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
|
||||
|
||||
model_ = _recursive_pre_grad_passes(model_, example_inputs_)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import pickle
|
|||
import pstats
|
||||
import shutil
|
||||
import subprocess
|
||||
import traceback
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, Callable, IO, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
|
@ -309,8 +310,17 @@ def enable_aot_logging() -> Iterator[None]:
|
|||
stack.close()
|
||||
|
||||
|
||||
# Used for provenance tracking
|
||||
# They are not stored in DebugContext because they are not set in
|
||||
# _inductor_triton_kernel_to_post_grad_node_info's Debug Context
|
||||
_inductor_post_to_pre_grad_nodes: dict[str, Any] = {}
|
||||
_pre_grad_graph_id: Optional[int] = None
|
||||
|
||||
|
||||
class DebugContext:
|
||||
_counter = itertools.count()
|
||||
|
||||
# Used for provenance tracking
|
||||
_inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {}
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -551,12 +561,22 @@ class DebugFormatter:
|
|||
|
||||
def log_inductor_triton_kernel_to_post_grad_node_info(
|
||||
self, filename: str = "inductor_triton_kernel_to_post_grad_nodes.json"
|
||||
) -> dict[str, list[str]]:
|
||||
) -> tuple[dict[str, list[str]], dict[str, Any]]:
|
||||
debug_info = {}
|
||||
with self.fopen(filename, "w") as fd:
|
||||
log.info("Writing provenance tracing debugging info to %s", fd.name)
|
||||
debug_info = DebugContext._inductor_triton_kernel_to_post_grad_node_info
|
||||
json.dump(debug_info, fd)
|
||||
return debug_info
|
||||
node_mapping = {}
|
||||
if _pre_grad_graph_id:
|
||||
with self.fopen(
|
||||
"inductor_provenance_tracking_node_mappings.json", "w"
|
||||
) as fd:
|
||||
node_mapping = create_node_mapping(
|
||||
_pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info
|
||||
)
|
||||
json.dump(node_mapping, fd)
|
||||
return debug_info, node_mapping
|
||||
|
||||
def log_autotuning_results(
|
||||
self,
|
||||
|
|
@ -656,6 +676,124 @@ class TensorMetadataHolder:
|
|||
save_args_cnt = itertools.count()
|
||||
|
||||
|
||||
def create_node_mapping(
|
||||
pre_grad_graph_id: int,
|
||||
post_to_pre_grad_nodes_json: dict[str, Any],
|
||||
triton_kernel_to_post_grad_json: dict[str, Any],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Create bidirectional mappings between:
|
||||
|
||||
- pre_grad graph nodes and post_grad graph code nodes, and vice versa
|
||||
- triton kernel name and post_grad graph code nodes, and vice versa
|
||||
"""
|
||||
|
||||
# return a dummy dict if there's any error
|
||||
empty_return: dict[str, dict[str, Any]] = {
|
||||
"preToPost": {},
|
||||
"postToPre": {},
|
||||
"cppCodeToPost": {},
|
||||
"postToCppCode": {},
|
||||
}
|
||||
|
||||
log.info("Creating node mappings for provenance tracking")
|
||||
|
||||
if not isinstance(post_to_pre_grad_nodes_json, dict):
|
||||
log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict")
|
||||
return empty_return
|
||||
|
||||
if not isinstance(triton_kernel_to_post_grad_json, dict):
|
||||
log.error(
|
||||
"Provenance tacking error: triton_kernel_to_post_grad_json is not a dict"
|
||||
)
|
||||
return empty_return
|
||||
|
||||
if not isinstance(pre_grad_graph_id, int):
|
||||
log.error("Provenance tacking error: pre_grad_graph_id is not an int")
|
||||
return empty_return
|
||||
|
||||
pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet)
|
||||
post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet)
|
||||
|
||||
post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet)
|
||||
|
||||
try:
|
||||
for outer_key, node_array in triton_kernel_to_post_grad_json.items():
|
||||
if not isinstance(node_array, list):
|
||||
log.error(
|
||||
"Provenance tacking error: triton_kernel_to_post_grad_json value is not a list"
|
||||
)
|
||||
return empty_return
|
||||
for curr_node in node_array:
|
||||
post_to_cpp_code[curr_node].add(outer_key)
|
||||
|
||||
def check_format(node: dict[str, Any]) -> bool:
|
||||
if not isinstance(node, dict):
|
||||
log.error(
|
||||
"Provenance tacking error: node provenance in post_to_pre_grad_nodes_json is not a dict"
|
||||
)
|
||||
return False
|
||||
if "graph_id" not in node or "name" not in node or "from_node" not in node:
|
||||
log.error(
|
||||
"Provenance tacking error: node provenance in post_to_pre_grad_nodes_json has wrong format"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
for outer_key, node_array in post_to_pre_grad_nodes_json.items():
|
||||
if not isinstance(node_array, list):
|
||||
log.error(
|
||||
"Provenance tacking error: post_to_pre_grad_nodes_json value is not a list"
|
||||
)
|
||||
return empty_return
|
||||
for node in node_array:
|
||||
if not check_format(node):
|
||||
return empty_return
|
||||
# Check the current node first
|
||||
if node.get("graph_id") == pre_grad_graph_id:
|
||||
pre_to_post[node["name"]].add(outer_key)
|
||||
post_to_pre[outer_key].add(node["name"])
|
||||
|
||||
# Check nested from_node array recursively, add node with the right graph_id to the map
|
||||
stack = [(n, outer_key) for n in node.get("from_node", [])]
|
||||
while stack:
|
||||
current_node, parent_key = stack.pop()
|
||||
if not check_format(current_node):
|
||||
return empty_return
|
||||
if current_node.get("graph_id") == pre_grad_graph_id:
|
||||
pre_to_post[current_node["name"]].add(parent_key)
|
||||
post_to_pre[parent_key].add(current_node["name"])
|
||||
stack.extend(
|
||||
(n, parent_key) for n in current_node.get("from_node", [])
|
||||
)
|
||||
|
||||
def convert_sets_to_lists(d: dict[str, Any]) -> None:
|
||||
for key in d:
|
||||
d[key] = list(d[key])
|
||||
d = dict(d)
|
||||
|
||||
# convert to list because set is not JSON serializable
|
||||
convert_sets_to_lists(pre_to_post)
|
||||
convert_sets_to_lists(post_to_pre)
|
||||
convert_sets_to_lists(post_to_cpp_code)
|
||||
return {
|
||||
"preToPost": pre_to_post,
|
||||
"postToPre": post_to_pre,
|
||||
"cppCodeToPost": triton_kernel_to_post_grad_json,
|
||||
"postToCppCode": post_to_cpp_code,
|
||||
}
|
||||
except Exception as e:
|
||||
# Since this is just logging code, it should never interfere with regular
|
||||
# program execution, so we use this try-except to guard against any error
|
||||
log.error("Unexpected error in create_node_mapping: %s", e)
|
||||
log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json)
|
||||
log.error(
|
||||
"triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json
|
||||
)
|
||||
log.error("pre_grad_graph_id: %s", pre_grad_graph_id)
|
||||
log.error(traceback.format_exc())
|
||||
return empty_return
|
||||
|
||||
|
||||
def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
This function is used to save arguments for a compile_fx_inner function call
|
||||
|
|
|
|||
|
|
@ -1938,16 +1938,32 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
"Finished codegen for all nodes. The list of kernel names available: %s",
|
||||
V.graph.all_codegen_kernel_names,
|
||||
)
|
||||
# Dump the inductor_triton_kernel_to_post_grad_node_info to a json file for debugging trace
|
||||
debug_info = V.debug.log_inductor_triton_kernel_to_post_grad_node_info()
|
||||
trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "inductor_triton_kernel_to_post_grad_nodes",
|
||||
"encoding": "json",
|
||||
},
|
||||
payload_fn=lambda: json.dumps(debug_info),
|
||||
# Dump provenance artifacts for debugging trace
|
||||
provenance_info = (
|
||||
V.debug.log_inductor_triton_kernel_to_post_grad_node_info()
|
||||
)
|
||||
# provenance_info might be None if config.trace.enabled is not set
|
||||
if provenance_info:
|
||||
(
|
||||
debug_info,
|
||||
node_mappings,
|
||||
) = provenance_info
|
||||
trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "inductor_triton_kernel_to_post_grad_nodes",
|
||||
"encoding": "json",
|
||||
},
|
||||
payload_fn=lambda: json.dumps(debug_info),
|
||||
)
|
||||
trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "inductor_provenance_tracking_node_mappings",
|
||||
"encoding": "json",
|
||||
},
|
||||
payload_fn=lambda: json.dumps(node_mappings),
|
||||
)
|
||||
|
||||
result = self.wrapper_code.generate(self.is_inference)
|
||||
self.wrapper_code.pop_codegened_graph()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import json
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
|
@ -224,9 +223,9 @@ def get_current_meta() -> dict[str, Any]:
|
|||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def get_graph_provenance_json(graph: Graph) -> str:
|
||||
def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
|
||||
"""
|
||||
Given an fx.Graph, return a json string that contains the provenance information of each node.
|
||||
Given an fx.Graph, return a json that contains the provenance information of each node.
|
||||
"""
|
||||
provenance_tracking_json = {}
|
||||
for node in graph.nodes:
|
||||
|
|
@ -236,4 +235,4 @@ def get_graph_provenance_json(graph: Graph) -> str:
|
|||
if "from_node" in node.meta
|
||||
else []
|
||||
)
|
||||
return json.dumps(provenance_tracking_json)
|
||||
return provenance_tracking_json
|
||||
|
|
|
|||
Loading…
Reference in a new issue