[export] Enable verifier [2/n] (#113075)

Summary: Turn on verifier check for exported program ctor. Note that this effectively detects a large surface of spec violations, so we also spent some time fixing them one by one in this diff.

Test Plan: CI

Differential Revision: D51014944

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113075
Approved by: https://github.com/angelayi
This commit is contained in:
Zhengxu Chen 2023-11-08 03:32:11 +00:00 committed by PyTorch MergeBot
parent f2963642c2
commit aa376e31fd
4 changed files with 61 additions and 30 deletions

View file

@ -303,4 +303,4 @@ class ExportedProgram:
range_constraints: Dict[str, RangeConstraint]
equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]
schema_version: int
example_inputs: Optional[Tuple[List[bytes], Dict[str, bytes]]]
dialect: str

View file

@ -11,22 +11,13 @@ import typing
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import (
Any,
Callable,
cast,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
)
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
import sympy
import torch
import torch.export.exported_program as ep
from torch._export.verifier import load_verifier
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.experimental import symbolic_shapes
from torch.utils._pytree import tree_map_only, treespec_dumps, treespec_loads
@ -984,7 +975,7 @@ class ExportedProgramSerializer:
range_constraints=serialized_range_constraints,
equality_constraints=serialized_equality_constraints,
schema_version=SCHEMA_VERSION,
example_inputs=None,
dialect=exported_program.dialect,
),
serialize_torch_artifact(exported_program.state_dict),
)
@ -1615,7 +1606,8 @@ class ExportedProgramDeserializer:
range_constraints,
equality_constraints,
res.module_call_graph,
None, # type: ignore[arg-type]
None,
load_verifier(serialized_exported_program.dialect),
)
return upgrader.upgrade(exported_program)

View file

@ -1,7 +1,7 @@
import inspect
import operator
from collections.abc import Iterable
from typing import Any, Dict, final, List, Tuple, Type
from typing import Any, Dict, final, List, Optional, Tuple, Type
import torch
from torch._ops import HigherOrderOperator, OpOverload
@ -41,7 +41,14 @@ def _check_val(node: torch.fx.Node) -> None:
return all(_check_correct_val(x) for x in val)
return False
def _no_returns(op):
if not isinstance(op, OpOverload):
return False
return len(op._schema.returns) == 0
if "val" not in node.meta:
if node.op == "call_function" and _no_returns(node.target):
return
raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
val = node.meta["val"]
@ -50,7 +57,7 @@ def _check_val(node: torch.fx.Node) -> None:
class _VerifierMeta(type):
__registry: Dict[str, Type['Verifier']] = {}
_registry: Dict[str, Type['Verifier']] = {}
def __new__(metacls, name, bases, attrs):
if bases:
@ -64,7 +71,7 @@ class _VerifierMeta(type):
assert isinstance(attrs["dialect"], str)
ret = type.__new__(metacls, name, bases, attrs)
metacls.__registry[attrs["dialect"]] = ret # type: ignore[assignment]
metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment]
return ret
@ -72,7 +79,21 @@ class Verifier(metaclass=_VerifierMeta):
dialect = "ATEN"
def allowed_builtin_ops(self) -> List:
return [operator.getitem, operator.add, operator.mul, operator.sub]
return [
operator.getitem,
operator.add,
operator.mul,
operator.sub,
operator.truediv,
operator.ge,
operator.le,
operator.gt,
operator.lt,
operator.eq,
operator.ne,
operator.floordiv,
operator.mod,
]
def allowed_op_types(self) -> Tuple[Type[Any], ...]:
return (OpOverload, HigherOrderOperator)
@ -95,7 +116,14 @@ class Verifier(metaclass=_VerifierMeta):
# TODO Enforce type checking in the constructor.
return
self._check_graph_module(ep.graph_module)
_verify_exported_program_signature(ep)
try:
_verify_exported_program_signature(ep)
except SpecViolationError as e:
# TODO Remove this branch.
if ep.dialect == "EDGE": # !!! Don't change this allowlist. !!!
pass
else:
raise e
@final
def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
@ -128,7 +156,7 @@ class Verifier(metaclass=_VerifierMeta):
raise SpecViolationError(
f"operator '{op}' is not functional"
)
self.check_valid_op(op)
self.check_valid_op(op)
for mod in gm.modules():
if not isinstance(mod, torch.fx.GraphModule):
@ -154,6 +182,16 @@ class Verifier(metaclass=_VerifierMeta):
)
attr = getattr(mod, node.target)
if isinstance(attr, torch.nn.Module):
def _is_type(name, ty):
return isinstance(getattr(attr, name, None), ty)
if type(attr).__name__ == "LoweredBackendModule" \
and _is_type("backend_id", str) \
and _is_type("processed_bytes", bytes) \
and _is_type("compile_specs", list) \
and hasattr(attr, "original_module"):
continue
if not isinstance(attr, _allowed_getattr_types()):
raise SpecViolationError(
f"Invalid get_attr type {type(attr)}. \n"
@ -163,6 +201,9 @@ class Verifier(metaclass=_VerifierMeta):
elif node.op == "placeholder":
_check_val(node)
# TODO(zhxchen17)
# elif node.op == "output":
# _check_flattened_outputs()
self.check_additional(gm)
@ -281,13 +322,7 @@ def _verify_exported_program_signature(exported_program) -> None:
assert output_node.op == "output"
output_nodes = [arg.name for arg in output_node.args[0]]
total_gs_outputs = (
len(gs.buffers_to_mutate) +
len(gs.user_outputs) +
len(bs_grad_to_param) +
len(bs_grad_to_user_inputs)
)
if len(output_nodes) != total_gs_outputs:
if len(output_nodes) != len(gs.output_specs):
raise SpecViolationError(
f"Number of output nodes {len(output_nodes)} is different "
"Than the number of outputs specified by the graph signature: \n"
@ -317,3 +352,9 @@ def _verify_exported_program_signature(exported_program) -> None:
"order or is not found in the "
f"exported program's user_output list: {gs.user_output}. "
)
def load_verifier(dialect: str) -> Optional[Type[Verifier]]:
    """Return the Verifier subclass registered for *dialect*.

    For the "ATEN" dialect a missing registration is tolerated and ``None``
    is returned; for every other dialect a registered verifier is required,
    so an unknown dialect raises ``KeyError``.
    """
    registry = _VerifierMeta._registry
    if dialect != "ATEN":
        # Non-ATEN dialects must have a verifier registered.
        return registry[dialect]
    # ATEN may legitimately have no verifier registered yet.
    return registry.get(dialect)

View file

@ -122,10 +122,8 @@ class ExportedProgram:
verifier = Verifier
assert issubclass(verifier, Verifier)
self._verifier = verifier
# Validate should be always the last step of the constructor.
# TODO(zhxchen17) Uncomment the following line.
# self.verifier().check(self)
self.verifier().check(self)
@property
@compatibility(is_backward_compatible=False)