mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
pre compute regex and match simple signature autograd codegen 15s -> 12s (#59852)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59852 This whole stack does not change anything to the codegened code Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D29063814 Pulled By: albanD fbshipit-source-id: a751047526f8d58f4760ee6f9ae906675bed5d75
This commit is contained in:
parent
30a18fe318
commit
d03ff1a17d
1 changed files with 15 additions and 11 deletions
|
|
@ -69,7 +69,7 @@ from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable
|
|||
#
|
||||
|
||||
# These functions require manual Python bindings or are not exposed to Python
|
||||
SKIP_PYTHON_BINDINGS = [
|
||||
_SKIP_PYTHON_BINDINGS = [
|
||||
'alias', 'contiguous', 'is_cuda', 'is_sparse', 'is_sparse_csr', 'size', 'stride',
|
||||
'.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
|
||||
'.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
|
||||
|
|
@ -94,27 +94,31 @@ SKIP_PYTHON_BINDINGS = [
|
|||
'fake_quantize_per_channel_affine_cachemask',
|
||||
]
|
||||
|
||||
SKIP_PYTHON_BINDINGS = list(map(lambda pattern: re.compile(rf'^{pattern}$'), _SKIP_PYTHON_BINDINGS))
|
||||
|
||||
# These function signatures are not exposed to Python. Note that this signature
|
||||
# list does not support regex.
|
||||
SKIP_PYTHON_BINDINGS_SIGNATURES = [
|
||||
'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
|
||||
'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
|
||||
'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
|
||||
'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
|
||||
'add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor',
|
||||
'add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)',
|
||||
'sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor',
|
||||
'sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)',
|
||||
'mul.Scalar(Tensor self, Scalar other) -> Tensor',
|
||||
'mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)',
|
||||
'div.Scalar(Tensor self, Scalar other) -> Tensor',
|
||||
'div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)',
|
||||
]
|
||||
|
||||
@with_native_function
|
||||
def should_generate_py_binding(f: NativeFunction) -> bool:
|
||||
name = cpp.name(f.func)
|
||||
for pattern in SKIP_PYTHON_BINDINGS:
|
||||
if re.match('^' + pattern + '$', name):
|
||||
for skip_regex in SKIP_PYTHON_BINDINGS:
|
||||
if skip_regex.match(name):
|
||||
return False
|
||||
|
||||
args = ', '.join(argument_type_str(arg.type)
|
||||
for arg in signature(f).arguments())
|
||||
sig = f'{name}({args})'
|
||||
signature = str(f.func)
|
||||
for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
|
||||
if pattern == sig:
|
||||
if pattern == signature:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
|
|||
Loading…
Reference in a new issue