Add some missing docs to torch.rst, new unittest to enforce torch.rst no longer miss anything (#16039)

Summary:
This prevent people (reviewer, PR author) from forgetting adding things to `torch.rst`.

When something new is added to `_torch_doc.py` or `functional.py` but intentionally not in `torch.rst`, people should manually whitelist it in `test_docs_coverage.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16039

Differential Revision: D14070903

Pulled By: ezyang

fbshipit-source-id: 60f2a42eb5efe81be073ed64e54525d143eb643e
This commit is contained in:
Xiang Gao 2019-02-15 06:44:56 -08:00 committed by Facebook Github Bot
parent a771a6ba67
commit 07b5782ff7
3 changed files with 70 additions and 0 deletions

View file

@ -259,6 +259,8 @@ Other Operations
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: bincount
.. autofunction:: broadcast_tensors
.. autofunction:: cartesian_prod
.. autofunction:: combinations
.. autofunction:: cross
.. autofunction:: diag
.. autofunction:: diag_embed
@ -275,7 +277,9 @@ Other Operations
.. autofunction:: tensordot
.. autofunction:: trace
.. autofunction:: tril
.. autofunction:: tril_indices
.. autofunction:: triu
.. autofunction:: triu_indices
BLAS and LAPACK Operations

View file

@ -28,6 +28,7 @@ TESTS = [
'dataloader',
'distributed',
'distributions',
'docs_coverage',
'expecttest',
'indexing',
'indexing_cuda',

View file

@ -0,0 +1,65 @@
import torch
import unittest
import os
import re
import ast
import _ast
path = os.path.dirname(os.path.realpath(__file__))
rstpath = os.path.join(path, '../docs/source/')
pypath = os.path.join(path, '../torch/_torch_docs.py')
r1 = re.compile(r'\.\. autofunction:: (\w*)')
class TestDocCoverage(unittest.TestCase):
def test_torch(self):
# get symbols documented in torch.rst
whitelist = [
'set_printoptions', 'get_rng_state', 'is_storage', 'initial_seed',
'set_default_tensor_type', 'load', 'save', 'set_default_dtype',
'is_tensor', 'compiled_with_cxx11_abi', 'set_rng_state',
'manual_seed'
]
everything = set()
filename = os.path.join(rstpath, 'torch.rst')
with open(filename, 'r') as f:
lines = f.readlines()
for l in lines:
l = l.strip()
name = r1.findall(l)
if name:
everything.add(name[0])
everything -= set(whitelist)
# get symbols in functional.py and _torch_docs.py
whitelist2 = ['product', 'inf', 'math', 'reduce', 'warnings', 'torch', 'annotate']
everything2 = set()
with open(pypath, 'r') as f:
body = ast.parse(f.read()).body
for i in body:
if not isinstance(i, _ast.Expr):
continue
i = i.value
if not isinstance(i, _ast.Call):
continue
if i.func.id != 'add_docstr':
continue
i = i.args[0]
if i.value.id != 'torch':
continue
i = i.attr
everything2.add(i)
for p in dir(torch.functional):
if not p.startswith('_') and p[0].islower():
everything2.add(p)
everything2 -= set(whitelist2)
# assert they are equal
for p in everything:
self.assertIn(p, everything2, 'in torch.rst but not in python')
for p in everything2:
self.assertIn(p, everything, 'in python but not in torch.rst')
if __name__ == '__main__':
unittest.main()