[quant][graphmode][fx] Merge quant_env and env (#59028)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59028

Previously we had both an env and a quant_env in convert, which was a bit confusing;
in this PR we merge them into a single Dict[str, Tuple[Node, torch.dtype]].

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D28724863

fbshipit-source-id: 722a682c70d300a6ccd2b988786a1ac2d45e880e
This commit is contained in:
Jerry Zhang 2021-06-01 09:20:19 -07:00 committed by Facebook GitHub Bot
parent afdfd2288a
commit 10fc42eacc
4 changed files with 98 additions and 53 deletions

View file

@@ -2351,6 +2351,34 @@ class TestQuantizeFx(QuantizationTestCase):
mp = prepare_fx(m, qconfig_dict)
mc = convert_fx(mp)
def test_shape_followed_by_quantized_op(self):
    """Check that reading Tensor meta information (``shape``) between two
    quantized ops does not insert a dequantize before the second op.
    """
    class TwoConv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2)
            self.conv2 = torch.nn.Conv2d(2, 2, 2)

        def forward(self, x):
            x = self.conv1(x)
            s = x.shape
            torch._assert(s == x.shape, "")
            x = self.conv2(x)
            return x

    # make sure quantization runs end to end
    model = TwoConv().eval()
    model = prepare_fx(model, {"": default_qconfig})
    model = convert_fx(model)
    model(torch.randn(2, 2, 4, 4))
    # exactly one quantize at the input and one dequantize at the output:
    # the shape access in between must not have dequantized the Tensor
    expected_counts = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method("dequantize"): 1,
    }
    self.checkGraphModuleNodes(model, expected_node_occurrence=expected_counts)
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):

View file

