[quant] added dq->op->q quantization patterns for GELU and softmax ops (#56004)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56004

Added reference pattern support for GELU, softmax, and bmm for int dtypes. For GELU and softmax, this consisted of adding reference patterns to the default node handler for int dtypes. Note that the GELU and softmax patterns are not registered by default, since these ops do not have a proper quantized kernel; registering them would either add unnecessary dequant and quant ops to the network or simply error out. This can be worked around with a custom qconfig, as in test_gelu_reference (sketched below).
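
A minimal sketch of that workaround (the toy module M here is illustrative, not part of this PR; it mirrors test_gelu_reference in the diff below):

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.quantization.fx.quantization_patterns import DefaultNodeQuantizeHandler

class M(torch.nn.Module):  # illustrative toy model
    def forward(self, x):
        return torch.nn.functional.gelu(x)

qconfig_dict = {"": get_default_qconfig("fbgemm")}
# register the normally-unregistered GELU patterns for this prepare call only
prepare_custom_config = {
    "additional_quant_pattern": {
        torch.nn.GELU: DefaultNodeQuantizeHandler,
        torch.nn.functional.gelu: DefaultNodeQuantizeHandler,
    }
}
m = prepare_fx(M().eval(), qconfig_dict, prepare_custom_config)
m(torch.randn(2, 2, 2, 2))            # calibrate observers
m = convert_fx(m, is_reference=True)  # emits the dq -> gelu -> q reference pattern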

bmm was added within the binary ops handler, along with some significant changes to how that code is structured (see the sketch below). Theoretically, the reference pattern used for bmm could be applied to other dtypes. This was not enabled because of issues relating to Line 1323 in quantize.py: in essence, the prepare step does not know whether an op will use a reference pattern or not, so for ops that support one dtype combination in reference mode and a different one in the normal mode, this has the potential to cause issues. This is difficult to get around without the is_reference flag being available in the prepare step, or without the discussed changes around separating the two flows.
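
A similarly minimal sketch of the int8 reference path for bmm (the wrapper module is illustrative; it mirrors test_bmm_int_reference in the diff below):

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class BmmWrapper(torch.nn.Module):  # illustrative wrapper around torch.bmm
    def forward(self, x, y):
        return torch.bmm(x, y)

qconfig_dict = {"": get_default_qconfig("fbgemm")}
m = prepare_fx(BmmWrapper().eval(), qconfig_dict)
m(torch.randn(2, 2, 2), torch.randn(2, 2, 2))  # calibrate observers
m = convert_fx(m, is_reference=True)
# expected graph: quantize_per_tensor on both inputs -> dequantize -> torch.bmm
# -> quantize_per_tensor -> dequantize (the dq -> op -> q reference pattern)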

Test Plan:
python test/test_quantization.py TestQuantizeFxOps.test_gelu_reference
python test/test_quantization.py TestQuantizeFxOps.test_gelu_normal
python test/test_quantization.py TestQuantizeFxOps.test_softmax_reference
python test/test_quantization.py TestQuantizeFxOps.test_softmax_normal
python test/test_quantization.py TestQuantizeFxOps.test_silu_reference
python test/test_quantization.py TestQuantizeFxOps.test_bmm_int_reference
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestFuseFx
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxModels

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D27818340

fbshipit-source-id: de65be0797035463cd2d1b0e4677d1a87f69143c
Author: Charles David Hernandez (2021-04-20 13:24:56 -07:00), committed by Facebook GitHub Bot
parent ea4af1511c
commit 6e1fc5cef8
3 changed files with 325 additions and 105 deletions

@@ -14,6 +14,8 @@ from torch.quantization.quantize_fx import (
prepare_qat_fx,
)
from torch.quantization.fx.quantization_patterns import DefaultNodeQuantizeHandler
from torch.quantization.fx.pattern_utils import (
is_match,
MatchAllNode,
@@ -3125,28 +3127,151 @@ class TestQuantizeFxOps(QuantizationTestCase):
quantized_module, torch.ops.quantized.instance_norm,
skip_op_arg_for_functional=True)
def test_silu(self):
def _test_default_node_quant_handler_ops(
self, module, functional, qconfig, is_reference=True, node_list=None, additional_quant_pattern_dict=None
):
class M(torch.nn.Module):
def __init__(self, mod, func):
super().__init__()
self.module = mod()
self.functional = func
def forward(self, x):
x = self.module(x)
x = self.functional(x)
return x
if node_list is None:
node_list = []
if additional_quant_pattern_dict is None:
additional_quant_pattern_dict = {}
data = torch.randn((2, 2, 2, 2))
quant_type = QuantType.STATIC
prepare_custom_qconfig_dict = {"additional_quant_pattern": additional_quant_pattern_dict}
qconfig_dict = {"": qconfig}
m = M(module, functional).eval()
m_prep = torch.quantization.quantize_fx.prepare_fx(m, qconfig_dict, prepare_custom_qconfig_dict)
m_prep(data)
m_quant = torch.quantization.quantize_fx.convert_fx(m_prep, is_reference=is_reference)
m_quant(data)
self.checkGraphModuleNodes(m_quant, expected_node_list=node_list)
def test_gelu_normal(self):
module = torch.nn.GELU
functional = torch.nn.functional.gelu
qconfig = torch.quantization.get_default_qconfig("fbgemm")
is_reference = False
node_list = [
ns.call_module(module),
ns.call_function(functional),
]
self._test_default_node_quant_handler_ops(
module, functional, qconfig, is_reference, node_list)
def test_softmax_normal(self):
module = torch.nn.Softmax
functional = torch.nn.functional.softmax
qconfig = torch.quantization.get_default_qconfig("fbgemm")
is_reference = False
node_list = [
ns.call_module(module),
ns.call_function(functional),
]
self._test_default_node_quant_handler_ops(
module, functional, qconfig, is_reference, node_list)
def test_gelu_reference(self):
module = torch.nn.GELU
functional = torch.nn.functional.gelu
qconfig = torch.quantization.get_default_qconfig("fbgemm")
is_reference = True
node_list = [
ns.call_function(torch.quantize_per_tensor),
ns.call_method("dequantize"),
ns.call_module(module),
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize'),
ns.call_function(functional),
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize')
]
additional_patterns = {torch.nn.GELU: DefaultNodeQuantizeHandler,
torch.nn.functional.gelu: DefaultNodeQuantizeHandler}
self._test_default_node_quant_handler_ops(
module, functional, qconfig, is_reference, node_list, additional_patterns)
def test_softmax_reference(self):
module = torch.nn.Softmax
functional = torch.nn.functional.softmax
qconfig = torch.quantization.get_default_qconfig("fbgemm")
is_reference = True
node_list = [
ns.call_function(torch.quantize_per_tensor),
ns.call_method("dequantize"),
ns.call_module(module),
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize'),
ns.call_function(functional),
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize')
]
additional_patterns = {torch.nn.Softmax: DefaultNodeQuantizeHandler,
torch.nn.functional.softmax: DefaultNodeQuantizeHandler}
self._test_default_node_quant_handler_ops(
module, functional, qconfig, is_reference, node_list, additional_patterns)
def test_silu_reference(self):
module = torch.nn.SiLU
functional = torch.nn.functional.silu
qconfig = float16_static_qconfig
is_reference = True
node_list = [
ns.call_method("to"),
ns.call_method("dequantize"),
ns.call_module(module),
ns.call_method("to"),
ns.call_method('dequantize'),
ns.call_function(functional),
ns.call_method("to"),
ns.call_method('dequantize')
]
self._test_default_node_quant_handler_ops(
module, functional, qconfig, is_reference, node_list)
def test_bmm_int_reference(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.silu = torch.nn.SiLU()
self.bmm = torch.bmm
def forward(self, x):
x = self.silu(x)
x = torch.nn.functional.silu(x)
return x
def forward(self, x, y):
out = self.bmm(x, y)
return out
data = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
quant_type = QuantType.STATIC
qconfig_dict = {
"": float16_static_qconfig
}
node_occurrence = {
ns.call_method("to"): 3
}
m = self.checkGraphModeFxOp(
M(), data, quant_type, custom_qconfig_dict=qconfig_dict,
expected_node_occurrence=node_occurrence)
data_x = torch.randn((2, 2, 2,))
data_y = torch.randn((2, 2, 2,))
qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
is_reference = True
node_list = [
ns.call_function(torch.quantize_per_tensor),
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize'),
ns.call_method('dequantize'),
ns.call_function(torch.bmm),
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize'),
]
m = M().eval()
m_prep = torch.quantization.quantize_fx.prepare_fx(m, qconfig_dict)
m_prep(data_x, data_y)
m_quant = torch.quantization.quantize_fx.convert_fx(m_prep, is_reference=is_reference)
m_quant(data_x, data_y)
self.checkGraphModuleNodes(m_quant, expected_node_list=node_list)
@skipIfNoFBGEMM
def test_clamp(self):

@@ -91,6 +91,9 @@ binary_op_all_dtypes = [
binary_op_float16_dtypes = [
(torch.float16, torch.float16, None)
]
binary_op_int8_dtypes = [
(torch.quint8, torch.qint8, None),
]
binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
operator.add: binary_op_all_dtypes,
torch.add: binary_op_all_dtypes,
@@ -103,6 +106,9 @@ binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
operator.truediv: binary_op_float16_dtypes,
torch.sum: binary_op_float16_dtypes
}
binary_reference_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
torch.bmm: binary_op_int8_dtypes,
}
@register_quant_pattern(operator.add)
@@ -170,60 +176,93 @@ class BinaryOpQuantizeHandler(QuantizeHandler):
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
# leave the op unquantized if the dtype combination is not supported
if dtypes not in binary_op_supported_dtypes[self.binary_op]:
warnings.warn(
"dtype combination: {} is not "
"supported by {} "
"supported dtype combinations are: {}".format(dtypes, self.binary_op, binary_op_supported_dtypes[self.binary_op]))
if self.relu_node:
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
return quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
else:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
if dtypes in [(torch.quint8, torch.qint8, None)]:
assert self.quantized_binary_op is not None
if self.num_tensor_args == 1:
# add/mul scalar
first_arg = self.binary_op_node.args[0]
cache_for_no_tensor_check: Dict[Node, bool] = dict()
if isinstance(first_arg, Node) and (
not all_node_args_have_no_tensors(
first_arg, quantizer.modules, cache_for_no_tensor_check)):
quantized_index = 0
else:
quantized_index = 1
return quantizer.quantized_graph.create_node(
'call_function', self.quantized_binary_op,
load_arg(quantized=[quantized_index])(self.binary_op_node.args), self.binary_op_node.kwargs)
else:
if is_reference and self.binary_op in binary_reference_op_supported_dtypes and \
dtypes in binary_reference_op_supported_dtypes[self.binary_op]:
if dtypes in binary_op_int8_dtypes:
args = load_arg(quantized=[0, 1])(node.args)
args = load_arg(quantized=False)(node.args)
kwargs = load_arg(quantized=False)(node.kwargs)
op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
cur_idx = quantizer.activation_post_process_indexes[node.name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
return quantize_node(
quantizer, op_out, activation_post_process,
node, is_input=False)
else:
warnings.warn(
"No implementation found for dtype combination: {}"
"for op {} with is_reference={} despite it being listed as supported"
"this should not happen".format(dtypes, self.binary_op, is_reference))
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
elif not is_reference and self.binary_op in binary_op_supported_dtypes and \
dtypes in binary_op_supported_dtypes[self.binary_op]:
if dtypes in [(torch.quint8, torch.qint8, None)]:
assert self.quantized_binary_op is not None
if self.num_tensor_args == 1:
# add/mul scalar
first_arg = self.binary_op_node.args[0]
cache_for_no_tensor_check: Dict[Node, bool] = dict()
if isinstance(first_arg, Node) and (
not all_node_args_have_no_tensors(
first_arg, quantizer.modules, cache_for_no_tensor_check)):
quantized_index = 0
else:
quantized_index = 1
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
return quantizer.quantized_graph.create_node(
'call_function', self.quantized_binary_op,
load_arg(quantized=[quantized_index])(self.binary_op_node.args), self.binary_op_node.kwargs)
else:
op = torch.ops.quantized.add
kwargs = {**self.binary_op_node.kwargs}
add_args = (*load_arg(quantized=True)(self.binary_op_node.args), scale_arg, zero_point_arg)
op = quantizer.quantized_graph.create_node(
'call_function', self.quantized_binary_op, add_args, kwargs)
return op
cur_idx = quantizer.activation_post_process_indexes[node.name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
kwargs = {**self.binary_op_node.kwargs}
add_args = (*load_arg(quantized=True)(self.binary_op_node.args), scale_arg, zero_point_arg)
op = quantizer.quantized_graph.create_node(
'call_function', self.quantized_binary_op, add_args, kwargs)
return op
else:
assert dtypes == (torch.float16, torch.float16, None)
# TODO (refactor) this is duplicated, maybe have a helper function
if self.relu_node:
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
return quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
else:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
else:
assert dtypes == (torch.float16, torch.float16, None)
# TODO (refactor) this is duplicated, maybe have a helper function
# leave the op unquantized if the dtype,reference combination is not supported
warnings.warn(
"dtype combination: {} is not "
"supported by {} for is_reference={}. "
"Supported non-reference dtype combinations are: {} "
"Supported reference dtype combinations are: {}"
"".format(dtypes,
self.binary_op,
is_reference,
binary_op_supported_dtypes[self.binary_op],
(
[] if self.binary_op not in binary_reference_op_supported_dtypes.keys()
else binary_reference_op_supported_dtypes[self.binary_op]
)
)
)
if self.relu_node:
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
relu_args = [op_out]
@@ -234,6 +273,7 @@ class BinaryOpQuantizeHandler(QuantizeHandler):
else:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
@register_quant_pattern(torch.cat)
class CatQuantizeHandler(QuantizeHandler):
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
@@ -775,14 +815,23 @@ ARGS_TO_SKIP = {
@register_quant_pattern(torch.nn.InstanceNorm3d)
@register_quant_pattern(torch.nn.LayerNorm)
@register_quant_pattern(torch.nn.SiLU)
# we currently only support reference patterns for these ops so they have been removed
# until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
# @register_quant_pattern(torch.nn.GELU)
# @register_quant_pattern(torch.nn.Softmax)
@register_quant_pattern(torch.nn.functional.hardswish)
@register_quant_pattern(torch.nn.functional.instance_norm)
@register_quant_pattern(torch.nn.functional.layer_norm)
@register_quant_pattern(torch.nn.functional.leaky_relu)
@register_quant_pattern(torch.nn.functional.silu)
# we currently only support reference patterns for these ops so they have been removed
# until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
# @register_quant_pattern(torch.nn.functional.gelu)
# @register_quant_pattern(torch.nn.functional.softmax)
class DefaultNodeQuantizeHandler(QuantizeHandler):
''' Common quantized op, first input and first output will be quantized
'''
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
if node.op == "call_function" or node.op == "call_method":
@@ -822,11 +871,15 @@ class DefaultNodeQuantizeHandler(QuantizeHandler):
torch.nn.InstanceNorm3d: int8_dtypes,
torch.nn.LayerNorm: all_dtypes,
torch.nn.SiLU: fp16_dtypes,
torch.nn.GELU: int8_dtypes,
torch.nn.Softmax: int8_dtypes,
torch.nn.functional.hardswish: int8_dtypes,
torch.nn.functional.instance_norm: int8_dtypes,
torch.nn.functional.layer_norm: all_dtypes,
torch.nn.functional.leaky_relu: int8_dtypes,
torch.nn.functional.silu: fp16_dtypes,
torch.nn.functional.gelu: int8_dtypes,
torch.nn.functional.softmax: int8_dtypes,
}
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
@@ -836,50 +889,73 @@ class DefaultNodeQuantizeHandler(QuantizeHandler):
"supported by {} "
"supported dtype combinations are: {}".format(dtypes, self.op, supported_dtypes[self.op]))
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
# TODO: make helper functions for (torch.quint8, torch.qint8, None)
if dtypes in [(torch.quint8, torch.qint8, None)]:
cur_idx = quantizer.activation_post_process_indexes[node.name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
if node.op == 'call_module':
module = quantizer.modules[node.target]
module.activation_post_process = activation_post_process
quantized_module_cls = get_static_quant_module_class(
type(module), additional_static_quant_mapping)
quantized_module = quantized_module_cls.from_float(module)
parent_name, name = _parent_name(node.target)
setattr(quantizer.modules[parent_name], name, quantized_module)
return quantizer.quantized_graph.create_node(
'call_module',
node.target,
load_arg(quantized=[0])(node.args),
load_arg(quantized=False)(node.kwargs))
if not is_reference:
if dtypes in [(torch.quint8, torch.qint8, None)]:
cur_idx = quantizer.activation_post_process_indexes[node.name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
if node.op == 'call_module':
module = quantizer.modules[node.target]
module.activation_post_process = activation_post_process
quantized_module_cls = get_static_quant_module_class(
type(module), additional_static_quant_mapping)
quantized_module = quantized_module_cls.from_float(module)
parent_name, name = _parent_name(node.target)
setattr(quantizer.modules[parent_name], name, quantized_module)
return quantizer.quantized_graph.create_node(
'call_module',
node.target,
load_arg(quantized=[0])(node.args),
load_arg(quantized=False)(node.kwargs))
else:
assert node.op == "call_function"
# call_function
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
assert not isinstance(node.target, str), "Expecting node.target for "
"call_function to be a function instead of a string"
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)
kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale_arg,
"output_zero_point": zero_point_arg}
if quantized_op in ARGS_TO_SKIP:
args_to_skip = ARGS_TO_SKIP[quantized_op]
for arg in args_to_skip:
if arg in kwargs:
kwargs.pop(arg)
return quantizer.quantized_graph.create_node(
"call_function", quantized_op, args, kwargs)
else:
assert node.op == "call_function"
# call_function
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
assert not isinstance(node.target, str), "Expecting node.target for "
"call_function to be a function instead of a string"
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)
kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale_arg, "output_zero_point": zero_point_arg}
if quantized_op in ARGS_TO_SKIP:
args_to_skip = ARGS_TO_SKIP[quantized_op]
for arg in args_to_skip:
if arg in kwargs:
kwargs.pop(arg)
return quantizer.quantized_graph.create_node(
"call_function", quantized_op, args, kwargs)
assert dtypes in [(torch.float16, torch.float16, None)]
# Generally fp16 kernels don't exist for fp16 ops
warnings.warn(
"Only reference patterns are currently supported for {dtype} dtype with {op} op"
"".format(dtype=dtypes, op=self.op))
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
else:
assert dtypes == (torch.float16, torch.float16, None)
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
assert is_reference
if dtypes in [(torch.quint8, torch.qint8, None)]:
load_arg(quantized=[0])(node.args)
args = load_arg(quantized=False)(node.args)
kwargs = load_arg(quantized=False)(node.kwargs)
op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
cur_idx = quantizer.activation_post_process_indexes[node.name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
return quantize_node(
quantizer, op_out, activation_post_process,
node, is_input=False)
else:
assert dtypes in [(torch.float16, torch.float16, None)]
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
# TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
@register_quant_pattern(torch.nn.functional.elu)

@@ -55,6 +55,7 @@ from .graph_module import (
from .quantization_patterns import (
binary_op_supported_dtypes,
binary_reference_op_supported_dtypes,
BinaryOpQuantizeHandler,
CatQuantizeHandler,
CopyNodeQuantizeHandler,
@@ -1429,11 +1430,29 @@ class Quantizer:
dtypes = get_qconfig_dtypes(this_node_qconfig)
# TODO(future PR): update the pattern to quantize
# handler logic to take this into account.
skip_this_match = (
# This needs to handle 3 cases
# 1) op and dtype is in either [is_ref or non-ref] list -> don't skip
# 2) op is not in either list (i.e. relu) -> don't skip
# 3) op is in non-ref list, but not for dtype, and op+dtype not in is_ref list -> skip
# note: the value of is_reference is unknown at prepare, so we have to cover both cases
# handle is_reference = False
skip_match_not_is_reference = (
(base_node.target in binary_op_supported_dtypes) and
(dtypes not in binary_op_supported_dtypes[base_node.target])
)
# handle is_reference = True
supported_is_reference = (
(base_node.target in binary_reference_op_supported_dtypes) and
(dtypes in binary_reference_op_supported_dtypes[base_node.target])
)
# only skip if not reference says skip and is_reference doesn't support
skip_this_match = skip_match_not_is_reference and not supported_is_reference
if not skip_this_match:
matched: List[Any] = []
record_match(pattern, node, matched)