diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index bf622b4e421..874e54d74a6 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -69,7 +69,7 @@ from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable # # These functions require manual Python bindings or are not exposed to Python -SKIP_PYTHON_BINDINGS = [ +_SKIP_PYTHON_BINDINGS = [ 'alias', 'contiguous', 'is_cuda', 'is_sparse', 'is_sparse_csr', 'size', 'stride', '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward', '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*', @@ -94,27 +94,31 @@ SKIP_PYTHON_BINDINGS = [ 'fake_quantize_per_channel_affine_cachemask', ] +SKIP_PYTHON_BINDINGS = list(map(lambda pattern: re.compile(rf'^{pattern}$'), _SKIP_PYTHON_BINDINGS)) + # These function signatures are not exposed to Python. Note that this signature # list does not support regex. SKIP_PYTHON_BINDINGS_SIGNATURES = [ - 'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)', - 'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)', - 'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)', - 'div(Tensor, Scalar)', 'div_(Tensor, Scalar)', + 'add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor', + 'add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)', + 'sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor', + 'sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)', + 'mul.Scalar(Tensor self, Scalar other) -> Tensor', + 'mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)', + 'div.Scalar(Tensor self, Scalar other) -> Tensor', + 'div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)', ] @with_native_function def should_generate_py_binding(f: NativeFunction) -> bool: name = cpp.name(f.func) - for pattern in SKIP_PYTHON_BINDINGS: - if re.match('^' + pattern + '$', name): + for skip_regex in SKIP_PYTHON_BINDINGS: + if skip_regex.match(name): return False - args = ', '.join(argument_type_str(arg.type) - for arg in signature(f).arguments()) - sig = f'{name}({args})' + signature = str(f.func) for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES: - if pattern == sig: + if pattern == signature: return False return True