@@ -114,6 +114,7 @@ class QuantizeHandler(ABC):
def should_mark_output_quantized_from_input_quantized_status(
self,
qconfig: QConfigAny
) -> bool:
"""
Returns true if after convert, the output of the matched pattern is
@@ -1087,8 +1088,10 @@ class FixedQParamsOpQuantizeHandler(QuantizeHandler):
def should_mark_output_quantized_from_input_quantized_status(
self,
qconfig: QConfigAny
) -> bool:
return True
# FixQParamOps are the same as CopyNode in int8 quantization
return activation_dtype(qconfig) in [torch.quint8, torch.qint8]
# some qhandlers override the activations constructor
def get_activation_ctr(self, qconfig, pattern) -> Optional[Callable]:
@@ -1180,6 +1183,7 @@ class FixedQParamsOpQuantizeHandler(QuantizeHandler):
class CopyNodeQuantizeHandler(QuantizeHandler):
def should_mark_output_quantized_from_input_quantized_status(
self,
qconfig: QConfigAny
) -> bool:
return True

View file

@@ -65,6 +65,7 @@ from .quantization_patterns import (
from .utils import (
_parent_name,
all_node_args_have_no_tensors,
is_get_tensor_info_node,
quantize_node,
get_custom_module_class_keys,
get_new_attr_name_with_prefix,
@@ -1129,9 +1130,7 @@ class Quantizer:
custom_module_classes=custom_module_classes)
self.quantized_graph = Graph()
env: Dict[str, Node] = {}
# TODO: merge quant_env with env
quant_env: Dict[str, Tuple[Node, torch.dtype]] = {}
env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}
graph_inputs: List[str] = []
for node in model.graph.nodes:
@@ -1139,29 +1138,29 @@
graph_inputs.append(node.name)
def load_non_quantized(n: Node) -> Node:
if n.name not in env:
assert n.name in quant_env, \
'trying to load float node but did not find ' + \
'node:' + n.name + \
' in quantized or non quantized environment, env: ' + \
str(env) + ' quant_env:' + str(quant_env)
quantized_node, _ = quant_env[n.name]
env[n.name] = Proxy(quantized_node).dequantize().node
return env[n.name]
assert n.name in env, \
'trying to load float node but did not find ' + \
'node:' + n.name + \
' in env: ' + \
str(env)
quantized_node, dtype = env[n.name]
if dtype and dtype != torch.float:
env[n.name] = Proxy(quantized_node).dequantize().node, torch.float
return env[n.name][0]
def load_quantized(n: Node) -> Node:
assert n.name in quant_env, \
assert n.name in env, \
'trying to load quantized node but did not find node:' + \
n.name + ' in quant environment:' + str(quant_env)
return quant_env[n.name][0]
n.name + ' in environment:' + str(env)
quantized_node, dtype = env[n.name]
assert dtype in [torch.quint8, torch.qint8, torch.float16], \
f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}'
return quantized_node
def load_x(n: Node) -> Node:
assert n.name in env or n.name in quant_env, \
'node ' + n.name + ' does not exist in either environment'
if n.name in quant_env:
return quant_env[n.name][0]
else:
return env[n.name]
assert n.name in env, \
'node ' + n.name + ' does not exist in environment'
return env[n.name][0]
def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
) -> Callable[[Node], Argument]:
@@ -1216,14 +1215,11 @@ class Quantizer:
def node_arg_is_quantized(node_arg: Any) -> bool:
if isinstance(node_arg, Node):
assert node_arg.name in env or node_arg.name in quant_env, \
assert node_arg.name in env, \
'Expecting node_arg to be in the environment'
# there might be nodes appearing in both environemnts, but
# quant_env will take precedence
if node_arg.name in quant_env:
return True
elif node_arg.name in env:
return False
if node_arg.name in env:
_, dtype = env[node_arg.name]
return dtype != torch.float
else:
return False
elif isinstance(node_arg, list):
@@ -1238,7 +1234,7 @@
else:
return False
def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
def is_output_quantized(node: Node, obj: QuantizeHandler, qconfig: QConfigAny) -> bool:
""" Check if output node is quantized or not """
assert self.modules is not None
# by default the output for a quantizable node is expected to be quantized
@@ -1248,7 +1244,7 @@
# of FixedQParamsQuantizeHandler
# TODO: we may want to try to remove the special case here
# as well
if obj.should_mark_output_quantized_from_input_quantized_status():
if obj.should_mark_output_quantized_from_input_quantized_status(qconfig):
assert node.op in [
'call_module',
'call_function',
@@ -1277,19 +1273,19 @@ class Quantizer:
if observer_module.dtype == torch.float32:
# copy the observer for fp32 dtype
env[node.name] = self.quantized_graph.node_copy(
node, load_non_quantized)
elif isinstance(prev_node, Node) and prev_node.name in quant_env:
node, load_non_quantized), torch.float
elif isinstance(prev_node, Node) and prev_node.name in env:
# if previous node is already quantized, we'll just remove the
# activation_post_process
_, prev_dtype = quant_env[prev_node.name]
_, prev_dtype = env[prev_node.name]
current_dtype = observer_module.dtype
if prev_dtype == current_dtype:
quant_env[node.name] = quant_env[prev_node.name]
env[node.name] = env[prev_node.name]
else:
root_module = self.modules[""]
assert isinstance(prev_node, Node)
observer_dtype: torch.dtype = observer_module.dtype # type: ignore[assignment]
quant_env[node.name] = (
env[node.name] = (
quantize_node(self, load_non_quantized(prev_node),
observer_module, node, is_input=True),
observer_dtype)
@@ -1298,7 +1294,7 @@
root_module = self.modules[""]
assert isinstance(node.args[0], Node)
dtype: torch.dtype = observer_module.dtype # type: ignore[assignment]
quant_env[node.name] = (
env[node.name] = (
quantize_node(self, load_non_quantized(node.args[0]),
observer_module, node, is_input=True),
dtype)
@@ -1352,12 +1348,12 @@ class Quantizer:
self, node, qconfig, load_arg, is_reference=is_reference,
convert_custom_config_dict=convert_custom_config_dict)
if not is_observed_standalone_module_node:
quantized = is_output_quantized(node, obj)
quantized = is_output_quantized(node, obj, qconfig)
if quantized:
quant_env[node.name] = result, activation_dtype(qconfig)
env[node.name] = result, activation_dtype(qconfig)
else:
env[node.name] = result
env[node.name] = result, torch.float
continue
elif root_node is not None:
if qconfig is None:
@@ -1371,7 +1367,7 @@
# function will not be called.
result = self.quantized_graph.node_copy(
node, load_non_quantized)
env[node.name] = result
env[node.name] = result, torch.float
continue
# handle activation post process calls
@@ -1382,34 +1378,43 @@
cur_placeholder_node_idx = placeholder_node_seen_cnt
placeholder_node_seen_cnt += 1
if cur_placeholder_node_idx in input_quantized_idxs:
quant_env[node.name] = \
self.quantized_graph.node_copy(node, load_non_quantized), activation_dtype(qconfig) if qconfig else None
env[node.name] = \
self.quantized_graph.node_copy(
node, load_non_quantized), torch.quint8
else:
env[node.name] = \
self.quantized_graph.node_copy(node, load_non_quantized)
self.quantized_graph.node_copy(node, load_non_quantized), torch.float
else:
# copy quantized or non-quantized node
env[node.name] = \
self.quantized_graph.node_copy(node, load_non_quantized)
# get_tensor_info_node like shape works for both
# quantized and non-quantized input and output a non-Tensor
# (we use None for dtype currently for non-Tensors)
if is_get_tensor_info_node(node):
env[node.name] = \
self.quantized_graph.node_copy(node, load_x), None
else:
env[node.name] = \
self.quantized_graph.node_copy(node, load_non_quantized), torch.float
# remove activation post process
act_post_process_removed_graph = Graph()
env = {}
remove_env: Dict[str, Node] = {}
def load_arg_remove(a: Argument) -> Argument:
return map_arg(a, lambda node: remove_env[node.name])
def load_arg_simple(a: Argument) -> Argument:
return map_arg(a, lambda node: env[node.name])
for node in self.quantized_graph.nodes:
if node.op == 'output':
act_post_process_removed_graph.output(
map_arg(node.args[0], load_arg_simple))
map_arg(node.args[0], load_arg_remove))
continue
if node.op == 'call_module' and \
is_activation_post_process(self.modules[node.target]):
# remove activation post process node
env[node.name] = env[node.args[0].name]
remove_env[node.name] = remove_env[node.args[0].name]
else:
env[node.name] = act_post_process_removed_graph.node_copy(
node, load_arg_simple)
remove_env[node.name] = act_post_process_removed_graph.node_copy(
node, load_arg_remove)
# removes qconfig and activation_post_process modules
if _remove_qconfig_flag:

View file

@@ -437,3 +437,11 @@ def node_bool_tensor_arg_indexes(node: Node) -> List[int]:
if node.op == "call_method" and node.target == "masked_fill":
return [1]
return []
def is_get_tensor_info_node(node: Node) -> bool:
    """Return True if ``node`` only extracts meta information from a Tensor.

    Such nodes (e.g. ``x.shape``) work on both quantized and non-quantized
    Tensors and produce a non-Tensor output, so they do not require a
    dequantize of their input.

    Args:
        node: the FX graph node to inspect.

    Returns:
        True iff the node is a ``call_function`` invocation of the builtin
        ``getattr`` fetching the ``"shape"`` attribute.
    """
    # `is` suffices for the builtin `getattr` and avoids invoking an
    # arbitrary __eq__ on other call targets; the length guard keeps a
    # malformed getattr call from raising IndexError.
    return (
        node.op == "call_function"
        and node.target is getattr
        and len(node.args) > 1
        and node.args[1] == "shape"
    )