mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[quant][graphmode][fx] Standalone module support {input/output}_quantized_idxs (#49754)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49754 This PR adds support for {input/output}_quantized_idxs for standalone modules. If input_quantized_idxs = [] and output_quantized_idxs = [], the standalone module expects float input and produces float output, and quantizes the input and dequantizes the output internally. If input_quantized_idxs = [0] and output_quantized_idxs = [0], the standalone module expects quantized input and produces quantized output; the input is quantized in the parent module and the output is dequantized in the parent module as well. This is similar to current quantized modules like nn.quantized.Conv2d. For more details, please see the test case. Test Plan: python test/test_quantization.py TestQuantizeFx.test_standalone_module Imported from OSS Reviewed By: raghuramank100 Differential Revision: D25684692 fbshipit-source-id: 900360e01c0e35b26fe85f4a887dc1fd6f7bfb66
This commit is contained in:
parent
69b1373587
commit
89b4899ea5
5 changed files with 221 additions and 74 deletions
|
|
@ -570,7 +570,16 @@ class TestQuantizeFx(QuantizationTestCase):
|
|||
m = convert_fx(m)
|
||||
m(tensor_input)
|
||||
|
||||
def test_standalone_module(self):
|
||||
def _test_standalone_module(
|
||||
self,
|
||||
interface_config,
|
||||
prepare_count_check,
|
||||
standalone_prepare_count_check,
|
||||
convert_count_check,
|
||||
standalone_convert_count_check):
|
||||
""" Test standalone module with different quantized input/quantized output
|
||||
configurations
|
||||
"""
|
||||
class StandaloneModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
@ -610,45 +619,32 @@ class TestQuantizeFx(QuantizationTestCase):
|
|||
original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
|
||||
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())
|
||||
|
||||
qconfig_dict = {"": default_qconfig}
|
||||
config_name = {"standalone_module_name": [("standalone", None, None)]}
|
||||
config_class = {"standalone_module_class": [(StandaloneModule, None, None)]}
|
||||
for prepare_config in [config_name, config_class]:
|
||||
for is_name in [True, False]:
|
||||
if is_name:
|
||||
prepare_config = {
|
||||
"standalone_module_name": [("standalone", None, interface_config)]
|
||||
}
|
||||
else:
|
||||
prepare_config = {
|
||||
"standalone_module_class": [(StandaloneModule, None, interface_config)]
|
||||
}
|
||||
|
||||
original_m_copy = copy.deepcopy(original_m)
|
||||
original_ref_m_copy = copy.deepcopy(original_ref_m)
|
||||
|
||||
qconfig_dict = {"": default_qconfig}
|
||||
# check prepared model
|
||||
m = prepare_fx(
|
||||
original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config)
|
||||
# calibration
|
||||
m(data)
|
||||
# input and output of first conv, observer for standalone module
|
||||
# will be inserted in the standalone module itself
|
||||
count_check = {
|
||||
ns.call_module(torch.quantization.MinMaxObserver): 2
|
||||
}
|
||||
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
|
||||
# for input and output of conv in the standalone module
|
||||
count_check = {
|
||||
ns.call_module(torch.quantization.MinMaxObserver): 2
|
||||
}
|
||||
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check)
|
||||
self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
|
||||
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)
|
||||
|
||||
# check converted/quantized model
|
||||
m = convert_fx(m)
|
||||
count_check = {
|
||||
ns.call_function(torch.quantize_per_tensor) : 1,
|
||||
ns.call_module(nnq.Conv2d) : 1,
|
||||
ns.call_method('dequantize') : 1,
|
||||
}
|
||||
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
|
||||
count_check = {
|
||||
# standalone module will take float as input and output
|
||||
# so we'll see quantize and dequantize in the modoule
|
||||
ns.call_function(torch.quantize_per_tensor) : 1,
|
||||
ns.call_module(nnq.Conv2d): 1,
|
||||
ns.call_method('dequantize') : 1,
|
||||
}
|
||||
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check)
|
||||
self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
|
||||
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
|
||||
res = m(data)
|
||||
|
||||
# quantize the reference model
|
||||
|
|
@ -658,6 +654,76 @@ class TestQuantizeFx(QuantizationTestCase):
|
|||
ref_res = ref_m(data)
|
||||
self.assertEqual(res, ref_res)
|
||||
|
||||
def test_standalone_module_float_interface(self):
|
||||
float_interface_config = {
|
||||
"input_quantized_idxs": [], # float input
|
||||
"output_quantized_idxs": [], # float output
|
||||
}
|
||||
interface_config = float_interface_config
|
||||
# input and output of first conv, observer for standalone module
|
||||
# will be inserted in the standalone module itself
|
||||
prepare_count_check = {
|
||||
ns.call_module(torch.quantization.MinMaxObserver): 2
|
||||
}
|
||||
# for input and output of conv in the standalone module
|
||||
standalone_prepare_count_check = {
|
||||
ns.call_module(torch.quantization.MinMaxObserver): 2
|
||||
}
|
||||
convert_count_check = {
|
||||
ns.call_function(torch.quantize_per_tensor) : 1,
|
||||
ns.call_module(nnq.Conv2d) : 1,
|
||||
ns.call_method("dequantize") : 1,
|
||||
}
|
||||
standalone_convert_count_check = {
|
||||
# standalone module will take float as input and output
|
||||
# so we'll see quantize and dequantize in the module
|
||||
ns.call_function(torch.quantize_per_tensor) : 1,
|
||||
ns.call_module(nnq.Conv2d): 1,
|
||||
ns.call_method("dequantize") : 1,
|
||||
}
|
||||
self._test_standalone_module(
|
||||
interface_config,
|
||||
prepare_count_check,
|
||||
standalone_prepare_count_check,
|
||||
convert_count_check,
|
||||
standalone_convert_count_check)
|
||||
|
||||
def test_standalone_module_quantized_interface(self):
|
||||
quantized_interface_config = {
|
||||
"input_quantized_idxs": [0], # quantized input
|
||||
"output_quantized_idxs": [0], # quantized output
|
||||
}
|
||||
interface_config = quantized_interface_config
|
||||
# observer for input and output of first conv
|
||||
prepare_count_check = {
|
||||
ns.call_module(torch.quantization.MinMaxObserver): 2
|
||||
}
|
||||
# for output of conv in the standalone module
|
||||
standalone_prepare_count_check = {
|
||||
ns.call_module(torch.quantization.MinMaxObserver): 1
|
||||
}
|
||||
convert_count_check = {
|
||||
# quantizing input for conv
|
||||
ns.call_function(torch.quantize_per_tensor) : 1,
|
||||
ns.call_module(nnq.Conv2d) : 1,
|
||||
# dequantizing output of standalone module
|
||||
ns.call_method("dequantize") : 1,
|
||||
}
|
||||
standalone_convert_count_check = {
|
||||
# quantization of input happens in parent module
|
||||
# quantization of output happens in the quantized conv module
|
||||
ns.call_function(torch.quantize_per_tensor) : 0,
|
||||
ns.call_module(nnq.Conv2d): 1,
|
||||
# dequantization for output happens in parent module
|
||||
ns.call_method("dequantize") : 0,
|
||||
}
|
||||
self._test_standalone_module(
|
||||
interface_config,
|
||||
prepare_count_check,
|
||||
standalone_prepare_count_check,
|
||||
convert_count_check,
|
||||
standalone_convert_count_check)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_qconfig_none(self):
|
||||
class M(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import torch
|
|||
import copy
|
||||
from torch.fx import GraphModule # type: ignore
|
||||
from torch.fx.graph import Graph
|
||||
from typing import Union, Dict, Any
|
||||
from typing import Union, Dict, Any, List
|
||||
|
||||
class ObservedGraphModule(GraphModule):
|
||||
|
||||
def get_preserved_attr_names(self):
|
||||
def get_preserved_attr_names(self) -> List[str]:
|
||||
return ['_activation_post_process_map',
|
||||
'_patterns',
|
||||
'_qconfig_map',
|
||||
|
|
@ -35,6 +35,12 @@ def is_observed_module(module: Any) -> bool:
|
|||
return isinstance(module, ObservedGraphModule)
|
||||
|
||||
class ObservedStandaloneGraphModule(ObservedGraphModule):
|
||||
def get_preserved_attr_names(self) -> List[str] :
|
||||
return super().get_preserved_attr_names() + [
|
||||
"_standalone_module_input_quantized_idxs",
|
||||
"_standalone_module_output_quantized_idxs"
|
||||
]
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
fake_mod = torch.nn.Module()
|
||||
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
|
||||
|
|
|
|||
|
|
@ -753,10 +753,10 @@ class StandaloneModuleQuantizeHandler(QuantizeHandler):
|
|||
qconfig = quantizer.qconfig_map[node.name]
|
||||
convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore
|
||||
observed_standalone_module = quantizer.modules[node.target]
|
||||
input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs
|
||||
quantized_standalone_module = convert(observed_standalone_module, debug=debug)
|
||||
parent_name, name = _parent_name(node.target)
|
||||
# update the modules dict
|
||||
setattr(quantizer.modules[parent_name], name, quantized_standalone_module)
|
||||
quantizer.modules[node.target] = quantized_standalone_module
|
||||
# standalone module takes float input
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs))
|
||||
|
|
|
|||
|
|
@ -102,14 +102,15 @@ def insert_observer(
|
|||
'call_module', observer_name, (load_arg(node),), {})
|
||||
observed_node_names_set.add(node.name)
|
||||
|
||||
def insert_observer_for_special_module(
|
||||
def maybe_insert_observer_for_special_module(
|
||||
quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module],
|
||||
prepare_custom_config_dict: Any, qconfig: Any, node: Node):
|
||||
prepare_custom_config_dict: Any, qconfig: Any, node: Node) -> Optional[List[int]]:
|
||||
""" Insert observer for custom module and standalone module
|
||||
Returns: standalone_module_input_idxs: the indexes of the inputs that
|
||||
need to be observed by the parent module
|
||||
"""
|
||||
assert modules is not None
|
||||
standalone_module_input_idxs = None
|
||||
if isinstance(quantize_handler, CustomModuleQuantizeHandler):
|
||||
custom_module = modules[node.target] # type: ignore
|
||||
custom_module_class_mapping = prepare_custom_config_dict.get(
|
||||
|
|
@ -129,19 +130,22 @@ def insert_observer_for_special_module(
|
|||
class_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_class_configs}
|
||||
name_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_name_configs}
|
||||
config = class_config_map.get(type(standalone_module), (None, None))
|
||||
config = name_config_map.get(node.target, (None, None))
|
||||
standalone_module_qconfig_dict = {"": qconfig} if config[0] is None else config[0]
|
||||
standalone_prepare_config_dict = {} if config[1] is None else config[1]
|
||||
config = name_config_map.get(node.target, config)
|
||||
sm_qconfig_dict = {"": qconfig} if config[0] is None else config[0]
|
||||
sm_prepare_config_dict = {} if config[1] is None else config[1]
|
||||
prepare = \
|
||||
torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore
|
||||
observed_standalone_module = \
|
||||
prepare(standalone_module, standalone_module_qconfig_dict, standalone_prepare_config_dict)
|
||||
prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict)
|
||||
standalone_module_input_idxs = observed_standalone_module.\
|
||||
_standalone_module_input_quantized_idxs
|
||||
observed_standalone_module = mark_observed_standalone_module(
|
||||
observed_standalone_module)
|
||||
parent_name, name = _parent_name(node.target)
|
||||
setattr(modules[parent_name], name,
|
||||
observed_standalone_module)
|
||||
modules[node.target] = observed_standalone_module # type: ignore
|
||||
return standalone_module_input_idxs
|
||||
|
||||
def insert_observer_for_output_of_the_node(
|
||||
node: Node,
|
||||
|
|
@ -155,7 +159,8 @@ def insert_observer_for_output_of_the_node(
|
|||
observed_graph: Graph,
|
||||
load_arg: Callable,
|
||||
observed_node_names_set: Set[str],
|
||||
matched_nodes: Optional[List[Node]]):
|
||||
matched_nodes: Optional[List[Node]],
|
||||
standalone_module_input_idxs: Optional[List[int]]):
|
||||
""" Insert observer/fake_quantize module for output of the observed
|
||||
module if needed
|
||||
"""
|
||||
|
|
@ -215,8 +220,11 @@ def insert_observer_for_output_of_the_node(
|
|||
observed_node_names_set.add(node.name)
|
||||
elif isinstance(quantize_handler,
|
||||
StandaloneModuleQuantizeHandler):
|
||||
# output is observed in the standalone module
|
||||
return
|
||||
assert node.op == "call_module"
|
||||
output_is_quantized = 0 in \
|
||||
modules[node.target]._standalone_module_output_quantized_idxs # type: ignore
|
||||
if output_is_quantized:
|
||||
observed_node_names_set.add(node.name)
|
||||
elif (quantize_handler.all_node_args and
|
||||
input_output_observed(quantize_handler)):
|
||||
# observer for outputs
|
||||
|
|
@ -226,6 +234,16 @@ def insert_observer_for_output_of_the_node(
|
|||
activation_post_process_map, env, observed_graph,
|
||||
load_arg, observed_node_names_set)
|
||||
|
||||
# insert observer for input of standalone module
|
||||
if standalone_module_input_idxs is not None:
|
||||
for idx in standalone_module_input_idxs:
|
||||
if node.args[idx].name not in observed_node_names_set: # type: ignore
|
||||
new_observer = qconfig.activation()
|
||||
insert_observer(
|
||||
node, new_observer, model,
|
||||
activation_post_process_map, env, observed_graph,
|
||||
load_arg, observed_node_names_set)
|
||||
|
||||
def insert_observer_for_input_arg_of_observed_node(
|
||||
node: Node, observed_node_names_set: Set[str],
|
||||
quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]],
|
||||
|
|
@ -373,10 +391,19 @@ class Quantizer:
|
|||
""" standalone_module means it a submodule that is not inlined in
|
||||
parent module, and will be quantized separately as one unit.
|
||||
|
||||
When we are preparing a standalone module:
|
||||
both input and output are observed in prepared standalone module
|
||||
How the standalone module is observed is specified by `input_quantized_idxs` and
|
||||
`output_quantized_idxs` in the prepare_custom_config for the standalone module
|
||||
Returns:
|
||||
model(GraphModule): prepared standalone module
|
||||
attributes:
|
||||
_standalone_module_input_quantized_idxs(List[Int]): a list of
|
||||
indexes for the graph input that is expected to be quantized,
|
||||
same as input_quantized_idxs configuration provided
|
||||
for the standalone module
|
||||
_standalone_module_output_quantized_idxs(List[Int]): a list of
|
||||
indexes for the graph output that is quantized,
|
||||
same as output_quantized_idxs configuration provided
|
||||
for the standalone module
|
||||
"""
|
||||
if prepare_custom_config_dict is None:
|
||||
prepare_custom_config_dict = {}
|
||||
|
|
@ -430,8 +457,6 @@ class Quantizer:
|
|||
def load_arg(a):
|
||||
return map_arg(a, lambda node: env[node.name])
|
||||
|
||||
# indexes for the inputs that needs to be observed
|
||||
standalone_module_observed_input_idxs: List[int] = []
|
||||
graph_inputs = []
|
||||
for node in model.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
|
|
@ -487,14 +512,15 @@ class Quantizer:
|
|||
# parent
|
||||
if qconfig is not None:
|
||||
assert obj is not None
|
||||
insert_observer_for_special_module(
|
||||
obj, self.modules, prepare_custom_config_dict, qconfig,
|
||||
node)
|
||||
standalone_module_input_idxs = \
|
||||
maybe_insert_observer_for_special_module(
|
||||
obj, self.modules, prepare_custom_config_dict, qconfig,
|
||||
node)
|
||||
insert_observer_for_output_of_the_node(
|
||||
node, obj, qconfig, self.modules, model, pattern,
|
||||
self.activation_post_process_map, env,
|
||||
observed_graph, load_arg, observed_node_names_set,
|
||||
matched_nodes)
|
||||
matched_nodes, standalone_module_input_idxs)
|
||||
else:
|
||||
env[node.name] = observed_graph.node_copy(node, load_arg)
|
||||
|
||||
|
|
@ -516,6 +542,19 @@ class Quantizer:
|
|||
model = GraphModule(model, observed_graph)
|
||||
self.save_state(model)
|
||||
model = mark_observed_module(model)
|
||||
if is_standalone_module:
|
||||
assert result_node is not None
|
||||
assert isinstance(result_node.args[0], Node), \
|
||||
"standalone module only supports returning simple value currently"\
|
||||
"(not tuple, dict etc.)"
|
||||
# indicator for whether output is observed or not.
|
||||
# This is used to correctly quantize standalone modules
|
||||
output_is_observed = \
|
||||
result_node.args[0].name in observed_node_names_set
|
||||
# these inputs are observed in parent
|
||||
model._standalone_module_input_quantized_idxs = \
|
||||
input_quantized_idxs
|
||||
model._standalone_module_output_quantized_idxs = output_quantized_idxs
|
||||
return model
|
||||
|
||||
def save_state(self, observed: GraphModule) -> None:
|
||||
|
|
@ -569,8 +608,10 @@ class Quantizer:
|
|||
""" standalone_module means it a submodule that is not inlined in
|
||||
parent module, and will be quantized separately as one unit.
|
||||
|
||||
Returns a quantized standalone module which accepts float input
|
||||
and produces float output.
|
||||
Returns a quantized standalone module, whether input/output is quantized is
|
||||
specified by prepare_custom_config_dict, with
|
||||
input_quantized_idxs, output_quantized_idxs, please
|
||||
see docs for prepare_fx for details
|
||||
"""
|
||||
if convert_custom_config_dict is None:
|
||||
convert_custom_config_dict = {}
|
||||
|
|
@ -627,36 +668,50 @@ class Quantizer:
|
|||
else:
|
||||
return env[n.name]
|
||||
|
||||
def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]]
|
||||
def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
|
||||
) -> Callable[[Node], Argument]:
|
||||
"""
|
||||
Input: quantized, which can be None, list, boolean or tuple
|
||||
- if quantized is a list or tuple, then arg should be a list and
|
||||
the args with corresponding indexes will be quantized
|
||||
- if quantized is a boolean, then all args will be
|
||||
quantized/not quantized
|
||||
- if quantized is None, then we'll load the node as long as it
|
||||
exists
|
||||
- if quantized is a boolean, then all args will be
|
||||
quantized/not quantized
|
||||
- if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False)
|
||||
- if quantized is a list or tuple, then arg should be a list and
|
||||
the args with corresponding indexes will be quantized
|
||||
|
||||
Output: fn which takes arg_or_args, and loads them from the
|
||||
corresponding environment depending on the value of quantized.
|
||||
"""
|
||||
assert quantized is None or \
|
||||
isinstance(quantized, (tuple, list, bool)), type(quantized)
|
||||
if isinstance(quantized, (tuple, list)) and len(quantized) == 0:
|
||||
# empty tuple or list means nothing is quantized
|
||||
quantized = False
|
||||
|
||||
def load_arg_impl(arg_or_args):
|
||||
if quantized is None:
|
||||
# we'll update the format of `quantized`
|
||||
# to better match arg_or_args
|
||||
updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized
|
||||
|
||||
if isinstance(quantized, (tuple, list)) and \
|
||||
len(quantized) == 1 and isinstance(arg_or_args, Node):
|
||||
# when argument is one Node instead of tuple, we just need to check
|
||||
# 0 is in the quantized list
|
||||
updated_quantized = 0 in quantized
|
||||
|
||||
if updated_quantized is None:
|
||||
return map_arg(arg_or_args, load_x)
|
||||
if isinstance(quantized, bool):
|
||||
if isinstance(updated_quantized, bool):
|
||||
return map_arg(
|
||||
arg_or_args,
|
||||
load_quantized if quantized else load_non_quantized)
|
||||
elif isinstance(quantized, (tuple, list)):
|
||||
load_quantized if updated_quantized else load_non_quantized)
|
||||
elif isinstance(updated_quantized, (tuple, list)):
|
||||
assert isinstance(arg_or_args, (tuple, list)), arg_or_args
|
||||
loaded_args = []
|
||||
# for now, we only support quantizing positional arguments
|
||||
for i, a in enumerate(arg_or_args):
|
||||
if i in quantized:
|
||||
if i in updated_quantized:
|
||||
loaded_args.append(map_arg(a, load_quantized))
|
||||
else:
|
||||
loaded_args.append(map_arg(a, load_non_quantized))
|
||||
|
|
@ -690,10 +745,10 @@ class Quantizer:
|
|||
def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
|
||||
""" Check if output node is quantized or not """
|
||||
assert self.modules is not None
|
||||
# by default the output is expected to be quantized
|
||||
# by default the output for a quantizable node is expected to be quantized
|
||||
quantized = True
|
||||
|
||||
# Need to get correct quantized/non-quantized state for the output
|
||||
# Need to get correct quantized/non-quantized state for the output
|
||||
# of CopyNode
|
||||
if type(obj) in [
|
||||
CopyNode,
|
||||
|
|
@ -750,7 +805,7 @@ class Quantizer:
|
|||
"output_quantized_idxs", [])
|
||||
|
||||
for node in model.graph.nodes:
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
cur_output_node_idx = output_node_seen_cnt
|
||||
output_node_seen_cnt += 1
|
||||
if cur_output_node_idx in output_quantized_idxs:
|
||||
|
|
@ -775,12 +830,19 @@ class Quantizer:
|
|||
quantized = False
|
||||
else:
|
||||
assert obj is not None
|
||||
# We will get whether the output is quantized or not before
|
||||
# convert for standalone module and after convert
|
||||
# for non-standalone module, since _standalone_module_output_quantized_idxs
|
||||
# is only available in observed standalone module
|
||||
if is_observed_standalone_module_node:
|
||||
out_quant_idxs = self.modules[node.target]._standalone_module_output_quantized_idxs
|
||||
assert len(out_quant_idxs) <= 1, "Currently standalone only support one output"
|
||||
quantized = 0 in out_quant_idxs
|
||||
|
||||
result = obj.convert(
|
||||
self, node, load_arg, debug=debug,
|
||||
convert_custom_config_dict=convert_custom_config_dict)
|
||||
if is_observed_standalone_module_node:
|
||||
quantized = False
|
||||
else:
|
||||
if not is_observed_standalone_module_node:
|
||||
quantized = is_output_quantized(node, obj)
|
||||
|
||||
if quantized:
|
||||
|
|
|
|||
|
|
@ -107,8 +107,20 @@ def _prepare_standalone_module_fx(
|
|||
standalone_module means it is a submodule that is not inlined in parent module,
|
||||
and will be quantized separately as one unit.
|
||||
|
||||
Both input and output of the module are observed in the
|
||||
standalone module.
|
||||
How the standalone module is observed is specified by `input_quantized_idxs` and
|
||||
`output_quantized_idxs` in the prepare_custom_config for the standalone module
|
||||
|
||||
Returns:
|
||||
model(GraphModule): prepared standalone module
|
||||
attributes:
|
||||
_standalone_module_input_quantized_idxs(List[Int]): a list of
|
||||
indexes for the graph input that is expected to be quantized,
|
||||
same as input_quantized_idxs configuration provided
|
||||
for the standalone module
|
||||
_standalone_module_output_quantized_idxs(List[Int]): a list of
|
||||
indexes for the graph output that is quantized,
|
||||
same as output_quantized_idxs configuration provided
|
||||
for the standalone module
|
||||
"""
|
||||
return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True)
|
||||
|
||||
|
|
@ -378,8 +390,9 @@ def _convert_standalone_module_fx(
|
|||
r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx`
|
||||
and convert it to a quantized model
|
||||
|
||||
Return:
|
||||
A quantized standalone module which accepts float input
|
||||
and produces float output.
|
||||
Returns a quantized standalone module, whether input/output is quantized is
|
||||
specified by prepare_custom_config_dict, with
|
||||
input_quantized_idxs, output_quantized_idxs, please
|
||||
see docs for prepare_fx for details
|
||||
"""
|
||||
return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True)
|
||||
|
|
|
|||
Loading…
Reference in a new issue