mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[export] Enable verifier [2/n] (#113075)
Summary: Turn on verifier check for exportec program ctor. Note that this effectively detect a large surface of spec violations, so we also spend some time fixing them one by one in this diff. Test Plan: CI Differential Revision: D51014944 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113075 Approved by: https://github.com/angelayi
This commit is contained in:
parent
f2963642c2
commit
aa376e31fd
4 changed files with 61 additions and 30 deletions
|
|
@ -303,4 +303,4 @@ class ExportedProgram:
|
|||
range_constraints: Dict[str, RangeConstraint]
|
||||
equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]
|
||||
schema_version: int
|
||||
example_inputs: Optional[Tuple[List[bytes], Dict[str, bytes]]]
|
||||
dialect: str
|
||||
|
|
|
|||
|
|
@ -11,22 +11,13 @@ import typing
|
|||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
import torch.export.exported_program as ep
|
||||
from torch._export.verifier import load_verifier
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
from torch.fx.experimental import symbolic_shapes
|
||||
from torch.utils._pytree import tree_map_only, treespec_dumps, treespec_loads
|
||||
|
|
@ -984,7 +975,7 @@ class ExportedProgramSerializer:
|
|||
range_constraints=serialized_range_constraints,
|
||||
equality_constraints=serialized_equality_constraints,
|
||||
schema_version=SCHEMA_VERSION,
|
||||
example_inputs=None,
|
||||
dialect=exported_program.dialect,
|
||||
),
|
||||
serialize_torch_artifact(exported_program.state_dict),
|
||||
)
|
||||
|
|
@ -1615,7 +1606,8 @@ class ExportedProgramDeserializer:
|
|||
range_constraints,
|
||||
equality_constraints,
|
||||
res.module_call_graph,
|
||||
None, # type: ignore[arg-type]
|
||||
None,
|
||||
load_verifier(serialized_exported_program.dialect),
|
||||
)
|
||||
return upgrader.upgrade(exported_program)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import inspect
|
||||
import operator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Dict, final, List, Tuple, Type
|
||||
from typing import Any, Dict, final, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch._ops import HigherOrderOperator, OpOverload
|
||||
|
|
@ -41,7 +41,14 @@ def _check_val(node: torch.fx.Node) -> None:
|
|||
return all(_check_correct_val(x) for x in val)
|
||||
return False
|
||||
|
||||
def _no_returns(op):
|
||||
if not isinstance(op, OpOverload):
|
||||
return False
|
||||
return len(op._schema.returns) == 0
|
||||
|
||||
if "val" not in node.meta:
|
||||
if node.op == "call_function" and _no_returns(node.target):
|
||||
return
|
||||
raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
|
||||
|
||||
val = node.meta["val"]
|
||||
|
|
@ -50,7 +57,7 @@ def _check_val(node: torch.fx.Node) -> None:
|
|||
|
||||
|
||||
class _VerifierMeta(type):
|
||||
__registry: Dict[str, Type['Verifier']] = {}
|
||||
_registry: Dict[str, Type['Verifier']] = {}
|
||||
|
||||
def __new__(metacls, name, bases, attrs):
|
||||
if bases:
|
||||
|
|
@ -64,7 +71,7 @@ class _VerifierMeta(type):
|
|||
|
||||
assert isinstance(attrs["dialect"], str)
|
||||
ret = type.__new__(metacls, name, bases, attrs)
|
||||
metacls.__registry[attrs["dialect"]] = ret # type: ignore[assignment]
|
||||
metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment]
|
||||
return ret
|
||||
|
||||
|
||||
|
|
@ -72,7 +79,21 @@ class Verifier(metaclass=_VerifierMeta):
|
|||
dialect = "ATEN"
|
||||
|
||||
def allowed_builtin_ops(self) -> List:
|
||||
return [operator.getitem, operator.add, operator.mul, operator.sub]
|
||||
return [
|
||||
operator.getitem,
|
||||
operator.add,
|
||||
operator.mul,
|
||||
operator.sub,
|
||||
operator.truediv,
|
||||
operator.ge,
|
||||
operator.le,
|
||||
operator.gt,
|
||||
operator.lt,
|
||||
operator.eq,
|
||||
operator.ne,
|
||||
operator.floordiv,
|
||||
operator.mod,
|
||||
]
|
||||
|
||||
def allowed_op_types(self) -> Tuple[Type[Any], ...]:
|
||||
return (OpOverload, HigherOrderOperator)
|
||||
|
|
@ -95,7 +116,14 @@ class Verifier(metaclass=_VerifierMeta):
|
|||
# TODO Enforce type checking in the constructor.
|
||||
return
|
||||
self._check_graph_module(ep.graph_module)
|
||||
_verify_exported_program_signature(ep)
|
||||
try:
|
||||
_verify_exported_program_signature(ep)
|
||||
except SpecViolationError as e:
|
||||
# TODO Remove this branch.
|
||||
if ep.dialect == "EDGE": # !!! Don't change this allowlist. !!!
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
|
||||
@final
|
||||
def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
|
||||
|
|
@ -128,7 +156,7 @@ class Verifier(metaclass=_VerifierMeta):
|
|||
raise SpecViolationError(
|
||||
f"operator '{op}' is not functional"
|
||||
)
|
||||
self.check_valid_op(op)
|
||||
self.check_valid_op(op)
|
||||
|
||||
for mod in gm.modules():
|
||||
if not isinstance(mod, torch.fx.GraphModule):
|
||||
|
|
@ -154,6 +182,16 @@ class Verifier(metaclass=_VerifierMeta):
|
|||
)
|
||||
|
||||
attr = getattr(mod, node.target)
|
||||
if isinstance(attr, torch.nn.Module):
|
||||
def _is_type(name, ty):
|
||||
return isinstance(getattr(attr, name, None), ty)
|
||||
if type(attr).__name__ == "LoweredBackendModule" \
|
||||
and _is_type("backend_id", str) \
|
||||
and _is_type("processed_bytes", bytes) \
|
||||
and _is_type("compile_specs", list) \
|
||||
and hasattr(attr, "original_module"):
|
||||
continue
|
||||
|
||||
if not isinstance(attr, _allowed_getattr_types()):
|
||||
raise SpecViolationError(
|
||||
f"Invalid get_attr type {type(attr)}. \n"
|
||||
|
|
@ -163,6 +201,9 @@ class Verifier(metaclass=_VerifierMeta):
|
|||
|
||||
elif node.op == "placeholder":
|
||||
_check_val(node)
|
||||
# TODO(zhxchen17)
|
||||
# elif node.op == "output":
|
||||
# _check_flattened_outputs()
|
||||
|
||||
self.check_additional(gm)
|
||||
|
||||
|
|
@ -281,13 +322,7 @@ def _verify_exported_program_signature(exported_program) -> None:
|
|||
assert output_node.op == "output"
|
||||
output_nodes = [arg.name for arg in output_node.args[0]]
|
||||
|
||||
total_gs_outputs = (
|
||||
len(gs.buffers_to_mutate) +
|
||||
len(gs.user_outputs) +
|
||||
len(bs_grad_to_param) +
|
||||
len(bs_grad_to_user_inputs)
|
||||
)
|
||||
if len(output_nodes) != total_gs_outputs:
|
||||
if len(output_nodes) != len(gs.output_specs):
|
||||
raise SpecViolationError(
|
||||
f"Number of output nodes {len(output_nodes)} is different "
|
||||
"Than the number of outputs specified by the graph signature: \n"
|
||||
|
|
@ -317,3 +352,9 @@ def _verify_exported_program_signature(exported_program) -> None:
|
|||
"order or is not found in the "
|
||||
f"exported program's user_output list: {gs.user_output}. "
|
||||
)
|
||||
|
||||
|
||||
def load_verifier(dialect: str) -> Optional[Type[Verifier]]:
|
||||
if dialect == "ATEN":
|
||||
return _VerifierMeta._registry.get(dialect)
|
||||
return _VerifierMeta._registry[dialect]
|
||||
|
|
|
|||
|
|
@ -122,10 +122,8 @@ class ExportedProgram:
|
|||
verifier = Verifier
|
||||
assert issubclass(verifier, Verifier)
|
||||
self._verifier = verifier
|
||||
|
||||
# Validate should be always the last step of the constructor.
|
||||
# TODO(zhxchen17) Uncomment the following line.
|
||||
# self.verifier().check(self)
|
||||
self.verifier().check(self)
|
||||
|
||||
@property
|
||||
@compatibility(is_backward_compatible=False)
|
||||
|
|
|
|||
Loading…
Reference in a new issue