[export] Update serialization schema to input/output specs. (#845) (#111204)

Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/845

Test Plan: CI

Differential Revision: D50191531

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111204
Approved by: https://github.com/angelayi
This commit is contained in:
Zhengxu Chen 2023-10-13 22:19:56 +00:00 committed by PyTorch MergeBot
parent a3e9b80082
commit ba7b9211ee
6 changed files with 269 additions and 147 deletions

View file

@ -338,9 +338,10 @@ class TestDeserialize(TestCase):
def test_sym_bool(self):
def f(x, y):
return x.size(0) in y
assert x.size(0) in y
return x + y
self.check_graph(f, (torch.ones(2), torch.ones(3)))
self.check_graph(f, (torch.ones(1), torch.ones(3)))
def test_shape(self):
def f(x):

View file

@ -203,26 +203,70 @@ class Graph:
@dataclass
class BackwardSignature:
gradients_to_parameters: Dict[str, str]
gradients_to_user_inputs: Dict[str, str]
loss_output: str
class UserInputSpec:
arg: Argument
@dataclass
class InputToParameterSpec:
arg: TensorArgument
parameter_name: str
@dataclass
class InputToBufferSpec:
arg: TensorArgument
buffer_name: str
@dataclass
class InputSpec(_Union):
user_input: UserInputSpec
parameter: InputToParameterSpec
buffer: InputToBufferSpec
@dataclass
class UserOutputSpec:
arg: Argument
@dataclass
class LossOutputSpec:
arg: TensorArgument
@dataclass
class BufferMutationSpec:
arg: TensorArgument
buffer_name: str
@dataclass
class GradientToParameterSpec:
arg: TensorArgument
parameter_name: str
@dataclass
class GradientToUserInputSpec:
arg: TensorArgument
user_input_name: str
@dataclass
class OutputSpec(_Union):
user_output: UserOutputSpec
loss_output: LossOutputSpec
buffer_mutation: BufferMutationSpec
gradient_to_parameter: GradientToParameterSpec
gradient_to_user_input: GradientToUserInputSpec
@dataclass
class GraphSignature:
inputs_to_parameters: Dict[str, str]
inputs_to_buffers: Dict[str, str]
user_inputs: List[str]
user_outputs: List[str]
buffers_to_mutate: Dict[str, str]
backward_signature: Optional[BackwardSignature]
@dataclass
class CallSpec:
in_spec: str
out_spec: str
input_specs: List[InputSpec]
output_specs: List[OutputSpec]
@dataclass
@ -249,8 +293,6 @@ class ModuleCallEntry:
class GraphModule:
graph: Graph
signature: GraphSignature
# TODO(zhxchen17) Merge call_spec into call graph.
call_spec: CallSpec
module_call_graph: List[ModuleCallEntry]

View file

@ -25,15 +25,22 @@ from torch.utils._sympy.value_ranges import ValueRanges
from .schema import ( # type: ignore[attr-defined]
_Union,
Argument,
BackwardSignature,
CallSpec,
BufferMutationSpec,
CustomObjArgument,
Device,
ExportedProgram,
GradientToParameterSpec,
GradientToUserInputSpec,
Graph,
GraphArgument,
GraphModule,
GraphSignature,
InputSpec,
InputToBufferSpec,
InputToParameterSpec,
UserInputSpec,
UserOutputSpec,
LossOutputSpec,
Layout,
MemoryFormat,
ModuleCallEntry,
@ -41,6 +48,7 @@ from .schema import ( # type: ignore[attr-defined]
NamedArgument,
Node,
OptionalTensorArgument,
OutputSpec,
RangeConstraint,
ScalarType,
SCHEMA_VERSION,
@ -197,41 +205,6 @@ def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta:
layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout],
)
# TODO(zhxchen17) Remove call spec.
def serialize_call_spec(call_spec: torch._export.exported_program.CallSpec) -> CallSpec:
return CallSpec(
in_spec=treespec_dumps(call_spec.in_spec, TREESPEC_VERSION) if call_spec.in_spec else "",
out_spec=treespec_dumps(call_spec.out_spec, TREESPEC_VERSION) if call_spec.out_spec else "",
)
def deserialize_call_spec(call_spec: CallSpec) -> torch._export.exported_program.CallSpec:
return torch._export.exported_program.CallSpec(
in_spec=treespec_loads(call_spec.in_spec) if call_spec.in_spec else None,
out_spec=treespec_loads(call_spec.out_spec) if call_spec.out_spec else None,
)
def serialize_signature(sig: ep.ExportGraphSignature) -> GraphSignature:
if bw_sig := sig.backward_signature:
backward_signature = BackwardSignature(
gradients_to_parameters=bw_sig.gradients_to_parameters,
gradients_to_user_inputs=bw_sig.gradients_to_user_inputs,
loss_output=bw_sig.loss_output,
)
else:
backward_signature = None
graph_signature = GraphSignature(
inputs_to_parameters=sig.inputs_to_parameters, # type: ignore[arg-type]
inputs_to_buffers=sig.inputs_to_buffers, # type: ignore[arg-type]
user_inputs=sig.user_inputs, # type: ignore[arg-type]
user_outputs=sig.user_outputs, # type: ignore[arg-type]
buffers_to_mutate=sig.buffers_to_mutate, # type: ignore[arg-type]
backward_signature=backward_signature,
)
return graph_signature
def serialize_torch_artifact(artifact) -> bytes:
buffer = io.BytesIO()
@ -327,12 +300,10 @@ class GraphModuleSerializer:
def __init__(
self,
graph_signature: ep.ExportGraphSignature,
call_spec: torch._export.exported_program.CallSpec,
module_call_graph: List[ep.ModuleCallEntry]
):
self.graph_state = GraphState()
self.graph_signature = graph_signature
self.call_spec = call_spec
self.module_call_graph = module_call_graph
@contextmanager
@ -683,18 +654,98 @@ class GraphModuleSerializer:
self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val)
return SymBoolArgument.create(as_name=name)
def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
if spec.kind == ep.InputKind.USER_INPUT:
return InputSpec.create(
user_input=UserInputSpec(
arg=self.serialize_argument_spec(spec.arg)
)
)
elif spec.kind == ep.InputKind.PARAMETER:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
return InputSpec.create(
parameter=InputToParameterSpec(
arg=TensorArgument(name=spec.arg.name),
parameter_name=spec.target,
)
)
elif spec.kind == ep.InputKind.BUFFER:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
return InputSpec.create(
buffer=InputToBufferSpec(
arg=TensorArgument(name=spec.arg.name),
buffer_name=spec.target,
)
)
else:
raise AssertionError(f"Unknown argument kind: {spec}")
def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec:
if spec.kind == ep.OutputKind.USER_OUTPUT:
return OutputSpec.create(
user_output=UserOutputSpec(
arg=self.serialize_argument_spec(spec.arg)
)
)
elif spec.kind == ep.OutputKind.LOSS_OUTPUT:
assert isinstance(spec.arg, ep.TensorArgument)
return OutputSpec.create(
loss_output=LossOutputSpec(
arg=TensorArgument(name=spec.arg.name)
)
)
elif spec.kind == ep.OutputKind.BUFFER_MUTATION:
assert spec.target is not None
assert isinstance(spec.arg, PyTensorArgument)
return OutputSpec.create(
buffer_mutation=BufferMutationSpec(
arg=TensorArgument(name=spec.arg.name),
buffer_name=spec.target,
)
)
elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER:
assert spec.target is not None
assert isinstance(spec.arg, PyTensorArgument)
return OutputSpec.create(
gradient_to_parameter=GradientToParameterSpec(
arg=TensorArgument(name=spec.arg.name),
parameter_name=spec.target,
)
)
elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT:
assert spec.target is not None
assert isinstance(spec.arg, PyTensorArgument)
return OutputSpec.create(
gradient_to_user_input=GradientToUserInputSpec(
arg=TensorArgument(name=spec.arg.name),
user_input_name=spec.target,
)
)
else:
raise AssertionError(f"Unknown argument kind: {spec}")
def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature:
return GraphSignature(
input_specs=[self.serialize_input_spec(s) for s in sig.input_specs],
output_specs=[self.serialize_output_spec(s) for s in sig.output_specs],
)
def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument:
if isinstance(x, PyTensorArgument):
return Argument.create(as_tensor=TensorArgument(name=x.name))
elif isinstance(x, PySymIntArgument):
return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name))
elif isinstance(x, PyConstantArgument):
return self.serialize_input(x.value)
else:
raise AssertionError("TODO")
def serialize_module_call_signature(self, module_call_signature: ep.ModuleCallSignature) -> ModuleCallSignature:
def serialize_argument(x: ep.ArgumentSpec) -> Argument:
if isinstance(x, PyTensorArgument):
return Argument.create(as_tensor=TensorArgument(name=x.name))
elif isinstance(x, PySymIntArgument):
return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name))
else:
assert isinstance(x, PyConstantArgument)
return self.serialize_input(x.value)
return ModuleCallSignature(
inputs=[serialize_argument(x) for x in module_call_signature.inputs],
outputs=[serialize_argument(x) for x in module_call_signature.outputs],
inputs=[self.serialize_argument_spec(x) for x in module_call_signature.inputs],
outputs=[self.serialize_argument_spec(x) for x in module_call_signature.outputs],
in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION),
out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION),
)
@ -821,8 +872,7 @@ class GraphModuleSerializer:
return GraphModule(
graph=graph,
signature=serialize_signature(self.graph_signature),
call_spec=serialize_call_spec(self.call_spec),
signature=self.serialize_signature(self.graph_signature),
module_call_graph=self.serialize_module_call_graph(self.module_call_graph),
)
@ -839,7 +889,6 @@ class ExportedProgramSerializer:
serialized_graph_module = (
GraphModuleSerializer(
exported_program.graph_signature,
exported_program.call_spec,
exported_program.module_call_graph
).serialize(exported_program.graph_module)
)
@ -860,6 +909,13 @@ class ExportedProgramSerializer:
class GraphModuleDeserializer:
@dataclasses.dataclass
class Result:
graph_module: torch.fx.GraphModule
signature: ep.ExportGraphSignature
module_call_graph: List[ep.ModuleCallEntry]
names_to_symbols: Dict[str, sympy.Symbol]
def __init__(self):
self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
self.serialized_name_to_meta: Dict[str, MetaType] = {}
@ -1069,47 +1125,73 @@ class GraphModuleDeserializer:
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
def deserialize_signature(
self,
sig: GraphSignature,
inputs: List[Argument],
outputs: List[Argument]
) -> ep.ExportGraphSignature:
# TODO(zhxchen17) Remove these after we change serialization schema.
def make_argument_spec(arg: Argument) -> ep.ArgumentSpec:
if arg.as_tensor is not None:
return ep.TensorArgument(name=arg.as_tensor.name)
elif arg.as_sym_int is not None:
assert arg.as_sym_int.as_name is not None
return ep.SymIntArgument(name=arg.as_sym_int.as_name)
else:
return ep.ConstantArgument(value=arg.value)
def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec:
if i.user_input is not None:
return ep.InputSpec(
kind=ep.InputKind.USER_INPUT,
arg=self.deserialize_argument_spec(i.user_input.arg),
target=None
)
elif i.parameter is not None:
return ep.InputSpec(
kind=ep.InputKind.PARAMETER,
arg=PyTensorArgument(name=i.parameter.arg.name),
target=i.parameter.parameter_name,
)
elif i.buffer is not None:
return ep.InputSpec(
kind=ep.InputKind.BUFFER,
arg=PyTensorArgument(name=i.buffer.arg.name),
target=i.buffer.buffer_name,
)
else:
raise AssertionError(f"Unknown input spec {i}")
input_specs, output_specs = ep._sig_to_specs(
user_inputs=set(sig.user_inputs),
inputs_to_parameters=sig.inputs_to_parameters,
inputs_to_buffers=sig.inputs_to_buffers,
user_outputs=set(sig.user_outputs),
buffer_mutations=sig.buffers_to_mutate,
grad_params=sig.backward_signature.gradients_to_parameters if sig.backward_signature is not None else {},
grad_user_inputs=sig.backward_signature.gradients_to_user_inputs if sig.backward_signature is not None else {},
loss_output=sig.backward_signature.loss_output if sig.backward_signature is not None else None,
inputs=[make_argument_spec(i) for i in inputs],
outputs=[make_argument_spec(o) for o in outputs],
def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec:
if o.user_output is not None:
return ep.OutputSpec(
kind=ep.OutputKind.USER_OUTPUT,
arg=self.deserialize_argument_spec(o.user_output.arg),
target=None,
)
elif o.loss_output is not None:
return ep.OutputSpec(
kind=ep.OutputKind.LOSS_OUTPUT,
arg=PyTensorArgument(name=o.loss_output.arg.name),
target=None,
)
elif o.buffer_mutation is not None:
return ep.OutputSpec(
kind=ep.OutputKind.BUFFER_MUTATION,
arg=PyTensorArgument(name=o.buffer_mutation.arg.name),
target=o.buffer_mutation.buffer_name
)
elif o.gradient_to_parameter is not None:
return ep.OutputSpec(
kind=ep.OutputKind.GRADIENT_TO_PARAMETER,
arg=PyTensorArgument(name=o.gradient_to_parameter.arg.name),
target=o.gradient_to_parameter.parameter_name
)
elif o.gradient_to_user_input is not None:
return ep.OutputSpec(
kind=ep.OutputKind.GRADIENT_TO_USER_INPUT,
arg=PyTensorArgument(name=o.gradient_to_user_input.arg.name),
target=o.gradient_to_user_input.user_input_name
)
else:
raise AssertionError(f"Unknown output spec {o}")
def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature:
return ep.ExportGraphSignature(
input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs],
output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs]
)
return ep.ExportGraphSignature(input_specs=input_specs, output_specs=output_specs)
def deserialize(
self,
serialized_graph_module: GraphModule,
symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
) -> Tuple[
torch.fx.GraphModule,
ep.ExportGraphSignature,
torch._export.exported_program.CallSpec,
List[ep.ModuleCallEntry],
Dict[str, sympy.Symbol]
]:
) -> Result:
self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True)
self.fake_tensor_mode = FakeTensorMode(
allow_fallback_kernels=False,
@ -1121,19 +1203,13 @@ class GraphModuleDeserializer:
self.deserialize_graph(serialized_graph_module.graph)
sig = self.deserialize_signature(
serialized_graph_module.signature,
serialized_graph_module.graph.inputs,
serialized_graph_module.graph.outputs
)
call_spec = deserialize_call_spec(serialized_graph_module.call_spec)
sig = self.deserialize_signature(serialized_graph_module.signature)
module_call_graph = self.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
return (
torch._export.exported_program._create_graph_module_for_export(self.module, self.graph),
sig,
call_spec,
module_call_graph,
self.symbol_name_to_symbol,
return GraphModuleDeserializer.Result(
graph_module=torch._export.exported_program._create_graph_module_for_export(self.module, self.graph),
signature=sig,
module_call_graph=module_call_graph,
names_to_symbols=self.symbol_name_to_symbol,
)
def sync_fx_node(self, name: str, fx_node: torch.fx.Node):
@ -1366,18 +1442,18 @@ class GraphModuleDeserializer:
ret["source_fn_stack"] = source_fn_st
return ret
def deserialize_module_call_signature(self, module_call_signature: ModuleCallSignature) -> ep.ModuleCallSignature:
def deserialize_argument(x: Argument) -> ep.ArgumentSpec:
if x.as_tensor is not None:
return PyTensorArgument(name=x.as_tensor.name)
elif x.as_symint is not None:
return PySymIntArgument(name=x.as_symint.as_name)
else:
return PyConstantArgument(value=self.deserialize_input(x))
def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec:
if x.as_tensor is not None:
return PyTensorArgument(name=x.as_tensor.name)
elif x.as_sym_int is not None:
return PySymIntArgument(name=x.as_sym_int.as_name)
else:
return PyConstantArgument(value=self.deserialize_input(x))
def deserialize_module_call_signature(self, module_call_signature: ModuleCallSignature) -> ep.ModuleCallSignature:
return ep.ModuleCallSignature(
inputs=[deserialize_argument(x) for x in module_call_signature.inputs],
outputs=[deserialize_argument(x) for x in module_call_signature.outputs],
inputs=[self.deserialize_argument_spec(x) for x in module_call_signature.inputs],
outputs=[self.deserialize_argument_spec(x) for x in module_call_signature.outputs],
in_spec=treespec_loads(module_call_signature.in_spec),
out_spec=treespec_loads(module_call_signature.out_spec),
)
@ -1426,7 +1502,7 @@ class ExportedProgramDeserializer:
for k, v in serialized_exported_program.range_constraints.items()
}
graph_module, sig, call_spec, module_call_graph, symbol_name_to_symbol = (
res = (
GraphModuleDeserializer()
.deserialize(
serialized_exported_program.graph_module,
@ -1434,7 +1510,7 @@ class ExportedProgramDeserializer:
)
)
range_constraints = self.deserialize_range_constraints(
symbol_name_to_range, symbol_name_to_symbol,
symbol_name_to_range, res.names_to_symbols,
)
model_opset_version: Optional[Dict[str, int]] = serialized_exported_program.opset_version
self._validate_model_opset_version(model_opset_version)
@ -1445,14 +1521,14 @@ class ExportedProgramDeserializer:
equality_constraints = deserialize_equality_constraints(serialized_exported_program.equality_constraints)
exported_program = ep.ExportedProgram(
graph_module,
graph_module.graph,
sig,
call_spec,
res.graph_module,
res.graph_module.graph,
res.signature,
None, # TODO(zhxchen17) Remove this.
state_dict, # type: ignore[arg-type]
range_constraints,
equality_constraints,
module_call_graph,
res.module_call_graph,
None, # type: ignore[arg-type]
)
return upgrader.upgrade(exported_program)

View file

@ -196,6 +196,5 @@ class GraphModuleOpUpgrader:
upgraded_program = exported_program._transform(_pass)
# NB: we have to retrace the graph_module instead of ep because of some failure.
exported_program = export(upgraded_program.module(), inputs, {})
exported_program._call_spec = upgraded_program.call_spec
return exported_program

View file

@ -3977,7 +3977,7 @@ class FallbackKernel(ExternKernelAlloc):
kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel
]
serializer = GraphModuleSerializer(None, None, None)
serializer = GraphModuleSerializer(None, None)
named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs)
# serialize_outputs

