[quant][fx] Make scale, zero_point buffers in the model and use FQN (for quantized ops) (#51166)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51166

Currently scale and zero_point values are stored as constant values in the graph.
This prevents these values from being updated in the graph and also does not enable saving
these values to state_dict

After this PR we store scale/zero_point values for quantized ops as buffers in the root module
and create get_attr nodes for them in the graph.

We also use the FQN of the module where the quantized ops are present to name these attributes so
that they can be uniquely identified and mapped to quantized ops.

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_qparams_buffers

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D26092965

fbshipit-source-id: b549b2d3dccb45c5d38415ce95a09c26f5bd590b
This commit is contained in:
Supriya Rao 2021-01-28 08:29:57 -08:00 committed by Facebook GitHub Bot
parent 096adf4b8b
commit 4c3f59b70e
5 changed files with 118 additions and 13 deletions

View file

@ -1495,6 +1495,58 @@ class TestQuantizeFx(QuantizationTestCase):
str(context.exception) ==
'Per channel weight observer is not supported yet for ConvTranspose{n}d.')
@skipIfNoFBGEMM
def test_qparams_buffers(self):
    """After convert_fx, each quantized op's scale/zero_point must be stored
    as buffers in the state_dict (named via the owning module's FQN), the
    converted model must still run, and scripting must preserve the state_dict.
    """
    class Linear(torch.nn.Module):
        # Uses functional linear with plain-attribute weight/bias so the op
        # is quantized as a functional (not module) pattern.
        def __init__(self):
            super().__init__()
            self.w = torch.ones(5, 5)
            self.b = torch.zeros(5)

        def forward(self, x):
            return torch.nn.functional.linear(x, self.w, self.b)

    class M(torch.nn.Module):
        # Three functional linears spread across two submodule scopes, so
        # the qparam buffer names must be disambiguated by FQN.
        def __init__(self):
            super().__init__()
            self.mods1 = torch.nn.Sequential(
                Linear(),
                Linear()
            )
            self.mods2 = Linear()

        def forward(self, x):
            x = self.mods1(x)
            x = self.mods2(x)
            return x

    model = M().eval()
    qconfig_dict = {"": default_qconfig}
    m = prepare_fx(model, qconfig_dict)
    # Calibrate once so observers record qparams before conversion.
    m(torch.rand(5, 5))
    m = convert_fx(m)
    keys = m.state_dict().keys()
    scale_count = 0
    zero_point_count = 0
    for k in keys:
        if 'scale' in k:
            scale_count += 1
        elif 'zero_point' in k:
            zero_point_count += 1
    # Expect each of the three quantized linear ops to contribute exactly
    # one scale buffer and one zero_point buffer.
    self.assertEqual(scale_count, 3, "Expect each quantized linear op to have a scale in state_dict")
    self.assertEqual(zero_point_count, 3, "Expect each quantized linear op to have a zero_point in state_dict")
    # ensure it runs
    m(torch.rand(5, 5))
    # ensure it is scriptable and that scripting keeps the same buffers
    scripted = torch.jit.script(m)
    scripted_keys = scripted.state_dict().keys()
    self.assertEqual(scripted_keys, keys, "Expected the scripted model to preserve the state_dict")
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops

View file

@ -10,7 +10,8 @@ class ObservedGraphModule(GraphModule):
return ['_activation_post_process_map',
'_patterns',
'_qconfig_map',
'_prepare_custom_config_dict']
'_prepare_custom_config_dict',
'_node_name_to_scope']
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
preserved_attrs = dict()

View file

@ -32,6 +32,7 @@ from .utils import (
quantize_node,
get_per_tensor_qparams,
get_linear_prepack_op_for_dtype,
create_qparam_nodes,
)
from .quantization_types import QuantizerCls
@ -112,13 +113,17 @@ class Add(QuantizeHandler):
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
kwargs = {**self.add_node.kwargs, 'scale': scale, 'zero_point': zero_point}
return quantizer.quantized_graph.create_node(
'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)
kwargs = {**self.add_node.kwargs}
add_args = (*load_arg(quantized=True)(self.add_node.args), scale_arg, zero_point_arg)
op = quantizer.quantized_graph.create_node(
'call_function', op, add_args, kwargs)
return op
# TODO: merge with Add
@register_quant_pattern(operator.mul)
@ -161,12 +166,16 @@ class Mul(QuantizeHandler):
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
if self.relu_node is not None:
op = torch.ops.quantized.mul_relu
else:
op = torch.ops.quantized.mul
kwargs = {**self.mul_node.kwargs, 'scale': scale, 'zero_point': zero_point}
return quantizer.quantized_graph.create_node('call_function', op, load_arg(quantized=True)(self.mul_node.args), kwargs)
kwargs = {**self.mul_node.kwargs}
args = (*load_arg(quantized=True)(self.mul_node.args), scale_arg, zero_point_arg)
return quantizer.quantized_graph.create_node('call_function', op, args, kwargs)
@register_quant_pattern(torch.cat)
class Cat(QuantizeHandler):
@ -179,7 +188,10 @@ class Cat(QuantizeHandler):
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
kwargs = {**load_arg(quantized=False)(node.kwargs), 'scale': scale, 'zero_point': zero_point}
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
kwargs = {**load_arg(quantized=False)(node.kwargs), 'scale': scale_arg, 'zero_point': zero_point_arg}
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.cat, load_arg(quantized=[0])(node.args), kwargs)
@ -311,7 +323,8 @@ class ConvRelu(QuantizeHandler):
act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
activation_post_process = quantizer.activation_post_process_map[act_post_process_name]
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
qconv_args = (conv_input, packed_weight, scale, zero_point)
scale_node, zero_point_node = create_qparam_nodes(quantizer, self.conv_node.name, scale, zero_point)
qconv_args = (conv_input, packed_weight, scale_node, zero_point_node)
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
return quantizer.quantized_graph.create_node(
'call_function', qconv_op, qconv_args, kwargs)
@ -464,7 +477,10 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
activation_post_process = \
quantizer.activation_post_process_map[act_post_process_name]
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
qlinear_args = (linear_input, packed_weight, scale, zero_point)
scale_node, zero_point_node = create_qparam_nodes(quantizer, self.linear_node.name, scale, zero_point)
qlinear_args = (linear_input, packed_weight, scale_node, zero_point_node)
return quantizer.quantized_graph.create_node(
"call_function", qlinear_op, qlinear_args, kwargs)
else:
@ -643,11 +659,13 @@ class DefaultNode(QuantizeHandler):
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
assert not isinstance(node.target, str), "Expecting node.target for "
"call_function to be a function instead of a string"
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)
kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale, "output_zero_point": zero_point}
kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale_arg, "output_zero_point": zero_point_arg}
if quantized_op in ARGS_TO_SKIP:
args_to_skip = ARGS_TO_SKIP[quantized_op]
for arg in args_to_skip:
@ -666,9 +684,12 @@ class ELU(QuantizeHandler):
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)
kwargs = {**load_arg(quantized=False)(node.kwargs), 'output_scale': scale, 'output_zero_point': zero_point}
kwargs = {**load_arg(quantized=False)(node.kwargs), 'output_scale': scale_arg, 'output_zero_point': zero_point_arg}
kwargs.pop('inplace')
return quantizer.quantized_graph.create_node(
'call_function', quantized_op, args, kwargs)

