[ao] Added function to determine whether dynamic or static quantization is appropriate

Summary: The _detect_dynamic_vs_static function was added. It takes in a
prepared FX graph model that already has ModelReportObservers inserted into
it, uses the collected information to determine whether the input and
output of each observed module are stationary or non-stationary, and, based
on this information, recommends whether to make linear modules static or
dynamic.

This PR will soon be followed by another PR that more rigorously tests the
end-to-end performance of the whole system; that end-to-end suite is where
the function added here will primarily be exercised, which is why this PR
only has one test.
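
For orientation, a minimal usage sketch condensed from the new test (model,
q_config_mapping, and example_input stand in for user-defined values; the
observer-insertion plumbing, shown in full in the test diff below, is elided):

    import torch.ao.quantization.quantize_fx as quantize_fx
    from torch.ao.quantization.fx._model_report._detector import _detect_dynamic_vs_static

    # prepare the model so observers can be attached around the linear modules of interest
    model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
    # ... insert ModelReportObserver pre/post observers here (see the test below) ...

    # calibrate so the observers record per-batch and per-epoch activation ranges
    for _ in range(10):
        model_prep(example_input)

    # get a string report and a per-module dictionary of recommendations
    report_str, report_dict = _detect_dynamic_vs_static(model_prep, tolerance=0.5)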

Test Plan: python test/quantization/fx/test_model_report_fx.py TestModelReportDetectDynamicStatic

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79326

Approved by: https://github.com/HDCharles
Author: vspenubarthi
Date: 2022-06-14 14:39:02 -07:00
Committed by: PyTorch MergeBot
Parent: e10b762537
Commit: 38952d9350
3 changed files with 300 additions and 90 deletions

File: test/quantization/fx/test_model_report_fx.py

@@ -2,27 +2,23 @@
# Owner(s): ["oncall: quantization"]
import torch
import torch.ao.quantization.quantize_fx
from torch.ao.quantization import QConfig, QConfigMapping
from torch.ao.quantization.fx._model_report._detector import _detect_per_channel
import torch.ao.quantization.quantize_fx as quantize_fx
import torch.nn.functional as F
from torch.ao.quantization.fx._model_report.model_report_observer import (
ModelReportObserver,
)
from torch.ao.quantization.observer import (
default_per_channel_weight_observer,
HistogramObserver,
)
from torch.ao.quantization import QConfig, QConfigMapping
from torch.ao.quantization.fx._model_report._detector import _detect_dynamic_vs_static, _detect_per_channel
from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
from torch.ao.quantization.observer import HistogramObserver, default_per_channel_weight_observer
from torch.nn.intrinsic.modules.fused import ConvReLU2d, LinearReLU
from torch.testing._internal.common_quantization import (
ConvModel,
QuantizationTestCase,
SingleLayerLinearModel,
TwoLayerLinearModel,
skipIfNoFBGEMM,
skipIfNoQNNPACK,
TwoLayerLinearModel,
)
"""
Partition of input domain:
@@ -41,7 +37,9 @@ There are possible changes / suggestions, there are no changes / suggestions
"""
# Default output for string if no optimizations are possible
DEFAULT_NO_OPTIMS_ANSWER_STRING = "Further Optimizations for backend {}: \nNo further per_channel optimizations possible."
DEFAULT_NO_OPTIMS_ANSWER_STRING = (
"Further Optimizations for backend {}: \nNo further per_channel optimizations possible."
)
# Example Sequential Model with multiple Conv and Linear with nesting involved
NESTED_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
@@ -71,14 +69,12 @@ FUSION_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
)
class TestModelReportFxDetector(QuantizationTestCase):
class TestFxModelReportDetector(QuantizationTestCase):
"""Prepares and callibrate the model"""
def _prepare_model_and_run_input(self, model, q_config_mapping, input):
model_prep = torch.ao.quantization.quantize_fx.prepare_fx(
model, q_config_mapping, input
) # prep model
model_prep = torch.ao.quantization.quantize_fx.prepare_fx(model, q_config_mapping, input) # prep model
model_prep(input).sum() # calibrate the model
return model_prep
@@ -95,14 +91,10 @@ class TestModelReportFxDetector(QuantizationTestCase):
torch.backends.quantized.engine = "onednn"
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(
torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
)
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
input = torch.randn(1, 3, 10, 10)
prepared_model = self._prepare_model_and_run_input(
ConvModel(), q_config_mapping, input
)
prepared_model = self._prepare_model_and_run_input(ConvModel(), q_config_mapping, input)
# run the detector
optims_str, per_channel_info = _detect_per_channel(prepared_model)
@@ -119,9 +111,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
per_channel_info["per_channel_status"]["conv"]["per_channel_supported"],
True,
)
self.assertEqual(
per_channel_info["per_channel_status"]["conv"]["per_channel_used"], True
)
self.assertEqual(per_channel_info["per_channel_status"]["conv"]["per_channel_used"], True)
"""Case includes:
Multiple conv or linear
@@ -137,9 +127,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
torch.backends.quantized.engine = "qnnpack"
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(
torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
)
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
prepared_model = self._prepare_model_and_run_input(
TwoLayerLinearModel(),
@@ -238,9 +226,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
elif "conv" in key:
self.assertEqual(module_entry["per_channel_used"], True)
else:
raise ValueError(
"Should only contain conv and linear layers as key values"
)
raise ValueError("Should only contain conv and linear layers as key values")
"""Case includes:
Multiple conv or linear
@@ -256,9 +242,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
torch.backends.quantized.engine = "qnnpack"
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(
torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
)
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
prepared_model = self._prepare_model_and_run_input(
NESTED_CONV_LINEAR_EXAMPLE,
@@ -299,9 +283,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
torch.backends.quantized.engine = "qnnpack"
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(
torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
)
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
prepared_model = self._prepare_model_and_run_input(
LAZY_CONV_LINEAR_EXAMPLE,
@@ -342,9 +324,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
torch.backends.quantized.engine = "fbgemm"
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(
torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
)
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
prepared_model = self._prepare_model_and_run_input(
FUSION_CONV_LINEAR_EXAMPLE,
@@ -409,9 +389,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
# model must be in eval mode for fusion
model_fp32.eval()
model_fp32_fused = torch.quantization.fuse_modules(
model_fp32, [["conv", "bn", "relu"]]
)
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [["conv", "bn", "relu"]])
# model must be set to train mode for QAT logic to work
model_fp32_fused.train()
@@ -454,7 +432,7 @@ Partition on Output
"""
class TestModelReportObserver(QuantizationTestCase):
class TestFxModelReportObserver(QuantizationTestCase):
class NestedModifiedSingleLayerLinear(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -488,12 +466,8 @@ class TestModelReportObserver(QuantizationTestCase):
getattr(model, "obs1").average_batch_activation_range,
torch.tensor(float(0)),
)
self.assertEqual(
getattr(model, "obs1").epoch_activation_min, torch.tensor(float("inf"))
)
self.assertEqual(
getattr(model, "obs1").epoch_activation_max, torch.tensor(float("-inf"))
)
self.assertEqual(getattr(model, "obs1").epoch_activation_min, torch.tensor(float("inf")))
self.assertEqual(getattr(model, "obs1").epoch_activation_max, torch.tensor(float("-inf")))
# loop through the batches and run through
for index, batch in enumerate(split_up_data):
@@ -503,9 +477,7 @@ class TestModelReportObserver(QuantizationTestCase):
# get general info about the batch and the model to use later
batch_min, batch_max = torch.aminmax(batch)
current_average_range = getattr(
model, "obs1"
).average_batch_activation_range
current_average_range = getattr(model, "obs1").average_batch_activation_range
current_epoch_min = getattr(model, "obs1").epoch_activation_min
current_epoch_max = getattr(model, "obs1").epoch_activation_max
@@ -513,9 +485,9 @@ class TestModelReportObserver(QuantizationTestCase):
model(ex_input)
# check that average batch activation range updated correctly
correct_updated_value = (
current_average_range * num_tracked_so_far + (batch_max - batch_min)
) / (num_tracked_so_far + 1)
correct_updated_value = (current_average_range * num_tracked_so_far + (batch_max - batch_min)) / (
num_tracked_so_far + 1
)
self.assertEqual(
getattr(model, "obs1").average_batch_activation_range,
correct_updated_value,
@@ -659,9 +631,9 @@ class TestModelReportObserver(QuantizationTestCase):
x = self.relu(x)
return x
class ThreeOps(torch.nn.Module):
class ModifiedThreeOps(torch.nn.Module):
def __init__(self, batch_norm_dim):
super(ThreeOps, self).__init__()
super(ModifiedThreeOps, self).__init__()
self.obs1 = ModelReportObserver()
self.linear = torch.nn.Linear(7, 3, 2)
self.obs2 = ModelReportObserver()
@@ -688,9 +660,9 @@ class TestModelReportObserver(QuantizationTestCase):
super(HighDimensionNet, self).__init__()
self.obs1 = ModelReportObserver()
self.fc1 = torch.nn.Linear(3, 7)
self.block1 = ThreeOps(3)
self.block1 = ModifiedThreeOps(3)
self.fc2 = torch.nn.Linear(3, 7)
self.block2 = ThreeOps(3)
self.block2 = ModifiedThreeOps(3)
self.fc3 = torch.nn.Linear(3, 7)
def forward(self, x):
@@ -706,7 +678,12 @@ class TestModelReportObserver(QuantizationTestCase):
# the purpose of this test is to give the observers a variety of data examples
# initialize the model
models = [self.NestedModifiedSingleLayerLinear(), LargerIncludeNestModel(), ThreeOps(2), HighDimensionNet()]
models = [
self.NestedModifiedSingleLayerLinear(),
LargerIncludeNestModel(),
ModifiedThreeOps(2),
HighDimensionNet(),
]
# get some number of epochs and batches
num_epochs = 10
@@ -723,3 +700,100 @@ class TestModelReportObserver(QuantizationTestCase):
# run it through the model and do general tests
for index, model in enumerate(models):
self.run_model_and_common_checks(model, inputs[index], num_epochs, num_batches)
"""
Partition on domain / things to test
There is only a single test case for now.
This will be more thoroughly tested with the implementation of the full end-to-end tool coming soon.
"""
class TestFxModelReportDetectDynamicStatic(QuantizationTestCase):
@skipIfNoFBGEMM
def test_nested_detection_case(self):
class SingleLinear(torch.nn.Module):
def __init__(self):
super(SingleLinear, self).__init__()
self.linear = torch.nn.Linear(3, 3)
def forward(self, x):
x = self.linear(x)
return x
class TwoBlockNet(torch.nn.Module):
def __init__(self):
super(TwoBlockNet, self).__init__()
self.block1 = SingleLinear()
self.block2 = SingleLinear()
def forward(self, x):
x = self.block1(x)
y = self.block2(x)
z = x + y
z = F.relu(z)
return z
# create model, example input, and qconfig mapping
torch.backends.quantized.engine = "fbgemm"
model = TwoBlockNet()
example_input = torch.randint(-10, 0, (1, 3, 3, 3))
example_input = example_input.to(torch.float)
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm"))
# prep model and select observer
model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
obs_ctr = ModelReportObserver
# find layer to attach to and store
linear_fqn = "block2.linear" # fqn of target linear
target_linear = None
for node in model_prep.graph.nodes:
if node.target == linear_fqn:
target_linear = node
break
# insert into both module and graph pre and post
# set up to insert before target_linear (pre_observer)
with model_prep.graph.inserting_before(target_linear):
obs_to_insert = obs_ctr()
pre_obs_fqn = linear_fqn + ".model_report_pre_observer"
model_prep.add_submodule(pre_obs_fqn, obs_to_insert)
model_prep.graph.create_node(op="call_module", target=pre_obs_fqn, args=target_linear.args)
# set up and insert after the target_linear (post_observer)
with model_prep.graph.inserting_after(target_linear):
obs_to_insert = obs_ctr()
post_obs_fqn = linear_fqn + ".model_report_post_observer"
model_prep.add_submodule(post_obs_fqn, obs_to_insert)
model_prep.graph.create_node(op="call_module", target=post_obs_fqn, args=(target_linear,))
# need to recompile the module after the submodules are added, then pass input through
model_prep.recompile()
num_iterations = 10
for i in range(num_iterations):
if i % 2 == 0:
example_input = torch.randint(-10, 0, (1, 3, 3, 3)).to(torch.float)
else:
example_input = torch.randint(0, 10, (1, 3, 3, 3)).to(torch.float)
model_prep(example_input)
# run it through the dynamic vs static detector
dynam_vs_stat_str, dynam_vs_stat_dict = _detect_dynamic_vs_static(model_prep, tolerance=0.5)
# one of the stats should be stationary, and the other non-stationary
# as a result, dynamic should be recommended
data_dist_info = [
dynam_vs_stat_dict[linear_fqn]["pre_observer_data_dist"],
dynam_vs_stat_dict[linear_fqn]["post_observer_data_dist"],
]
self.assertTrue("stationary" in data_dist_info)
self.assertTrue("non-stationary" in data_dist_info)
self.assertTrue(dynam_vs_stat_dict[linear_fqn]["dynamic_recommended"])

File: test/test_quantization.py

@@ -84,8 +84,9 @@ except ImportError:
# Test the model report module
try:
from quantization.fx.test_model_report_fx import TestModelReportFxDetector # noqa: F401
from quantization.fx.test_model_report_fx import TestModelReportObserver # noqa: F401
from quantization.fx.test_model_report_fx import TestFxModelReportDetector # noqa: F401
from quantization.fx.test_model_report_fx import TestFxModelReportObserver # noqa: F401
from quantization.fx.test_model_report_fx import TestFxModelReportDetectDynamicStatic # noqa: F401
except ImportError:
pass

File: torch/ao/quantization/fx/_model_report/_detector.py

@@ -3,11 +3,13 @@ from typing import Any, Dict, Set, Tuple
import torch
import torch.nn as nn
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.fx.graph_module import GraphModule
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.qconfig import QConfig
from torch.nn.qat.modules.conv import _ConvNd as QatConvNd
from torch.nn.qat.modules.linear import Linear as QatLinear
# Default map for representing supported per channel quantization modules for different backends
DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = {
"fbgemm": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, QatLinear, QatConvNd]),
@@ -35,15 +37,9 @@ def _detect_per_channel(model: nn.Module) -> Tuple[str, Dict[str, Any]]:
backend_chosen = torch.backends.quantized.engine
supported_modules = set([])
if backend_chosen in DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
supported_modules = DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[
backend_chosen
]
supported_modules = DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[backend_chosen]
else:
raise ValueError(
"Not configured to work with {}. Try a different default backend".format(
backend_chosen
)
)
raise ValueError("Not configured to work with {}. Try a different default backend".format(backend_chosen))
# store information on submodules: whether per_channel quantization is supported and used, as well as qconfig information
per_channel_info = {"backend": backend_chosen, "per_channel_status": {}}
@@ -66,11 +62,7 @@ def _detect_per_channel(model: nn.Module) -> Tuple[str, Dict[str, Any]]:
# asserts for MyPy
assert isinstance(fqn, str) and isinstance(per_channel_info["per_channel_status"], dict)
is_in_include_list = (
True
if sum(list(map(lambda x: isinstance(module, x), supported_modules))) > 0
else False
)
is_in_include_list = sum(list(map(lambda x: isinstance(module, x), supported_modules))) > 0
# check if the module per_channel is supported
# based on backend
@@ -85,22 +77,15 @@ def _detect_per_channel(model: nn.Module) -> Tuple[str, Dict[str, Any]]:
# this object should either be fake quant or observer
q_or_s_obj = module.qconfig.weight.p.func()
assert isinstance(q_or_s_obj, FakeQuantize) or isinstance(
q_or_s_obj, ObserverBase
)
assert isinstance(q_or_s_obj, FakeQuantize) or isinstance(q_or_s_obj, ObserverBase)
per_channel_used = False # will be true if found in qconfig
if hasattr(
q_or_s_obj, "ch_axis"
): # then we know that per_channel quantization used
if hasattr(q_or_s_obj, "ch_axis"): # then we know that per_channel quantization used
# all fake quants have channel axis so need to check is_per_channel
if isinstance(q_or_s_obj, FakeQuantize):
if (
hasattr(q_or_s_obj, "is_per_channel")
and q_or_s_obj.is_per_channel
):
if hasattr(q_or_s_obj, "is_per_channel") and q_or_s_obj.is_per_channel:
per_channel_used = True
elif isinstance(q_or_s_obj, ObserverBase):
# should be an observer otherwise
@@ -117,9 +102,7 @@ def _detect_per_channel(model: nn.Module) -> Tuple[str, Dict[str, Any]]:
_detect_per_channel_helper(model)
# String to let the user know of further optimizations
further_optims_str = "Further Optimizations for backend {}: \n".format(
backend_chosen
)
further_optims_str = "Further Optimizations for backend {}: \n".format(backend_chosen)
# assert for MyPy check
assert isinstance(per_channel_info["per_channel_status"], dict)
@@ -134,9 +117,161 @@ def _detect_per_channel(model: nn.Module) -> Tuple[str, Dict[str, Any]]:
)
if optimizations_possible:
further_optims_str += "To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
further_optims_str += (
"To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
)
else:
further_optims_str += "No further per_channel optimizations possible."
# return the string and the dictionary form of same information
return (further_optims_str, per_channel_info)
# names for the pre and post observers that are inserted
DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"
# naming conventions for stationary vs non-stationary data
DEFAULT_STATIONARY = "stationary"
DEFAULT_NON_STATIONARY = "non-stationary"
# modules that are supported both dynamic and static for this report function
DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = set([nn.Linear])
def _detect_dynamic_vs_static(model: GraphModule, tolerance=0.5) -> Tuple[str, Dict[str, Any]]:
"""
Determines whether dynamic or static quantization is more appropriate for a given module.
Takes advantage of the ModelReportObserver that records range information.
A stationary data distribution is one that is strictly above the tolerance level for the comparison statistic:
S = average_batch_activation_range / epoch_activation_range
A non-stationary distribution is at or below the tolerance level for this metric.
If the distribution of data right after the module is non-stationary, dynamic quantization is recommended;
otherwise static quantization is recommended.
This then generates suggestions for dynamic vs. static quantization, focused on Linear modules.
Args:
model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers around layers of interest
tolerance (float, optional): The threshold above which the S metric is considered stationary and at or below which it is considered non-stationary. Default: 0.5
Returns a tuple with two elements:
String report of whether dynamic or static quantization is recommended for certain modules
Dictionary mapping each module with ModelReportObservers around it to:
whether dynamic quantization is recommended
the S metric of the input to the module
whether the input to the module is stationary or non-stationary
the S metric of the output of the module
whether the output of the module is stationary or non-stationary
the tolerance level used to decide whether input/output is stationary or non-stationary
"""
# store modules dynamic vs static information
module_dynamic_static_info = {}
# This for loop goes through the modules and extracts all relevant information into module_dynamic_static_info
# This information primarily includes whether the data distributions around a supported module are stationary or not
# Based on this, it is recorded whether dynamic or static quantization is recommended
# loop through all submodules, including nested ones
for fqn, module in model.named_modules():
# check to see if module is of a supported type
is_supported_type = sum(list(map(lambda x: isinstance(module, x), DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED))) > 0
# if module is a Linear and has the pre and post ModelReportObservers attached to it
if is_supported_type and hasattr(module, DEFAULT_PRE_OBSERVER_NAME) and hasattr(module, DEFAULT_POST_OBSERVER_NAME):
# get pre and post observers for the module
pre_obs = getattr(module, DEFAULT_PRE_OBSERVER_NAME)
post_obs = getattr(module, DEFAULT_POST_OBSERVER_NAME)
# get the statistics for each module
pre_stat = pre_obs.get_batch_to_epoch_ratio()
post_stat = post_obs.get_batch_to_epoch_ratio()
# record module, pre and post stat, and whether to do dynamic or static based off it
# true if post observer data distribution is non-stationary, false if it's stationary
dynamic_recommended = post_stat <= tolerance
# specify the classifications for whether data distributions considered stationary or non-stationary
pre_obs_dist_classification = DEFAULT_STATIONARY if pre_stat > tolerance else DEFAULT_NON_STATIONARY
post_obs_dist_classification = DEFAULT_STATIONARY if post_stat > tolerance else DEFAULT_NON_STATIONARY
# store the set of important information for this module
module_info = {
"tolerance": tolerance,
"dynamic_recommended": dynamic_recommended,
"pre_observer_comp_stat": pre_stat,
"pre_observer_data_dist": pre_obs_dist_classification,
"post_observer_comp_stat": post_stat,
"post_observer_data_dist": post_obs_dist_classification,
}
module_dynamic_static_info[fqn] = module_info
dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n"
# This for loop goes through the information collected in module_dynamic_static_info and:
# Populates the string based report with the information from module_dynamic_static_info
# Compiles the complete report by appending relevant formatted strings
for module_fqn in module_dynamic_static_info.keys():
module_info = module_dynamic_static_info[module_fqn]
suggestion_string_template = "For module {} it is suggested to use {} quantization because {}.\n"
# decide what string formatting values will be
quantization_type = ""
quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}."
dynamic_benefit = " You will get more accurate results if you use dynamic quantization"
static_benefit = " You can increase model efficiency if you use static quantization"
benefit_str = ""
# strings for when a dynamic quantize per tensor layer is needed
recommend_per_tensor = " We recommend to add a {} before this module if it is static."
rec_lay_to_add = "dynamic quantize per tensor layer"
dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add)
dynamic_per_tensor_reasoning_string = (
" This is because the input to this module has a non-stationary distribution."
)
# start composing explanation
if module_info["dynamic_recommended"]:
quantization_type = "dynamic"
benefit_str = dynamic_benefit
else:
quantization_type = "static"
benefit_str = static_benefit
# now set the quantization explanation string
quantization_reasoning = (
quantization_reasoning.format(
module_fqn, module_info["pre_observer_data_dist"], module_info["post_observer_data_dist"]
)
+ benefit_str
)
# if we have a non-stationary input -> linear -> stationary, we suggest static
# however, we also want to recommend adding a dynamic quantize per tensor layer right before it if this change is made
if (
module_info["pre_observer_data_dist"] == DEFAULT_NON_STATIONARY
and module_info["post_observer_data_dist"] == DEFAULT_STATIONARY
):
quantization_reasoning = (
quantization_reasoning + dynamic_per_tensor_string + dynamic_per_tensor_reasoning_string
)
# format the overall suggestion string with the specific inputs
module_suggestion_string = suggestion_string_template.format(
module_fqn, quantization_type, quantization_reasoning
)
# append to overall suggestion
dynamic_vs_static_string += module_suggestion_string
# return the string as well as the dictionary of information
return (dynamic_vs_static_string, module_dynamic_static_info)
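# For reference, a suggestion assembled from the templates above would read as follows
# (illustrative values, matching the nested test case in this PR):
#   For module block2.linear it is suggested to use dynamic quantization because the
#   distribution of data before block2.linear is stationary and the distribution after
#   is non-stationary. You will get more accurate results if you use dynamic quantization.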