diff --git a/test/test_cpp_api_parity.py b/test/test_cpp_api_parity.py index 0c27051e8e5..2193243b751 100644 --- a/test/test_cpp_api_parity.py +++ b/test/test_cpp_api_parity.py @@ -37,56 +37,51 @@ class TestCppApiParity(common.TestCase): expected_test_params_dicts = [] -if not common.IS_ARM64: - for test_params_dicts, test_instance_class in [ - (sample_module.module_tests, common_nn.NewModuleTest), - (sample_functional.functional_tests, common_nn.NewModuleTest), - (common_nn.module_tests, common_nn.NewModuleTest), - (common_nn.get_new_module_tests(), common_nn.NewModuleTest), - (common_nn.criterion_tests, common_nn.CriterionTest), - ]: - for test_params_dict in test_params_dicts: - if test_params_dict.get("test_cpp_api_parity", True): - if is_torch_nn_functional_test(test_params_dict): - functional_impl_check.write_test_to_test_class( - TestCppApiParity, - test_params_dict, - test_instance_class, - parity_table, - devices, - ) - else: - module_impl_check.write_test_to_test_class( - TestCppApiParity, - test_params_dict, - test_instance_class, - parity_table, - devices, - ) - expected_test_params_dicts.append(test_params_dict) +for test_params_dicts, test_instance_class in [ + (sample_module.module_tests, common_nn.NewModuleTest), + (sample_functional.functional_tests, common_nn.NewModuleTest), + (common_nn.module_tests, common_nn.NewModuleTest), + (common_nn.get_new_module_tests(), common_nn.NewModuleTest), + (common_nn.criterion_tests, common_nn.CriterionTest), +]: + for test_params_dict in test_params_dicts: + if test_params_dict.get("test_cpp_api_parity", True): + if is_torch_nn_functional_test(test_params_dict): + functional_impl_check.write_test_to_test_class( + TestCppApiParity, + test_params_dict, + test_instance_class, + parity_table, + devices, + ) + else: + module_impl_check.write_test_to_test_class( + TestCppApiParity, + test_params_dict, + test_instance_class, + parity_table, + devices, + ) + expected_test_params_dicts.append(test_params_dict) - # Assert that all NN module/functional test dicts appear in the parity test - assert len( - [name for name in TestCppApiParity.__dict__ if "test_torch_nn_" in name] - ) == len(expected_test_params_dicts) * len(devices) +# Assert that all NN module/functional test dicts appear in the parity test +assert len( + [name for name in TestCppApiParity.__dict__ if "test_torch_nn_" in name] +) == len(expected_test_params_dicts) * len(devices) - # Assert that there exists auto-generated tests for `SampleModule` and `sample_functional`. - # 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices) - assert ( - len([name for name in TestCppApiParity.__dict__ if "SampleModule" in name]) == 4 - ) - # 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices) - assert ( - len([name for name in TestCppApiParity.__dict__ if "sample_functional" in name]) - == 4 - ) +# Assert that there exists auto-generated tests for `SampleModule` and `sample_functional`. +# 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices) +assert len([name for name in TestCppApiParity.__dict__ if "SampleModule" in name]) == 4 +# 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices) +assert ( + len([name for name in TestCppApiParity.__dict__ if "sample_functional" in name]) + == 4 +) - module_impl_check.build_cpp_tests( - TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE - ) - functional_impl_check.build_cpp_tests( - TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE - ) +module_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE) +functional_impl_check.build_cpp_tests( + TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE +) if __name__ == "__main__": common.TestCase._default_dtype_check_enabled = True