View file

@ -332,6 +332,9 @@ class Quantizer:
self.patterns: Optional[Dict[Pattern, QuantizeHandler]] = None
self.prepare_custom_config_dict: Dict[str, Any] = {}
# mapping from node name to the scope of the module which contains the node.
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
def _qat_swap_modules(
self, root: torch.nn.Module,
@ -347,7 +350,7 @@ class Quantizer:
qconfig_dict: Any,
node_name_to_scope: Dict[str, Tuple[str, type]]) -> None:
global_qconfig = qconfig_dict.get("", None)
self.node_name_to_scope = node_name_to_scope
self.qconfig_map = dict()
for node in input_graph.nodes:
if node.op == "get_attr":
@ -567,6 +570,7 @@ class Quantizer:
observed._qconfig_map = self.qconfig_map # type: ignore
observed._prepare_custom_config_dict = \
self.prepare_custom_config_dict # type: ignore
observed._node_name_to_scope = self.node_name_to_scope # type: ignore
def restore_state(self, observed: GraphModule) -> None:
assert is_observed_module(observed), \
@ -577,6 +581,7 @@ class Quantizer:
self.qconfig_map = observed._qconfig_map # type: ignore
self.prepare_custom_config_dict = \
observed._prepare_custom_config_dict # type: ignore
self.node_name_to_scope = observed._node_name_to_scope # type: ignore
def prepare(
self,

View file

@ -9,7 +9,8 @@ from torch.fx.graph import (
Node,
)
from typing import Callable, Optional, List, Dict, Any, Set
from typing import Callable, Optional, List, Dict, Any, Set, Tuple
from .quantization_types import QuantizerCls
# turn foo.bar -> ['foo', 'bar']
def _parent_name(target):
@ -185,6 +186,8 @@ def get_linear_prepack_op_for_dtype(dtype):
# >> new_name = get_new_observer_name(module)
# new_name will be an unused attribute name on module, e.g. `_observer_1`
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
prefix = prefix.replace(".", "_")
def get_new_attr_name(module: torch.nn.Module):
def get_attr_name(i: int):
return prefix + str(i)
@ -260,3 +263,26 @@ def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
)
device = next(iter(devices)) if len(devices) > 0 else None
return device
def create_getattr_from_value(module: GraphModule, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Register *value* as a tensor buffer on *module* under a fresh attribute
    name derived from *prefix*, and return a ``get_attr`` node in *graph*
    that refers to the new buffer.
    """
    pick_unused_name = get_new_attr_name_with_prefix(prefix)
    buffer_name = pick_unused_name(module)
    # Store the value on the root module so it lands in state_dict.
    module.register_buffer(buffer_name, torch.tensor(value))
    # Emit the graph node that fetches the freshly registered buffer.
    return graph.create_node("get_attr", buffer_name)
def create_qparam_nodes(quantizer: QuantizerCls, node_name: str, scale: Any, zero_point: Any) -> Tuple[Node, Node]:
    """
    Create ``get_attr`` nodes in the quantizer's graph for the given scale and
    zero_point values, registering them as buffers on the root module.
    """
    root_module = quantizer.modules['']
    graph = quantizer.quantized_graph
    # Prefix the buffer names with the FQN of the module that owns the op so
    # each quantized op's qparams are uniquely identifiable.
    module_path, _ = quantizer.node_name_to_scope[node_name]
    return (
        create_getattr_from_value(root_module, graph, module_path + "_scale_", scale),
        create_getattr_from_value(root_module, graph, module_path + "_zero_point_", zero_point),
    )