Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
ns for fx: move linear activation test case to new API (#53777)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53777

Moves the linear activation test case to the new NS API.

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIsModels.test_compare_activations_linear
```

Imported from OSS

Reviewed By: hx89

Differential Revision: D26967107

fbshipit-source-id: 83c4401b2bf79d15227b7fb3e59c54276ec5626b
This commit is contained in:
parent 57bf13409a
commit 2912ad1324

1 changed file with 18 additions and 52 deletions
```diff
@@ -216,57 +216,6 @@ class TestGraphModeNumericSuite(QuantizationTestCase):
                 v["float"][i][1].shape == v["quantized"][i][1].shape
             )
 
-    @override_qengines
-    def test_compare_model_outputs_linear_static_fx(self):
-        r"""Compare the output of linear layer in static quantized model and corresponding
-        output of linear layer in float model
-        """
-
-        qengine = torch.backends.quantized.engine
-        qconfig = get_default_qconfig(qengine)
-        qconfig_dict = {"": qconfig}
-
-        float_model = SingleLayerLinearModel()
-        float_model.eval()
-
-        prepared_model = prepare_fx(float_model, qconfig_dict)
-
-        prepared_float_model = copy.deepcopy(prepared_model)
-
-        # Run calibration
-        test_only_eval_fn(prepared_model, self.calib_data)
-        q_model = convert_fx(prepared_model)
-
-        linear_data = self.calib_data[0][0]
-
-        expected_act_compare_dict_keys = {"x.stats", "fc1.stats"}
-        self.compare_and_validate_model_outputs_results_fx(
-            prepared_float_model, q_model, expected_act_compare_dict_keys, linear_data
-        )
-
-    @override_qengines
-    def test_compare_model_outputs_linear_dynamic_fx(self):
-        r"""Compare the output of linear layer in dynamic quantized model and corresponding
-        output of linear layer in float model
-        """
-
-        qconfig_dict = {"object_type": [(nn.Linear, default_dynamic_qconfig)]}
-
-        float_model = SingleLayerLinearDynamicModel()
-        float_model.eval()
-
-        prepared_model = prepare_fx(float_model, qconfig_dict)
-        prepared_float_model = copy.deepcopy(prepared_model)
-
-        q_model = convert_fx(prepared_model)
-
-        linear_data = self.calib_data[0][0]
-
-        expected_act_compare_dict_keys = {"x.stats", "fc1.stats"}
-        self.compare_and_validate_model_outputs_results_fx(
-            prepared_float_model, q_model, expected_act_compare_dict_keys, linear_data
-        )
-
     @override_qengines
     def test_compare_model_outputs_lstm_dynamic_fx(self):
         r"""Compare the output of LSTM layer in dynamic quantized model and corresponding
```
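Both deleted tests exercise the same prepare/calibrate/convert flow before comparing outputs. A minimal self-contained sketch of that flow, assuming the PyTorch 1.8-era FX API in which `prepare_fx` takes `(model, qconfig_dict)`; `SmallLinear` is a hypothetical stand-in for `SingleLayerLinearModel`, not a class from this PR:

```python
import copy
import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class SmallLinear(nn.Module):
    # Hypothetical stand-in for SingleLayerLinearModel.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(5, 5)

    def forward(self, x):
        return self.fc1(x)

float_model = SmallLinear().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}  # global ("") qconfig

prepared = prepare_fx(float_model, qconfig_dict)    # insert observers
prepared_float = copy.deepcopy(prepared)            # float reference for comparison
prepared(torch.randn(2, 5))                         # calibration pass
q_model = convert_fx(prepared)                      # produce the quantized model
print(q_model(torch.randn(2, 5)).shape)
```

The deep copy before calibration mirrors what the removed tests did: it keeps a float model with the same graph structure, so per-node outputs can later be matched against the quantized model.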
```diff
@@ -573,8 +522,11 @@ class FXNumericSuiteQuantizationTestCase(QuantizationTestCase):
     def _test_match_activations(
         self, m, data, prepared_expected_node_occurrence=None, results_len=0,
         should_log_inputs=False,
+        qconfig_dict=None,
     ):
-        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
+        if qconfig_dict is None:
+            qconfig_dict = {'': torch.quantization.default_qconfig}
+        mp = prepare_fx(m, qconfig_dict)
         mp(*data)
         # TODO(future PR): prevent the need for copying here, we can copy the
         # modules but should reuse the underlying tensors
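Note the new parameter defaults to `None` with an `is None` check inside the function, rather than defaulting to the dict literal directly. A small illustration of why that pattern is idiomatic Python (the names here are hypothetical, not from this PR):

```python
# A dict default is created once at definition time and shared by every
# call (Python's mutable-default pitfall); a None sentinel builds a fresh
# dict on each call instead.
def bad(cfg={}):        # one dict object shared across all calls
    cfg.setdefault("seen", 0)
    cfg["seen"] += 1
    return cfg["seen"]

def good(cfg=None):     # fresh default per call, as in _test_match_activations
    if cfg is None:
        cfg = {}
    cfg.setdefault("seen", 0)
    cfg["seen"] += 1
    return cfg["seen"]

assert bad() == 1 and bad() == 2      # state leaks between calls
assert good() == 1 and good() == 1    # calls stay independent
```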
```diff
@@ -836,6 +788,20 @@ class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
         res = self._test_match_activations(
             m, (torch.randn(1, 3, 4, 4),), results_len=1)
 
+    @skipIfNoFBGEMM
+    def test_compare_activations_linear(self):
+        test_cases = (
+            (SingleLayerLinearModel(), None),
+            (
+                SingleLayerLinearDynamicModel(),
+                {"object_type": [(nn.Linear, default_dynamic_qconfig)]},
+            ),
+        )
+        for m, qconfig_dict in test_cases:
+            m.eval()
+            res = self._test_match_activations(
+                m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict)
+
     @skipIfNoFBGEMM
     def test_sparsenn_compare_activations(self):
         for should_log_inputs in (True, False):
```
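The replacement test is table-driven: each case pairs a model with an optional `qconfig_dict`, where `None` falls back to the global default inside `_test_match_activations`. A sketch contrasting the two `qconfig_dict` styles the cases exercise, under the same 1.8-era API assumption as above:

```python
import torch
import torch.nn as nn
from torch.quantization import default_dynamic_qconfig, get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Global style: the "" key applies a static qconfig to every quantizable
# op, so a calibration pass is needed before convert_fx.
static_qconfig_dict = {"": get_default_qconfig("fbgemm")}

# object_type style: only nn.Linear is targeted, here with a dynamic
# qconfig, so no calibration data is required before conversion.
dynamic_qconfig_dict = {"object_type": [(nn.Linear, default_dynamic_qconfig)]}

m = nn.Sequential(nn.Linear(5, 5)).eval()
mq = convert_fx(prepare_fx(m, dynamic_qconfig_dict))
print(mq(torch.randn(2, 5)).shape)  # torch.Size([2, 5])
```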