mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: This prevents people (reviewers, PR authors) from forgetting to add things to `torch.rst`. When something new is added to `_torch_docs.py` or `functional.py` but is intentionally not in `torch.rst`, it should be manually whitelisted in `test_docs_coverage.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16039 Differential Revision: D14070903 Pulled By: ezyang fbshipit-source-id: 60f2a42eb5efe81be073ed64e54525d143eb643e
65 lines
2.2 KiB
Python
65 lines
2.2 KiB
Python
import torch
|
|
import unittest
|
|
import os
|
|
import re
|
|
import ast
|
|
import _ast
|
|
|
|
|
|
# Directory containing this test file; the other paths below are built
# relative to it.
path = os.path.split(os.path.realpath(__file__))[0]
# Sphinx .rst source directory and the module that attaches docstrings to
# torch functions, both reached via '..' from this file's directory.
rstpath, pypath = (
    os.path.join(path, '../docs/source/'),
    os.path.join(path, '../torch/_torch_docs.py'),
)
# Captures the leading run of word characters after ".. autofunction:: "
# in an .rst line.
r1 = re.compile(r'\.\. autofunction:: (\w*)')
|
|
|
|
|
|
class TestDocCoverage(unittest.TestCase):
    """Check that torch.rst documents exactly the functions defined in Python.

    The docs side is every ``.. autofunction:: <name>`` entry in
    ``docs/source/torch.rst``. The Python side is every ``torch.<name>`` that
    receives a docstring through a top-level ``add_docstr(torch.<name>, ...)``
    call in ``torch/_torch_docs.py``, plus every public lowercase name exposed
    by ``torch.functional``. Symbols intentionally present on only one side
    must be added to the whitelists in ``test_torch``.
    """

    def test_torch(self):
        # Documented in torch.rst but not collected from _torch_docs.py or
        # torch.functional below, so excluded from the docs-side set.
        whitelist = {
            'set_printoptions', 'get_rng_state', 'is_storage', 'initial_seed',
            'set_default_tensor_type', 'load', 'save', 'set_default_dtype',
            'is_tensor', 'compiled_with_cxx11_abi', 'set_rng_state',
            'manual_seed',
        }
        # Lowercase names visible in torch.functional that are not expected in
        # torch.rst (presumably imported modules/helpers such as `math`,
        # `warnings`, `torch` -- confirm when extending).
        whitelist2 = {'product', 'inf', 'math', 'reduce', 'warnings', 'torch',
                      'annotate'}

        # get symbols documented in torch.rst
        everything = set()
        with open(os.path.join(rstpath, 'torch.rst'), 'r') as f:
            for line in f:
                names = r1.findall(line.strip())
                if names:
                    everything.add(names[0])
        everything -= whitelist

        # get symbols in functional.py and _torch_docs.py
        everything2 = set()
        with open(pypath, 'r') as f:
            body = ast.parse(f.read()).body
        for stmt in body:
            # Only bare top-level statements of the exact form
            # `add_docstr(torch.<name>, ...)` count. Any other shape (a call
            # through an attribute, a call without arguments, a target like
            # `torch.Tensor.foo`) is skipped instead of raising
            # AttributeError/IndexError.
            if not isinstance(stmt, ast.Expr):
                continue
            call = stmt.value
            if not isinstance(call, ast.Call):
                continue
            if not (isinstance(call.func, ast.Name) and
                    call.func.id == 'add_docstr'):
                continue
            if not call.args:
                continue
            target = call.args[0]
            if (isinstance(target, ast.Attribute) and
                    isinstance(target.value, ast.Name) and
                    target.value.id == 'torch'):
                everything2.add(target.attr)

        # Public lowercase names exposed by torch.functional.
        for name in dir(torch.functional):
            if not name.startswith('_') and name[0].islower():
                everything2.add(name)
        everything2 -= whitelist2

        # assert they are equal, reporting every mismatch at once (sorted for
        # stable output) rather than stopping at the first one.
        only_in_rst = sorted(everything - everything2)
        only_in_python = sorted(everything2 - everything)
        self.assertEqual(
            only_in_rst, [],
            'in torch.rst but not in python: {}'.format(only_in_rst))
        self.assertEqual(
            only_in_python, [],
            'in python but not in torch.rst: {}'.format(only_in_python))
|
|
|
|
|
|
# Allow running this file directly as a script; unittest.main() runs the
# TestCase classes defined in this module.
if __name__ == '__main__':
    unittest.main()
|