From ba7b9211eed3c29b183fa622eee696cd1d3c8287 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Fri, 13 Oct 2023 22:19:56 +0000 Subject: [PATCH] [export] Update serialization schema to input/output specs. (#845) (#111204) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/845 Test Plan: CI Differential Revision: D50191531 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111204 Approved by: https://github.com/angelayi --- test/export/test_serialize.py | 5 +- torch/_export/serde/schema.py | 78 ++++++-- torch/_export/serde/serialize.py | 306 +++++++++++++++++++------------ torch/_export/serde/upgrade.py | 1 - torch/_inductor/ir.py | 2 +- torch/export/exported_program.py | 24 ++- 6 files changed, 269 insertions(+), 147 deletions(-) diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 895b9be298a..c09894f3c3e 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -338,9 +338,10 @@ class TestDeserialize(TestCase): def test_sym_bool(self): def f(x, y): - return x.size(0) in y + assert x.size(0) in y + return x + y - self.check_graph(f, (torch.ones(2), torch.ones(3))) + self.check_graph(f, (torch.ones(1), torch.ones(3))) def test_shape(self): def f(x): diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 94ce2f2ddd6..041e48143eb 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -203,26 +203,70 @@ class Graph: @dataclass -class BackwardSignature: - gradients_to_parameters: Dict[str, str] - gradients_to_user_inputs: Dict[str, str] - loss_output: str +class UserInputSpec: + arg: Argument + + +@dataclass +class InputToParameterSpec: + arg: TensorArgument + parameter_name: str + + +@dataclass +class InputToBufferSpec: + arg: TensorArgument + buffer_name: str + + +@dataclass +class InputSpec(_Union): + user_input: UserInputSpec + parameter: InputToParameterSpec + buffer: InputToBufferSpec + + +@dataclass +class UserOutputSpec: + 
arg: Argument + + +@dataclass +class LossOutputSpec: + arg: TensorArgument + + +@dataclass +class BufferMutationSpec: + arg: TensorArgument + buffer_name: str + + +@dataclass +class GradientToParameterSpec: + arg: TensorArgument + parameter_name: str + + +@dataclass +class GradientToUserInputSpec: + arg: TensorArgument + user_input_name: str + + +@dataclass +class OutputSpec(_Union): + user_output: UserOutputSpec + loss_output: LossOutputSpec + buffer_mutation: BufferMutationSpec + gradient_to_parameter: GradientToParameterSpec + gradient_to_user_input: GradientToUserInputSpec @dataclass class GraphSignature: - inputs_to_parameters: Dict[str, str] - inputs_to_buffers: Dict[str, str] - user_inputs: List[str] - user_outputs: List[str] - buffers_to_mutate: Dict[str, str] - backward_signature: Optional[BackwardSignature] - - -@dataclass -class CallSpec: - in_spec: str - out_spec: str + input_specs: List[InputSpec] + output_specs: List[OutputSpec] @dataclass @@ -249,8 +293,6 @@ class ModuleCallEntry: class GraphModule: graph: Graph signature: GraphSignature - # TODO(zhxchen17) Merge call_spec into call graph. 
- call_spec: CallSpec module_call_graph: List[ModuleCallEntry] diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 3f559fece1e..98a5ddafd38 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -25,15 +25,22 @@ from torch.utils._sympy.value_ranges import ValueRanges from .schema import ( # type: ignore[attr-defined] _Union, Argument, - BackwardSignature, - CallSpec, + BufferMutationSpec, CustomObjArgument, Device, ExportedProgram, + GradientToParameterSpec, + GradientToUserInputSpec, Graph, GraphArgument, GraphModule, GraphSignature, + InputSpec, + InputToBufferSpec, + InputToParameterSpec, + UserInputSpec, + UserOutputSpec, + LossOutputSpec, Layout, MemoryFormat, ModuleCallEntry, @@ -41,6 +48,7 @@ from .schema import ( # type: ignore[attr-defined] NamedArgument, Node, OptionalTensorArgument, + OutputSpec, RangeConstraint, ScalarType, SCHEMA_VERSION, @@ -197,41 +205,6 @@ def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta: layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout], ) -# TODO(zhxchen17) Remove call spec. 
-def serialize_call_spec(call_spec: torch._export.exported_program.CallSpec) -> CallSpec: - return CallSpec( - in_spec=treespec_dumps(call_spec.in_spec, TREESPEC_VERSION) if call_spec.in_spec else "", - out_spec=treespec_dumps(call_spec.out_spec, TREESPEC_VERSION) if call_spec.out_spec else "", - ) - - -def deserialize_call_spec(call_spec: CallSpec) -> torch._export.exported_program.CallSpec: - return torch._export.exported_program.CallSpec( - in_spec=treespec_loads(call_spec.in_spec) if call_spec.in_spec else None, - out_spec=treespec_loads(call_spec.out_spec) if call_spec.out_spec else None, - ) - - -def serialize_signature(sig: ep.ExportGraphSignature) -> GraphSignature: - if bw_sig := sig.backward_signature: - backward_signature = BackwardSignature( - gradients_to_parameters=bw_sig.gradients_to_parameters, - gradients_to_user_inputs=bw_sig.gradients_to_user_inputs, - loss_output=bw_sig.loss_output, - ) - else: - backward_signature = None - - graph_signature = GraphSignature( - inputs_to_parameters=sig.inputs_to_parameters, # type: ignore[arg-type] - inputs_to_buffers=sig.inputs_to_buffers, # type: ignore[arg-type] - user_inputs=sig.user_inputs, # type: ignore[arg-type] - user_outputs=sig.user_outputs, # type: ignore[arg-type] - buffers_to_mutate=sig.buffers_to_mutate, # type: ignore[arg-type] - backward_signature=backward_signature, - ) - return graph_signature - def serialize_torch_artifact(artifact) -> bytes: buffer = io.BytesIO() @@ -327,12 +300,10 @@ class GraphModuleSerializer: def __init__( self, graph_signature: ep.ExportGraphSignature, - call_spec: torch._export.exported_program.CallSpec, module_call_graph: List[ep.ModuleCallEntry] ): self.graph_state = GraphState() self.graph_signature = graph_signature - self.call_spec = call_spec self.module_call_graph = module_call_graph @contextmanager @@ -683,18 +654,98 @@ class GraphModuleSerializer: self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val) return SymBoolArgument.create(as_name=name) 
+ def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: + if spec.kind == ep.InputKind.USER_INPUT: + return InputSpec.create( + user_input=UserInputSpec( + arg=self.serialize_argument_spec(spec.arg) + ) + ) + elif spec.kind == ep.InputKind.PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + parameter=InputToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.BUFFER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + buffer=InputToBufferSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: + if spec.kind == ep.OutputKind.USER_OUTPUT: + return OutputSpec.create( + user_output=UserOutputSpec( + arg=self.serialize_argument_spec(spec.arg) + ) + ) + elif spec.kind == ep.OutputKind.LOSS_OUTPUT: + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + loss_output=LossOutputSpec( + arg=TensorArgument(name=spec.arg.name) + ) + ) + elif spec.kind == ep.OutputKind.BUFFER_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, PyTensorArgument) + return OutputSpec.create( + buffer_mutation=BufferMutationSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, PyTensorArgument) + return OutputSpec.create( + gradient_to_parameter=GradientToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT: + assert spec.target is not None + assert isinstance(spec.arg, PyTensorArgument) + return OutputSpec.create( + 
gradient_to_user_input=GradientToUserInputSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature: + return GraphSignature( + input_specs=[self.serialize_input_spec(s) for s in sig.input_specs], + output_specs=[self.serialize_output_spec(s) for s in sig.output_specs], + ) + + def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument: + if isinstance(x, PyTensorArgument): + return Argument.create(as_tensor=TensorArgument(name=x.name)) + elif isinstance(x, PySymIntArgument): + return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name)) + elif isinstance(x, PyConstantArgument): + return self.serialize_input(x.value) + else: + raise AssertionError("TODO") + def serialize_module_call_signature(self, module_call_signature: ep.ModuleCallSignature) -> ModuleCallSignature: - def serialize_argument(x: ep.ArgumentSpec) -> Argument: - if isinstance(x, PyTensorArgument): - return Argument.create(as_tensor=TensorArgument(name=x.name)) - elif isinstance(x, PySymIntArgument): - return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name)) - else: - assert isinstance(x, PyConstantArgument) - return self.serialize_input(x.value) return ModuleCallSignature( - inputs=[serialize_argument(x) for x in module_call_signature.inputs], - outputs=[serialize_argument(x) for x in module_call_signature.outputs], + inputs=[self.serialize_argument_spec(x) for x in module_call_signature.inputs], + outputs=[self.serialize_argument_spec(x) for x in module_call_signature.outputs], in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION), out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION), ) @@ -821,8 +872,7 @@ class GraphModuleSerializer: return GraphModule( graph=graph, - signature=serialize_signature(self.graph_signature), - 
call_spec=serialize_call_spec(self.call_spec), + signature=self.serialize_signature(self.graph_signature), module_call_graph=self.serialize_module_call_graph(self.module_call_graph), ) @@ -839,7 +889,6 @@ class ExportedProgramSerializer: serialized_graph_module = ( GraphModuleSerializer( exported_program.graph_signature, - exported_program.call_spec, exported_program.module_call_graph ).serialize(exported_program.graph_module) ) @@ -860,6 +909,13 @@ class ExportedProgramSerializer: class GraphModuleDeserializer: + @dataclasses.dataclass + class Result: + graph_module: torch.fx.GraphModule + signature: ep.ExportGraphSignature + module_call_graph: List[ep.ModuleCallEntry] + names_to_symbols: Dict[str, sympy.Symbol] + def __init__(self): self.serialized_name_to_node: Dict[str, torch.fx.Node] = {} self.serialized_name_to_meta: Dict[str, MetaType] = {} @@ -1069,47 +1125,73 @@ class GraphModuleDeserializer: fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) - def deserialize_signature( - self, - sig: GraphSignature, - inputs: List[Argument], - outputs: List[Argument] - ) -> ep.ExportGraphSignature: - # TODO(zhxchen17) Remove these after we change serialization schema. 
- def make_argument_spec(arg: Argument) -> ep.ArgumentSpec: - if arg.as_tensor is not None: - return ep.TensorArgument(name=arg.as_tensor.name) - elif arg.as_sym_int is not None: - assert arg.as_sym_int.as_name is not None - return ep.SymIntArgument(name=arg.as_sym_int.as_name) - else: - return ep.ConstantArgument(value=arg.value) + def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: + if i.user_input is not None: + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=self.deserialize_argument_spec(i.user_input.arg), + target=None + ) + elif i.parameter is not None: + return ep.InputSpec( + kind=ep.InputKind.PARAMETER, + arg=PyTensorArgument(name=i.parameter.arg.name), + target=i.parameter.parameter_name, + ) + elif i.buffer is not None: + return ep.InputSpec( + kind=ep.InputKind.BUFFER, + arg=PyTensorArgument(name=i.buffer.arg.name), + target=i.buffer.buffer_name, + ) + else: + raise AssertionError(f"Unknown input spec {i}") - input_specs, output_specs = ep._sig_to_specs( - user_inputs=set(sig.user_inputs), - inputs_to_parameters=sig.inputs_to_parameters, - inputs_to_buffers=sig.inputs_to_buffers, - user_outputs=set(sig.user_outputs), - buffer_mutations=sig.buffers_to_mutate, - grad_params=sig.backward_signature.gradients_to_parameters if sig.backward_signature is not None else {}, - grad_user_inputs=sig.backward_signature.gradients_to_user_inputs if sig.backward_signature is not None else {}, - loss_output=sig.backward_signature.loss_output if sig.backward_signature is not None else None, - inputs=[make_argument_spec(i) for i in inputs], - outputs=[make_argument_spec(o) for o in outputs], + def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: + if o.user_output is not None: + return ep.OutputSpec( + kind=ep.OutputKind.USER_OUTPUT, + arg=self.deserialize_argument_spec(o.user_output.arg), + target=None, + ) + elif o.loss_output is not None: + return ep.OutputSpec( + kind=ep.OutputKind.LOSS_OUTPUT, + 
arg=PyTensorArgument(name=o.loss_output.arg.name), + target=None, + ) + elif o.buffer_mutation is not None: + return ep.OutputSpec( + kind=ep.OutputKind.BUFFER_MUTATION, + arg=PyTensorArgument(name=o.buffer_mutation.arg.name), + target=o.buffer_mutation.buffer_name + ) + elif o.gradient_to_parameter is not None: + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_PARAMETER, + arg=PyTensorArgument(name=o.gradient_to_parameter.arg.name), + target=o.gradient_to_parameter.parameter_name + ) + elif o.gradient_to_user_input is not None: + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_USER_INPUT, + arg=PyTensorArgument(name=o.gradient_to_user_input.arg.name), + target=o.gradient_to_user_input.user_input_name + ) + else: + raise AssertionError(f"Unknown output spec {o}") + + def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature: + return ep.ExportGraphSignature( + input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs], + output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs] ) - return ep.ExportGraphSignature(input_specs=input_specs, output_specs=output_specs) def deserialize( self, serialized_graph_module: GraphModule, symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None, - ) -> Tuple[ - torch.fx.GraphModule, - ep.ExportGraphSignature, - torch._export.exported_program.CallSpec, - List[ep.ModuleCallEntry], - Dict[str, sympy.Symbol] - ]: + ) -> Result: self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True) self.fake_tensor_mode = FakeTensorMode( allow_fallback_kernels=False, @@ -1121,19 +1203,13 @@ class GraphModuleDeserializer: self.deserialize_graph(serialized_graph_module.graph) - sig = self.deserialize_signature( - serialized_graph_module.signature, - serialized_graph_module.graph.inputs, - serialized_graph_module.graph.outputs - ) - call_spec = deserialize_call_spec(serialized_graph_module.call_spec) + sig = 
self.deserialize_signature(serialized_graph_module.signature) module_call_graph = self.deserialize_module_call_graph(serialized_graph_module.module_call_graph) - return ( - torch._export.exported_program._create_graph_module_for_export(self.module, self.graph), - sig, - call_spec, - module_call_graph, - self.symbol_name_to_symbol, + return GraphModuleDeserializer.Result( + graph_module=torch._export.exported_program._create_graph_module_for_export(self.module, self.graph), + signature=sig, + module_call_graph=module_call_graph, + names_to_symbols=self.symbol_name_to_symbol, ) def sync_fx_node(self, name: str, fx_node: torch.fx.Node): @@ -1366,18 +1442,18 @@ class GraphModuleDeserializer: ret["source_fn_stack"] = source_fn_st return ret - def deserialize_module_call_signature(self, module_call_signature: ModuleCallSignature) -> ep.ModuleCallSignature: - def deserialize_argument(x: Argument) -> ep.ArgumentSpec: - if x.as_tensor is not None: - return PyTensorArgument(name=x.as_tensor.name) - elif x.as_symint is not None: - return PySymIntArgument(name=x.as_symint.as_name) - else: - return PyConstantArgument(value=self.deserialize_input(x)) + def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: + if x.as_tensor is not None: + return PyTensorArgument(name=x.as_tensor.name) + elif x.as_sym_int is not None: + return PySymIntArgument(name=x.as_sym_int.as_name) + else: + return PyConstantArgument(value=self.deserialize_input(x)) + def deserialize_module_call_signature(self, module_call_signature: ModuleCallSignature) -> ep.ModuleCallSignature: return ep.ModuleCallSignature( - inputs=[deserialize_argument(x) for x in module_call_signature.inputs], - outputs=[deserialize_argument(x) for x in module_call_signature.outputs], + inputs=[self.deserialize_argument_spec(x) for x in module_call_signature.inputs], + outputs=[self.deserialize_argument_spec(x) for x in module_call_signature.outputs], in_spec=treespec_loads(module_call_signature.in_spec), 
out_spec=treespec_loads(module_call_signature.out_spec), ) @@ -1426,7 +1502,7 @@ class ExportedProgramDeserializer: for k, v in serialized_exported_program.range_constraints.items() } - graph_module, sig, call_spec, module_call_graph, symbol_name_to_symbol = ( + res = ( GraphModuleDeserializer() .deserialize( serialized_exported_program.graph_module, @@ -1434,7 +1510,7 @@ class ExportedProgramDeserializer: ) ) range_constraints = self.deserialize_range_constraints( - symbol_name_to_range, symbol_name_to_symbol, + symbol_name_to_range, res.names_to_symbols, ) model_opset_version: Optional[Dict[str, int]] = serialized_exported_program.opset_version self._validate_model_opset_version(model_opset_version) @@ -1445,14 +1521,14 @@ class ExportedProgramDeserializer: equality_constraints = deserialize_equality_constraints(serialized_exported_program.equality_constraints) exported_program = ep.ExportedProgram( - graph_module, - graph_module.graph, - sig, - call_spec, + res.graph_module, + res.graph_module.graph, + res.signature, + None, # TODO(zhxchen17) Remove this. state_dict, # type: ignore[arg-type] range_constraints, equality_constraints, - module_call_graph, + res.module_call_graph, None, # type: ignore[arg-type] ) return upgrader.upgrade(exported_program) diff --git a/torch/_export/serde/upgrade.py b/torch/_export/serde/upgrade.py index 0f6b55ab363..b4c0e2b3eff 100644 --- a/torch/_export/serde/upgrade.py +++ b/torch/_export/serde/upgrade.py @@ -196,6 +196,5 @@ class GraphModuleOpUpgrader: upgraded_program = exported_program._transform(_pass) # NB: we have to retrace the graph_module instead of ep because of some failure. 
exported_program = export(upgraded_program.module(), inputs, {}) - exported_program._call_spec = upgraded_program.call_spec return exported_program diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 86067e09b56..3932b156fbf 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3977,7 +3977,7 @@ class FallbackKernel(ExternKernelAlloc): kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel ] - serializer = GraphModuleSerializer(None, None, None) + serializer = GraphModuleSerializer(None, None) named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs) # serialize_outputs diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index b61583171b1..fa6711379fc 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -433,10 +433,7 @@ class ExportedProgram: example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None, dialect: Optional[str] = None, ): - from torch._export.exported_program import ( - _create_graph_module_for_export, - CallSpec, - ) + from torch._export.exported_program import _create_graph_module_for_export from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( InputDim, ) @@ -448,7 +445,6 @@ class ExportedProgram: self._graph_module.meta.update(root.meta) self._graph_signature: ExportGraphSignature = graph_signature - self._call_spec: CallSpec = call_spec self._state_dict: Dict[str, Any] = state_dict self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints self._equality_constraints: List[ @@ -512,11 +508,6 @@ class ExportedProgram: for buffer_name in self.graph_signature.buffers: yield buffer_name, self.state_dict[buffer_name] - @property - @compatibility(is_backward_compatible=False) - def call_spec(self): - return self._call_spec - @property @compatibility(is_backward_compatible=False) def range_constraints(self): @@ -537,6 +528,19 @@ class ExportedProgram: def example_inputs(self): 
return self._example_inputs + @property + @compatibility(is_backward_compatible=False) + def call_spec(self): + from torch._export.exported_program import CallSpec + + if len(self.module_call_graph) == 0: + return CallSpec(in_spec=None, out_spec=None) + assert self.module_call_graph[0].fqn == "" + return CallSpec( + in_spec=self.module_call_graph[0].signature.in_spec, + out_spec=self.module_call_graph[0].signature.out_spec, + ) + @property def dialect(self): return self._dialect