mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Enable C++ API parity tests on AArch64 (#145370)
Re-enables C++ API parity tests on AArch64 which now pass. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145370 Approved by: https://github.com/albanD
This commit is contained in:
parent
2f60f12f8b
commit
f85e4c1360
1 changed files with 42 additions and 47 deletions
|
|
@ -37,56 +37,51 @@ class TestCppApiParity(common.TestCase):
|
|||
|
||||
expected_test_params_dicts = []
|
||||
|
||||
if not common.IS_ARM64:
|
||||
for test_params_dicts, test_instance_class in [
|
||||
(sample_module.module_tests, common_nn.NewModuleTest),
|
||||
(sample_functional.functional_tests, common_nn.NewModuleTest),
|
||||
(common_nn.module_tests, common_nn.NewModuleTest),
|
||||
(common_nn.get_new_module_tests(), common_nn.NewModuleTest),
|
||||
(common_nn.criterion_tests, common_nn.CriterionTest),
|
||||
]:
|
||||
for test_params_dict in test_params_dicts:
|
||||
if test_params_dict.get("test_cpp_api_parity", True):
|
||||
if is_torch_nn_functional_test(test_params_dict):
|
||||
functional_impl_check.write_test_to_test_class(
|
||||
TestCppApiParity,
|
||||
test_params_dict,
|
||||
test_instance_class,
|
||||
parity_table,
|
||||
devices,
|
||||
)
|
||||
else:
|
||||
module_impl_check.write_test_to_test_class(
|
||||
TestCppApiParity,
|
||||
test_params_dict,
|
||||
test_instance_class,
|
||||
parity_table,
|
||||
devices,
|
||||
)
|
||||
expected_test_params_dicts.append(test_params_dict)
|
||||
for test_params_dicts, test_instance_class in [
|
||||
(sample_module.module_tests, common_nn.NewModuleTest),
|
||||
(sample_functional.functional_tests, common_nn.NewModuleTest),
|
||||
(common_nn.module_tests, common_nn.NewModuleTest),
|
||||
(common_nn.get_new_module_tests(), common_nn.NewModuleTest),
|
||||
(common_nn.criterion_tests, common_nn.CriterionTest),
|
||||
]:
|
||||
for test_params_dict in test_params_dicts:
|
||||
if test_params_dict.get("test_cpp_api_parity", True):
|
||||
if is_torch_nn_functional_test(test_params_dict):
|
||||
functional_impl_check.write_test_to_test_class(
|
||||
TestCppApiParity,
|
||||
test_params_dict,
|
||||
test_instance_class,
|
||||
parity_table,
|
||||
devices,
|
||||
)
|
||||
else:
|
||||
module_impl_check.write_test_to_test_class(
|
||||
TestCppApiParity,
|
||||
test_params_dict,
|
||||
test_instance_class,
|
||||
parity_table,
|
||||
devices,
|
||||
)
|
||||
expected_test_params_dicts.append(test_params_dict)
|
||||
|
||||
# Assert that all NN module/functional test dicts appear in the parity test
|
||||
assert len(
|
||||
[name for name in TestCppApiParity.__dict__ if "test_torch_nn_" in name]
|
||||
) == len(expected_test_params_dicts) * len(devices)
|
||||
# Assert that all NN module/functional test dicts appear in the parity test
|
||||
assert len(
|
||||
[name for name in TestCppApiParity.__dict__ if "test_torch_nn_" in name]
|
||||
) == len(expected_test_params_dicts) * len(devices)
|
||||
|
||||
# Assert that there exists auto-generated tests for `SampleModule` and `sample_functional`.
|
||||
# 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
|
||||
assert (
|
||||
len([name for name in TestCppApiParity.__dict__ if "SampleModule" in name]) == 4
|
||||
)
|
||||
# 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
|
||||
assert (
|
||||
len([name for name in TestCppApiParity.__dict__ if "sample_functional" in name])
|
||||
== 4
|
||||
)
|
||||
# Assert that there exists auto-generated tests for `SampleModule` and `sample_functional`.
|
||||
# 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
|
||||
assert len([name for name in TestCppApiParity.__dict__ if "SampleModule" in name]) == 4
|
||||
# 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
|
||||
assert (
|
||||
len([name for name in TestCppApiParity.__dict__ if "sample_functional" in name])
|
||||
== 4
|
||||
)
|
||||
|
||||
module_impl_check.build_cpp_tests(
|
||||
TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
|
||||
)
|
||||
functional_impl_check.build_cpp_tests(
|
||||
TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
|
||||
)
|
||||
module_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
|
||||
functional_impl_check.build_cpp_tests(
|
||||
TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
common.TestCase._default_dtype_check_enabled = True
|
||||
|
|
|
|||
Loading…
Reference in a new issue