pytorch/test/cpp_api_parity/__init__.py
Will Feng d7d3aedd2c Make various improvements to C++ API parity test harness (#25828)
Summary:
This PR makes the following improvements to C++ API parity test harness:
1. Remove `options_args` since we can get the list of options from the Python module constructor args.
2. Add test for mapping `int` or `tuple` in Python module constructor args to `ExpandingArray` in C++ module options.
3. Use regex to split up e.g. `(1, {2, 3}, 4)` into `['1', '{2, 3}', '4']` for `cpp_default_constructor_args`.
4. Add options arg accessor tests in `_test_torch_nn_module_ctor_args`.

We will be able to merge https://github.com/pytorch/pytorch/pull/24160 and https://github.com/pytorch/pytorch/pull/24860 after these improvements.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25828

Differential Revision: D17266197

Pulled By: yf225

fbshipit-source-id: 96d0d4a2fcc4b47cd1782d4df2c9bac107dec3f9
2019-09-09 15:43:55 -07:00

86 lines
2.4 KiB
Python

from collections import namedtuple
# Record type bundling the fields of one torch.nn C++/Python parity test case.
# NOTE(review): the meaning of each field is defined by the harness that
# builds and consumes these records, which is not visible in this file.
TorchNNTestParams = namedtuple(
    'TorchNNTestParams',
    (
        'module_name',
        'module_variant_name',
        'python_constructor',
        'python_constructor_args',
        'cpp_constructor_args',
        'example_inputs',
        'has_parity',
        'cpp_sources',
        'num_attrs_recursive',
        'device',
    ),
)
# A (type, value) pair -- presumably the C++ type and value of an argument;
# confirm against the harness code that builds these.
CppArg = namedtuple('CppArg', ('type', 'value'))
# Per-API parity flags: the "Implementation Parity" and "Doc Parity" columns
# of the parity tracker table, as booleans.
ParityStatus = namedtuple('ParityStatus', ('has_impl_parity', 'has_doc_parity'))
# Optional per-module metadata for the parity tests; every field may be omitted.
TorchNNModuleMetadata = namedtuple(
    'TorchNNModuleMetadata',
    (
        'cpp_default_constructor_args',
        'num_attrs_recursive',
        'cpp_sources',
    ),
)
# Defaults, in field order: cpp_default_constructor_args=None,
# num_attrs_recursive=None, cpp_sources=''. Setting __new__.__defaults__
# (rather than namedtuple(..., defaults=...), which needs Python 3.7+) also
# works on older interpreters.
TorchNNModuleMetadata.__new__.__defaults__ = (None, None, '')
def parse_parity_tracker_table(file_path):
    '''
    Parse the parity tracker Markdown file at `file_path` into a nested dict.

    This function expects the parity tracker Markdown file to have the following format:
    ```
    ## package1_name
    API | Implementation Parity | Doc Parity
    ------------- | ------------- | -------------
    API_Name|No|No
    ...
    ## package2_name
    API | Implementation Parity | Doc Parity
    ------------- | ------------- | -------------
    API_Name|No|No
    ...
    ```
    Rows may optionally carry leading/trailing pipes, e.g. `| API_Name | No | No |`.
    Any text before the first `##` (such as a document title) is ignored.

    The returned dict has the following format:
    ```
    Dict[package_name]
      -> Dict[api_name]
        -> ParityStatus
    ```

    Raises RuntimeError if a `##` section is empty, if a package name appears
    more than once, if an API row does not have exactly three columns, or if a
    parity cell is anything other than "Yes" or "No".
    '''
    def parse_parity_choice(choice):
        # Only the exact strings "Yes" and "No" are accepted.
        if choice in ('Yes', 'No'):
            return choice == 'Yes'
        raise RuntimeError(
            '{} is not a supported parity choice. The valid choices are "Yes" and "No".'.format(choice))

    parity_tracker_dict = {}
    with open(file_path, 'r') as f:
        all_text = f.read()
    # The text is split on every `##`; each chunk after the first is one
    # package section (package name on its first line, then its table).
    packages = all_text.split('##')
    for package in packages[1:]:
        lines = [line.strip() for line in package.split('\n') if line.strip() != '']
        if not lines:
            raise RuntimeError("Empty `##` section found in {}".format(file_path))
        package_name = lines[0]
        if package_name in parity_tracker_dict:
            raise RuntimeError("Duplicated package name `{}` found in {}".format(package_name, file_path))
        parity_tracker_dict[package_name] = {}
        # lines[1] is the table header and lines[2] the `---` separator row;
        # the API rows follow.
        for api_status in lines[3:]:
            # Drop optional outer pipes before splitting into cells.
            cells = [x.strip() for x in api_status.strip('|').split('|')]
            if len(cells) != 3:
                raise RuntimeError(
                    'Malformed row `{}` in package `{}` of {}: expected 3 columns '
                    '(API | Implementation Parity | Doc Parity), got {}'.format(
                        api_status, package_name, file_path, len(cells)))
            api_name, has_impl_parity_str, has_doc_parity_str = cells
            parity_tracker_dict[package_name][api_name] = ParityStatus(
                has_impl_parity=parse_parity_choice(has_impl_parity_str),
                has_doc_parity=parse_parity_choice(has_doc_parity_str))
    return parity_tracker_dict