diff --git a/docs/source/torch.rst b/docs/source/torch.rst index f04b7fb21a6..79ac07fef24 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -259,6 +259,8 @@ Other Operations ~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: bincount .. autofunction:: broadcast_tensors +.. autofunction:: cartesian_prod +.. autofunction:: combinations .. autofunction:: cross .. autofunction:: diag .. autofunction:: diag_embed @@ -275,7 +277,9 @@ Other Operations .. autofunction:: tensordot .. autofunction:: trace .. autofunction:: tril +.. autofunction:: tril_indices .. autofunction:: triu +.. autofunction:: triu_indices BLAS and LAPACK Operations diff --git a/test/run_test.py b/test/run_test.py index 24c16417986..613cc948ec0 100644 --- a/test/run_test.py +++ b/test/run_test.py @@ -28,6 +28,7 @@ TESTS = [ 'dataloader', 'distributed', 'distributions', + 'docs_coverage', 'expecttest', 'indexing', 'indexing_cuda', diff --git a/test/test_docs_coverage.py b/test/test_docs_coverage.py new file mode 100644 index 00000000000..d5719003241 --- /dev/null +++ b/test/test_docs_coverage.py @@ -0,0 +1,65 @@ +import torch +import unittest +import os +import re +import ast +import _ast + + +path = os.path.dirname(os.path.realpath(__file__)) +rstpath = os.path.join(path, '../docs/source/') +pypath = os.path.join(path, '../torch/_torch_docs.py') +r1 = re.compile(r'\.\. autofunction:: (\w*)') + + +class TestDocCoverage(unittest.TestCase): + + def test_torch(self): + # get symbols documented in torch.rst + whitelist = [ + 'set_printoptions', 'get_rng_state', 'is_storage', 'initial_seed', + 'set_default_tensor_type', 'load', 'save', 'set_default_dtype', + 'is_tensor', 'compiled_with_cxx11_abi', 'set_rng_state', + 'manual_seed' + ] + everything = set() + filename = os.path.join(rstpath, 'torch.rst') + with open(filename, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip() + name = r1.findall(l) + if name: + everything.add(name[0]) + everything -= set(whitelist) + # get symbols in functional.py and _torch_docs.py + whitelist2 = ['product', 'inf', 'math', 'reduce', 'warnings', 'torch', 'annotate'] + everything2 = set() + with open(pypath, 'r') as f: + body = ast.parse(f.read()).body + for i in body: + if not isinstance(i, _ast.Expr): + continue + i = i.value + if not isinstance(i, _ast.Call): + continue + if i.func.id != 'add_docstr': + continue + i = i.args[0] + if i.value.id != 'torch': + continue + i = i.attr + everything2.add(i) + for p in dir(torch.functional): + if not p.startswith('_') and p[0].islower(): + everything2.add(p) + everything2 -= set(whitelist2) + # assert they are equal + for p in everything: + self.assertIn(p, everything2, 'in torch.rst but not in python') + for p in everything2: + self.assertIn(p, everything, 'in python but not in torch.rst') + + +if __name__ == '__main__': + unittest.main()