mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[quant][fx] Make scale, zero_point buffers in the model and use FQN (for quantized ops) (#51166)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51166 Currently scale and zero_point values are stored as constant values in the graph. This prevents these values from being updated in the graph and also does not enable saving these values to state_dict After this PR we store scale/zero_point values for quantized ops as buffers in the root module and createe get_attr nodes for them in the graph. We also use the FQN of the module where the quantized ops are present to name these attributes so that they can be uniquely identified and mapped to quantized ops. Test Plan: python test/test_quantization.py TestQuantizeFx.test_qparams_buffers Imported from OSS Reviewed By: jerryzh168 Differential Revision: D26092965 fbshipit-source-id: b549b2d3dccb45c5d38415ce95a09c26f5bd590b
This commit is contained in:
parent
096adf4b8b
commit
4c3f59b70e
5 changed files with 118 additions and 13 deletions
|
|
@ -1495,6 +1495,58 @@ class TestQuantizeFx(QuantizationTestCase):
|
|||
str(context.exception) ==
|
||||
'Per channel weight observer is not supported yet for ConvTranspose{n}d.')
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_qparams_buffers(self):
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.w = torch.ones(5, 5)
|
||||
self.b = torch.zeros(5)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.linear(x, self.w, self.b)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mods1 = torch.nn.Sequential(
|
||||
Linear(),
|
||||
Linear()
|
||||
)
|
||||
self.mods2 = Linear()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.mods1(x)
|
||||
x = self.mods2(x)
|
||||
return x
|
||||
|
||||
model = M().eval()
|
||||
qconfig_dict = {"": default_qconfig}
|
||||
m = prepare_fx(model, qconfig_dict)
|
||||
m(torch.rand(5, 5))
|
||||
|
||||
m = convert_fx(m)
|
||||
keys = m.state_dict().keys()
|
||||
|
||||
scale_count = 0
|
||||
zero_point_count = 0
|
||||
for k in keys:
|
||||
if 'scale' in k:
|
||||
scale_count = scale_count + 1
|
||||
elif 'zero_point' in k:
|
||||
zero_point_count = zero_point_count + 1
|
||||
|
||||
# Expect each quantized linear op to have a scale and zero point
|
||||
self.assertTrue(scale_count == 3, "Expect each quantized linear op to have a scale in state_dict")
|
||||
self.assertTrue(zero_point_count == 3, "Expect each quantized linear op to have a zero_point in state_dict")
|
||||
# ensure it runs
|
||||
m(torch.rand(5, 5))
|
||||
# ensure it is scriptable
|
||||
scripted = torch.jit.script(m)
|
||||
scripted_keys = scripted.state_dict().keys()
|
||||
self.assertTrue(scripted_keys == keys, "Expected the scripted model to preserve the state_dict")
|
||||
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
class TestQuantizeFxOps(QuantizationTestCase):
|
||||
"""Unit tests for individual ops
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ class ObservedGraphModule(GraphModule):
|
|||
return ['_activation_post_process_map',
|
||||
'_patterns',
|
||||
'_qconfig_map',
|
||||
'_prepare_custom_config_dict']
|
||||
'_prepare_custom_config_dict',
|
||||
'_node_name_to_scope']
|
||||
|
||||
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
|
||||
preserved_attrs = dict()
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from .utils import (
|
|||
quantize_node,
|
||||
get_per_tensor_qparams,
|
||||
get_linear_prepack_op_for_dtype,
|
||||
create_qparam_nodes,
|
||||
)
|
||||
|
||||
from .quantization_types import QuantizerCls
|
||||
|
|
@ -112,13 +113,17 @@ class Add(QuantizeHandler):
|
|||
scale, zero_point = activation_post_process.calculate_qparams()
|
||||
scale = float(scale)
|
||||
zero_point = int(zero_point)
|
||||
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
||||
|
||||
if self.relu_node is not None:
|
||||
op = torch.ops.quantized.add_relu
|
||||
else:
|
||||
op = torch.ops.quantized.add
|
||||
kwargs = {**self.add_node.kwargs, 'scale': scale, 'zero_point': zero_point}
|
||||
return quantizer.quantized_graph.create_node(
|
||||
'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)
|
||||
kwargs = {**self.add_node.kwargs}
|
||||
add_args = (*load_arg(quantized=True)(self.add_node.args), scale_arg, zero_point_arg)
|
||||
op = quantizer.quantized_graph.create_node(
|
||||
'call_function', op, add_args, kwargs)
|
||||
return op
|
||||
|
||||
# TODO: merge with Add
|
||||
@register_quant_pattern(operator.mul)
|
||||
|
|
@ -161,12 +166,16 @@ class Mul(QuantizeHandler):
|
|||
scale, zero_point = activation_post_process.calculate_qparams()
|
||||
scale = float(scale)
|
||||
zero_point = int(zero_point)
|
||||
|
||||
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
||||
|
||||
if self.relu_node is not None:
|
||||
op = torch.ops.quantized.mul_relu
|
||||
else:
|
||||
op = torch.ops.quantized.mul
|
||||
kwargs = {**self.mul_node.kwargs, 'scale': scale, 'zero_point': zero_point}
|
||||
return quantizer.quantized_graph.create_node('call_function', op, load_arg(quantized=True)(self.mul_node.args), kwargs)
|
||||
kwargs = {**self.mul_node.kwargs}
|
||||
args = (*load_arg(quantized=True)(self.mul_node.args), scale_arg, zero_point_arg)
|
||||
return quantizer.quantized_graph.create_node('call_function', op, args, kwargs)
|
||||
|
||||
@register_quant_pattern(torch.cat)
|
||||
class Cat(QuantizeHandler):
|
||||
|
|
@ -179,7 +188,10 @@ class Cat(QuantizeHandler):
|
|||
scale, zero_point = activation_post_process.calculate_qparams()
|
||||
scale = float(scale)
|
||||
zero_point = int(zero_point)
|
||||
kwargs = {**load_arg(quantized=False)(node.kwargs), 'scale': scale, 'zero_point': zero_point}
|
||||
|
||||
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
||||
|
||||
kwargs = {**load_arg(quantized=False)(node.kwargs), 'scale': scale_arg, 'zero_point': zero_point_arg}
|
||||
return quantizer.quantized_graph.create_node(
|
||||
'call_function', torch.ops.quantized.cat, load_arg(quantized=[0])(node.args), kwargs)
|
||||
|
||||
|
|
@ -311,7 +323,8 @@ class ConvRelu(QuantizeHandler):
|
|||
act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
|
||||
activation_post_process = quantizer.activation_post_process_map[act_post_process_name]
|
||||
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
|
||||
qconv_args = (conv_input, packed_weight, scale, zero_point)
|
||||
scale_node, zero_point_node = create_qparam_nodes(quantizer, self.conv_node.name, scale, zero_point)
|
||||
qconv_args = (conv_input, packed_weight, scale_node, zero_point_node)
|
||||
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
|
||||
return quantizer.quantized_graph.create_node(
|
||||
'call_function', qconv_op, qconv_args, kwargs)
|
||||
|
|
@ -464,7 +477,10 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
|
|||
activation_post_process = \
|
||||
quantizer.activation_post_process_map[act_post_process_name]
|
||||
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
|
||||
qlinear_args = (linear_input, packed_weight, scale, zero_point)
|
||||
|
||||
scale_node, zero_point_node = create_qparam_nodes(quantizer, self.linear_node.name, scale, zero_point)
|
||||
|
||||
qlinear_args = (linear_input, packed_weight, scale_node, zero_point_node)
|
||||
return quantizer.quantized_graph.create_node(
|
||||
"call_function", qlinear_op, qlinear_args, kwargs)
|
||||
else:
|
||||
|
|
@ -643,11 +659,13 @@ class DefaultNode(QuantizeHandler):
|
|||
scale = float(scale)
|
||||
zero_point = int(zero_point)
|
||||
|
||||
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
||||
|
||||
assert not isinstance(node.target, str), "Expecting node.target for "
|
||||
"call_function to be a function instead of a string"
|
||||
quantized_op = get_quantized_operator(node.target)
|
||||
args = load_arg(quantized=[0])(node.args)
|
||||
kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale, "output_zero_point": zero_point}
|
||||
kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale_arg, "output_zero_point": zero_point_arg}
|
||||
if quantized_op in ARGS_TO_SKIP:
|
||||
args_to_skip = ARGS_TO_SKIP[quantized_op]
|
||||
for arg in args_to_skip:
|
||||
|
|
@ -666,9 +684,12 @@ class ELU(QuantizeHandler):
|
|||
scale, zero_point = activation_post_process.calculate_qparams()
|
||||
scale = float(scale)
|
||||
zero_point = int(zero_point)
|
||||
|
||||
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
||||
|
||||
quantized_op = get_quantized_operator(node.target)
|
||||
args = load_arg(quantized=[0])(node.args)
|
||||
kwargs = {**load_arg(quantized=False)(node.kwargs), 'output_scale': scale, 'output_zero_point': zero_point}
|
||||
kwargs = {**load_arg(quantized=False)(node.kwargs), 'output_scale': scale_arg, 'output_zero_point': zero_point_arg}
|
||||
kwargs.pop('inplace')
|
||||
return quantizer.quantized_graph.create_node(
|
||||
'call_function', quantized_op, args, kwargs)
|
||||
|
|
|
|||
|
|
@ -332,6 +332,9 @@ class Quantizer:
|
|||
self.patterns: Optional[Dict[Pattern, QuantizeHandler]] = None
|
||||
self.prepare_custom_config_dict: Dict[str, Any] = {}
|
||||
|
||||
# mapping from node name to the scope of the module which contains the node.
|
||||
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
|
||||
|
||||
|
||||
def _qat_swap_modules(
|
||||
self, root: torch.nn.Module,
|
||||
|
|
@ -347,7 +350,7 @@ class Quantizer:
|
|||
qconfig_dict: Any,
|
||||
node_name_to_scope: Dict[str, Tuple[str, type]]) -> None:
|
||||
global_qconfig = qconfig_dict.get("", None)
|
||||
|
||||
self.node_name_to_scope = node_name_to_scope
|
||||
self.qconfig_map = dict()
|
||||
for node in input_graph.nodes:
|
||||
if node.op == "get_attr":
|
||||
|
|
@ -567,6 +570,7 @@ class Quantizer:
|
|||
observed._qconfig_map = self.qconfig_map # type: ignore
|
||||
observed._prepare_custom_config_dict = \
|
||||
self.prepare_custom_config_dict # type: ignore
|
||||
observed._node_name_to_scope = self.node_name_to_scope # type: ignore
|
||||
|
||||
def restore_state(self, observed: GraphModule) -> None:
|
||||
assert is_observed_module(observed), \
|
||||
|
|
@ -577,6 +581,7 @@ class Quantizer:
|
|||
self.qconfig_map = observed._qconfig_map # type: ignore
|
||||
self.prepare_custom_config_dict = \
|
||||
observed._prepare_custom_config_dict # type: ignore
|
||||
self.node_name_to_scope = observed._node_name_to_scope # type: ignore
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ from torch.fx.graph import (
|
|||
Node,
|
||||
)
|
||||
|
||||
from typing import Callable, Optional, List, Dict, Any, Set
|
||||
from typing import Callable, Optional, List, Dict, Any, Set, Tuple
|
||||
from .quantization_types import QuantizerCls
|
||||
|
||||
# turn foo.bar -> ['foo', 'bar']
|
||||
def _parent_name(target):
|
||||
|
|
@ -185,6 +186,8 @@ def get_linear_prepack_op_for_dtype(dtype):
|
|||
# >> new_name = get_new_observer_name(module)
|
||||
# new_name will be an unused attribute name on module, e.g. `_observer_1`
|
||||
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
|
||||
prefix = prefix.replace(".", "_")
|
||||
|
||||
def get_new_attr_name(module: torch.nn.Module):
|
||||
def get_attr_name(i: int):
|
||||
return prefix + str(i)
|
||||
|
|
@ -260,3 +263,26 @@ def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
|
|||
)
|
||||
device = next(iter(devices)) if len(devices) > 0 else None
|
||||
return device
|
||||
|
||||
def create_getattr_from_value(module: GraphModule, graph: Graph, prefix: str, value: Any) -> Node:
|
||||
"""
|
||||
Given a value of any type, creates a getattr node corresponding to the value and
|
||||
registers the value as a buffer to the module.
|
||||
"""
|
||||
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
|
||||
attr_name = get_new_attr_name(module)
|
||||
module.register_buffer(attr_name, torch.tensor(value))
|
||||
# Create get_attr with value
|
||||
attr_node = graph.create_node("get_attr", attr_name)
|
||||
return attr_node
|
||||
|
||||
def create_qparam_nodes(quantizer: QuantizerCls, node_name: str, scale: Any, zero_point: Any) -> Tuple[Node, Node]:
|
||||
"""
|
||||
Create getattr nodes in the quantizer graph for scale and zero point values.
|
||||
The nodes are registered with the root_module of the model.
|
||||
"""
|
||||
root_module = quantizer.modules['']
|
||||
module_path, _ = quantizer.node_name_to_scope[node_name]
|
||||
scale_node = create_getattr_from_value(root_module, quantizer.quantized_graph, (module_path + "_scale_"), scale)
|
||||
zero_point_node = create_getattr_from_value(root_module, quantizer.quantized_graph, (module_path + "_zero_point_"), zero_point)
|
||||
return (scale_node, zero_point_node)
|
||||
|
|
|
|||
Loading…
Reference in a new issue