mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[BE][export] add "+export" logging to de/serialization (#145283)
adds de/serialization debug logging to `TORCH_LOGS="+dynamic"` Pull Request resolved: https://github.com/pytorch/pytorch/pull/145283 Approved by: https://github.com/ydwu4, https://github.com/angelayi
This commit is contained in:
parent
ce4a097bf7
commit
d53f2067fe
2 changed files with 60 additions and 11 deletions
|
|
@ -466,6 +466,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
def handle_placeholder(self, node: torch.fx.Node):
|
||||
assert node.op == "placeholder"
|
||||
val = node.meta["val"]
|
||||
log.debug("[handle_placeholder] %s: %s", node.name, val)
|
||||
if isinstance(val, torch.Tensor):
|
||||
graph_input = Argument.create(as_tensor=self.serialize_tensor_output(node.name, val))
|
||||
elif isinstance(val, torch.SymInt):
|
||||
|
|
@ -490,6 +491,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
assert node.op == "output"
|
||||
assert len(node.args) == 1, "FX.Node's args should have one arg"
|
||||
node_args = node.args[0]
|
||||
log.debug("[handle_output] %s: %s", node.name, node_args)
|
||||
if isinstance(node_args, torch.fx.Node):
|
||||
# For singleton tensor returns
|
||||
self.graph_state.is_single_tensor_return = True
|
||||
|
|
@ -511,12 +513,13 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
|
||||
def handle_call_function(self, node: torch.fx.Node):
|
||||
assert node.op == "call_function"
|
||||
meta_val = node.meta.get("val")
|
||||
log.debug("[handle_call_function] %s: %s(%s, {%s}) -> %s", node.name, node.target, node.args, node.kwargs, meta_val)
|
||||
|
||||
# getitem has been handled in the producer node, skip it here
|
||||
if node.target is operator.getitem:
|
||||
return
|
||||
|
||||
meta_val = node.meta.get("val")
|
||||
if (
|
||||
node.target in _SYM_OPS
|
||||
or (meta_val is not None and isinstance(meta_val, (torch.SymInt, torch.SymBool, torch.SymFloat)))
|
||||
|
|
@ -571,7 +574,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
self.graph_state.nodes.append(ex_node)
|
||||
|
||||
def handle_get_attr(self, node):
|
||||
pass
|
||||
log.debug("[handle_get_attr] %s", node.name)
|
||||
|
||||
def _output_node_at_index(self, node, index) -> Optional[torch.fx.Node]:
|
||||
user_node = None
|
||||
|
|
@ -638,6 +641,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
def serialize_script_obj_meta(
|
||||
self, script_obj_meta: ep.CustomObjArgument
|
||||
) -> CustomObjArgument:
|
||||
log.debug("[serialize_script_obj_meta] %s", script_obj_meta)
|
||||
return CustomObjArgument(
|
||||
name=script_obj_meta.name,
|
||||
class_fqn=script_obj_meta.class_fqn,
|
||||
|
|
@ -1007,6 +1011,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
return SymBoolArgument.create(as_name=name)
|
||||
|
||||
def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
|
||||
log.debug("[serialize_input_spec] %s", spec)
|
||||
if spec.kind == ep.InputKind.USER_INPUT:
|
||||
if isinstance(spec.arg, ep.ConstantArgument):
|
||||
if type(spec.arg.value) is int:
|
||||
|
|
@ -1083,6 +1088,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
raise AssertionError(f"Unknown argument kind: {spec}")
|
||||
|
||||
def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec:
|
||||
log.debug("[serialize_output_spec] %s", spec)
|
||||
if spec.kind == ep.OutputKind.USER_OUTPUT:
|
||||
return OutputSpec.create(
|
||||
user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg))
|
||||
|
|
@ -1139,6 +1145,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
raise AssertionError(f"Unknown argument kind: {spec}")
|
||||
|
||||
def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature:
|
||||
log.debug("\n[serialize_signature]")
|
||||
return GraphSignature(
|
||||
input_specs=[self.serialize_input_spec(s) for s in sig.input_specs],
|
||||
output_specs=[self.serialize_output_spec(s) for s in sig.output_specs],
|
||||
|
|
@ -1163,6 +1170,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
def serialize_module_call_signature(
|
||||
self, module_call_signature: ep.ModuleCallSignature
|
||||
) -> ModuleCallSignature:
|
||||
log.debug("[serialize_module_call_signature] %s", module_call_signature)
|
||||
return ModuleCallSignature(
|
||||
inputs=[
|
||||
self.serialize_argument_spec(x) for x in module_call_signature.inputs
|
||||
|
|
@ -1178,6 +1186,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
def serialize_module_call_graph(
|
||||
self, module_call_graph: list[ep.ModuleCallEntry]
|
||||
) -> list[ModuleCallEntry]:
|
||||
log.debug("\n[serialize_module_call_graph]")
|
||||
return [
|
||||
ModuleCallEntry(
|
||||
fqn=entry.fqn,
|
||||
|
|
@ -1382,6 +1391,8 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
|
||||
def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph:
|
||||
assert isinstance(graph_module, torch.fx.GraphModule)
|
||||
log.debug("[serialize_graph]\n\n%s", graph_module.print_readable())
|
||||
|
||||
for node in graph_module.graph.nodes:
|
||||
try:
|
||||
getattr(self, f"handle_{node.op}")(node)
|
||||
|
|
@ -1405,6 +1416,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
def serialize_graph_module_metadata(self, meta: dict[str, Any]):
|
||||
ret = {}
|
||||
if custom := meta.get("custom"):
|
||||
log.debug("\n[serialize_graph_module_metadata] %s", custom)
|
||||
try:
|
||||
ret["custom"] = json.dumps(custom)
|
||||
except Exception as e:
|
||||
|
|
@ -1415,6 +1427,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
return ret
|
||||
|
||||
def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule:
|
||||
log.debug("\n[serialize]")
|
||||
graph = self.serialize_graph(graph_module)
|
||||
|
||||
return GraphModule(
|
||||
|
|
@ -1686,29 +1699,43 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
raise SerializeError(f"Unable to deserialize output node {output}")
|
||||
|
||||
def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
|
||||
log.debug("\n[deserialize_graph]")
|
||||
|
||||
# Handle the tensor metas.
|
||||
for name, tensor_value in serialized_graph.tensor_values.items():
|
||||
log.debug("[deserialize_tensor_meta] %s (input): %s", name, tensor_value)
|
||||
meta_val = self.deserialize_tensor_meta(tensor_value)
|
||||
log.debug("[deserialize_tensor_meta] %s (output): %s", name, meta_val)
|
||||
self.serialized_name_to_meta[name] = meta_val
|
||||
|
||||
for name, sym_int_value in serialized_graph.sym_int_values.items():
|
||||
self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value)
|
||||
log.debug("[deserialize_sym_int] %s (input): %s", name, sym_int_value)
|
||||
int_val = self.deserialize_sym_int(sym_int_value)
|
||||
log.debug("[deserialize_sym_int] %s (output): %s", name, int_val)
|
||||
self.serialized_name_to_meta[name] = int_val
|
||||
|
||||
for name, sym_int_value in serialized_graph.sym_float_values.items():
|
||||
self.serialized_name_to_meta[name] = self.deserialize_sym_float(sym_int_value)
|
||||
for name, sym_float_value in serialized_graph.sym_float_values.items():
|
||||
log.debug("[deserialize_sym_float] %s (input): %s", name, sym_float_value)
|
||||
float_val = self.deserialize_sym_float(sym_float_value)
|
||||
log.debug("[deserialize_sym_float] %s (output): %s", name, float_val)
|
||||
self.serialized_name_to_meta[name] = float_val
|
||||
|
||||
for name, sym_bool_value in serialized_graph.sym_bool_values.items():
|
||||
self.serialized_name_to_meta[name] = self.deserialize_sym_bool(
|
||||
sym_bool_value
|
||||
)
|
||||
log.debug("[deserialize_sym_bool] %s (input): %s", name, sym_bool_value)
|
||||
bool_val = self.deserialize_sym_bool(sym_bool_value)
|
||||
log.debug("[deserialize_sym_bool] %s (output): %s", name, bool_val)
|
||||
self.serialized_name_to_meta[name] = bool_val
|
||||
|
||||
for name, script_obj_meta in serialized_graph.custom_obj_values.items():
|
||||
log.debug("[deserialize_script_obj_meta] %s", script_obj_meta)
|
||||
self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta(
|
||||
script_obj_meta
|
||||
)
|
||||
|
||||
log.debug("\n[deserialize graph nodes]")
|
||||
# Inputs: convert to placeholder nodes in FX.
|
||||
for i, input_ in enumerate(serialized_graph.inputs):
|
||||
log.debug("[deserialize input] %s", input_)
|
||||
if input_.type in ("as_tensor", "as_custom_obj"):
|
||||
node_name = input_.value.name
|
||||
placeholder_node = self.graph.placeholder(node_name)
|
||||
|
|
@ -1751,7 +1778,10 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
) from e
|
||||
|
||||
# Outputs: convert to a single `output` node.
|
||||
outputs = [self.deserialize_graph_output(output) for output in serialized_graph.outputs]
|
||||
outputs = []
|
||||
for output in serialized_graph.outputs:
|
||||
log.debug("[deserialize output] %s", output)
|
||||
outputs.append(self.deserialize_graph_output(output))
|
||||
|
||||
if serialized_graph.is_single_tensor_return:
|
||||
assert len(outputs) == 1
|
||||
|
|
@ -1839,10 +1869,19 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
)
|
||||
|
||||
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
|
||||
log.debug(
|
||||
"[deserialize_node] %s: %s(%s, {%s}) -> %s",
|
||||
fx_node.name,
|
||||
fx_node.target,
|
||||
fx_node.args,
|
||||
fx_node.kwargs,
|
||||
fx_node.meta.get("val"),
|
||||
)
|
||||
if fx_node.op not in ["placeholder", "output"] and "nn_module_stack" not in fx_node.meta:
|
||||
fx_node.meta["nn_module_stack"] = {} # serialization throws away empty dicts
|
||||
|
||||
def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec:
|
||||
log.debug("[deserialize_input_spec] %s", i)
|
||||
if i.type == "user_input":
|
||||
return ep.InputSpec(
|
||||
kind=ep.InputKind.USER_INPUT,
|
||||
|
|
@ -1895,6 +1934,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
raise AssertionError(f"Unknown input spec {i}")
|
||||
|
||||
def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec:
|
||||
log.debug("[deserialize_output_spec] %s", o)
|
||||
if o.type == "user_output":
|
||||
return ep.OutputSpec(
|
||||
kind=ep.OutputKind.USER_OUTPUT,
|
||||
|
|
@ -1941,6 +1981,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
raise AssertionError(f"Unknown output spec {o}")
|
||||
|
||||
def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature:
|
||||
log.debug("\n[deserialize_signature]")
|
||||
return ep.ExportGraphSignature(
|
||||
input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs],
|
||||
output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs],
|
||||
|
|
@ -1958,6 +1999,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
assert _CURRENT_DESERIALIZER is None
|
||||
_CURRENT_DESERIALIZER = self
|
||||
try:
|
||||
log.debug("\n[deserialize]")
|
||||
self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True)
|
||||
self.fake_tensor_mode = FakeTensorMode(
|
||||
allow_fallback_kernels=False,
|
||||
|
|
@ -2415,6 +2457,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
return ret
|
||||
|
||||
def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec:
|
||||
log.debug("[deserialize_argument_spec] %s", x)
|
||||
if x.type == "as_tensor":
|
||||
return ep.TensorArgument(name=x.as_tensor.name)
|
||||
elif x.type == "as_sym_int":
|
||||
|
|
@ -2444,6 +2487,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
def deserialize_module_call_graph(
|
||||
self, module_call_graph: list[ModuleCallEntry]
|
||||
) -> list[ep.ModuleCallEntry]:
|
||||
log.debug("\n[deserialize_module_call_graph]")
|
||||
return [
|
||||
ep.ModuleCallEntry(
|
||||
fqn=entry.fqn,
|
||||
|
|
@ -2471,12 +2515,14 @@ class ExportedProgramDeserializer(metaclass=Final):
|
|||
symbol_name_to_range: dict[str, symbolic_shapes.ValueRanges],
|
||||
symbol_name_to_symbol: dict[str, sympy.Symbol],
|
||||
) -> dict[sympy.Symbol, ValueRanges]:
|
||||
log.debug("\n[deserialize_range_constraints]")
|
||||
range_constraints = {}
|
||||
for k, v in symbol_name_to_range.items():
|
||||
if symbol := symbol_name_to_symbol.get(k):
|
||||
log.debug("[deserialize_range_constraints] %s -> %s", k, v)
|
||||
range_constraints[symbol] = v # type: ignore[arg-type]
|
||||
else:
|
||||
log.warning(f"Symbol {k} did not appear in the graph that was deserialized") # noqa: G004
|
||||
log.warning("Symbol %s did not appear in the graph that was deserialized", k)
|
||||
return range_constraints
|
||||
|
||||
def deserialize(
|
||||
|
|
@ -2520,7 +2566,7 @@ class ExportedProgramDeserializer(metaclass=Final):
|
|||
res.names_to_symbols,
|
||||
)
|
||||
|
||||
return ep.ExportedProgram(
|
||||
result = ep.ExportedProgram(
|
||||
root=res.graph_module,
|
||||
graph=res.graph_module.graph,
|
||||
graph_signature=res.signature,
|
||||
|
|
@ -2531,6 +2577,8 @@ class ExportedProgramDeserializer(metaclass=Final):
|
|||
constants=res.constants,
|
||||
verifiers=[load_verifier(v) for v in exported_program.verifiers],
|
||||
)
|
||||
log.debug("\n[deserialize]: %s", result)
|
||||
return result
|
||||
|
||||
|
||||
class EnumEncoder(json.JSONEncoder):
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ register_log(
|
|||
*DYNAMIC,
|
||||
"torch._export.converter",
|
||||
"torch._export.non_strict_utils",
|
||||
"torch._export.serde.serialize",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue