mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
avoid error string formatting aten codegen 28s -> 23s (#59848)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59848 This whole stack does not change anything to the codegened code Test Plan: Imported from OSS Reviewed By: ailzhang Differential Revision: D29063818 Pulled By: albanD fbshipit-source-id: c68734672eeacd212d7bd9bebe3d53aaa20c3c24
This commit is contained in:
parent
7143a6a189
commit
504ec30109
4 changed files with 6 additions and 5 deletions
|
|
@ -25,7 +25,7 @@ def native_function_manager(g: Union[NativeFunctionsGroup, NativeFunction]) -> I
|
|||
f = g.out
|
||||
else:
|
||||
f = g
|
||||
with context(f'in native_functions.yaml line {f.loc}:\n {f.func}'):
|
||||
with context(lambda: f'in native_functions.yaml line {f.loc}:\n {f.func}'):
|
||||
with local.parametrize(use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors):
|
||||
yield
|
||||
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ def parse_native_yaml(path: str) -> ParsedYaml:
|
|||
assert isinstance(e.get('__line__'), int), e
|
||||
loc = Location(path, e['__line__'])
|
||||
funcs = e.get('func')
|
||||
with context(f'in {loc}:\n {funcs}'):
|
||||
with context(lambda: f'in {loc}:\n {funcs}'):
|
||||
func, m = NativeFunction.from_yaml(e, loc)
|
||||
rs.append(func)
|
||||
BackendIndex.grow_index(bs, m)
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
|
|||
|
||||
backend_key: Optional[DispatchKey] = None
|
||||
if len(supported) > 0:
|
||||
with context(f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'):
|
||||
with context(lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'):
|
||||
backend_key = DispatchKey.parse(backend)
|
||||
|
||||
backend_idx = create_backend_index(supported, backend_key)
|
||||
|
|
@ -87,7 +87,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
|
|||
|
||||
autograd_key: Optional[DispatchKey] = None
|
||||
if len(supported_autograd) > 0:
|
||||
with context(f'The "autograd" key was specified, which indicates that you would like to override \
|
||||
with context(lambda: f'The "autograd" key was specified, which indicates that you would like to override \
|
||||
the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.'):
|
||||
autograd_key = DispatchKey.parse(f'Autograd{backend}')
|
||||
|
||||
|
|
|
|||
|
|
@ -60,11 +60,12 @@ def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
|
|||
# easily say that an error occurred while processing a specific
|
||||
# context.
|
||||
@contextlib.contextmanager
|
||||
def context(msg: str) -> Iterator[None]:
|
||||
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
# TODO: this does the wrong thing with KeyError
|
||||
msg = msg_fn()
|
||||
msg = textwrap.indent(msg, ' ')
|
||||
msg = f'{e.args[0]}\n{msg}' if e.args else msg
|
||||
e.args = (msg,) + e.args[1:]
|
||||
|
|
|
|||
Loading…
Reference in a new issue