[BE][export] add "+export" logging to de/serialization (#145283)

adds de/serialization debug logging to `TORCH_LOGS="+export"`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145283
Approved by: https://github.com/ydwu4, https://github.com/angelayi
This commit is contained in:
Pian Pawakapan 2025-01-23 19:47:48 +00:00 committed by PyTorch MergeBot
parent ce4a097bf7
commit d53f2067fe
2 changed files with 60 additions and 11 deletions

View file

@ -466,6 +466,7 @@ class GraphModuleSerializer(metaclass=Final):
def handle_placeholder(self, node: torch.fx.Node):
assert node.op == "placeholder"
val = node.meta["val"]
log.debug("[handle_placeholder] %s: %s", node.name, val)
if isinstance(val, torch.Tensor):
graph_input = Argument.create(as_tensor=self.serialize_tensor_output(node.name, val))
elif isinstance(val, torch.SymInt):
@ -490,6 +491,7 @@ class GraphModuleSerializer(metaclass=Final):
assert node.op == "output"
assert len(node.args) == 1, "FX.Node's args should have one arg"
node_args = node.args[0]
log.debug("[handle_output] %s: %s", node.name, node_args)
if isinstance(node_args, torch.fx.Node):
# For singleton tensor returns
self.graph_state.is_single_tensor_return = True
@ -511,12 +513,13 @@ class GraphModuleSerializer(metaclass=Final):
def handle_call_function(self, node: torch.fx.Node):
assert node.op == "call_function"
meta_val = node.meta.get("val")
log.debug("[handle_call_function] %s: %s(%s, {%s}) -> %s", node.name, node.target, node.args, node.kwargs, meta_val)
# getitem has been handled in the producer node, skip it here
if node.target is operator.getitem:
return
meta_val = node.meta.get("val")
if (
node.target in _SYM_OPS
or (meta_val is not None and isinstance(meta_val, (torch.SymInt, torch.SymBool, torch.SymFloat)))
@ -571,7 +574,7 @@ class GraphModuleSerializer(metaclass=Final):
self.graph_state.nodes.append(ex_node)
def handle_get_attr(self, node):
pass
log.debug("[handle_get_attr] %s", node.name)
def _output_node_at_index(self, node, index) -> Optional[torch.fx.Node]:
user_node = None
@ -638,6 +641,7 @@ class GraphModuleSerializer(metaclass=Final):
def serialize_script_obj_meta(
self, script_obj_meta: ep.CustomObjArgument
) -> CustomObjArgument:
log.debug("[serialize_script_obj_meta] %s", script_obj_meta)
return CustomObjArgument(
name=script_obj_meta.name,
class_fqn=script_obj_meta.class_fqn,
@ -1007,6 +1011,7 @@ class GraphModuleSerializer(metaclass=Final):
return SymBoolArgument.create(as_name=name)
def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
log.debug("[serialize_input_spec] %s", spec)
if spec.kind == ep.InputKind.USER_INPUT:
if isinstance(spec.arg, ep.ConstantArgument):
if type(spec.arg.value) is int:
@ -1083,6 +1088,7 @@ class GraphModuleSerializer(metaclass=Final):
raise AssertionError(f"Unknown argument kind: {spec}")
def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec:
log.debug("[serialize_output_spec] %s", spec)
if spec.kind == ep.OutputKind.USER_OUTPUT:
return OutputSpec.create(
user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg))
@ -1139,6 +1145,7 @@ class GraphModuleSerializer(metaclass=Final):
raise AssertionError(f"Unknown argument kind: {spec}")
def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature:
log.debug("\n[serialize_signature]")
return GraphSignature(
input_specs=[self.serialize_input_spec(s) for s in sig.input_specs],
output_specs=[self.serialize_output_spec(s) for s in sig.output_specs],
@ -1163,6 +1170,7 @@ class GraphModuleSerializer(metaclass=Final):
def serialize_module_call_signature(
self, module_call_signature: ep.ModuleCallSignature
) -> ModuleCallSignature:
log.debug("[serialize_module_call_signature] %s", module_call_signature)
return ModuleCallSignature(
inputs=[
self.serialize_argument_spec(x) for x in module_call_signature.inputs
@ -1178,6 +1186,7 @@ class GraphModuleSerializer(metaclass=Final):
def serialize_module_call_graph(
self, module_call_graph: list[ep.ModuleCallEntry]
) -> list[ModuleCallEntry]:
log.debug("\n[serialize_module_call_graph]")
return [
ModuleCallEntry(
fqn=entry.fqn,
@ -1382,6 +1391,8 @@ class GraphModuleSerializer(metaclass=Final):
def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph:
assert isinstance(graph_module, torch.fx.GraphModule)
log.debug("[serialize_graph]\n\n%s", graph_module.print_readable())
for node in graph_module.graph.nodes:
try:
getattr(self, f"handle_{node.op}")(node)
@ -1405,6 +1416,7 @@ class GraphModuleSerializer(metaclass=Final):
def serialize_graph_module_metadata(self, meta: dict[str, Any]):
ret = {}
if custom := meta.get("custom"):
log.debug("\n[serialize_graph_module_metadata] %s", custom)
try:
ret["custom"] = json.dumps(custom)
except Exception as e:
@ -1415,6 +1427,7 @@ class GraphModuleSerializer(metaclass=Final):
return ret
def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule:
log.debug("\n[serialize]")
graph = self.serialize_graph(graph_module)
return GraphModule(
@ -1686,29 +1699,43 @@ class GraphModuleDeserializer(metaclass=Final):
raise SerializeError(f"Unable to deserialize output node {output}")
def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
log.debug("\n[deserialize_graph]")
# Handle the tensor metas.
for name, tensor_value in serialized_graph.tensor_values.items():
log.debug("[deserialize_tensor_meta] %s (input): %s", name, tensor_value)
meta_val = self.deserialize_tensor_meta(tensor_value)
log.debug("[deserialize_tensor_meta] %s (output): %s", name, meta_val)
self.serialized_name_to_meta[name] = meta_val
for name, sym_int_value in serialized_graph.sym_int_values.items():
self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value)
log.debug("[deserialize_sym_int] %s (input): %s", name, sym_int_value)
int_val = self.deserialize_sym_int(sym_int_value)
log.debug("[deserialize_sym_int] %s (output): %s", name, int_val)
self.serialized_name_to_meta[name] = int_val
for name, sym_int_value in serialized_graph.sym_float_values.items():
self.serialized_name_to_meta[name] = self.deserialize_sym_float(sym_int_value)
for name, sym_float_value in serialized_graph.sym_float_values.items():
log.debug("[deserialize_sym_float] %s (input): %s", name, sym_float_value)
float_val = self.deserialize_sym_float(sym_float_value)
log.debug("[deserialize_sym_float] %s (output): %s", name, float_val)
self.serialized_name_to_meta[name] = float_val
for name, sym_bool_value in serialized_graph.sym_bool_values.items():
self.serialized_name_to_meta[name] = self.deserialize_sym_bool(
sym_bool_value
)
log.debug("[deserialize_sym_bool] %s (input): %s", name, sym_bool_value)
bool_val = self.deserialize_sym_bool(sym_bool_value)
log.debug("[deserialize_sym_bool] %s (output): %s", name, bool_val)
self.serialized_name_to_meta[name] = bool_val
for name, script_obj_meta in serialized_graph.custom_obj_values.items():
log.debug("[deserialize_script_obj_meta] %s", script_obj_meta)
self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta(
script_obj_meta
)
log.debug("\n[deserialize graph nodes]")
# Inputs: convert to placeholder nodes in FX.
for i, input_ in enumerate(serialized_graph.inputs):
log.debug("[deserialize input] %s", input_)
if input_.type in ("as_tensor", "as_custom_obj"):
node_name = input_.value.name
placeholder_node = self.graph.placeholder(node_name)
@ -1751,7 +1778,10 @@ class GraphModuleDeserializer(metaclass=Final):
) from e
# Outputs: convert to a single `output` node.
outputs = [self.deserialize_graph_output(output) for output in serialized_graph.outputs]
outputs = []
for output in serialized_graph.outputs:
log.debug("[deserialize output] %s", output)
outputs.append(self.deserialize_graph_output(output))
if serialized_graph.is_single_tensor_return:
assert len(outputs) == 1
@ -1839,10 +1869,19 @@ class GraphModuleDeserializer(metaclass=Final):
)
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
log.debug(
"[deserialize_node] %s: %s(%s, {%s}) -> %s",
fx_node.name,
fx_node.target,
fx_node.args,
fx_node.kwargs,
fx_node.meta.get("val"),
)
if fx_node.op not in ["placeholder", "output"] and "nn_module_stack" not in fx_node.meta:
fx_node.meta["nn_module_stack"] = {} # serialization throws away empty dicts
def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec:
log.debug("[deserialize_input_spec] %s", i)
if i.type == "user_input":
return ep.InputSpec(
kind=ep.InputKind.USER_INPUT,
@ -1895,6 +1934,7 @@ class GraphModuleDeserializer(metaclass=Final):
raise AssertionError(f"Unknown input spec {i}")
def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec:
log.debug("[deserialize_output_spec] %s", o)
if o.type == "user_output":
return ep.OutputSpec(
kind=ep.OutputKind.USER_OUTPUT,
@ -1941,6 +1981,7 @@ class GraphModuleDeserializer(metaclass=Final):
raise AssertionError(f"Unknown output spec {o}")
def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature:
log.debug("\n[deserialize_signature]")
return ep.ExportGraphSignature(
input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs],
output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs],
@ -1958,6 +1999,7 @@ class GraphModuleDeserializer(metaclass=Final):
assert _CURRENT_DESERIALIZER is None
_CURRENT_DESERIALIZER = self
try:
log.debug("\n[deserialize]")
self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True)
self.fake_tensor_mode = FakeTensorMode(
allow_fallback_kernels=False,
@ -2415,6 +2457,7 @@ class GraphModuleDeserializer(metaclass=Final):
return ret
def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec:
log.debug("[deserialize_argument_spec] %s", x)
if x.type == "as_tensor":
return ep.TensorArgument(name=x.as_tensor.name)
elif x.type == "as_sym_int":
@ -2444,6 +2487,7 @@ class GraphModuleDeserializer(metaclass=Final):
def deserialize_module_call_graph(
self, module_call_graph: list[ModuleCallEntry]
) -> list[ep.ModuleCallEntry]:
log.debug("\n[deserialize_module_call_graph]")
return [
ep.ModuleCallEntry(
fqn=entry.fqn,
@ -2471,12 +2515,14 @@ class ExportedProgramDeserializer(metaclass=Final):
symbol_name_to_range: dict[str, symbolic_shapes.ValueRanges],
symbol_name_to_symbol: dict[str, sympy.Symbol],
) -> dict[sympy.Symbol, ValueRanges]:
log.debug("\n[deserialize_range_constraints]")
range_constraints = {}
for k, v in symbol_name_to_range.items():
if symbol := symbol_name_to_symbol.get(k):
log.debug("[deserialize_range_constraints] %s -> %s", k, v)
range_constraints[symbol] = v # type: ignore[arg-type]
else:
log.warning(f"Symbol {k} did not appear in the graph that was deserialized") # noqa: G004
log.warning("Symbol %s did not appear in the graph that was deserialized", k)
return range_constraints
def deserialize(
@ -2520,7 +2566,7 @@ class ExportedProgramDeserializer(metaclass=Final):
res.names_to_symbols,
)
return ep.ExportedProgram(
result = ep.ExportedProgram(
root=res.graph_module,
graph=res.graph_module.graph,
graph_signature=res.signature,
@ -2531,6 +2577,8 @@ class ExportedProgramDeserializer(metaclass=Final):
constants=res.constants,
verifiers=[load_verifier(v) for v in exported_program.verifiers],
)
log.debug("\n[deserialize]: %s", result)
return result
class EnumEncoder(json.JSONEncoder):

View file

@ -49,6 +49,7 @@ register_log(
*DYNAMIC,
"torch._export.converter",
"torch._export.non_strict_utils",
"torch._export.serde.serialize",
],
)