pytorch/test/cpp_api_parity/__init__.py
Will Feng 1bf1970fe2 Add Python/C++ torch.nn API parity test harness (#23852)
Summary:
This PR adds a test harness for checking Python / C++ API parity for `torch.nn.Module` subclasses. Under the hood, we use JIT tracing to transfer `nn.Module` state from Python to C++, so that we can test initialization / forward / backward on Python / C++ modules with the same parameters and buffers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23852

Differential Revision: D16830204

Pulled By: yf225

fbshipit-source-id: 9b5298c0e8cd30e341a9f026e6f05604a82d6002
2019-08-26 08:02:25 -07:00


from collections import namedtuple

# Parameters describing one Python / C++ parity test case for a torch.nn module.
TorchNNTestParams = namedtuple(
'TorchNNTestParams',
[
'module_name',
'module_variant_name',
'python_constructor_args',
'cpp_constructor_args',
'example_inputs',
'has_parity',
'python_module_class',
'cpp_sources',
'num_attrs_recursive',
'device',
]
)

# A C++ argument, given as its C++ type and value.
CppArg = namedtuple('CppArg', ['type', 'value'])
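The two records above are plain `namedtuple`s, so a test case is built by filling in every field. The sketch below shows what that could look like. It repeats the definitions so it runs on its own, and it uses placeholder values instead of real torch modules and tensors. The field values, the `Linear_basic` variant name, the attribute count and the `cpp_arg_declaration` helper are all hypothetical and meant only for illustration. They are not taken from the actual harness.

```python
from collections import namedtuple

# Same field layout as the definitions above, repeated so this sketch is self-contained.
TorchNNTestParams = namedtuple(
    'TorchNNTestParams',
    ['module_name', 'module_variant_name', 'python_constructor_args',
     'cpp_constructor_args', 'example_inputs', 'has_parity',
     'python_module_class', 'cpp_sources', 'num_attrs_recursive', 'device'],
)
CppArg = namedtuple('CppArg', ['type', 'value'])

# Hypothetical values: the real harness would supply torch modules,
# input tensors and C++ source strings here.
cpu_params = TorchNNTestParams(
    module_name='Linear',
    module_variant_name='Linear_basic',        # hypothetical variant name
    python_constructor_args=(10, 8),
    cpp_constructor_args='(10, 8)',            # hypothetical C++ argument string
    example_inputs=('<input tensor placeholder>',),
    has_parity=True,
    python_module_class=None,                  # stand-in for a Python module class
    cpp_sources='',
    num_attrs_recursive=3,                     # hypothetical count
    device='cpu',
)

# namedtuples are immutable; _replace returns a copy with the given fields changed,
# e.g. to derive a CUDA variant of the same test case.
cuda_params = cpu_params._replace(device='cuda')

# Assumed interpretation: a CppArg pairs a C++ type with a C++ value expression.
arg = CppArg(type='torch::Tensor', value='torch::randn({2, 10})')


def cpp_arg_declaration(name, cpp_arg):
    """Hypothetical helper: render a CppArg as a C++ variable declaration."""
    return f'{cpp_arg.type} {name} = {cpp_arg.value};'


print(cpp_arg_declaration('x', arg))
```

Deriving variants with `_replace` keeps one base record per module and avoids repeating all ten fields for each device.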