mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add some missing docs to torch.rst, new unittest to enforce torch.rst no longer miss anything (#16039)
Summary: This prevent people (reviewer, PR author) from forgetting adding things to `torch.rst`. When something new is added to `_torch_doc.py` or `functional.py` but intentionally not in `torch.rst`, people should manually whitelist it in `test_docs_coverage.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16039 Differential Revision: D14070903 Pulled By: ezyang fbshipit-source-id: 60f2a42eb5efe81be073ed64e54525d143eb643e
This commit is contained in:
parent
a771a6ba67
commit
07b5782ff7
3 changed files with 70 additions and 0 deletions
|
|
@ -259,6 +259,8 @@ Other Operations
|
|||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: bincount
|
||||
.. autofunction:: broadcast_tensors
|
||||
.. autofunction:: cartesian_prod
|
||||
.. autofunction:: combinations
|
||||
.. autofunction:: cross
|
||||
.. autofunction:: diag
|
||||
.. autofunction:: diag_embed
|
||||
|
|
@ -275,7 +277,9 @@ Other Operations
|
|||
.. autofunction:: tensordot
|
||||
.. autofunction:: trace
|
||||
.. autofunction:: tril
|
||||
.. autofunction:: tril_indices
|
||||
.. autofunction:: triu
|
||||
.. autofunction:: triu_indices
|
||||
|
||||
|
||||
BLAS and LAPACK Operations
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ TESTS = [
|
|||
'dataloader',
|
||||
'distributed',
|
||||
'distributions',
|
||||
'docs_coverage',
|
||||
'expecttest',
|
||||
'indexing',
|
||||
'indexing_cuda',
|
||||
|
|
|
|||
65
test/test_docs_coverage.py
Normal file
65
test/test_docs_coverage.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import torch
|
||||
import unittest
|
||||
import os
|
||||
import re
|
||||
import ast
|
||||
import _ast
|
||||
|
||||
|
||||
path = os.path.dirname(os.path.realpath(__file__))
|
||||
rstpath = os.path.join(path, '../docs/source/')
|
||||
pypath = os.path.join(path, '../torch/_torch_docs.py')
|
||||
r1 = re.compile(r'\.\. autofunction:: (\w*)')
|
||||
|
||||
|
||||
class TestDocCoverage(unittest.TestCase):
|
||||
|
||||
def test_torch(self):
|
||||
# get symbols documented in torch.rst
|
||||
whitelist = [
|
||||
'set_printoptions', 'get_rng_state', 'is_storage', 'initial_seed',
|
||||
'set_default_tensor_type', 'load', 'save', 'set_default_dtype',
|
||||
'is_tensor', 'compiled_with_cxx11_abi', 'set_rng_state',
|
||||
'manual_seed'
|
||||
]
|
||||
everything = set()
|
||||
filename = os.path.join(rstpath, 'torch.rst')
|
||||
with open(filename, 'r') as f:
|
||||
lines = f.readlines()
|
||||
for l in lines:
|
||||
l = l.strip()
|
||||
name = r1.findall(l)
|
||||
if name:
|
||||
everything.add(name[0])
|
||||
everything -= set(whitelist)
|
||||
# get symbols in functional.py and _torch_docs.py
|
||||
whitelist2 = ['product', 'inf', 'math', 'reduce', 'warnings', 'torch', 'annotate']
|
||||
everything2 = set()
|
||||
with open(pypath, 'r') as f:
|
||||
body = ast.parse(f.read()).body
|
||||
for i in body:
|
||||
if not isinstance(i, _ast.Expr):
|
||||
continue
|
||||
i = i.value
|
||||
if not isinstance(i, _ast.Call):
|
||||
continue
|
||||
if i.func.id != 'add_docstr':
|
||||
continue
|
||||
i = i.args[0]
|
||||
if i.value.id != 'torch':
|
||||
continue
|
||||
i = i.attr
|
||||
everything2.add(i)
|
||||
for p in dir(torch.functional):
|
||||
if not p.startswith('_') and p[0].islower():
|
||||
everything2.add(p)
|
||||
everything2 -= set(whitelist2)
|
||||
# assert they are equal
|
||||
for p in everything:
|
||||
self.assertIn(p, everything2, 'in torch.rst but not in python')
|
||||
for p in everything2:
|
||||
self.assertIn(p, everything, 'in python but not in torch.rst')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue