Mirror of https://github.com/saymrwulf/pytorch.git
The original change was about 9.5% slower than before #122074. This improves it to be only about 1.4% slower. Also touched up some unrelated nits that the linter complained about.

Fixes #126293

Ran torchbench 3 times on each change. Perf values before (stable), after (fix), and with #122074 backed out (backout):

```
../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench pyhpc_isoneutral_mixing amp first dynamic cpp
  stable:  43.948x 45.754x 44.906x
  fix:     47.505x 49.987x 47.493x
  backout: 48.243x 48.199x 48.192x

../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench pyhpc_equation_of_state amp first static default
  stable:  15.224x 13.286x 15.354x
  fix:     16.402x 16.370x 16.183x
  backout: 16.554x 16.675x 16.787x

../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench lennard_jones float32 first static default
  stable:  1.712x 1.651x 1.640x
  fix:     1.804x 1.798x 1.792x
  backout: 1.864x 1.824x 1.836x
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126996
Approved by: https://github.com/jansel
83 lines · 3.2 KiB · Python
from enum import Enum
from typing import Optional, Tuple

from torch import Tensor

# Defined in torch/csrc/functorch/init.cpp

def _set_dynamic_layer_keys_included(included: bool) -> None: ...
def get_unwrapped(tensor: Tensor) -> Tensor: ...
def is_batchedtensor(tensor: Tensor) -> bool: ...
def is_functionaltensor(tensor: Tensor) -> bool: ...
def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
def is_legacy_batchedtensor(tensor: Tensor) -> bool: ...
def maybe_get_bdim(tensor: Tensor) -> int: ...
def maybe_get_level(tensor: Tensor) -> int: ...
def maybe_current_level() -> Optional[int]: ...
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...
def current_level() -> int: ...
def count_jvp_interpreters() -> int: ...
def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
def get_single_level_autograd_function_allowed() -> bool: ...
def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...
def _vmap_increment_nesting(batch_size: int, randomness: str) -> int: ...
def _vmap_decrement_nesting() -> int: ...
def _grad_increment_nesting() -> int: ...
def _grad_decrement_nesting() -> int: ...
def _jvp_increment_nesting() -> int: ...
def _jvp_decrement_nesting() -> int: ...

# Defined in aten/src/ATen/functorch/Interpreter.h

class TransformType(Enum):
    Torch: TransformType = ...
    Vmap: TransformType = ...
    Grad: TransformType = ...
    Jvp: TransformType = ...
    Functionalize: TransformType = ...

class RandomnessType(Enum):
    Error: RandomnessType = ...
    Same: RandomnessType = ...
    Different: RandomnessType = ...

class CInterpreter:
    def key(self) -> TransformType: ...
    def level(self) -> int: ...

class CGradInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def lift(self, tensor: Tensor) -> Tensor: ...
    def prevGradMode(self) -> bool: ...

class CJvpInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def lift(self, tensor: Tensor) -> Tensor: ...
    def prevFwdGradMode(self) -> bool: ...

class CFunctionalizeInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def key(self) -> TransformType: ...
    def level(self) -> int: ...
    def functionalizeAddBackViews(self) -> bool: ...

class CVmapInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def key(self) -> TransformType: ...
    def level(self) -> int: ...
    def batchSize(self) -> int: ...
    def randomness(self) -> RandomnessType: ...

class DynamicLayer: ...

def get_dynamic_layer_stack_depth() -> int: ...
def get_interpreter_stack() -> list[CInterpreter]: ...
def peek_interpreter_stack() -> CInterpreter: ...
def pop_dynamic_layer_stack() -> DynamicLayer: ...
def pop_dynamic_layer_stack_and_undo_to_depth(depth: int) -> None: ...
def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...