mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[quant] added dq->op->q quantization patterns for GELU and softmax ops (#56004)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56004 added reference pattern support for GELU, softmax and bmm for int dtypes. For GELU and Softmax, this consisted of adding reference patterns to the default node handler for int dtypes. Note GELU and softmax patterns are not registered since they do not have a proper quantized kernel which means they would either add unnecessary dequant and quant ops to the network, or they would simply error. This can be circumvented with custom qconfig usage as in test_gelu_reference bmm was added within binary ops along with some significant changes to how that code is structured. Theoretically the reference pattern used for bmm could be applied to other dtypes. This was not enabled because of issues relating to Line 1323 in quantize.py. In essence, the prepare step does not know whether an op will use a reference pattern or not, so for ops that are supported with one dtype in reference and one dtype normally, this has the potential to cause issues. This is difficult to get aorund with the is_reference flag being available in the prepare step or discussed changes around separating Test Plan: python test/test_quantization.py TestQuantizeFxOps.test_gelu_reference python test/test_quantization.py TestQuantizeFxOps.ttest_gelu_normal python test/test_quantization.py TestQuantizeFxOps.test_softmax_reference python test/test_quantization.py TestQuantizeFxOps.test_softmax_normal python test/test_quantization.py TestQuantizeFxOps.test_silu_reference python test/test_quantization.py TestQuantizeFxOps.test_bmm_int_reference python test/test_quantization.py TestQuantizeFxOps python test/test_quantization.py TestFuseFx python test/test_quantization.py TestQuantizeFx python test/test_quantization.py TestQuantizeFxModels Imported from OSS Reviewed By: raghuramank100 Differential Revision: D27818340 fbshipit-source-id: de65be0797035463cd2d1b0e4677d1a87f69143c
This commit is contained in:
parent
ea4af1511c
commit
6e1fc5cef8
3 changed files with 325 additions and 105 deletions
|
|
@ -14,6 +14,8 @@ from torch.quantization.quantize_fx import (
|
|||
prepare_qat_fx,
|
||||
)
|
||||
|
||||
from torch.quantization.fx.quantization_patterns import DefaultNodeQuantizeHandler
|
||||
|
||||
from torch.quantization.fx.pattern_utils import (
|
||||
is_match,
|
||||
MatchAllNode,
|
||||
|
|
@ -3125,28 +3127,151 @@ class TestQuantizeFxOps(QuantizationTestCase):
|
|||
quantized_module, torch.ops.quantized.instance_norm,
|
||||
skip_op_arg_for_functional=True)
|
||||
|
||||
def test_silu(self):
|
||||
def _test_default_node_quant_handler_ops(
|
||||
self, module, functional, qconfig, is_reference=True, node_list=None, additional_quant_pattern_dict=None
|
||||
):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, mod, func):
|
||||
super().__init__()
|
||||
self.module = mod()
|
||||
self.functional = func
|
||||
|
||||
def forward(self, x):
|
||||
x = self.module(x)
|
||||
x = self.functional(x)
|
||||
return x
|
||||
|
||||
if node_list is None:
|
||||
node_list = []
|
||||
if additional_quant_pattern_dict is None:
|
||||
additional_quant_pattern_dict = {}
|
||||
|
||||
data = torch.randn((2, 2, 2, 2))
|
||||
quant_type = QuantType.STATIC
|
||||
prepare_custom_qconfig_dict = {"additional_quant_pattern": additional_quant_pattern_dict}
|
||||
qconfig_dict = {"": qconfig}
|
||||
|
||||
m = M(module, functional).eval()
|
||||
m_prep = torch.quantization.quantize_fx.prepare_fx(m, qconfig_dict, prepare_custom_qconfig_dict)
|
||||
m_prep(data)
|
||||
m_quant = torch.quantization.quantize_fx.convert_fx(m_prep, is_reference=is_reference)
|
||||
m_quant(data)
|
||||
|
||||
self.checkGraphModuleNodes(m_quant, expected_node_list=node_list)
|
||||
|
||||
def test_gelu_normal(self):
|
||||
module = torch.nn.GELU
|
||||
functional = torch.nn.functional.gelu
|
||||
qconfig = torch.quantization.get_default_qconfig("fbgemm")
|
||||
is_reference = False
|
||||
node_list = [
|
||||
ns.call_module(module),
|
||||
ns.call_function(functional),
|
||||
]
|
||||
self._test_default_node_quant_handler_ops(
|
||||
module, functional, qconfig, is_reference, node_list)
|
||||
|
||||
def test_softmax_normal(self):
|
||||
module = torch.nn.Softmax
|
||||
functional = torch.nn.functional.softmax
|
||||
qconfig = torch.quantization.get_default_qconfig("fbgemm")
|
||||
is_reference = False
|
||||
node_list = [
|
||||
ns.call_module(module),
|
||||
ns.call_function(functional),
|
||||
]
|
||||
self._test_default_node_quant_handler_ops(
|
||||
module, functional, qconfig, is_reference, node_list)
|
||||
|
||||
def test_gelu_reference(self):
|
||||
module = torch.nn.GELU
|
||||
functional = torch.nn.functional.gelu
|
||||
qconfig = torch.quantization.get_default_qconfig("fbgemm")
|
||||
is_reference = True
|
||||
node_list = [
|
||||
ns.call_function(torch.quantize_per_tensor),
|
||||
ns.call_method("dequantize"),
|
||||
ns.call_module(module),
|
||||
ns.call_function(torch.quantize_per_tensor),
|
||||
ns.call_method('dequantize'),
|
||||
ns.call_function(functional),
|
||||
ns.call_function(torch.quantize_per_tensor),
|
||||
ns.call_method('dequantize')
|
||||
]
|
||||
additional_patterns = {torch.nn.GELU: DefaultNodeQuantizeHandler,
|
||||
torch.nn.functional.gelu: DefaultNodeQuantizeHandler}
|
||||
self._test_default_node_quant_handler_ops(
|
||||
module, functional, qconfig, is_reference, node_list, additional_patterns)
|
||||
|
||||
def test_softmax_reference(self):
|
||||
module = torch.nn.Softmax
|
||||
functional = torch.nn.functional.softmax
|
||||
qconfig = torch.quantization.get_default_qconfig("fbgemm")
|
||||
is_reference = True
|
||||
node_list = [
|
||||
ns.call_function(torch.quantize_per_tensor),
|
||||
ns.call_method("dequantize"),
|
||||
ns.call_module(module),
|
||||
ns.call_function(torch.quantize_per_tensor),
|
||||
ns.call_method('dequantize'),
|
||||
ns.call_function(functional),
|
||||
ns.call_function(torch.quantize_per_tensor),
|
||||
ns.call_method('dequantize')
|
||||
]
|
||||
additional_patterns = {torch.nn.Softmax: DefaultNodeQuantizeHandler,
|
||||
torch.nn.functional.softmax: DefaultNodeQuantizeHandler}
|
||||
self._test_default_node_quant_handler_ops(
|
||||
module, functional, qconfig, is_reference, node_list, additional_patterns)
|
||||
|
||||
def test_silu_reference(self):
|
||||
module = torch.nn.SiLU
|
||||
functional = torch.nn.functional.silu
|
||||
qconfig = float16_static_qconfig
|
||||
is_reference = True
|
||||
node_list = [
|
||||
ns.call_method("to"),
|
||||
ns.call_method("dequantize"),
|
||||
ns.call_module(module),
|
||||
ns.call_method("to"),
|
||||
ns.call_method('dequantize'),
|
||||
ns.call_function(functional),
|
||||
ns.call_method("to"),
|
||||
ns.call_method('dequantize')
|
||||
]
|
||||
self._test_default_node_quant_handler_ops(
|
||||
module, functional, qconfig, is_reference, node_list)
|
||||
|
||||
def test_bmm_int_reference(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.silu = torch.nn.SiLU()
|
||||
self.bmm = torch.bmm
|
||||
|
||||
def forward(self, x):
|
||||
x = self.silu(x)
|
||||
x = torch.nn.functional.silu(x)
|
||||
return x
|
||||
def forward(self, x, y):
|
||||
out = self.bmm(x, y)
|
||||
return out
|
||||
|
||||
data = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
|
||||
quant_type = QuantType.STATIC
|
||||
qconfig_dict = {
|
||||
"": float16_static_qconfig
|
||||
}
|
||||
node_occurrence = {
|
||||
ns.call_method("to"): 3
|
||||
}
|
||||
m = self.checkGraphModeFxOp(
|
||||
M(), data, quant_type, custom_qconfig_dict=qconfig_dict,
|
||||
expected_node_occurrence=node_occurrence)
|
||||
data_x = torch.randn((2, 2, 2,))
|
||||
data_y = torch.randn((2, 2, 2,))
|
||||
qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
|
||||
is_reference = True
|
||||
node_list = [
|
||||
ns.call_function(torch.quantize_per_tensor),
|
||||
ns.call_function(torch.quantize_per_tensor),
|
||||
ns.call_method('dequantize'),
|
||||
ns.call_method('dequantize'),
|
||||
ns.call_function(torch.bmm),
|
||||
ns.call_function(torch.quantize_per_tensor),
|
||||
ns.call_method('dequantize'),
|
||||
]
|
||||
|
||||
m = M().eval()
|
||||
m_prep = torch.quantization.quantize_fx.prepare_fx(m, qconfig_dict)
|
||||
m_prep(data_x, data_y)
|
||||
m_quant = torch.quantization.quantize_fx.convert_fx(m_prep, is_reference=is_reference)
|
||||
m_quant(data_x, data_y)
|
||||
|
||||
self.checkGraphModuleNodes(m_quant, expected_node_list=node_list)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_clamp(self):
|
||||
|
|
|
|||
|
|
@ -91,6 +91,9 @@ binary_op_all_dtypes = [
|
|||
binary_op_float16_dtypes = [
|
||||
(torch.float16, torch.float16, None)
|
||||
]
|
||||
binary_op_int8_dtypes = [
|
||||
(torch.quint8, torch.qint8, None),
|
||||
]
|
||||
binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
|
||||
operator.add: binary_op_all_dtypes,
|
||||
torch.add: binary_op_all_dtypes,
|
||||
|
|
@ -103,6 +106,9 @@ binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype,
|
|||
operator.truediv: binary_op_float16_dtypes,
|
||||
torch.sum: binary_op_float16_dtypes
|
||||
}
|
||||
binary_reference_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
|
||||
torch.bmm: binary_op_int8_dtypes,
|
||||
}
|
||||
|
||||
|
||||
@register_quant_pattern(operator.add)
|
||||
|
|
@ -170,60 +176,93 @@ class BinaryOpQuantizeHandler(QuantizeHandler):
|
|||
|
||||
qconfig = quantizer.qconfig_map[node.name]
|
||||
dtypes = get_qconfig_dtypes(qconfig)
|
||||
# leave the op unquantized if the dtype combination is not supported
|
||||
if dtypes not in binary_op_supported_dtypes[self.binary_op]:
|
||||
warnings.warn(
|
||||
"dtype combination: {} is not "
|
||||
"supported by {} "
|
||||
"supported dtype combinations are: {}".format(dtypes, self.binary_op, binary_op_supported_dtypes[self.binary_op]))
|
||||
if self.relu_node:
|
||||
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
|
||||
relu_args = [op_out]
|
||||
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
|
||||
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
|
||||
return quantizer.quantized_graph.create_node(
|
||||
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
|
||||
else:
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
|
||||
if dtypes in [(torch.quint8, torch.qint8, None)]:
|
||||
assert self.quantized_binary_op is not None
|
||||
if self.num_tensor_args == 1:
|
||||
# add/mul scalar
|
||||
first_arg = self.binary_op_node.args[0]
|
||||
cache_for_no_tensor_check: Dict[Node, bool] = dict()
|
||||
if isinstance(first_arg, Node) and (
|
||||
not all_node_args_have_no_tensors(
|
||||
first_arg, quantizer.modules, cache_for_no_tensor_check)):
|
||||
quantized_index = 0
|
||||
else:
|
||||
quantized_index = 1
|
||||
|
||||
return quantizer.quantized_graph.create_node(
|
||||
'call_function', self.quantized_binary_op,
|
||||
load_arg(quantized=[quantized_index])(self.binary_op_node.args), self.binary_op_node.kwargs)
|
||||
else:
|
||||
if is_reference and self.binary_op in binary_reference_op_supported_dtypes and \
|
||||
dtypes in binary_reference_op_supported_dtypes[self.binary_op]:
|
||||
if dtypes in binary_op_int8_dtypes:
|
||||
args = load_arg(quantized=[0, 1])(node.args)
|
||||
args = load_arg(quantized=False)(node.args)
|
||||
kwargs = load_arg(quantized=False)(node.kwargs)
|
||||
op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
cur_idx = quantizer.activation_post_process_indexes[node.name]
|
||||
activation_post_process = \
|
||||
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
|
||||
quantizer.activation_post_process_indexes[node.name] += 1
|
||||
scale, zero_point = activation_post_process.calculate_qparams()
|
||||
scale = float(scale)
|
||||
zero_point = int(zero_point)
|
||||
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
||||
return quantize_node(
|
||||
quantizer, op_out, activation_post_process,
|
||||
node, is_input=False)
|
||||
else:
|
||||
warnings.warn(
|
||||
"No implementation found for dtype combination: {}"
|
||||
"for op {} with is_reference={} despite it being listed as supported"
|
||||
"this should not happen".format(dtypes, self.binary_op, is_reference))
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
elif not is_reference and self.binary_op in binary_op_supported_dtypes and \
|
||||
dtypes in binary_op_supported_dtypes[self.binary_op]:
|
||||
if dtypes in [(torch.quint8, torch.qint8, None)]:
|
||||
assert self.quantized_binary_op is not None
|
||||
if self.num_tensor_args == 1:
|
||||
# add/mul scalar
|
||||
first_arg = self.binary_op_node.args[0]
|
||||
cache_for_no_tensor_check: Dict[Node, bool] = dict()
|
||||
if isinstance(first_arg, Node) and (
|
||||
not all_node_args_have_no_tensors(
|
||||
first_arg, quantizer.modules, cache_for_no_tensor_check)):
|
||||
quantized_index = 0
|
||||
else:
|
||||
quantized_index = 1
|
||||
|
||||
if self.relu_node is not None:
|
||||
op = torch.ops.quantized.add_relu
|
||||
return quantizer.quantized_graph.create_node(
|
||||
'call_function', self.quantized_binary_op,
|
||||
load_arg(quantized=[quantized_index])(self.binary_op_node.args), self.binary_op_node.kwargs)
|
||||
else:
|
||||
op = torch.ops.quantized.add
|
||||
kwargs = {**self.binary_op_node.kwargs}
|
||||
add_args = (*load_arg(quantized=True)(self.binary_op_node.args), scale_arg, zero_point_arg)
|
||||
op = quantizer.quantized_graph.create_node(
|
||||
'call_function', self.quantized_binary_op, add_args, kwargs)
|
||||
return op
|
||||
cur_idx = quantizer.activation_post_process_indexes[node.name]
|
||||
activation_post_process = \
|
||||
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
|
||||
quantizer.activation_post_process_indexes[node.name] += 1
|
||||
scale, zero_point = activation_post_process.calculate_qparams()
|
||||
scale = float(scale)
|
||||
zero_point = int(zero_point)
|
||||
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
||||
|
||||
if self.relu_node is not None:
|
||||
op = torch.ops.quantized.add_relu
|
||||
else:
|
||||
op = torch.ops.quantized.add
|
||||
kwargs = {**self.binary_op_node.kwargs}
|
||||
add_args = (*load_arg(quantized=True)(self.binary_op_node.args), scale_arg, zero_point_arg)
|
||||
op = quantizer.quantized_graph.create_node(
|
||||
'call_function', self.quantized_binary_op, add_args, kwargs)
|
||||
return op
|
||||
else:
|
||||
assert dtypes == (torch.float16, torch.float16, None)
|
||||
# TODO (refactor) this is duplicated, maybe have a helper function
|
||||
if self.relu_node:
|
||||
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
|
||||
relu_args = [op_out]
|
||||
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
|
||||
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
|
||||
return quantizer.quantized_graph.create_node(
|
||||
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
|
||||
else:
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
else:
|
||||
assert dtypes == (torch.float16, torch.float16, None)
|
||||
# TODO (refactor) this is duplicated, maybe have a helper function
|
||||
# leave the op unquantized if the dtype,reference combination is not supported
|
||||
warnings.warn(
|
||||
"dtype combination: {} is not "
|
||||
"supported by {} for is_reference={}. "
|
||||
"Supported non-reference dtype combinations are: {} "
|
||||
"Supported reference dtype combinations are: {}"
|
||||
"".format(dtypes,
|
||||
self.binary_op,
|
||||
is_reference,
|
||||
binary_op_supported_dtypes[self.binary_op],
|
||||
(
|
||||
[] if self.binary_op not in binary_reference_op_supported_dtypes.keys()
|
||||
else binary_reference_op_supported_dtypes[self.binary_op]
|
||||
)
|
||||
)
|
||||
)
|
||||
if self.relu_node:
|
||||
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
|
||||
relu_args = [op_out]
|
||||
|
|
@ -234,6 +273,7 @@ class BinaryOpQuantizeHandler(QuantizeHandler):
|
|||
else:
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
|
||||
|
||||
@register_quant_pattern(torch.cat)
|
||||
class CatQuantizeHandler(QuantizeHandler):
|
||||
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
||||
|
|
@ -775,14 +815,23 @@ ARGS_TO_SKIP = {
|
|||
@register_quant_pattern(torch.nn.InstanceNorm3d)
|
||||
@register_quant_pattern(torch.nn.LayerNorm)
|
||||
@register_quant_pattern(torch.nn.SiLU)
|
||||
# we currently only support reference patterns for these ops so they have been removed
|
||||
# until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
|
||||
# @register_quant_pattern(torch.nn.GELU)
|
||||
# @register_quant_pattern(torch.nn.Softmax)
|
||||
@register_quant_pattern(torch.nn.functional.hardswish)
|
||||
@register_quant_pattern(torch.nn.functional.instance_norm)
|
||||
@register_quant_pattern(torch.nn.functional.layer_norm)
|
||||
@register_quant_pattern(torch.nn.functional.leaky_relu)
|
||||
@register_quant_pattern(torch.nn.functional.silu)
|
||||
# we currently only support reference patterns for these ops so they have been removed
|
||||
# until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
|
||||
# @register_quant_pattern(torch.nn.functional.gelu)
|
||||
# @register_quant_pattern(torch.nn.functional.softmax)
|
||||
class DefaultNodeQuantizeHandler(QuantizeHandler):
|
||||
''' Common quantized op, first input and first output will be quantized
|
||||
'''
|
||||
|
||||
def __init__(self, quantizer: QuantizerCls, node: Node):
|
||||
super().__init__(quantizer, node)
|
||||
if node.op == "call_function" or node.op == "call_method":
|
||||
|
|
@ -822,11 +871,15 @@ class DefaultNodeQuantizeHandler(QuantizeHandler):
|
|||
torch.nn.InstanceNorm3d: int8_dtypes,
|
||||
torch.nn.LayerNorm: all_dtypes,
|
||||
torch.nn.SiLU: fp16_dtypes,
|
||||
torch.nn.GELU: int8_dtypes,
|
||||
torch.nn.Softmax: int8_dtypes,
|
||||
torch.nn.functional.hardswish: int8_dtypes,
|
||||
torch.nn.functional.instance_norm: int8_dtypes,
|
||||
torch.nn.functional.layer_norm: all_dtypes,
|
||||
torch.nn.functional.leaky_relu: int8_dtypes,
|
||||
torch.nn.functional.silu: fp16_dtypes,
|
||||
torch.nn.functional.gelu: int8_dtypes,
|
||||
torch.nn.functional.softmax: int8_dtypes,
|
||||
}
|
||||
qconfig = quantizer.qconfig_map[node.name]
|
||||
dtypes = get_qconfig_dtypes(qconfig)
|
||||
|
|
@ -836,50 +889,73 @@ class DefaultNodeQuantizeHandler(QuantizeHandler):
|
|||
"supported by {} "
|
||||
"supported dtype combinations are: {}".format(dtypes, self.op, supported_dtypes[self.op]))
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
|
||||
# TODO: make helper functions for (torch.quint8, torch.qint8, None)
|
||||
if dtypes in [(torch.quint8, torch.qint8, None)]:
|
||||
cur_idx = quantizer.activation_post_process_indexes[node.name]
|
||||
activation_post_process = \
|
||||
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
|
||||
quantizer.activation_post_process_indexes[node.name] += 1
|
||||
if node.op == 'call_module':
|
||||
module = quantizer.modules[node.target]
|
||||
module.activation_post_process = activation_post_process
|
||||
quantized_module_cls = get_static_quant_module_class(
|
||||
type(module), additional_static_quant_mapping)
|
||||
quantized_module = quantized_module_cls.from_float(module)
|
||||
parent_name, name = _parent_name(node.target)
|
||||
setattr(quantizer.modules[parent_name], name, quantized_module)
|
||||
return quantizer.quantized_graph.create_node(
|
||||
'call_module',
|
||||
node.target,
|
||||
load_arg(quantized=[0])(node.args),
|
||||
load_arg(quantized=False)(node.kwargs))
|
||||
if not is_reference:
|
||||
if dtypes in [(torch.quint8, torch.qint8, None)]:
|
||||
cur_idx = quantizer.activation_post_process_indexes[node.name]
|
||||
activation_post_process = \
|
||||
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
|
||||
quantizer.activation_post_process_indexes[node.name] += 1
|
||||
if node.op == 'call_module':
|
||||
module = quantizer.modules[node.target]
|
||||
module.activation_post_process = activation_post_process
|
||||
quantized_module_cls = get_static_quant_module_class(
|
||||
type(module), additional_static_quant_mapping)
|
||||
quantized_module = quantized_module_cls.from_float(module)
|
||||
parent_name, name = _parent_name(node.target)
|
||||
setattr(quantizer.modules[parent_name], name, quantized_module)
|
||||
return quantizer.quantized_graph.create_node(
|
||||
'call_module',
|
||||
node.target,
|
||||
load_arg(quantized=[0])(node.args),
|
||||
load_arg(quantized=False)(node.kwargs))
|
||||
else:
|
||||
assert node.op == "call_function"
|
||||
# call_function
|
||||
scale, zero_point = activation_post_process.calculate_qparams()
|
||||
scale = float(scale)
|
||||
zero_point = int(zero_point)
|
||||
|
||||
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
||||
|
||||
assert not isinstance(node.target, str), "Expecting node.target for "
|
||||
"call_function to be a function instead of a string"
|
||||
quantized_op = get_quantized_operator(node.target)
|
||||
args = load_arg(quantized=[0])(node.args)
|
||||
kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale_arg,
|
||||
"output_zero_point": zero_point_arg}
|
||||
if quantized_op in ARGS_TO_SKIP:
|
||||
args_to_skip = ARGS_TO_SKIP[quantized_op]
|
||||
for arg in args_to_skip:
|
||||
if arg in kwargs:
|
||||
kwargs.pop(arg)
|
||||
return quantizer.quantized_graph.create_node(
|
||||
"call_function", quantized_op, args, kwargs)
|
||||
else:
|
||||
assert node.op == "call_function"
|
||||
# call_function
|
||||
scale, zero_point = activation_post_process.calculate_qparams()
|
||||
scale = float(scale)
|
||||
zero_point = int(zero_point)
|
||||
|
||||
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
||||
|
||||
assert not isinstance(node.target, str), "Expecting node.target for "
|
||||
"call_function to be a function instead of a string"
|
||||
quantized_op = get_quantized_operator(node.target)
|
||||
args = load_arg(quantized=[0])(node.args)
|
||||
kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale_arg, "output_zero_point": zero_point_arg}
|
||||
if quantized_op in ARGS_TO_SKIP:
|
||||
args_to_skip = ARGS_TO_SKIP[quantized_op]
|
||||
for arg in args_to_skip:
|
||||
if arg in kwargs:
|
||||
kwargs.pop(arg)
|
||||
return quantizer.quantized_graph.create_node(
|
||||
"call_function", quantized_op, args, kwargs)
|
||||
assert dtypes in [(torch.float16, torch.float16, None)]
|
||||
# Generally fp16 kernels don't exist for fp16 ops
|
||||
warnings.warn(
|
||||
"Only reference patterns are currently supported for {dtype} dtype with {op} op"
|
||||
"".format(dtype=dtypes, op=self.op))
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
else:
|
||||
assert dtypes == (torch.float16, torch.float16, None)
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
assert is_reference
|
||||
if dtypes in [(torch.quint8, torch.qint8, None)]:
|
||||
load_arg(quantized=[0])(node.args)
|
||||
args = load_arg(quantized=False)(node.args)
|
||||
kwargs = load_arg(quantized=False)(node.kwargs)
|
||||
op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
cur_idx = quantizer.activation_post_process_indexes[node.name]
|
||||
activation_post_process = \
|
||||
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
|
||||
quantizer.activation_post_process_indexes[node.name] += 1
|
||||
return quantize_node(
|
||||
quantizer, op_out, activation_post_process,
|
||||
node, is_input=False)
|
||||
else:
|
||||
assert dtypes in [(torch.float16, torch.float16, None)]
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
||||
|
||||
|
||||
# TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
|
||||
@register_quant_pattern(torch.nn.functional.elu)
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ from .graph_module import (
|
|||
|
||||
from .quantization_patterns import (
|
||||
binary_op_supported_dtypes,
|
||||
binary_reference_op_supported_dtypes,
|
||||
BinaryOpQuantizeHandler,
|
||||
CatQuantizeHandler,
|
||||
CopyNodeQuantizeHandler,
|
||||
|
|
@ -1429,11 +1430,29 @@ class Quantizer:
|
|||
dtypes = get_qconfig_dtypes(this_node_qconfig)
|
||||
# TODO(future PR): update the pattern to quantize
|
||||
# handler logic to take this into account.
|
||||
skip_this_match = (
|
||||
|
||||
|
||||
# This needs to handle 3 cases
|
||||
# 1) op and dtype is in either [is_ref or non-ref] list -> don't skip
|
||||
# 2) op is not in either list (i.e. relu) -> don't skip
|
||||
# 3) op is in non-ref list, but not for dtype, and op+dtype not in is_ref list -> skip
|
||||
|
||||
# note: the value of is_reference is unknown at prepare, so we have to cover both cases
|
||||
# handle is_reference = False
|
||||
skip_match_not_is_reference = (
|
||||
(base_node.target in binary_op_supported_dtypes) and
|
||||
(dtypes not in binary_op_supported_dtypes[base_node.target])
|
||||
)
|
||||
|
||||
# handle is_reference = True
|
||||
supported_is_reference = (
|
||||
(base_node.target in binary_reference_op_supported_dtypes) and
|
||||
(dtypes in binary_reference_op_supported_dtypes[base_node.target])
|
||||
)
|
||||
|
||||
# only skip if not reference says skip and is_reference doesn't support
|
||||
skip_this_match = skip_match_not_is_reference and not supported_is_reference
|
||||
|
||||
if not skip_this_match:
|
||||
matched: List[Any] = []
|
||||
record_match(pattern, node, matched)
|
||||
|
|
|
|||
Loading…
Reference in a new issue