mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: This PR adds a test harness for checking Python / C++ API parity for `torch.nn.Module` subclasses. Under the hood, we use JIT tracing to transfer `nn.Module` state from Python to C++, so that we can test initialization / forward / backward on Python / C++ modules with the same parameters and buffers. Pull Request resolved: https://github.com/pytorch/pytorch/pull/23852; Differential Revision: D16830204; Pulled By: yf225; fbshipit-source-id: 9b5298c0e8cd30e341a9f026e6f05604a82d6002
19 lines
429 B
Python
from collections import namedtuple


# Immutable record describing one test case for the Python / C++ `torch.nn`
# module parity checks (inferred from the type name -- confirm against the
# harness that consumes it). Fields are positional in the order listed, so
# the order below is part of the public interface and must not change.
# The per-field notes are assumptions drawn from the field names only.
TorchNNTestParams = namedtuple(
    'TorchNNTestParams',
    [
        'module_name',              # presumably the nn.Module class under test
        'module_variant_name',      # presumably distinguishes configs of one module
        'python_constructor_args',  # assumed: args for the Python constructor
        'cpp_constructor_args',     # assumed: args for the C++ constructor
        'example_inputs',           # assumed: inputs used to exercise the module
        'has_parity',               # assumed: whether Python/C++ are expected to match
        'python_module_class',      # assumed: the Python module class object
        'cpp_sources',              # assumed: C++ source snippets for the test
        'num_attrs_recursive',      # assumed: expected attribute count, recursively
        'device',                   # assumed: device the test runs on
    ],
)


# Two-field record pairing a `type` with a `value`; presumably a C++ argument's
# type spelling and its value (assumed from the names -- confirm with callers).
CppArg = namedtuple('CppArg', 'type value')