Revert D26403094: ns for fx - stubs of the three APIs (compare weights, activations, activations with shadow)
Test Plan: revert-hammer
Differential Revision: D26403094 (37622db76a)
Original commit changeset: 9752331d4ae0
fbshipit-source-id: f0a32d443a29b25af33d90420dfd1bada40c917c
parent 4949eea0ff
commit eaddadd4f7
8 changed files with 31 additions and 971 deletions
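Note: for context on the three reverted APIs, the deleted tests below exercised them roughly as follows. This is a condensed, non-authoritative sketch assembled from the removed test code (model/variable names such as m, mp, mq mirror the tests, and the prepare_fx/convert_fx import path is assumed); it does not run against the tree after this revert, since the module it imports is removed.

import copy
import torch
import torch.nn as nn
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.quantization.ns.numeric_suite_core_apis_fx import (
    compare_weights,
    prepare_model_outputs,
    prepare_model_with_stubs,
    OutputLogger,
    get_matching_activations,
    get_matching_activations_a_shadows_b,
)

# prepare and convert a toy model, as in the removed tests
m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
mp(torch.randn(1, 1, 4, 4))
mq = convert_fx(copy.deepcopy(mp))

# 1. compare weights of matched layers
weights = compare_weights('fp32_prepared', mp, 'int8', mq)

# 2. compare activations: log outputs of matched layers in both models
mp_ns, mq_ns = prepare_model_outputs('fp32_prepared', mp, 'int8', mq, OutputLogger)
data = torch.randn(1, 1, 4, 4)
mp_ns(data)
mq_ns(data)
acts = get_matching_activations(mp_ns, mq_ns, OutputLogger)

# 3. compare activations with shadow: fp32 ops shadow their int8 counterparts
mp_shadows_mq = prepare_model_with_stubs('fp32_prepared', mp, 'int8', mq, OutputLogger)
mp_shadows_mq(data)
shadow_acts = get_matching_activations_a_shadows_b(mp_shadows_mq, OutputLogger)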
@@ -32,20 +32,11 @@ from torch.testing._internal.common_quantization import (
    skip_if_no_torchvision,
    test_only_eval_fn,
)
from torch.testing._internal.common_quantization import NodeSpec as ns
from torch.testing._internal.common_quantized import override_qengines
from torch.quantization.ns.graph_matcher import (
    get_matching_node_pairs,
    GraphMatchingException,
)
from torch.quantization.ns.numeric_suite_core_apis_fx import (
    compare_weights,
    prepare_model_outputs,
    OutputLogger,
    prepare_model_with_stubs,
    get_matching_activations,
    get_matching_activations_a_shadows_b,
)


class TestGraphModeNumericSuite(QuantizationTestCase):
@@ -577,197 +568,3 @@ class TestFXGraphMatcherModels(QuantizationTestCase):
        mq = convert_fx(mp_copy)
        # assume success if no exceptions
        results = get_matching_node_pairs(mp, mq)


class TestFXNumericSuiteCoreAPIs(QuantizationTestCase):

    @override_qengines
    def test_compare_weights_mod(self):
        m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        results = compare_weights('fp32_prepared', mp, 'int8', mq)
        self.assertTrue(len(results) == 2)
        self.assert_ns_weight_compare_dict_valid(results)

    @override_qengines
    def test_compare_weights_fun(self):
        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.w = nn.Parameter(torch.Tensor(4, 1))
                self.b = nn.Parameter(torch.Tensor(4))
                torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))

            def forward(self, x):
                return F.linear(x, self.w, self.b)

        m = M().eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        mp(torch.randn(1, 1))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        results = compare_weights('fp32_prepared', mp, 'int8', mq)
        self.assertTrue(len(results) == 1)
        self.assert_ns_weight_compare_dict_valid(results)

    @override_qengines
    def test_match_activations_mod(self):
        m = nn.Sequential(
            torch.quantization.QuantStub(),
            nn.Conv2d(1, 1, 1),
            nn.Conv2d(1, 1, 1),
        ).eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        mp(torch.randn(2, 1, 2, 2))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        mp_ns, mq_ns = prepare_model_outputs(
            'fp32_prepared', mp, 'int8', mq, OutputLogger)

        expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self.checkGraphModuleNodes(
            mp_ns, expected_node_occurrence=expected_occurrence)
        self.checkGraphModuleNodes(
            mq_ns, expected_node_occurrence=expected_occurrence)

        # TODO(before land): test both scripted and non-scripted
        mp_ns = torch.jit.script(mp_ns)
        mq_ns = torch.jit.script(mq_ns)

        # calibrate
        input_fp32 = torch.randn(2, 1, 2, 2)
        mp_ns(input_fp32)
        mq_ns(input_fp32)

        # check activation result correctness
        act_compare_dict = get_matching_activations(mp_ns, mq_ns, OutputLogger)
        self.assertTrue(len(act_compare_dict) == 2)
        self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)

    @override_qengines
    def test_match_activations_fun(self):
        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.w1 = nn.Parameter(torch.Tensor(4, 4))
                self.b1 = nn.Parameter(torch.Tensor(4))
                self.w2 = nn.Parameter(torch.Tensor(4, 4))
                self.b2 = nn.Parameter(torch.Tensor(4))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
                torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))

            def forward(self, x):
                x = F.linear(x, self.w1, self.b1)
                x = F.linear(x, self.w2, self.b2)
                return x

        m = M().eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        mp(torch.randn(4, 4))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        mp_ns, mq_ns = prepare_model_outputs(
            'fp32_prepared', mp, 'int8', mq, OutputLogger)

        expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self.checkGraphModuleNodes(
            mp_ns, expected_node_occurrence=expected_occurrence)
        self.checkGraphModuleNodes(
            mq_ns, expected_node_occurrence=expected_occurrence)

        # TODO(before land): test both scripted and non-scripted
        mp_ns = torch.jit.script(mp_ns)
        mq_ns = torch.jit.script(mq_ns)

        # calibrate
        input_fp32 = torch.randn(4, 4)
        mp_ns(input_fp32)
        mq_ns(input_fp32)

        # check activation result correctness
        act_compare_dict = get_matching_activations(mp_ns, mq_ns, OutputLogger)
        self.assertTrue(len(act_compare_dict) == 2)
        self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)

    @override_qengines
    def test_prepare_model_with_stubs_mod(self):
        m = nn.Sequential(
            nn.Conv2d(1, 1, 1),
            nn.Conv2d(1, 1, 1),
        ).eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        mp(torch.randn(1, 1, 4, 4))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        mp_shadows_mq = prepare_model_with_stubs('fp32_prepared', mp, 'int8', mq, OutputLogger)

        # TODO(before land): test both scripted and non-scripted
        mp_shadows_mq = torch.jit.script(mp_shadows_mq)

        # calibrate
        input_fp32 = torch.randn(1, 1, 4, 4)
        mp_shadows_mq(input_fp32)

        # check activation result correctness
        act_compare_dict = get_matching_activations_a_shadows_b(
            mp_shadows_mq, OutputLogger)
        self.assertTrue(len(act_compare_dict) == 2)
        self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)

    @override_qengines
    def test_prepare_model_with_stubs_fun(self):
        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.w1 = nn.Parameter(torch.Tensor(4, 4))
                self.b1 = nn.Parameter(torch.Tensor(4))
                self.w2 = nn.Parameter(torch.Tensor(4, 4))
                self.b2 = nn.Parameter(torch.Tensor(4))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
                torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))

            def forward(self, x):
                x = F.linear(x, self.w1, self.b1)
                x = F.linear(x, self.w2, self.b2)
                return x

        m = M().eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        mp(torch.randn(4, 4))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        mp_shadows_mq = prepare_model_with_stubs('fp32_prepared', mp, 'int8', mq, OutputLogger)

        # TODO(before land): test both scripted and non-scripted
        mp_shadows_mq = torch.jit.script(mp_shadows_mq)

        # calibrate
        input_fp32 = torch.randn(4, 4)
        mp_shadows_mq(input_fp32)

        # check activation result correctness
        act_compare_dict = get_matching_activations_a_shadows_b(
            mp_shadows_mq, OutputLogger)
        self.assertTrue(len(act_compare_dict) == 2)
        self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
@@ -80,7 +80,6 @@ try:
    from quantization.test_numeric_suite_fx import TestGraphModeNumericSuite  # noqa: F401
    from quantization.test_numeric_suite_fx import TestFXGraphMatcher  # noqa: F401
    from quantization.test_numeric_suite_fx import TestFXGraphMatcherModels  # noqa: F401
    from quantization.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIs  # noqa: F401
except ImportError:
    pass

@@ -11,9 +11,16 @@ toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.fx.graph import Graph, Node

from .utils import getattr_from_fqn
from typing import Dict, Tuple, List, Optional, Set, Callable, Any

from typing import Dict, Tuple, List, Optional, Set, Callable

# TODO(before land): delete this
def _print_node(node: Optional[Node]) -> None:
    if node is None:
        print(None)
    else:
        print(
            node, ', target:', node.target, ', op:', node.op,
            ', args:', node.args, ', kwargs:', node.kwargs)

def _get_output_nodes(g: Graph) -> List[Node]:
    return [n for n in g.nodes if n.op == 'output']

@@ -82,6 +89,16 @@ def get_non_matchable_modules() -> Set[Callable]:
        torch.quantization.FakeQuantizeBase,
    ])

def _getattr_from_fqn(gm: GraphModule, fqn: str) -> Any:
    """
    Given a gm and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz.
    """
    fqn_parts = fqn.split(".")
    cur_val = gm
    for part in fqn_parts:
        cur_val = getattr(cur_val, part)
    return cur_val

class _NSGraphMatchableNodesIterator:
    """
    Iterates through the graph of gm, starting with the output nodes
@@ -136,7 +153,7 @@ class _NSGraphMatchableNodesIterator:
        elif node.op == 'call_module':
            assert isinstance(node.target, str)
            # target_mod = getattr(self.gm, node.target)
            target_mod = getattr_from_fqn(self.gm, node.target)
            target_mod = _getattr_from_fqn(self.gm, node.target)
            return not \
                any(isinstance(target_mod, t)  # type: ignore
                    for t in self.non_matchable_modules)
@@ -170,9 +187,9 @@ def _node_a_related_to_b(
    elif node_a.op == 'call_module':
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        mod_a = _getattr_from_fqn(gm_a, node_a.target)
        assert isinstance(node_b.target, str)
        mod_b = getattr_from_fqn(gm_b, node_b.target)
        mod_b = _getattr_from_fqn(gm_b, node_b.target)
        # modules with equivalent types always match (i.e. nn.Conv2d and nn.Conv2d)
        if type(mod_a) == type(mod_b):
            return True
@@ -196,7 +213,7 @@ def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[Callable]:
        return node.target  # type: ignore
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        mod = _getattr_from_fqn(gm, node.target)
        return type(mod)
    return None

@@ -263,6 +280,13 @@ def get_matching_node_pairs(
        except StopIteration:
            pass

        # TODO(before land): remove
        if False:
            print('a')
            _print_node(cur_node_a)
            print('b')
            _print_node(cur_node_b)

        # look up types of a and b for useful error messages
        type_a, type_b = None, None
        if cur_node_a is not None:
@@ -1,346 +0,0 @@
import torch
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node
from torch.quantization.fx.quantize import is_activation_post_process
from torch.quantization.fx.utils import get_new_attr_name_with_prefix

from .utils import (
    get_node_io_type,
    getattr_from_fqn,
    print_node,
    NodeIOType,
    return_first_non_observer_node,
)

from typing import Dict, Tuple, Callable, List, Any, Optional

def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    model_name: str,
    other_node_name: Optional[str] = None,
) -> Node:
    """
    Given a starting graph of

    prev_node -> node -> next_node

    This function creates a new logger_cls obj and adds it
    after node, resulting in

    prev_node -> node -> logger_obj -> next_node
    """
    # create new name
    logger_node_name = \
        get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm)
    # create the logger object
    logger_obj = logger_cls(node.name, model_name, other_node_name)
    # attach the logger object to the parent module
    setattr(gm, logger_node_name, logger_obj)
    logger_node = node.graph.create_node(
        'call_module', logger_node_name, (node,), {})
    return logger_node

def remove_observers_add_loggers(
    gm: GraphModule,
    nodes_to_instrument: List[Node],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, removes all observers, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if node.op == 'call_module' and is_activation_post_process(modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]

        elif node in nodes_to_instrument:
            # ensure env is populated with base node
            env[node.name] = new_graph.node_copy(node, load_arg)
            # add the logger after the base node
            env[node.name] = _insert_logger_after_node(
                env[node.name], gm, logger_cls, '_ns_logger_', model_name)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm

def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.
    """
    dtype_cast_op = None
    node_io_type_a = get_node_io_type(node_a, gm_a)
    node_io_type_c = get_node_io_type(node_c, gm_b)

    if node_io_type_a == NodeIOType.FP32 and node_io_type_c == NodeIOType.INT8:
        dtype_cast_op = torch.dequantize
    else:
        raise AssertionError(
            f"dtype cast from {node_io_type_c} to {node_io_type_a} needs to be implemented")

    new_dtype_cast_name = \
        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
    return prev_node_c.graph.create_node(
        'call_function', dtype_cast_op, (prev_node_c,), {},
        new_dtype_cast_name)

def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Node,
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.

    If node_a is a call_module, points the module pointed to by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                          /  /
    weight -> weight_obs    /
                           /
    bias ------------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                      /  /
    weight_copy -----/  /
                       /
    bias_copy --------/
    """
    graph_c = input_node_c.graph

    # generically handle all args and kwargs except for the input
    # Note: this hasn't been tested with many ops, logic may change.
    new_args = []
    # assumes that the first arg is the input
    for node_a_arg in node_a.args[1:]:
        if isinstance(node_a_arg, Node):
            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
            arg_a_copy_name = \
                get_new_attr_name_with_prefix(arg_a.name + '_shadow_copy_')(gm_b)  # type: ignore
            arg_a_obj = getattr_from_fqn(gm_a, arg_a.target)  # type: ignore
            setattr(gm_b, arg_a_copy_name, arg_a_obj.detach())
            node_a_arg_copy = graph_c.create_node(
                'get_attr', arg_a_copy_name, (), {}, arg_a_copy_name)
            new_args.append(node_a_arg_copy)
        else:
            raise AssertionError(
                f"handling for arg of type {type(node_a_arg)} is not implemented")

    new_kwargs = {}
    for node_a_k, node_a_kwarg in node_a.kwargs.items():
        kwarg_a_copy_name = \
            get_new_attr_name_with_prefix(node_a_kwarg.name + '_shadow_copy_')(gm_b)  # type: ignore
        kwarg_a_obj = getattr_from_fqn(gm_a, node_a_kwarg.target)  # type: ignore
        setattr(gm_b, kwarg_a_copy_name, kwarg_a_obj.detach())
        node_a_kwarg_copy = graph_c.create_node(
            'get_attr', kwarg_a_copy_name, (), {}, kwarg_a_copy_name)
        new_kwargs[node_a_k] = node_a_kwarg_copy

    node_a_shadows_c_name = \
        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == 'call_module':
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
            node_a.op, new_mod_copy_name, (input_node_c, *new_args),
            new_kwargs, node_a_shadows_c_name)  # type: ignore
        return node_a_shadows_c
    else:
        assert node_a.op == 'call_function'
        node_a_shadows_c = graph_c.create_node(
            node_a.op, node_a.target, (input_node_c, *new_args),
            new_kwargs, node_a_shadows_c_name)  # type: ignore
        return node_a_shadows_c

def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_node_pairs: Dict[str, Tuple[Node, Node]],
    logger_cls: Callable,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B. For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    nodes_to_instrument_b_to_a = {}
    for match_name, (node_a, node_b) in matched_node_pairs.items():
        nodes_to_instrument_b_to_a[node_b] = node_a

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in nodes_to_instrument_b_to_a:
            node_a = nodes_to_instrument_b_to_a[node_b]
            if False:
                print('b')
                print_node(node_b)
                print('a')
                print_node(node_a)

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_a, node_c, node_c.args[0], gm_a, gm_b, node_b.name + '_dtype_cast_')
            env_c[dtype_cast_node.name] = dtype_cast_node
            # subgraph so far:
            #
            #       dtype_cast_node
            #      /
            # node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_node_a_after_input_node_c(
                env_c[dtype_cast_node.name],
                node_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node --> node_a_copy(args/kwargs not shown)
            #      /
            # node_c

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_', name_b)
            # subgraph so far:
            #
            #       dtype_cast_node --> node_a_copy
            #      /
            # node_c --> logger_c

            # hook up a logger to the mod_a copy
            # Note: we pass node_b.name to this logger, for easy matching later
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_', name_a,
                node_b.name)
            # subgraph so far:
            #
            #       dtype_cast_node --> node_a_copy --> logger_a
            #      /
            # node_c --> logger_c

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
@@ -1,241 +0,0 @@
import collections

import torch
import torch.nn as nn
import torch.nn.functional as F
toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.quantization.ns.graph_matcher import (
    get_matching_node_pairs,
    get_type_a_related_to_b,
)

from .utils import (
    getattr_from_fqn,
)

from .weight_utils import (
    get_conv_mod_weight,
    get_linear_fun_weight,
)

from .graph_passes import (
    remove_observers_add_loggers,
    create_a_shadows_b,
)

from typing import Dict, Tuple, Callable, List, Optional


# Note: this is not a user facing API
# TODO(future PR): wrap this in a user facing API which does not
# expose FX types.
def compare_weights(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
) -> Dict[str, Dict[str, torch.Tensor]]:
    type_a_related_to_b = get_type_a_related_to_b()
    matched_node_pairs = get_matching_node_pairs(gm_a, gm_b)

    results = {}

    for match_name, match in matched_node_pairs.items():

        node_a, node_b = match
        assert node_a.op == node_b.op and \
            node_a.op in ('call_function', 'call_module')

        if node_a.op == 'call_function':

            # linear
            # TODO(future PR): other function types
            a_related_to_linear = node_a.target in (F.linear,) or \
                (node_a.target, F.linear) in type_a_related_to_b

            if a_related_to_linear:
                weight_a = get_linear_fun_weight(node_a, gm_a)
                weight_b = get_linear_fun_weight(node_b, gm_b)

                results[match_name] = {
                    name_a: weight_a,
                    name_b: weight_b,
                }

        else:  # call_module
            # for call_module, we need to look up the modules to do the type check
            assert isinstance(node_a.target, str)
            mod_a = getattr_from_fqn(gm_a, node_a.target)
            assert isinstance(node_b.target, str)
            mod_b = getattr_from_fqn(gm_b, node_b.target)

            # check that A is one the modules we need
            # assume B is related (this is done by graph matcher)
            a_related_to_conv2d_mod = isinstance(mod_a, nn.Conv2d) or \
                (type(mod_a), nn.Conv2d) in type_a_related_to_b

            # TODO(future PR): other module types
            if a_related_to_conv2d_mod:
                weight_a = get_conv_mod_weight(mod_a)
                weight_b = get_conv_mod_weight(mod_b)
                results[match_name] = {
                    name_a: weight_a,
                    name_b: weight_b,
                }

    return results


class OutputLogger(nn.Module):
    stats: List[torch.Tensor]

    def __init__(
        self,
        node_name: str,
        model_name: str,
        other_node_name: Optional[str] = None,
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        # name of the node whose output this Logger is capturing
        self.node_name = node_name
        # name of the model from which the node originated from
        self.model_name = model_name
        # name of the other node with a matching Logger
        # used to link node_a_copy -> logger_a to node_c -> logger_c
        # in a_shadows_b
        self.other_node_name = other_node_name

    def forward(self, x: torch.Tensor):
        self.stats.append(x.detach())
        return x

    def __repr__(self):
        return f"OutputLogger(node_name={self.node_name}, model_name={self.model_name}, other_node_name={self.other_node_name})"

# Note: this is not a user facing API
# TODO(future PR): wrap this in a user facing API which does not
# expose FX types.
def prepare_model_outputs(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
) -> Tuple[GraphModule, GraphModule]:

    matched_node_pairs = get_matching_node_pairs(gm_a, gm_b)

    nodes_to_instrument_a = []
    nodes_to_instrument_b = []
    for match_name, (node_a, node_b,) in matched_node_pairs.items():
        # TODO(future PR): do not observe pairs of nodes we do not care
        # about (both fp32, denylist, etc)
        nodes_to_instrument_a.append(node_a)
        nodes_to_instrument_b.append(node_b)

    gm_a = remove_observers_add_loggers(gm_a, nodes_to_instrument_a, logger_cls, name_a)
    gm_b = remove_observers_add_loggers(gm_b, nodes_to_instrument_b, logger_cls, name_b)
    return (gm_a, gm_b)

# Note: this is not a user facing API
# TODO(future PR): wrap this in a user facing API which does not
# expose FX types.
# TODO(future PR): align on naming
# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
def get_matching_activations(
    gm_a: GraphModule,
    gm_b: GraphModule,
    logger_cls: Callable,
) -> Dict[str, Dict[str, List[torch.Tensor]]]:
    """
    Same thing as ns.get_matching_activations, but for FX models prepared with
    this module.

    TODO(future PR): real docblock

    Output format:

    {
        'layer1.stats': {
            'name_a': [torch.Tensor(...), ...],
            'name_b': [torch.Tensor(...), ...],
        },
        ...
    }

    Note, there are three differences from the output format of Eager NS:
    1. `name_a` and `name_b` are used instead of hardcoding names
       to `float` and `quantized`.
    2. Lists of Tensors are returned instead of individual Tensors, to unify
       the return type for calibrating with 1 input vs N inputs.
    3. `logger_cls` is included in the API for easy result extraction
    """
    results: Dict[str, Dict[str, List[torch.Tensor]]] = \
        collections.defaultdict(dict)
    for gm in (gm_a, gm_b):
        for gm_name, mod in gm.named_modules():
            # TODO(future PR): better check when scripted
            is_logger = (
                isinstance(mod, logger_cls)  # type: ignore
                or (
                    isinstance(mod, torch.jit.RecursiveScriptModule)
                    and mod.original_name == 'OutputLogger'
                )
            )
            if is_logger:
                results[mod.node_name + '.stats'][mod.model_name] = mod.stats
    return dict(results)

# Note: this is not a user facing API
# TODO(future PR): wrap this in a user facing API which does not
# expose FX types.
def prepare_model_with_stubs(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
) -> GraphModule:
    """
    Same thing as prepare_model_outputs, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    matched_node_pairs = get_matching_node_pairs(gm_a, gm_b)
    gm_a_shadows_b = create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, matched_node_pairs, logger_cls)
    return gm_a_shadows_b

# Note: this is not a user facing API
# TODO(future PR): wrap this in a user facing API which does not
# expose FX types.
# TODO(future PR): align on naming
# this is equivalent of just the comparison extraction part of `ns.compare_model_stub`
def get_matching_activations_a_shadows_b(
    gm_a_shadows_b: GraphModule,
    logger_cls: Callable,
) -> Dict[str, Dict[str, List[torch.Tensor]]]:
    """
    Same thing as get_matching_activations, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    results: Dict[str, Dict[str, List[torch.Tensor]]] = \
        collections.defaultdict(dict)
    for name, mod in gm_a_shadows_b.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            # If logger_obj.other_node_name is populated, then this logger
            # is from model A, and other_node_name is the name from model B.
            if mod.other_node_name is None:
                results[mod.node_name + '.stats'][mod.model_name] = mod.stats
            else:
                results[mod.other_node_name + '.stats'][mod.model_name] = mod.stats
    return dict(results)
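As a reading aid for the output format documented in get_matching_activations above, here is a minimal hedged sketch of consuming the returned dict. The helper name summarize_act_compare_dict and the mean-squared-error metric are illustrative only and are not part of the removed API.

import torch

def summarize_act_compare_dict(act_compare_dict, name_a, name_b):
    # act_compare_dict: {'layer.stats': {name_a: [Tensor, ...], name_b: [Tensor, ...]}}
    summary = {}
    for layer_name, layer_data in act_compare_dict.items():
        errors = []
        for t_a, t_b in zip(layer_data[name_a], layer_data[name_b]):
            # dequantize if one model logged quantized tensors
            if t_b.is_quantized:
                t_b = t_b.dequantize()
            if t_a.is_quantized:
                t_a = t_a.dequantize()
            errors.append(torch.mean((t_a - t_b) ** 2).item())
        summary[layer_name] = errors
    return summary

# example (names follow the removed tests):
# summarize_act_compare_dict(act_compare_dict, 'fp32_prepared', 'int8')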
@@ -1,84 +0,0 @@
import enum

from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.quantization.fx.quantize import is_activation_post_process

from typing import Optional, Any

# TODO(future PR): delete this after FX has a util for it
def print_node(node: Optional[Node]) -> None:
    if node is None:
        print(None)
    else:
        print(
            node, ', target:', node.target, ', op:', node.op,
            ', args:', node.args, ', kwargs:', node.kwargs)

def getattr_from_fqn(gm: GraphModule, fqn: str) -> Any:
    """
    Given a gm and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz.
    """
    fqn_parts = fqn.split(".")
    cur_val = gm
    for part in fqn_parts:
        cur_val = getattr(cur_val, part)
    return cur_val

class NodeIOType(enum.Enum):
    FP32 = enum.auto()  # all inputs and outputs fp32
    INT8 = enum.auto()  # all inputs and outputs int8
    # TODO(future PRs): dynamic quant, fake quant, etc


def get_node_io_type(node: Node, gm: GraphModule) -> NodeIOType:
    if node.op == 'call_function':
        fp32_fun_target_names = ('torch.nn.functional', 'torch.nn')
        int8_fun_target_names = ('torch._ops.quantized',)
        # For now, hacky check to see which op is in which namespace
        # TODO(future PR): use a real mapping
        if node.target.__module__ in fp32_fun_target_names:
            return NodeIOType.FP32
        else:
            assert node.target.__module__ in int8_fun_target_names, \
                'unknown node target %s' % node.target
            return NodeIOType.INT8
    else:
        assert node.op == 'call_module'
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        # For now, hacky check to see which mod is in which namespace
        # TODO(future PR): use a real mapping
        if mod.__module__.startswith('torch.nn.modules'):
            return NodeIOType.FP32
        else:
            assert mod.__module__.startswith('torch.nn.q'), \
                'unknown node target %s' % mod
            return NodeIOType.INT8

def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it. If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer. For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
    """
    node_obj = getattr_from_fqn(gm, node.target)  # type: ignore
    if is_activation_post_process(node_obj):
        assert len(node.args) == 1
        assert isinstance(node.args[0], Node)
        node = node.args[0]
        # code duplication intended, not worth refactoring
        assert isinstance(node.target, str)
        node_obj = getattr_from_fqn(gm, node.target)
        if is_activation_post_process(node_obj):
            assert len(node.args) == 1
            assert isinstance(node.args[0], Node)
            node = node.args[0]
    return node
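For readers unfamiliar with FX fully qualified names, a small hedged sketch of what the removed getattr_from_fqn helper above computes, using a made-up module hierarchy and the standard-library equivalent:

import functools
import torch.nn as nn

root = nn.Module()
root.foo = nn.Module()
root.foo.bar = nn.Linear(2, 2)

# the removed helper walks attribute names in the dotted path, equivalent to:
resolved = functools.reduce(getattr, 'foo.bar'.split('.'), root)
assert resolved is root.foo.bar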
@@ -1,45 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.fx.graph import Node

from .utils import getattr_from_fqn

def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
    # TODO(future PR): make more generic, handle everything
    if isinstance(mod, nn.Conv2d):
        return mod.weight.detach()
    else:
        return mod._weight_bias()[0]  # type: ignore

def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # TODO(future PR): better docblock, with example FX IR
    if node.target in (F.linear,):
        # traverse backwards from the weight arg, accounting for
        # any observers
        weight_arg_node = node.args[1]
        # print_node(weight_arg_node)
        assert isinstance(weight_arg_node, Node)
        weight_node = weight_arg_node.args[0]
        # print_node(weight_node)
        # TODO(future PR): currently assumes 1 observer, handle arbitrary
        # levels of observation, from 0 to N
        assert isinstance(weight_node, Node)
        assert weight_node.op == 'get_attr'
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore
        return weight.detach()

    else:
        assert node.target in (toq.linear,)
        # packed weight is arg 1
        packed_weight_node = node.args[1]
        assert isinstance(packed_weight_node, Node)
        assert packed_weight_node.op == 'get_attr'
        packed_weight = getattr_from_fqn(gm, packed_weight_node.target)  # type: ignore
        # TODO(future PR): why does packed_weight.unpack() not work?
        # TODO(future PR): discuss if we even need to unpack, or if the
        # caller can handle the unpacking
        (weight, _bias), _name = packed_weight.__getstate__()
        return weight
@@ -42,7 +42,7 @@ import os
import unittest
import numpy as np
from torch.testing import FileCheck
from typing import Callable, Tuple, Dict, List
from typing import Callable, Tuple, Dict

class NodeSpec:
    ''' Used for checking GraphModule Node
@@ -558,7 +558,6 @@ class QuantizationTestCase(TestCase):
                nodes_in_graph[n] = 1

        if expected_node is not None:
            print('expected_node', expected_node)
            self.assertTrue(expected_node in nodes_in_graph, 'node:' + str(expected_node) +
                            ' not found in the graph module')

@@ -655,49 +654,6 @@ class QuantizationTestCase(TestCase):
                (k, (expected_type_a, expected_type_b), (actual_type_a, actual_type_b))
            )

    def assert_ns_weight_compare_dict_valid(
        self,
        weight_compare_dict: Dict[str, Dict[str, torch.Tensor]],
    ) -> None:
        """
        Verifieds that the weight_compare dict (output of Numeric Suite
        weight matching APIs) is valid:
        1. for each layer, results are recorded for two models
        2. shapes of each pair of weights match
        """
        for layer_name, layer_data in weight_compare_dict.items():
            self.assertTrue(
                len(layer_data) == 2,
                f"Layer {layer_name} does not have exactly two model results.")
            k0, k1 = layer_data.keys()
            self.assertTrue(
                layer_data[k0].shape == layer_data[k1].shape,
                f"Layer {layer_name}, {k0} and {k1} have a shape mismatch.")

    def assert_ns_logger_act_compare_dict_valid(
        self,
        act_compare_dict: Dict[str, Dict[str, List[torch.Tensor]]],
    ) -> None:
        """
        Verifies that the act_compare_dict (output of Numeric Suite
        activation matching APIs) is valid:
        1. for each layer, results are recorded for two models
        2. number of seen tensors match
        3. shapes of each pair of seen tensors match
        """
        for layer_name, layer_data in act_compare_dict.items():
            self.assertTrue(
                len(layer_data) == 2,
                f"Layer {layer_name} does not have exactly two model results.")
            k0, k1 = layer_data.keys()
            self.assertTrue(
                len(layer_data[k0]) == len(layer_data[k1]),
                f"Layer {layer_name}, {k0} and {k1} do not have the same number of seen Tensors.")
            for idx in range(len(layer_data[k0])):
                self.assertTrue(
                    layer_data[k0][idx].shape == layer_data[k1][idx].shape,
                    f"Layer {layer_name}, {k0} and {k1} have a shape mismatch at idx {idx}.")

    def checkGraphModeFxOp(self, model, inputs, quant_type,
                           expected_node=None,
                           expected_node_occurrence=None,