[quant][graphmode][fx] Standalone module support {input/output}_quantized_idxs (#49754)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49754

This PR adds support for {input/output}_quantized_idxs for standalone modules.

If input_quantized_idxs = [] and output_quantized_idxs = [], the standalone module expects float
input and produces float output, and will quantize the input and dequantize the output internally.

If input_quantized_idxs = [0] and output_quantized_idxs = [0], the standalone module expects quantized
input and produces quantized output; the input will be quantized in the parent module, and the output will be dequantized
in the parent module as well. This is similar to existing quantized modules like nn.quantized.Conv2d.
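
As an illustration, here is a minimal sketch of selecting the interface through prepare_custom_config_dict (not part of this PR; the module definitions, input shape, and the submodule name "standalone" are made up for the example):

import torch
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import convert_fx, prepare_fx

# hypothetical modules, only for illustration
class Standalone(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.standalone = Standalone()

    def forward(self, x):
        return self.standalone(self.conv(x))

# float interface: the standalone module quantizes its input and
# dequantizes its output internally
float_interface_config = {
    "input_quantized_idxs": [],
    "output_quantized_idxs": [],
}
# quantized interface: the parent module quantizes the input and
# dequantizes the output of the standalone module, similar to
# nn.quantized.Conv2d (swap in float_interface_config below to
# use the float interface instead)
quantized_interface_config = {
    "input_quantized_idxs": [0],
    "output_quantized_idxs": [0],
}
prepare_custom_config_dict = {
    "standalone_module_name": [("standalone", None, quantized_interface_config)]
}

m = Parent().eval()
m = prepare_fx(m, {"": default_qconfig},
               prepare_custom_config_dict=prepare_custom_config_dict)
m(torch.randn(1, 1, 4, 4))  # calibration
m = convert_fx(m)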

For more details, please see the test cases.

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_standalone_module

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D25684692

fbshipit-source-id: 900360e01c0e35b26fe85f4a887dc1fd6f7bfb66
Jerry Zhang 2020-12-23 22:34:54 -08:00 committed by Facebook GitHub Bot
parent 69b1373587
commit 89b4899ea5
5 changed files with 221 additions and 74 deletions

@@ -570,7 +570,16 @@ class TestQuantizeFx(QuantizationTestCase):
m = convert_fx(m)
m(tensor_input)
def test_standalone_module(self):
def _test_standalone_module(
self,
interface_config,
prepare_count_check,
standalone_prepare_count_check,
convert_count_check,
standalone_convert_count_check):
""" Test standalone module with different quantized input/quantized output
configurations
"""
class StandaloneModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -610,45 +619,32 @@ class TestQuantizeFx(QuantizationTestCase):
original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())
qconfig_dict = {"": default_qconfig}
config_name = {"standalone_module_name": [("standalone", None, None)]}
config_class = {"standalone_module_class": [(StandaloneModule, None, None)]}
for prepare_config in [config_name, config_class]:
for is_name in [True, False]:
if is_name:
prepare_config = {
"standalone_module_name": [("standalone", None, interface_config)]
}
else:
prepare_config = {
"standalone_module_class": [(StandaloneModule, None, interface_config)]
}
original_m_copy = copy.deepcopy(original_m)
original_ref_m_copy = copy.deepcopy(original_ref_m)
qconfig_dict = {"": default_qconfig}
# check prepared model
m = prepare_fx(
original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config)
# calibration
m(data)
# input and output of first conv, observer for standalone module
# will be inserted in the standalone module itself
count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
# for input and output of conv in the standalone module
count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check)
self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)
# check converted/quantized model
m = convert_fx(m)
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d) : 1,
ns.call_method('dequantize') : 1,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
count_check = {
# standalone module will take float as input and output
# so we'll see quantize and dequantize in the module
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d): 1,
ns.call_method('dequantize') : 1,
}
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check)
self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
res = m(data)
# quantize the reference model
@@ -658,6 +654,76 @@ class TestQuantizeFx(QuantizationTestCase):
ref_res = ref_m(data)
self.assertEqual(res, ref_res)
def test_standalone_module_float_interface(self):
float_interface_config = {
"input_quantized_idxs": [], # float input
"output_quantized_idxs": [], # float output
}
interface_config = float_interface_config
# input and output of first conv, observer for standalone module
# will be inserted in the standalone module itself
prepare_count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
# for input and output of conv in the standalone module
standalone_prepare_count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
convert_count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d) : 1,
ns.call_method("dequantize") : 1,
}
standalone_convert_count_check = {
# standalone module will take float as input and output
# so we'll see quantize and dequantize in the module
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d): 1,
ns.call_method("dequantize") : 1,
}
self._test_standalone_module(
interface_config,
prepare_count_check,
standalone_prepare_count_check,
convert_count_check,
standalone_convert_count_check)
def test_standalone_module_quantized_interface(self):
quantized_interface_config = {
"input_quantized_idxs": [0], # quantized input
"output_quantized_idxs": [0], # quantized output
}
interface_config = quantized_interface_config
# observer for input and output of first conv
prepare_count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
# for output of conv in the standalone module
standalone_prepare_count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 1
}
convert_count_check = {
# quantizing input for conv
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d) : 1,
# dequantizing output of standalone module
ns.call_method("dequantize") : 1,
}
standalone_convert_count_check = {
# quantization of input happens in parent module
# quantization of output happens in the quantized conv module
ns.call_function(torch.quantize_per_tensor) : 0,
ns.call_module(nnq.Conv2d): 1,
# dequantization for output happens in parent module
ns.call_method("dequantize") : 0,
}
self._test_standalone_module(
interface_config,
prepare_count_check,
standalone_prepare_count_check,
convert_count_check,
standalone_convert_count_check)
@skipIfNoFBGEMM
def test_qconfig_none(self):
class M(torch.nn.Module):

@@ -2,11 +2,11 @@ import torch
import copy
from torch.fx import GraphModule # type: ignore
from torch.fx.graph import Graph
from typing import Union, Dict, Any
from typing import Union, Dict, Any, List
class ObservedGraphModule(GraphModule):
def get_preserved_attr_names(self):
def get_preserved_attr_names(self) -> List[str]:
return ['_activation_post_process_map',
'_patterns',
'_qconfig_map',
@@ -35,6 +35,12 @@ def is_observed_module(module: Any) -> bool:
return isinstance(module, ObservedGraphModule)
class ObservedStandaloneGraphModule(ObservedGraphModule):
def get_preserved_attr_names(self) -> List[str] :
return super().get_preserved_attr_names() + [
"_standalone_module_input_quantized_idxs",
"_standalone_module_output_quantized_idxs"
]
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)

@@ -753,10 +753,10 @@ class StandaloneModuleQuantizeHandler(QuantizeHandler):
qconfig = quantizer.qconfig_map[node.name]
convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore
observed_standalone_module = quantizer.modules[node.target]
input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs
quantized_standalone_module = convert(observed_standalone_module, debug=debug)
parent_name, name = _parent_name(node.target)
# update the modules dict
setattr(quantizer.modules[parent_name], name, quantized_standalone_module)
quantizer.modules[node.target] = quantized_standalone_module
# standalone module takes float input
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs))

@@ -102,14 +102,15 @@ def insert_observer(
'call_module', observer_name, (load_arg(node),), {})
observed_node_names_set.add(node.name)
def insert_observer_for_special_module(
def maybe_insert_observer_for_special_module(
quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module],
prepare_custom_config_dict: Any, qconfig: Any, node: Node):
prepare_custom_config_dict: Any, qconfig: Any, node: Node) -> Optional[List[int]]:
""" Insert observer for custom module and standalone module
Returns: standalone_module_input_idxs: the indexes for inputs that
need to be observed by the parent module
"""
assert modules is not None
standalone_module_input_idxs = None
if isinstance(quantize_handler, CustomModuleQuantizeHandler):
custom_module = modules[node.target] # type: ignore
custom_module_class_mapping = prepare_custom_config_dict.get(
@@ -129,19 +130,22 @@ def insert_observer_for_special_module(
class_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_class_configs}
name_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_name_configs}
config = class_config_map.get(type(standalone_module), (None, None))
config = name_config_map.get(node.target, (None, None))
standalone_module_qconfig_dict = {"": qconfig} if config[0] is None else config[0]
standalone_prepare_config_dict = {} if config[1] is None else config[1]
config = name_config_map.get(node.target, config)
sm_qconfig_dict = {"": qconfig} if config[0] is None else config[0]
sm_prepare_config_dict = {} if config[1] is None else config[1]
prepare = \
torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore
observed_standalone_module = \
prepare(standalone_module, standalone_module_qconfig_dict, standalone_prepare_config_dict)
prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict)
standalone_module_input_idxs = observed_standalone_module.\
_standalone_module_input_quantized_idxs
observed_standalone_module = mark_observed_standalone_module(
observed_standalone_module)
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name,
observed_standalone_module)
modules[node.target] = observed_standalone_module # type: ignore
return standalone_module_input_idxs
def insert_observer_for_output_of_the_node(
node: Node,
@@ -155,7 +159,8 @@ def insert_observer_for_output_of_the_node(
observed_graph: Graph,
load_arg: Callable,
observed_node_names_set: Set[str],
matched_nodes: Optional[List[Node]]):
matched_nodes: Optional[List[Node]],
standalone_module_input_idxs: Optional[List[int]]):
""" Insert observer/fake_quantize module for output of the observed
module if needed
"""
@@ -215,8 +220,11 @@ def insert_observer_for_output_of_the_node(
observed_node_names_set.add(node.name)
elif isinstance(quantize_handler,
StandaloneModuleQuantizeHandler):
# output is observed in the standalone module
return
assert node.op == "call_module"
output_is_quantized = 0 in \
modules[node.target]._standalone_module_output_quantized_idxs # type: ignore
if output_is_quantized:
observed_node_names_set.add(node.name)
elif (quantize_handler.all_node_args and
input_output_observed(quantize_handler)):
# observer for outputs
@@ -226,6 +234,16 @@ def insert_observer_for_output_of_the_node(
activation_post_process_map, env, observed_graph,
load_arg, observed_node_names_set)
# insert observer for input of standalone module
if standalone_module_input_idxs is not None:
for idx in standalone_module_input_idxs:
if node.args[idx].name not in observed_node_names_set: # type: ignore
new_observer = qconfig.activation()
insert_observer(
node, new_observer, model,
activation_post_process_map, env, observed_graph,
load_arg, observed_node_names_set)
def insert_observer_for_input_arg_of_observed_node(
node: Node, observed_node_names_set: Set[str],
quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]],
@@ -373,10 +391,19 @@ class Quantizer:
""" standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.
When we are preparing a standalone module:
both input and output are observed in prepared standalone module
How the standalone module is observed is specified by `input_quantized_idxs` and
`output_quantized_idxs` in the prepare_custom_config for the standalone module
Returns:
model(GraphModule): prepared standalone module
attributes:
_standalone_module_input_quantized_idxs(List[Int]): a list of
indexes for the graph input that is expected to be quantized,
same as input_quantized_idxs configuration provided
for the standalone module
_standalone_module_output_quantized_idxs(List[Int]): a list of
indexes for the graph outputs that are quantized,
same as the output_quantized_idxs configuration provided
for the standalone module
"""
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
@@ -430,8 +457,6 @@ class Quantizer:
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
# indexes for the inputs that needs to be observed
standalone_module_observed_input_idxs: List[int] = []
graph_inputs = []
for node in model.graph.nodes:
if node.op == 'placeholder':
@@ -487,14 +512,15 @@ class Quantizer:
# parent
if qconfig is not None:
assert obj is not None
insert_observer_for_special_module(
obj, self.modules, prepare_custom_config_dict, qconfig,
node)
standalone_module_input_idxs = \
maybe_insert_observer_for_special_module(
obj, self.modules, prepare_custom_config_dict, qconfig,
node)
insert_observer_for_output_of_the_node(
node, obj, qconfig, self.modules, model, pattern,
self.activation_post_process_map, env,
observed_graph, load_arg, observed_node_names_set,
matched_nodes)
matched_nodes, standalone_module_input_idxs)
else:
env[node.name] = observed_graph.node_copy(node, load_arg)
@@ -516,6 +542,19 @@ class Quantizer:
model = GraphModule(model, observed_graph)
self.save_state(model)
model = mark_observed_module(model)
if is_standalone_module:
assert result_node is not None
assert isinstance(result_node.args[0], Node), \
"standalone module only supports returning simple value currently"\
"(not tuple, dict etc.)"
# indicator for whether output is observed or not.
# This is used to correctly quantize standalone modules
output_is_observed = \
result_node.args[0].name in observed_node_names_set
# these inputs are observed in parent
model._standalone_module_input_quantized_idxs = \
input_quantized_idxs
model._standalone_module_output_quantized_idxs = output_quantized_idxs
return model
def save_state(self, observed: GraphModule) -> None:
@@ -569,8 +608,10 @@ class Quantizer:
""" standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.
Returns a quantized standalone module which accepts float input
and produces float output.
Returns a quantized standalone module; whether the input/output is
quantized is specified by input_quantized_idxs and output_quantized_idxs
in prepare_custom_config_dict, please see the docs for prepare_fx
for details
"""
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
@@ -627,36 +668,50 @@ class Quantizer:
else:
return env[n.name]
def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]]
def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
) -> Callable[[Node], Argument]:
"""
Input: quantized, which can be None, list, boolean or tuple
- if quantized is a list or tuple, then arg should be a list and
the args with corresponding indexes will be quantized
- if quantized is a boolean, then all args will be
quantized/not quantized
- if quantized is None, then we'll load the node as long as it
exists
- if quantized is a boolean, then all args will be
quantized/not quantized
- if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False)
- if quantized is a list or tuple, then arg should be a list and
the args with corresponding indexes will be quantized
Output: fn which takes arg_or_args, and loads them from the
corresponding environment depending on the value of quantized.
"""
assert quantized is None or \
isinstance(quantized, (tuple, list, bool)), type(quantized)
if isinstance(quantized, (tuple, list)) and len(quantized) == 0:
# empty tuple or list means nothing is quantized
quantized = False
def load_arg_impl(arg_or_args):
if quantized is None:
# we'll update the format of `quantized`
# to better match arg_or_args
updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized
if isinstance(quantized, (tuple, list)) and \
len(quantized) == 1 and isinstance(arg_or_args, Node):
# when argument is one Node instead of tuple, we just need to check
# 0 is in the quantized list
updated_quantized = 0 in quantized
if updated_quantized is None:
return map_arg(arg_or_args, load_x)
if isinstance(quantized, bool):
if isinstance(updated_quantized, bool):
return map_arg(
arg_or_args,
load_quantized if quantized else load_non_quantized)
elif isinstance(quantized, (tuple, list)):
load_quantized if updated_quantized else load_non_quantized)
elif isinstance(updated_quantized, (tuple, list)):
assert isinstance(arg_or_args, (tuple, list)), arg_or_args
loaded_args = []
# for now, we only support quantizing positional arguments
for i, a in enumerate(arg_or_args):
if i in quantized:
if i in updated_quantized:
loaded_args.append(map_arg(a, load_quantized))
else:
loaded_args.append(map_arg(a, load_non_quantized))
@@ -690,10 +745,10 @@ class Quantizer:
def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
""" Check if output node is quantized or not """
assert self.modules is not None
# by default the output is expected to be quantized
# by default the output for a quantizable node is expected to be quantized
quantized = True
# Need to get correct quantized/non-quantized state for the output
# Need to get correct quantized/non-quantized state for the output
# of CopyNode
if type(obj) in [
CopyNode,
@@ -750,7 +805,7 @@ class Quantizer:
"output_quantized_idxs", [])
for node in model.graph.nodes:
if node.op == 'output':
if node.op == "output":
cur_output_node_idx = output_node_seen_cnt
output_node_seen_cnt += 1
if cur_output_node_idx in output_quantized_idxs:
@@ -775,12 +830,19 @@ class Quantizer:
quantized = False
else:
assert obj is not None
# We get whether the output is quantized before convert for standalone
# modules and after convert for non-standalone modules, since
# _standalone_module_output_quantized_idxs is only available on the
# observed standalone module
if is_observed_standalone_module_node:
out_quant_idxs = self.modules[node.target]._standalone_module_output_quantized_idxs
assert len(out_quant_idxs) <= 1, "Currently standalone only supports one output"
quantized = 0 in out_quant_idxs
result = obj.convert(
self, node, load_arg, debug=debug,
convert_custom_config_dict=convert_custom_config_dict)
if is_observed_standalone_module_node:
quantized = False
else:
if not is_observed_standalone_module_node:
quantized = is_output_quantized(node, obj)
if quantized:

@@ -107,8 +107,20 @@ def _prepare_standalone_module_fx(
standalone_module means it is a submodule that is not inlined in the parent module,
and will be quantized separately as one unit.
Both input and output of the module are observed in the
standalone module.
How the standalone module is observed is specified by `input_quantized_idxs` and
`output_quantized_idxs` in the prepare_custom_config for the standalone module
Returns:
model(GraphModule): prepared standalone module
attributes:
_standalone_module_input_quantized_idxs(List[Int]): a list of
indexes for the graph input that is expected to be quantized,
same as input_quantized_idxs configuration provided
for the standalone module
_standalone_module_output_quantized_idxs(List[Int]): a list of
indexes for the graph outputs that are quantized,
same as the output_quantized_idxs configuration provided
for the standalone module
"""
return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True)
@@ -378,8 +390,9 @@ def _convert_standalone_module_fx(
r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx`
and convert it to a quantized model
Return:
A quantized standalone module which accepts float input
and produces float output.
Returns a quantized standalone module; whether the input/output is
quantized is specified by input_quantized_idxs and output_quantized_idxs
in prepare_custom_config_dict, please see the docs for prepare_fx
for details
"""
return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True)