[quant][graphmode][fx] Merge quant_env and env (#59028)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59028

Previously we had an env and a quant_env in convert, which was a bit confusing. In this PR we merge them into a single Dict[str, Tuple[Node, torch.dtype]].

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D28724863

fbshipit-source-id: 722a682c70d300a6ccd2b988786a1ac2d45e880e
This commit is contained in:
parent afdfd2288a
commit 10fc42eacc

4 changed files with 98 additions and 53 deletions
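Before this change, convert threaded two parallel maps through the pass; after it, there is one map whose dtype tag distinguishes the cases. A minimal sketch of the merged shape (the comments and the bare env name here are illustrative, not part of the patch):

    from typing import Dict, Optional, Tuple

    import torch
    from torch.fx import Node

    # Before: two maps that had to be kept in sync.
    #   env:       Dict[str, Node]                      (float values)
    #   quant_env: Dict[str, Tuple[Node, torch.dtype]]  (quantized values)
    #
    # After: one map; the tag says how the value is represented:
    #   torch.float                -> float Tensor
    #   torch.quint8/qint8/float16 -> quantized Tensor
    #   None                       -> non-Tensor (e.g. a shape)
    env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}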
test/quantization/test_quantize_fx.py
@@ -2351,6 +2351,34 @@ class TestQuantizeFx(QuantizationTestCase):
         mp = prepare_fx(m, qconfig_dict)
         mc = convert_fx(mp)
 
+    def test_shape_followed_by_quantized_op(self):
+        """ Make sure that shape does not dequantize
+        the Tensor before the next operator
+        """
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(2, 2, 2)
+                self.conv2 = torch.nn.Conv2d(2, 2, 2)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                s = x.shape
+                torch._assert(s == x.shape, "")
+                x = self.conv2(x)
+                return x
+
+        # make sure quantization runs
+        m = M().eval()
+        m = prepare_fx(m, {"": default_qconfig})
+        m = convert_fx(m)
+        m(torch.randn(2, 2, 4, 4))
+        node_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 1,
+            ns.call_method("dequantize"): 1
+        }
+        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
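What the occurrence counts pin down: with shape treated as a tensor-info query, the converted model quantizes once on the way in and dequantizes once on the way out. Roughly (an illustrative node sequence, not captured output):

    quantize_per_tensor -> quantized conv1 -> getattr 'shape' / _assert -> quantized conv2 -> dequantize

Had the shape access forced a dequantize before conv2, an extra quantize/dequantize pair would appear and both counts would fail.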
torch/quantization/fx/quantization_patterns.py
@@ -114,6 +114,7 @@ class QuantizeHandler(ABC):
 
     def should_mark_output_quantized_from_input_quantized_status(
         self,
+        qconfig: QConfigAny
     ) -> bool:
         """
         Returns true if after convert, the output of the matched pattern is
@@ -1087,8 +1088,10 @@ class FixedQParamsOpQuantizeHandler(QuantizeHandler):
 
     def should_mark_output_quantized_from_input_quantized_status(
         self,
+        qconfig: QConfigAny
     ) -> bool:
-        return True
+        # FixQParamOps are the same as CopyNode in int8 quantization
+        return activation_dtype(qconfig) in [torch.quint8, torch.qint8]
 
     # some qhandlers override the activations constructor
     def get_activation_ctr(self, qconfig, pattern) -> Optional[Callable]:
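The handler now keys its decision off the qconfig's activation dtype. For reference, activation_dtype boils down to instantiating the activation observer the qconfig carries and reading its dtype; a simplified sketch (treat the exact body as an assumption, not a quote of the helper):

    import torch

    def activation_dtype(qconfig) -> torch.dtype:
        # Build the activation observer this qconfig would insert and
        # report the dtype it quantizes to (e.g. torch.quint8).
        assert qconfig is not None
        activation = qconfig.activation()
        return activation.dtype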
@@ -1180,6 +1183,7 @@ class FixedQParamsOpQuantizeHandler(QuantizeHandler):
 class CopyNodeQuantizeHandler(QuantizeHandler):
     def should_mark_output_quantized_from_input_quantized_status(
         self,
+        qconfig: QConfigAny
     ) -> bool:
         return True
 
torch/quantization/fx/quantize.py
@@ -65,6 +65,7 @@ from .quantization_patterns import (
 from .utils import (
     _parent_name,
     all_node_args_have_no_tensors,
+    is_get_tensor_info_node,
     quantize_node,
     get_custom_module_class_keys,
     get_new_attr_name_with_prefix,
@@ -1129,9 +1130,7 @@ class Quantizer:
             custom_module_classes=custom_module_classes)
 
         self.quantized_graph = Graph()
-        env: Dict[str, Node] = {}
-        # TODO: merge quant_env with env
-        quant_env: Dict[str, Tuple[Node, torch.dtype]] = {}
+        env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}
 
         graph_inputs: List[str] = []
         for node in model.graph.nodes:
@@ -1139,29 +1138,29 @@
                 graph_inputs.append(node.name)
 
         def load_non_quantized(n: Node) -> Node:
-            if n.name not in env:
-                assert n.name in quant_env, \
-                    'trying to load float node but did not find ' + \
-                    'node:' + n.name + \
-                    ' in quantized or non quantized environment, env: ' + \
-                    str(env) + ' quant_env:' + str(quant_env)
-                quantized_node, _ = quant_env[n.name]
-                env[n.name] = Proxy(quantized_node).dequantize().node
-            return env[n.name]
+            assert n.name in env, \
+                'trying to load float node but did not find ' + \
+                'node:' + n.name + \
+                ' in env: ' + \
+                str(env)
+            quantized_node, dtype = env[n.name]
+            if dtype and dtype != torch.float:
+                env[n.name] = Proxy(quantized_node).dequantize().node, torch.float
+            return env[n.name][0]
 
         def load_quantized(n: Node) -> Node:
-            assert n.name in quant_env, \
+            assert n.name in env, \
                 'trying to load quantized node but did not find node:' + \
-                n.name + ' in quant environment:' + str(quant_env)
-            return quant_env[n.name][0]
+                n.name + ' in environment:' + str(env)
+            quantized_node, dtype = env[n.name]
+            assert dtype in [torch.quint8, torch.qint8, torch.float16], \
+                f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}'
+            return quantized_node
 
         def load_x(n: Node) -> Node:
-            assert n.name in env or n.name in quant_env, \
-                'node ' + n.name + ' does not exist in either environment'
-            if n.name in quant_env:
-                return quant_env[n.name][0]
-            else:
-                return env[n.name]
+            assert n.name in env, \
+                'node ' + n.name + ' does not exist in environment'
+            return env[n.name][0]
 
         def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
                      ) -> Callable[[Node], Argument]:
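All three loaders now consult a single dtype tag. The same lookup discipline, reduced to a self-contained toy over a plain dict (the names and string payloads are hypothetical; strings stand in for fx.Node):

    from typing import Dict, Optional, Tuple

    import torch

    # name -> (value, dtype tag), mirroring the merged env in convert
    toy_env: Dict[str, Tuple[str, Optional[torch.dtype]]] = {
        "x": ("x_float", torch.float),
        "y": ("y_quant", torch.quint8),
        "s": ("y_shape", None),  # non-Tensor info such as a shape
    }

    def load_non_quantized(name: str) -> str:
        node, dtype = toy_env[name]
        if dtype is not None and dtype != torch.float:
            # convert would insert a dequantize here; rewriting the entry
            # lets later uses share the dequantized value
            toy_env[name] = (node + "_dequantized", torch.float)
        return toy_env[name][0]

    def load_quantized(name: str) -> str:
        node, dtype = toy_env[name]
        assert dtype in (torch.quint8, torch.qint8, torch.float16), \
            f"expected a quantized value, got dtype {dtype}"
        return node

    def load_x(name: str) -> str:
        # accepts either representation, e.g. for shape-style queries
        return toy_env[name][0]

    print(load_quantized("y"))      # y_quant
    print(load_non_quantized("y"))  # y_quant_dequantized
    print(load_x("s"))              # y_shape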
@@ -1216,14 +1215,11 @@
 
         def node_arg_is_quantized(node_arg: Any) -> bool:
             if isinstance(node_arg, Node):
-                assert node_arg.name in env or node_arg.name in quant_env, \
+                assert node_arg.name in env, \
                     'Expecting node_arg to be in the environment'
-                # there might be nodes appearing in both environemnts, but
-                # quant_env will take precedence
-                if node_arg.name in quant_env:
-                    return True
-                elif node_arg.name in env:
-                    return False
+                if node_arg.name in env:
+                    _, dtype = env[node_arg.name]
+                    return dtype != torch.float
                 else:
                     return False
             elif isinstance(node_arg, list):
@@ -1238,7 +1234,7 @@
             else:
                 return False
 
-        def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
+        def is_output_quantized(node: Node, obj: QuantizeHandler, qconfig: QConfigAny) -> bool:
             """ Check if output node is quantized or not """
             assert self.modules is not None
             # by default the output for a quantizable node is expected to be quantized
@@ -1248,7 +1244,7 @@
             # of FixedQParamsQuantizeHandler
             # TODO: we may want to try to remove the special case here
             # as well
-            if obj.should_mark_output_quantized_from_input_quantized_status():
+            if obj.should_mark_output_quantized_from_input_quantized_status(qconfig):
                 assert node.op in [
                     'call_module',
                     'call_function',
|
@ -1277,19 +1273,19 @@ class Quantizer:
|
|||
if observer_module.dtype == torch.float32:
|
||||
# copy the observer for fp32 dtype
|
||||
env[node.name] = self.quantized_graph.node_copy(
|
||||
node, load_non_quantized)
|
||||
elif isinstance(prev_node, Node) and prev_node.name in quant_env:
|
||||
node, load_non_quantized), torch.float
|
||||
elif isinstance(prev_node, Node) and prev_node.name in env:
|
||||
# if previous node is already quantized, we'll just remove the
|
||||
# activation_post_process
|
||||
_, prev_dtype = quant_env[prev_node.name]
|
||||
_, prev_dtype = env[prev_node.name]
|
||||
current_dtype = observer_module.dtype
|
||||
if prev_dtype == current_dtype:
|
||||
quant_env[node.name] = quant_env[prev_node.name]
|
||||
env[node.name] = env[prev_node.name]
|
||||
else:
|
||||
root_module = self.modules[""]
|
||||
assert isinstance(prev_node, Node)
|
||||
observer_dtype: torch.dtype = observer_module.dtype # type: ignore[assignment]
|
||||
quant_env[node.name] = (
|
||||
env[node.name] = (
|
||||
quantize_node(self, load_non_quantized(prev_node),
|
||||
observer_module, node, is_input=True),
|
||||
observer_dtype)
|
||||
|
|
@@ -1298,7 +1294,7 @@
                     root_module = self.modules[""]
                     assert isinstance(node.args[0], Node)
                     dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
-                    quant_env[node.name] = (
+                    env[node.name] = (
                         quantize_node(self, load_non_quantized(node.args[0]),
                                       observer_module, node, is_input=True),
                         dtype)
@@ -1352,12 +1348,12 @@
                     self, node, qconfig, load_arg, is_reference=is_reference,
                     convert_custom_config_dict=convert_custom_config_dict)
                 if not is_observed_standalone_module_node:
-                    quantized = is_output_quantized(node, obj)
+                    quantized = is_output_quantized(node, obj, qconfig)
 
                 if quantized:
-                    quant_env[node.name] = result, activation_dtype(qconfig)
+                    env[node.name] = result, activation_dtype(qconfig)
                 else:
-                    env[node.name] = result
+                    env[node.name] = result, torch.float
                 continue
             elif root_node is not None:
                 if qconfig is None:
@@ -1371,7 +1367,7 @@
                 # function will not be called.
                 result = self.quantized_graph.node_copy(
                     node, load_non_quantized)
-                env[node.name] = result
+                env[node.name] = result, torch.float
                 continue
 
         # handle activation post process calls
@@ -1382,34 +1378,43 @@
                 cur_placeholder_node_idx = placeholder_node_seen_cnt
                 placeholder_node_seen_cnt += 1
                 if cur_placeholder_node_idx in input_quantized_idxs:
-                    quant_env[node.name] = \
-                        self.quantized_graph.node_copy(node, load_non_quantized), activation_dtype(qconfig) if qconfig else None
+                    env[node.name] = \
+                        self.quantized_graph.node_copy(
+                            node, load_non_quantized), torch.quint8
                 else:
                     env[node.name] = \
-                        self.quantized_graph.node_copy(node, load_non_quantized)
+                        self.quantized_graph.node_copy(node, load_non_quantized), torch.float
             else:
                 # copy quantized or non-quantized node
-                env[node.name] = \
-                    self.quantized_graph.node_copy(node, load_non_quantized)
+                # get_tensor_info_node like shape works for both
+                # quantized and non-quantized input and output a non-Tensor
+                # (we use None for dtype currently for non-Tensors)
+                if is_get_tensor_info_node(node):
+                    env[node.name] = \
+                        self.quantized_graph.node_copy(node, load_x), None
+                else:
+                    env[node.name] = \
+                        self.quantized_graph.node_copy(node, load_non_quantized), torch.float
 
         # remove activation post process
         act_post_process_removed_graph = Graph()
-        env = {}
+        remove_env: Dict[str, Node] = {}
 
-        def load_arg_simple(a: Argument) -> Argument:
-            return map_arg(a, lambda node: env[node.name])
+        def load_arg_remove(a: Argument) -> Argument:
+            return map_arg(a, lambda node: remove_env[node.name])
 
         for node in self.quantized_graph.nodes:
             if node.op == 'output':
                 act_post_process_removed_graph.output(
-                    map_arg(node.args[0], load_arg_simple))
+                    map_arg(node.args[0], load_arg_remove))
                 continue
             if node.op == 'call_module' and \
                     is_activation_post_process(self.modules[node.target]):
                 # remove activation post process node
-                env[node.name] = env[node.args[0].name]
+                remove_env[node.name] = remove_env[node.args[0].name]
             else:
-                env[node.name] = act_post_process_removed_graph.node_copy(
-                    node, load_arg_simple)
+                remove_env[node.name] = act_post_process_removed_graph.node_copy(
+                    node, load_arg_remove)
 
         # removes qconfig and activation_post_process modules
         if _remove_qconfig_flag:
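The cleanup pass now uses its own remove_env so it cannot clobber the dtype-tagged env built above. The loop itself is the stock FX rebuild idiom; stripped to its core it looks like this (copy_graph is a hypothetical name for illustration):

    from torch.fx import Graph

    def copy_graph(src: Graph) -> Graph:
        # Rebuild a graph node by node, keeping an old-name -> new-node
        # table so copied nodes can rewire their arguments.
        dst = Graph()
        remap = {}
        for node in src.nodes:
            remap[node.name] = dst.node_copy(node, lambda n: remap[n.name])
        return dst

The activation_post_process branch simply aliases the table entry instead of copying the node, which is what drops the observer from the rebuilt graph.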
torch/quantization/fx/utils.py
@@ -437,3 +437,11 @@ def node_bool_tensor_arg_indexes(node: Node) -> List[int]:
     if node.op == "call_method" and node.target == "masked_fill":
         return [1]
     return []
+
+def is_get_tensor_info_node(node: Node) -> bool:
+    """ Returns True if this node is a node that takes a Tensor as input and output some
+    meta information about the Tensor, e.g. shape, size etc.
+    """
+    result: bool = \
+        node.op == "call_function" and node.target == getattr and node.args[1] == "shape"  # type: ignore[assignment]
+    return result
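Why the getattr check works: tracing an attribute access on an FX Proxy records a call_function node whose target is the builtin getattr, with the attribute name as the second argument. A quick standalone check (the TinyShape module is made up for illustration):

    import torch
    import torch.fx

    class TinyShape(torch.nn.Module):
        def forward(self, x):
            s = x.shape  # traced as call_function(getattr, (x, 'shape'))
            torch._assert(s == x.shape, "")
            return x

    gm = torch.fx.symbolic_trace(TinyShape())
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == getattr \
                and node.args[1] == "shape":
            print("tensor-info node:", node.op, node.target, node.args)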