pre compute regex and match simple signature autograd codegen 15s -> 12s (#59852)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59852

This whole stack does not change anything to the codegened code

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D29063814

Pulled By: albanD

fbshipit-source-id: a751047526f8d58f4760ee6f9ae906675bed5d75
This commit is contained in:
albanD 2021-06-12 06:55:44 -07:00 committed by Facebook GitHub Bot
parent 30a18fe318
commit d03ff1a17d

View file

@ -69,7 +69,7 @@ from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable
#
# These functions require manual Python bindings or are not exposed to Python
SKIP_PYTHON_BINDINGS = [
_SKIP_PYTHON_BINDINGS = [
'alias', 'contiguous', 'is_cuda', 'is_sparse', 'is_sparse_csr', 'size', 'stride',
'.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
'.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
@ -94,27 +94,31 @@ SKIP_PYTHON_BINDINGS = [
'fake_quantize_per_channel_affine_cachemask',
]
SKIP_PYTHON_BINDINGS = list(map(lambda pattern: re.compile(rf'^{pattern}$'), _SKIP_PYTHON_BINDINGS))
# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
'add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor',
'add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)',
'sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor',
'sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)',
'mul.Scalar(Tensor self, Scalar other) -> Tensor',
'mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)',
'div.Scalar(Tensor self, Scalar other) -> Tensor',
'div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)',
]
@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
name = cpp.name(f.func)
for pattern in SKIP_PYTHON_BINDINGS:
if re.match('^' + pattern + '$', name):
for skip_regex in SKIP_PYTHON_BINDINGS:
if skip_regex.match(name):
return False
args = ', '.join(argument_type_str(arg.type)
for arg in signature(f).arguments())
sig = f'{name}({args})'
signature = str(f.func)
for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
if pattern == sig:
if pattern == signature:
return False
return True