mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/845 Test Plan: CI Differential Revision: D50191531 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111204 Approved by: https://github.com/angelayi
This commit is contained in:
parent
a3e9b80082
commit
ba7b9211ee
6 changed files with 269 additions and 147 deletions
|
|
@ -338,9 +338,10 @@ class TestDeserialize(TestCase):
|
|||
|
||||
def test_sym_bool(self):
|
||||
def f(x, y):
|
||||
return x.size(0) in y
|
||||
assert x.size(0) in y
|
||||
return x + y
|
||||
|
||||
self.check_graph(f, (torch.ones(2), torch.ones(3)))
|
||||
self.check_graph(f, (torch.ones(1), torch.ones(3)))
|
||||
|
||||
def test_shape(self):
|
||||
def f(x):
|
||||
|
|
|
|||
|
|
@ -203,26 +203,70 @@ class Graph:
|
|||
|
||||
|
||||
@dataclass
|
||||
class BackwardSignature:
|
||||
gradients_to_parameters: Dict[str, str]
|
||||
gradients_to_user_inputs: Dict[str, str]
|
||||
loss_output: str
|
||||
class UserInputSpec:
|
||||
arg: Argument
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputToParameterSpec:
|
||||
arg: TensorArgument
|
||||
parameter_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputToBufferSpec:
|
||||
arg: TensorArgument
|
||||
buffer_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputSpec(_Union):
|
||||
user_input: UserInputSpec
|
||||
parameter: InputToParameterSpec
|
||||
buffer: InputToBufferSpec
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserOutputSpec:
|
||||
arg: Argument
|
||||
|
||||
|
||||
@dataclass
|
||||
class LossOutputSpec:
|
||||
arg: TensorArgument
|
||||
|
||||
|
||||
@dataclass
|
||||
class BufferMutationSpec:
|
||||
arg: TensorArgument
|
||||
buffer_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GradientToParameterSpec:
|
||||
arg: TensorArgument
|
||||
parameter_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GradientToUserInputSpec:
|
||||
arg: TensorArgument
|
||||
user_input_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputSpec(_Union):
|
||||
user_output: UserOutputSpec
|
||||
loss_outout: LossOutputSpec
|
||||
buffer_mutation: BufferMutationSpec
|
||||
gradient_to_parameter: GradientToParameterSpec
|
||||
gradient_to_user_input: GradientToUserInputSpec
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphSignature:
|
||||
inputs_to_parameters: Dict[str, str]
|
||||
inputs_to_buffers: Dict[str, str]
|
||||
user_inputs: List[str]
|
||||
user_outputs: List[str]
|
||||
buffers_to_mutate: Dict[str, str]
|
||||
backward_signature: Optional[BackwardSignature]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallSpec:
|
||||
in_spec: str
|
||||
out_spec: str
|
||||
input_specs: List[InputSpec]
|
||||
output_specs: List[OutputSpec]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -249,8 +293,6 @@ class ModuleCallEntry:
|
|||
class GraphModule:
|
||||
graph: Graph
|
||||
signature: GraphSignature
|
||||
# TODO(zhxchen17) Merge call_spec into call graph.
|
||||
call_spec: CallSpec
|
||||
module_call_graph: List[ModuleCallEntry]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,15 +25,22 @@ from torch.utils._sympy.value_ranges import ValueRanges
|
|||
from .schema import ( # type: ignore[attr-defined]
|
||||
_Union,
|
||||
Argument,
|
||||
BackwardSignature,
|
||||
CallSpec,
|
||||
BufferMutationSpec,
|
||||
CustomObjArgument,
|
||||
Device,
|
||||
ExportedProgram,
|
||||
GradientToParameterSpec,
|
||||
GradientToUserInputSpec,
|
||||
Graph,
|
||||
GraphArgument,
|
||||
GraphModule,
|
||||
GraphSignature,
|
||||
InputSpec,
|
||||
InputToBufferSpec,
|
||||
InputToParameterSpec,
|
||||
UserInputSpec,
|
||||
UserOutputSpec,
|
||||
LossOutputSpec,
|
||||
Layout,
|
||||
MemoryFormat,
|
||||
ModuleCallEntry,
|
||||
|
|
@ -41,6 +48,7 @@ from .schema import ( # type: ignore[attr-defined]
|
|||
NamedArgument,
|
||||
Node,
|
||||
OptionalTensorArgument,
|
||||
OutputSpec,
|
||||
RangeConstraint,
|
||||
ScalarType,
|
||||
SCHEMA_VERSION,
|
||||
|
|
@ -197,41 +205,6 @@ def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta:
|
|||
layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout],
|
||||
)
|
||||
|
||||
# TODO(zhxchen17) Remove call spec.
|
||||
def serialize_call_spec(call_spec: torch._export.exported_program.CallSpec) -> CallSpec:
|
||||
return CallSpec(
|
||||
in_spec=treespec_dumps(call_spec.in_spec, TREESPEC_VERSION) if call_spec.in_spec else "",
|
||||
out_spec=treespec_dumps(call_spec.out_spec, TREESPEC_VERSION) if call_spec.out_spec else "",
|
||||
)
|
||||
|
||||
|
||||
def deserialize_call_spec(call_spec: CallSpec) -> torch._export.exported_program.CallSpec:
|
||||
return torch._export.exported_program.CallSpec(
|
||||
in_spec=treespec_loads(call_spec.in_spec) if call_spec.in_spec else None,
|
||||
out_spec=treespec_loads(call_spec.out_spec) if call_spec.out_spec else None,
|
||||
)
|
||||
|
||||
|
||||
def serialize_signature(sig: ep.ExportGraphSignature) -> GraphSignature:
|
||||
if bw_sig := sig.backward_signature:
|
||||
backward_signature = BackwardSignature(
|
||||
gradients_to_parameters=bw_sig.gradients_to_parameters,
|
||||
gradients_to_user_inputs=bw_sig.gradients_to_user_inputs,
|
||||
loss_output=bw_sig.loss_output,
|
||||
)
|
||||
else:
|
||||
backward_signature = None
|
||||
|
||||
graph_signature = GraphSignature(
|
||||
inputs_to_parameters=sig.inputs_to_parameters, # type: ignore[arg-type]
|
||||
inputs_to_buffers=sig.inputs_to_buffers, # type: ignore[arg-type]
|
||||
user_inputs=sig.user_inputs, # type: ignore[arg-type]
|
||||
user_outputs=sig.user_outputs, # type: ignore[arg-type]
|
||||
buffers_to_mutate=sig.buffers_to_mutate, # type: ignore[arg-type]
|
||||
backward_signature=backward_signature,
|
||||
)
|
||||
return graph_signature
|
||||
|
||||
|
||||
def serialize_torch_artifact(artifact) -> bytes:
|
||||
buffer = io.BytesIO()
|
||||
|
|
@ -327,12 +300,10 @@ class GraphModuleSerializer:
|
|||
def __init__(
|
||||
self,
|
||||
graph_signature: ep.ExportGraphSignature,
|
||||
call_spec: torch._export.exported_program.CallSpec,
|
||||
module_call_graph: List[ep.ModuleCallEntry]
|
||||
):
|
||||
self.graph_state = GraphState()
|
||||
self.graph_signature = graph_signature
|
||||
self.call_spec = call_spec
|
||||
self.module_call_graph = module_call_graph
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -683,18 +654,98 @@ class GraphModuleSerializer:
|
|||
self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val)
|
||||
return SymBoolArgument.create(as_name=name)
|
||||
|
||||
def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
|
||||
if spec.kind == ep.InputKind.USER_INPUT:
|
||||
return InputSpec.create(
|
||||
user_input=UserInputSpec(
|
||||
arg=self.serialize_argument_spec(spec.arg)
|
||||
)
|
||||
)
|
||||
elif spec.kind == ep.InputKind.PARAMETER:
|
||||
assert spec.target is not None
|
||||
assert isinstance(spec.arg, ep.TensorArgument)
|
||||
return InputSpec.create(
|
||||
parameter=InputToParameterSpec(
|
||||
arg=TensorArgument(name=spec.arg.name),
|
||||
parameter_name=spec.target,
|
||||
)
|
||||
)
|
||||
elif spec.kind == ep.InputKind.BUFFER:
|
||||
assert spec.target is not None
|
||||
assert isinstance(spec.arg, ep.TensorArgument)
|
||||
return InputSpec.create(
|
||||
buffer=InputToBufferSpec(
|
||||
arg=TensorArgument(name=spec.arg.name),
|
||||
buffer_name=spec.target,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unknown argument kind: {spec}")
|
||||
|
||||
def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec:
|
||||
if spec.kind == ep.OutputKind.USER_OUTPUT:
|
||||
return OutputSpec.create(
|
||||
user_output=UserOutputSpec(
|
||||
arg=self.serialize_argument_spec(spec.arg)
|
||||
)
|
||||
)
|
||||
elif spec.kind == ep.OutputKind.LOSS_OUTPUT:
|
||||
assert isinstance(spec.arg, ep.TensorArgument)
|
||||
return OutputSpec.create(
|
||||
loss_output=LossOutputSpec(
|
||||
arg=TensorArgument(name=spec.arg.name)
|
||||
)
|
||||
)
|
||||
elif spec.kind == ep.OutputKind.BUFFER_MUTATION:
|
||||
assert spec.target is not None
|
||||
assert isinstance(spec.arg, PyTensorArgument)
|
||||
return OutputSpec.create(
|
||||
buffer_mutation=BufferMutationSpec(
|
||||
arg=TensorArgument(name=spec.arg.name),
|
||||
buffer_name=spec.target,
|
||||
)
|
||||
)
|
||||
elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER:
|
||||
assert spec.target is not None
|
||||
assert isinstance(spec.arg, PyTensorArgument)
|
||||
return OutputSpec.create(
|
||||
gradient_to_parameter=GradientToParameterSpec(
|
||||
arg=TensorArgument(name=spec.arg.name),
|
||||
parameter_name=spec.target,
|
||||
)
|
||||
)
|
||||
elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT:
|
||||
assert spec.target is not None
|
||||
assert isinstance(spec.arg, PyTensorArgument)
|
||||
return OutputSpec.create(
|
||||
gradient_to_user_input=GradientToUserInputSpec(
|
||||
arg=TensorArgument(name=spec.arg.name),
|
||||
user_input_name=spec.target,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unknown argument kind: {spec}")
|
||||
|
||||
def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature:
|
||||
return GraphSignature(
|
||||
input_specs=[self.serialize_input_spec(s) for s in sig.input_specs],
|
||||
output_specs=[self.serialize_output_spec(s) for s in sig.output_specs],
|
||||
)
|
||||
|
||||
def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument:
|
||||
if isinstance(x, PyTensorArgument):
|
||||
return Argument.create(as_tensor=TensorArgument(name=x.name))
|
||||
elif isinstance(x, PySymIntArgument):
|
||||
return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name))
|
||||
elif isinstance(x, PyConstantArgument):
|
||||
return self.serialize_input(x.value)
|
||||
else:
|
||||
raise AssertionError("TODO")
|
||||
|
||||
def serialize_module_call_signature(self, module_call_signature: ep.ModuleCallSignature) -> ModuleCallSignature:
|
||||
def serialize_argument(x: ep.ArgumentSpec) -> Argument:
|
||||
if isinstance(x, PyTensorArgument):
|
||||
return Argument.create(as_tensor=TensorArgument(name=x.name))
|
||||
elif isinstance(x, PySymIntArgument):
|
||||
return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name))
|
||||
else:
|
||||
assert isinstance(x, PyConstantArgument)
|
||||
return self.serialize_input(x.value)
|
||||
return ModuleCallSignature(
|
||||
inputs=[serialize_argument(x) for x in module_call_signature.inputs],
|
||||
outputs=[serialize_argument(x) for x in module_call_signature.outputs],
|
||||
inputs=[self.serialize_argument_spec(x) for x in module_call_signature.inputs],
|
||||
outputs=[self.serialize_argument_spec(x) for x in module_call_signature.outputs],
|
||||
in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION),
|
||||
out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION),
|
||||
)
|
||||
|
|
@ -821,8 +872,7 @@ class GraphModuleSerializer:
|
|||
|
||||
return GraphModule(
|
||||
graph=graph,
|
||||
signature=serialize_signature(self.graph_signature),
|
||||
call_spec=serialize_call_spec(self.call_spec),
|
||||
signature=self.serialize_signature(self.graph_signature),
|
||||
module_call_graph=self.serialize_module_call_graph(self.module_call_graph),
|
||||
)
|
||||
|
||||
|
|
@ -839,7 +889,6 @@ class ExportedProgramSerializer:
|
|||
serialized_graph_module = (
|
||||
GraphModuleSerializer(
|
||||
exported_program.graph_signature,
|
||||
exported_program.call_spec,
|
||||
exported_program.module_call_graph
|
||||
).serialize(exported_program.graph_module)
|
||||
)
|
||||
|
|
@ -860,6 +909,13 @@ class ExportedProgramSerializer:
|
|||
|
||||
|
||||
class GraphModuleDeserializer:
|
||||
@dataclasses.dataclass
|
||||
class Result:
|
||||
graph_module: torch.fx.GraphModule
|
||||
signature: ep.ExportGraphSignature
|
||||
module_call_graph: List[ep.ModuleCallEntry]
|
||||
names_to_symbols: Dict[str, sympy.Symbol]
|
||||
|
||||
def __init__(self):
|
||||
self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
|
||||
self.serialized_name_to_meta: Dict[str, MetaType] = {}
|
||||
|
|
@ -1069,47 +1125,73 @@ class GraphModuleDeserializer:
|
|||
|
||||
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
|
||||
|
||||
def deserialize_signature(
|
||||
self,
|
||||
sig: GraphSignature,
|
||||
inputs: List[Argument],
|
||||
outputs: List[Argument]
|
||||
) -> ep.ExportGraphSignature:
|
||||
# TODO(zhxchen17) Remove these after we change serialization schema.
|
||||
def make_argument_spec(arg: Argument) -> ep.ArgumentSpec:
|
||||
if arg.as_tensor is not None:
|
||||
return ep.TensorArgument(name=arg.as_tensor.name)
|
||||
elif arg.as_sym_int is not None:
|
||||
assert arg.as_sym_int.as_name is not None
|
||||
return ep.SymIntArgument(name=arg.as_sym_int.as_name)
|
||||
else:
|
||||
return ep.ConstantArgument(value=arg.value)
|
||||
def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec:
|
||||
if i.user_input is not None:
|
||||
return ep.InputSpec(
|
||||
kind=ep.InputKind.USER_INPUT,
|
||||
arg=self.deserialize_argument_spec(i.user_input.arg),
|
||||
target=None
|
||||
)
|
||||
elif i.parameter is not None:
|
||||
return ep.InputSpec(
|
||||
kind=ep.InputKind.PARAMETER,
|
||||
arg=PyTensorArgument(name=i.parameter.arg.name),
|
||||
target=i.parameter.parameter_name,
|
||||
)
|
||||
elif i.buffer is not None:
|
||||
return ep.InputSpec(
|
||||
kind=ep.InputKind.BUFFER,
|
||||
arg=PyTensorArgument(name=i.buffer.arg.name),
|
||||
target=i.buffer.buffer_name,
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unkown input spec {i}")
|
||||
|
||||
input_specs, output_specs = ep._sig_to_specs(
|
||||
user_inputs=set(sig.user_inputs),
|
||||
inputs_to_parameters=sig.inputs_to_parameters,
|
||||
inputs_to_buffers=sig.inputs_to_buffers,
|
||||
user_outputs=set(sig.user_outputs),
|
||||
buffer_mutations=sig.buffers_to_mutate,
|
||||
grad_params=sig.backward_signature.gradients_to_parameters if sig.backward_signature is not None else {},
|
||||
grad_user_inputs=sig.backward_signature.gradients_to_user_inputs if sig.backward_signature is not None else {},
|
||||
loss_output=sig.backward_signature.loss_output if sig.backward_signature is not None else None,
|
||||
inputs=[make_argument_spec(i) for i in inputs],
|
||||
outputs=[make_argument_spec(o) for o in outputs],
|
||||
def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec:
|
||||
if o.user_output is not None:
|
||||
return ep.OutputSpec(
|
||||
kind=ep.OutputKind.USER_OUTPUT,
|
||||
arg=self.deserialize_argument_spec(o.user_output.arg),
|
||||
target=None,
|
||||
)
|
||||
elif o.loss_output is not None:
|
||||
return ep.OutputSpec(
|
||||
kind=ep.OutputKind.LOSS_OUTPUT,
|
||||
arg=PyTensorArgument(name=o.loss_output.arg.name),
|
||||
target=None,
|
||||
)
|
||||
elif o.buffer_mutation is not None:
|
||||
return ep.OutputSpec(
|
||||
kind=ep.OutputKind.BUFFER_MUTATION,
|
||||
arg=PyTensorArgument(name=o.buffer_mutation.arg.name),
|
||||
target=o.buffer_mutation.buffer_name
|
||||
)
|
||||
elif o.gradient_to_parameter is not None:
|
||||
return ep.OutputSpec(
|
||||
kind=ep.OutputKind.GRADIENT_TO_PARAMETER,
|
||||
arg=PyTensorArgument(name=o.gradient_to_parameter.arg.name),
|
||||
target=o.gradient_to_parameter.parameter_name
|
||||
)
|
||||
elif o.gradient_to_user_input is not None:
|
||||
return ep.OutputSpec(
|
||||
kind=ep.OutputKind.GRADIENT_TO_USER_INPUT,
|
||||
arg=PyTensorArgument(name=o.gradient_to_user_input.arg.name),
|
||||
target=o.gradient_to_user_input.user_input_name
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unknown output spec {o}")
|
||||
|
||||
def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature:
|
||||
return ep.ExportGraphSignature(
|
||||
input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs],
|
||||
output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs]
|
||||
)
|
||||
return ep.ExportGraphSignature(input_specs=input_specs, output_specs=output_specs)
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
serialized_graph_module: GraphModule,
|
||||
symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
|
||||
) -> Tuple[
|
||||
torch.fx.GraphModule,
|
||||
ep.ExportGraphSignature,
|
||||
torch._export.exported_program.CallSpec,
|
||||
List[ep.ModuleCallEntry],
|
||||
Dict[str, sympy.Symbol]
|
||||
]:
|
||||
) -> Result:
|
||||
self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True)
|
||||
self.fake_tensor_mode = FakeTensorMode(
|
||||
allow_fallback_kernels=False,
|
||||
|
|
@ -1121,19 +1203,13 @@ class GraphModuleDeserializer:
|
|||
|
||||
self.deserialize_graph(serialized_graph_module.graph)
|
||||
|
||||
sig = self.deserialize_signature(
|
||||
serialized_graph_module.signature,
|
||||
serialized_graph_module.graph.inputs,
|
||||
serialized_graph_module.graph.outputs
|
||||
)
|
||||
call_spec = deserialize_call_spec(serialized_graph_module.call_spec)
|
||||
sig = self.deserialize_signature(serialized_graph_module.signature)
|
||||
module_call_graph = self.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
|
||||
return (
|
||||
torch._export.exported_program._create_graph_module_for_export(self.module, self.graph),
|
||||
sig,
|
||||
call_spec,
|
||||
module_call_graph,
|
||||
self.symbol_name_to_symbol,
|
||||
return GraphModuleDeserializer.Result(
|
||||
graph_module=torch._export.exported_program._create_graph_module_for_export(self.module, self.graph),
|
||||
signature=sig,
|
||||
module_call_graph=module_call_graph,
|
||||
names_to_symbols=self.symbol_name_to_symbol,
|
||||
)
|
||||
|
||||
def sync_fx_node(self, name: str, fx_node: torch.fx.Node):
|
||||
|
|
@ -1366,18 +1442,18 @@ class GraphModuleDeserializer:
|
|||
ret["source_fn_stack"] = source_fn_st
|
||||
return ret
|
||||
|
||||
def deserialize_module_call_signature(self, module_call_signature: ModuleCallSignature) -> ep.ModuleCallSignature:
|
||||
def deserialize_argument(x: Argument) -> ep.ArgumentSpec:
|
||||
if x.as_tensor is not None:
|
||||
return PyTensorArgument(name=x.as_tensor.name)
|
||||
elif x.as_symint is not None:
|
||||
return PySymIntArgument(name=x.as_symint.as_name)
|
||||
else:
|
||||
return PyConstantArgument(value=self.deserialize_input(x))
|
||||
def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec:
|
||||
if x.as_tensor is not None:
|
||||
return PyTensorArgument(name=x.as_tensor.name)
|
||||
elif x.as_sym_int is not None:
|
||||
return PySymIntArgument(name=x.as_sym_int.as_name)
|
||||
else:
|
||||
return PyConstantArgument(value=self.deserialize_input(x))
|
||||
|
||||
def deserialize_module_call_signature(self, module_call_signature: ModuleCallSignature) -> ep.ModuleCallSignature:
|
||||
return ep.ModuleCallSignature(
|
||||
inputs=[deserialize_argument(x) for x in module_call_signature.inputs],
|
||||
outputs=[deserialize_argument(x) for x in module_call_signature.outputs],
|
||||
inputs=[self.deserialize_argument_spec(x) for x in module_call_signature.inputs],
|
||||
outputs=[self.deserialize_argument_spec(x) for x in module_call_signature.outputs],
|
||||
in_spec=treespec_loads(module_call_signature.in_spec),
|
||||
out_spec=treespec_loads(module_call_signature.out_spec),
|
||||
)
|
||||
|
|
@ -1426,7 +1502,7 @@ class ExportedProgramDeserializer:
|
|||
for k, v in serialized_exported_program.range_constraints.items()
|
||||
}
|
||||
|
||||
graph_module, sig, call_spec, module_call_graph, symbol_name_to_symbol = (
|
||||
res = (
|
||||
GraphModuleDeserializer()
|
||||
.deserialize(
|
||||
serialized_exported_program.graph_module,
|
||||
|
|
@ -1434,7 +1510,7 @@ class ExportedProgramDeserializer:
|
|||
)
|
||||
)
|
||||
range_constraints = self.deserialize_range_constraints(
|
||||
symbol_name_to_range, symbol_name_to_symbol,
|
||||
symbol_name_to_range, res.names_to_symbols,
|
||||
)
|
||||
model_opset_version: Optional[Dict[str, int]] = serialized_exported_program.opset_version
|
||||
self._validate_model_opset_version(model_opset_version)
|
||||
|
|
@ -1445,14 +1521,14 @@ class ExportedProgramDeserializer:
|
|||
equality_constraints = deserialize_equality_constraints(serialized_exported_program.equality_constraints)
|
||||
|
||||
exported_program = ep.ExportedProgram(
|
||||
graph_module,
|
||||
graph_module.graph,
|
||||
sig,
|
||||
call_spec,
|
||||
res.graph_module,
|
||||
res.graph_module.graph,
|
||||
res.signature,
|
||||
None, # TODO(zhxchen17) Remove this.
|
||||
state_dict, # type: ignore[arg-type]
|
||||
range_constraints,
|
||||
equality_constraints,
|
||||
module_call_graph,
|
||||
res.module_call_graph,
|
||||
None, # type: ignore[arg-type]
|
||||
)
|
||||
return upgrader.upgrade(exported_program)
|
||||
|
|
|
|||
|
|
@ -196,6 +196,5 @@ class GraphModuleOpUpgrader:
|
|||
upgraded_program = exported_program._transform(_pass)
|
||||
# NB: we have to retrace the graph_module instead of ep because of some failure.
|
||||
exported_program = export(upgraded_program.module(), inputs, {})
|
||||
exported_program._call_spec = upgraded_program.call_spec
|
||||
|
||||
return exported_program
|
||||
|
|
|
|||
|
|
@ -3977,7 +3977,7 @@ class FallbackKernel(ExternKernelAlloc):
|
|||
kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel
|
||||
]
|
||||
|
||||
serializer = GraphModuleSerializer(None, None, None)
|
||||
serializer = GraphModuleSerializer(None, None)
|
||||
named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs)
|
||||
|
||||
# serialize_outputs
|
||||
|
|
|
|||
|
|
@ -433,10 +433,7 @@ class ExportedProgram:
|
|||
example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
|
||||
dialect: Optional[str] = None,
|
||||
):
|
||||
from torch._export.exported_program import (
|
||||
_create_graph_module_for_export,
|
||||
CallSpec,
|
||||
)
|
||||
from torch._export.exported_program import _create_graph_module_for_export
|
||||
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
|
||||
InputDim,
|
||||
)
|
||||
|
|
@ -448,7 +445,6 @@ class ExportedProgram:
|
|||
self._graph_module.meta.update(root.meta)
|
||||
|
||||
self._graph_signature: ExportGraphSignature = graph_signature
|
||||
self._call_spec: CallSpec = call_spec
|
||||
self._state_dict: Dict[str, Any] = state_dict
|
||||
self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
|
||||
self._equality_constraints: List[
|
||||
|
|
@ -512,11 +508,6 @@ class ExportedProgram:
|
|||
for buffer_name in self.graph_signature.buffers:
|
||||
yield buffer_name, self.state_dict[buffer_name]
|
||||
|
||||
@property
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def call_spec(self):
|
||||
return self._call_spec
|
||||
|
||||
@property
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def range_constraints(self):
|
||||
|
|
@ -537,6 +528,19 @@ class ExportedProgram:
|
|||
def example_inputs(self):
|
||||
return self._example_inputs
|
||||
|
||||
@property
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def call_spec(self):
|
||||
from torch._export.exported_program import CallSpec
|
||||
|
||||
if len(self.module_call_graph) == 0:
|
||||
return CallSpec(in_spec=None, out_spec=None)
|
||||
assert self.module_call_graph[0].fqn == ""
|
||||
return CallSpec(
|
||||
in_spec=self.module_call_graph[0].signature.in_spec,
|
||||
out_spec=self.module_call_graph[0].signature.out_spec,
|
||||
)
|
||||
|
||||
@property
|
||||
def dialect(self):
|
||||
return self._dialect
|
||||
|
|
|
|||
Loading…
Reference in a new issue