View file

@ -433,10 +433,7 @@ class ExportedProgram:
example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
dialect: Optional[str] = None,
):
from torch._export.exported_program import (
_create_graph_module_for_export,
CallSpec,
)
from torch._export.exported_program import _create_graph_module_for_export
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
InputDim,
)
@ -448,7 +445,6 @@ class ExportedProgram:
self._graph_module.meta.update(root.meta)
self._graph_signature: ExportGraphSignature = graph_signature
self._call_spec: CallSpec = call_spec
self._state_dict: Dict[str, Any] = state_dict
self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
self._equality_constraints: List[
@ -512,11 +508,6 @@ class ExportedProgram:
for buffer_name in self.graph_signature.buffers:
yield buffer_name, self.state_dict[buffer_name]
@property
@compatibility(is_backward_compatible=False)
def call_spec(self):
return self._call_spec
@property
@compatibility(is_backward_compatible=False)
def range_constraints(self):
@ -537,6 +528,19 @@ class ExportedProgram:
def example_inputs(self):
return self._example_inputs
@property
@compatibility(is_backward_compatible=False)
def call_spec(self):
from torch._export.exported_program import CallSpec
if len(self.module_call_graph) == 0:
return CallSpec(in_spec=None, out_spec=None)
assert self.module_call_graph[0].fqn == ""
return CallSpec(
in_spec=self.module_call_graph[0].signature.in_spec,
out_spec=self.module_call_graph[0].signature.out_spec,
)
@property
def dialect(self):
return self._dialect