From fcf9dc3b118c5a7f89c34379465c27394a4c829e Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Mon, 6 Jan 2025 08:45:43 -0800 Subject: [PATCH] Migrate from Tuple -> tuple in benchmarks (#144259) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144259 Approved by: https://github.com/yanboliang --- benchmarks/dynamo/common.py | 3 +- .../microbenchmarks/operator_inp_utils.py | 4 +- benchmarks/fastrnns/cells.py | 16 ++++--- benchmarks/fastrnns/custom_lstms.py | 42 +++++++++---------- benchmarks/fastrnns/factory.py | 22 +++++----- .../torchaudio_models.py | 8 ++-- .../functional_autograd_benchmark/utils.py | 16 +++---- benchmarks/gpt_fast/generate.py | 4 +- .../instruction_counts/execution/runner.py | 8 ++-- benchmarks/instruction_counts/worker/main.py | 8 ++-- benchmarks/transformer/score_mod.py | 22 +++++----- benchmarks/transformer/sdpa.py | 4 +- 12 files changed, 77 insertions(+), 80 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index e0aae9cdaed..086cd68dc91 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -32,7 +32,6 @@ from typing import ( NamedTuple, Optional, Sequence, - Tuple, Type, TYPE_CHECKING, ) @@ -746,7 +745,7 @@ def timed( return (time_total, result) if return_result else time_total -def _normalize_bench_inputs(example_inputs) -> Tuple[Tuple[Any], Mapping[str, Any]]: +def _normalize_bench_inputs(example_inputs) -> tuple[tuple[Any], Mapping[str, Any]]: # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary, # and consumed like `model(**example_inputs)`. # For other benchmarks, example_inputs are formatted as tuple and consumed diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py index 602c3bc516f..af8bf7a7f8f 100644 --- a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py +++ b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py @@ -4,7 +4,7 @@ import math import os from collections import Counter, defaultdict from functools import partial -from typing import Any, Dict, Generator, Iterable, Tuple +from typing import Any, Dict, Generator, Iterable import torch from torch.testing import make_tensor @@ -263,7 +263,7 @@ class OperatorInputsLoader: def get_inputs_for_operator( self, operator, dtype=None, device="cuda" - ) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], None, None]: + ) -> Generator[tuple[Iterable[Any], Dict[str, Any]], None, None]: assert ( str(operator) in self.operator_db ), f"Could not find {operator}, must provide overload" diff --git a/benchmarks/fastrnns/cells.py b/benchmarks/fastrnns/cells.py index 21e6149256f..ec55f444400 100644 --- a/benchmarks/fastrnns/cells.py +++ b/benchmarks/fastrnns/cells.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch from torch import Tensor @@ -27,12 +25,12 @@ def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias): def lstm_cell( input: Tensor, - hidden: Tuple[Tensor, Tensor], + hidden: tuple[Tensor, Tensor], w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: hx, cx = hidden gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh @@ -57,7 +55,7 @@ def flat_lstm_cell( w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) @@ -75,11 +73,11 @@ def flat_lstm_cell( def premul_lstm_cell( igates: Tensor, - hidden: Tuple[Tensor, Tensor], + hidden: tuple[Tensor, Tensor], w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: hx, cx = hidden gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh @@ -97,8 +95,8 @@ def premul_lstm_cell( def premul_lstm_cell_no_bias( - igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor -) -> Tuple[Tensor, Tensor]: + igates: Tensor, hidden: tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor +) -> tuple[Tensor, Tensor]: hx, cx = hidden gates = igates + torch.mm(hx, w_hh.t()) + b_hh diff --git a/benchmarks/fastrnns/custom_lstms.py b/benchmarks/fastrnns/custom_lstms.py index 0e5643bbeda..edc42206eff 100644 --- a/benchmarks/fastrnns/custom_lstms.py +++ b/benchmarks/fastrnns/custom_lstms.py @@ -1,7 +1,7 @@ import numbers import warnings from collections import namedtuple -from typing import List, Tuple +from typing import List import torch import torch.jit as jit @@ -131,8 +131,8 @@ class LSTMCell(jit.ScriptModule): @jit.script_method def forward( - self, input: Tensor, state: Tuple[Tensor, Tensor] - ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + self, input: Tensor, state: tuple[Tensor, Tensor] + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: hx, cx = state gates = ( torch.mm(input, self.weight_ih.t()) @@ -199,8 +199,8 @@ class LayerNormLSTMCell(jit.ScriptModule): @jit.script_method def forward( - self, input: Tensor, state: Tuple[Tensor, Tensor] - ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + self, input: Tensor, state: tuple[Tensor, Tensor] + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: hx, cx = state igates = self.layernorm_i(torch.mm(input, self.weight_ih.t())) hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t())) @@ -225,8 +225,8 @@ class LSTMLayer(jit.ScriptModule): @jit.script_method def forward( - self, input: Tensor, state: Tuple[Tensor, Tensor] - ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + self, input: Tensor, state: tuple[Tensor, Tensor] + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: inputs = input.unbind(0) outputs = torch.jit.annotate(List[Tensor], []) for i in range(len(inputs)): @@ -242,8 +242,8 @@ class ReverseLSTMLayer(jit.ScriptModule): @jit.script_method def forward( - self, input: Tensor, state: Tuple[Tensor, Tensor] - ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + self, input: Tensor, state: tuple[Tensor, Tensor] + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: inputs = reverse(input.unbind(0)) outputs = jit.annotate(List[Tensor], []) for i in range(len(inputs)): @@ -266,11 +266,11 @@ class BidirLSTMLayer(jit.ScriptModule): @jit.script_method def forward( - self, input: Tensor, states: List[Tuple[Tensor, Tensor]] - ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + self, input: Tensor, states: List[tuple[Tensor, Tensor]] + ) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]: # List[LSTMState]: [forward LSTMState, backward LSTMState] outputs = jit.annotate(List[Tensor], []) - output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + output_states = jit.annotate(List[tuple[Tensor, Tensor]], []) # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 i = 0 for direction in self.directions: @@ -300,10 +300,10 @@ class StackedLSTM(jit.ScriptModule): @jit.script_method def forward( - self, input: Tensor, states: List[Tuple[Tensor, Tensor]] - ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + self, input: Tensor, states: List[tuple[Tensor, Tensor]] + ) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]: # List[LSTMState]: One state per layer - output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + output_states = jit.annotate(List[tuple[Tensor, Tensor]], []) output = input # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 i = 0 @@ -330,11 +330,11 @@ class StackedLSTM2(jit.ScriptModule): @jit.script_method def forward( - self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]] - ) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]: + self, input: Tensor, states: List[List[tuple[Tensor, Tensor]]] + ) -> tuple[Tensor, List[List[tuple[Tensor, Tensor]]]]: # List[List[LSTMState]]: The outer list is for layers, # inner list is for directions. - output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], []) + output_states = jit.annotate(List[List[tuple[Tensor, Tensor]]], []) output = input # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 i = 0 @@ -370,10 +370,10 @@ class StackedLSTMWithDropout(jit.ScriptModule): @jit.script_method def forward( - self, input: Tensor, states: List[Tuple[Tensor, Tensor]] - ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + self, input: Tensor, states: List[tuple[Tensor, Tensor]] + ) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]: # List[LSTMState]: One state per layer - output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + output_states = jit.annotate(List[tuple[Tensor, Tensor]], []) output = input # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 i = 0 diff --git a/benchmarks/fastrnns/factory.py b/benchmarks/fastrnns/factory.py index 32bb3eec504..afa198ec985 100644 --- a/benchmarks/fastrnns/factory.py +++ b/benchmarks/fastrnns/factory.py @@ -1,5 +1,5 @@ from collections import namedtuple -from typing import List, Tuple +from typing import List import torch from torch import Tensor @@ -266,12 +266,12 @@ def varlen_pytorch_lstm_creator(**kwargs): def varlen_lstm_factory(cell, script): def dynamic_rnn( sequences: List[Tensor], - hiddens: Tuple[Tensor, Tensor], + hiddens: tuple[Tensor, Tensor], wih: Tensor, whh: Tensor, bih: Tensor, bhh: Tensor, - ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]: + ) -> tuple[List[Tensor], tuple[List[Tensor], List[Tensor]]]: hx, cx = hiddens hxs = hx.unbind(1) cxs = cx.unbind(1) @@ -406,12 +406,12 @@ def lstm_inputs( def lstm_factory(cell, script): def dynamic_rnn( input: Tensor, - hidden: Tuple[Tensor, Tensor], + hidden: tuple[Tensor, Tensor], wih: Tensor, whh: Tensor, bih: Tensor, bhh: Tensor, - ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: hx, cx = hidden outputs = [] inputs = input.unbind(0) @@ -432,12 +432,12 @@ def lstm_factory(cell, script): def lstm_factory_premul(premul_cell, script): def dynamic_rnn( input: Tensor, - hidden: Tuple[Tensor, Tensor], + hidden: tuple[Tensor, Tensor], wih: Tensor, whh: Tensor, bih: Tensor, bhh: Tensor, - ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: hx, cx = hidden outputs = [] inputs = torch.matmul(input, wih.t()).unbind(0) @@ -458,12 +458,12 @@ def lstm_factory_premul(premul_cell, script): def lstm_factory_premul_bias(premul_cell, script): def dynamic_rnn( input: Tensor, - hidden: Tuple[Tensor, Tensor], + hidden: tuple[Tensor, Tensor], wih: Tensor, whh: Tensor, bih: Tensor, bhh: Tensor, - ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: hx, cx = hidden outputs = [] inpSize = input.size() @@ -506,8 +506,8 @@ def lstm_factory_simple(cell, script): def lstm_factory_multilayer(cell, script): def dynamic_rnn( - input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor] - ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + input: Tensor, hidden: tuple[Tensor, Tensor], params: List[Tensor] + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: params_stride = 4 # NB: this assumes that biases are there hx, cx = hidden hy, cy = hidden # for scoping... diff --git a/benchmarks/functional_autograd_benchmark/torchaudio_models.py b/benchmarks/functional_autograd_benchmark/torchaudio_models.py index 257d62ce58c..184762a198d 100644 --- a/benchmarks/functional_autograd_benchmark/torchaudio_models.py +++ b/benchmarks/functional_autograd_benchmark/torchaudio_models.py @@ -3,7 +3,7 @@ import math from collections import OrderedDict -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -512,7 +512,7 @@ class MultiheadAttentionContainer(torch.nn.Module): attn_mask: Optional[torch.Tensor] = None, bias_k: Optional[torch.Tensor] = None, bias_v: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: r""" Args: query, key, value (Tensor): map a query and a set of key-value pairs to an output. @@ -589,7 +589,7 @@ class ScaledDotProduct(torch.nn.Module): attn_mask: Optional[torch.Tensor] = None, bias_k: Optional[torch.Tensor] = None, bias_v: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: r"""Uses a scaled dot product with the projected key-value pair to update the projected query. Args: @@ -686,7 +686,7 @@ class InProjContainer(torch.nn.Module): def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: r"""Projects the input sequences using in-proj layers. Args: query, key, value (Tensors): sequence to be projected diff --git a/benchmarks/functional_autograd_benchmark/utils.py b/benchmarks/functional_autograd_benchmark/utils.py index e19570ffe3c..1aa6e696702 100644 --- a/benchmarks/functional_autograd_benchmark/utils.py +++ b/benchmarks/functional_autograd_benchmark/utils.py @@ -1,20 +1,20 @@ from collections import defaultdict -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Union import torch from torch import nn, Tensor # Type helpers -InputsType = Union[Tensor, Tuple[Tensor, ...]] +InputsType = Union[Tensor, tuple[Tensor, ...]] # A Getter takes in a device and returns a callable and the inputs to that callable -GetterReturnType = Tuple[Callable[..., Tensor], InputsType] +GetterReturnType = tuple[Callable[..., Tensor], InputsType] GetterType = Callable[[torch.device], GetterReturnType] # V here refers to the v in either vjp, jvp, vhp or hvp -VType = Union[None, Tensor, Tuple[Tensor, ...]] +VType = Union[None, Tensor, tuple[Tensor, ...]] # Type used to store timing results. The first key is the model name, the second key # is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after. -TimingResultType = Dict[str, Dict[str, Tuple[float, ...]]] +TimingResultType = Dict[str, Dict[str, tuple[float, ...]]] # Utilities to make nn.Module "functional" @@ -44,7 +44,7 @@ def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None: _set_nested_attr(getattr(obj, names[0]), names[1:], value) -def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]: +def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], List[str]]: """ This function removes all the Parameters from the model and return them as a tuple as well as their original attribute names. @@ -65,7 +65,7 @@ def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]: return params, names -def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -> None: +def load_weights(mod: nn.Module, names: List[str], params: tuple[Tensor, ...]) -> None: """ Reload a set of weights so that `mod` can be used again to perform a forward pass. Note that the `params` are regular Tensors (that can have history) and so are left @@ -77,7 +77,7 @@ def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) - # Utilities to read/write markdown table-like content. def to_markdown_table( - res: TimingResultType, header: Optional[Tuple[str, ...]] = None + res: TimingResultType, header: Optional[tuple[str, ...]] = None ) -> str: if header is None: header = ("model", "task", "mean", "var") diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py index 01f98609129..8b4e4a550b9 100644 --- a/benchmarks/gpt_fast/generate.py +++ b/benchmarks/gpt_fast/generate.py @@ -2,7 +2,7 @@ import dataclasses import itertools import platform import time -from typing import Optional, Tuple +from typing import Optional import torchao from common import Experiment, register_experiment @@ -89,7 +89,7 @@ def prefill( def decode_one_token( model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: # input_pos: [B, 1] assert input_pos.shape[-1] == 1 logits = model(x, input_pos) diff --git a/benchmarks/instruction_counts/execution/runner.py b/benchmarks/instruction_counts/execution/runner.py index a8660805903..43994b9d9fe 100644 --- a/benchmarks/instruction_counts/execution/runner.py +++ b/benchmarks/instruction_counts/execution/runner.py @@ -8,7 +8,7 @@ import subprocess import textwrap import threading import time -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Dict, List, Optional, Set, Union from worker.main import WorkerFailure, WorkerOutput @@ -55,7 +55,7 @@ class CorePool: True for _ in range(min_core_id, min_core_id + self._num_cores) ] - self._reservations: Dict[str, Tuple[int, ...]] = {} + self._reservations: Dict[str, tuple[int, ...]] = {} self._lock = threading.Lock() def reserve(self, n: int) -> Optional[str]: @@ -87,11 +87,11 @@ class CorePool: class Runner: def __init__( self, - work_items: Tuple[WorkOrder, ...], + work_items: tuple[WorkOrder, ...], core_pool: Optional[CorePool] = None, cadence: float = 1.0, ) -> None: - self._work_items: Tuple[WorkOrder, ...] = work_items + self._work_items: tuple[WorkOrder, ...] = work_items self._core_pool: CorePool = core_pool or CorePool(0, CPU_COUNT - 4) self._cadence: float = cadence diff --git a/benchmarks/instruction_counts/worker/main.py b/benchmarks/instruction_counts/worker/main.py index b8c277eb6dc..73cbe029878 100644 --- a/benchmarks/instruction_counts/worker/main.py +++ b/benchmarks/instruction_counts/worker/main.py @@ -24,7 +24,7 @@ import pickle import sys import timeit import traceback -from typing import Any, Tuple, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING, Union if TYPE_CHECKING: @@ -81,8 +81,8 @@ class WorkerTimerArgs: @dataclasses.dataclass(frozen=True) class WorkerOutput: # Only return values to reduce communication between main process and workers. - wall_times: Tuple[float, ...] - instructions: Tuple[int, ...] + wall_times: tuple[float, ...] + instructions: tuple[int, ...] @dataclasses.dataclass(frozen=True) @@ -145,7 +145,7 @@ def _run(timer_args: WorkerTimerArgs) -> WorkerOutput: m = timer.blocked_autorange(min_run_time=MIN_RUN_TIME) - stats: Tuple[CallgrindStats, ...] = timer.collect_callgrind( + stats: tuple[CallgrindStats, ...] = timer.collect_callgrind( number=CALLGRIND_NUMBER, collect_baseline=False, repeats=CALLGRIND_REPEATS, diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py index e9ac9d1df79..fd845086cdc 100644 --- a/benchmarks/transformer/score_mod.py +++ b/benchmarks/transformer/score_mod.py @@ -6,7 +6,7 @@ from collections import defaultdict from contextlib import nullcontext from dataclasses import asdict, dataclass from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Union import numpy as np from tabulate import tabulate @@ -41,7 +41,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> @dataclass(frozen=True) class ExperimentConfig: - shape: Tuple[int] # [B, Hq, M, Hkv, N, D] + shape: tuple[int] # [B, Hq, M, Hkv, N, D] attn_type: str dtype: torch.dtype calculate_bwd_time: bool @@ -149,7 +149,7 @@ def generate_inputs( def generate_jagged_inputs( - shape: Tuple[int], + shape: tuple[int], query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -611,7 +611,7 @@ softcap_value = 50 dropout_p = 0.0 -def generate_score_mod(attn_type: str, shape: Tuple[int]) -> Callable | None: +def generate_score_mod(attn_type: str, shape: tuple[int]) -> Callable | None: B, Hq, M, Hkv, N, D = shape is_decoding = M == 1 from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap @@ -653,7 +653,7 @@ sliding_window_size = 512 prefix_length = 512 -def generate_block_mask(attn_type: str, shape: Tuple[int]): +def generate_block_mask(attn_type: str, shape: tuple[int]): B, Hq, M, Hkv, N, D = shape is_decoding = M == 1 @@ -728,7 +728,7 @@ def generate_block_mask(attn_type: str, shape: Tuple[int]): return block_mask, mask_mod_kwargs -def get_kernel_options(attn_type: str, shape: Tuple[int]): +def get_kernel_options(attn_type: str, shape: tuple[int]): B, Hq, M, Hkv, N, D = shape is_decoding = M == 1 kernel_opt_training_dict = { @@ -815,7 +815,7 @@ def get_backend_context(backend: str): def generate_FA_callable( - attn_type: str, shape: Tuple[int], dtype: torch.dtype, backend: str, **kwargs + attn_type: str, shape: tuple[int], dtype: torch.dtype, backend: str, **kwargs ) -> Callable | None: if dtype not in [torch.float16, torch.bfloat16]: return None @@ -882,7 +882,7 @@ def generate_FA_callable( def generate_FD_callable( - attn_type: str, shape: Tuple[int], dtype: torch.dtype + attn_type: str, shape: tuple[int], dtype: torch.dtype ) -> Callable | None: if dtype not in [torch.float16, torch.bfloat16]: return None @@ -929,7 +929,7 @@ def generate_FD_callable( def generate_attn_mask_linear_score_mod( - shape: Tuple[int], block_mask: BlockMask, score_mod: Callable, dtype: torch.dtype + shape: tuple[int], block_mask: BlockMask, score_mod: Callable, dtype: torch.dtype ): B, Hq, M, N = shape if block_mask is None and score_mod is None: @@ -954,7 +954,7 @@ def generate_attn_mask_linear_score_mod( def generate_eager_sdpa( attn_type: str, - shape: Tuple[int], + shape: tuple[int], dtype: torch.dtype, block_mask: BlockMask, score_mod: Callable | None = None, @@ -1025,7 +1025,7 @@ def generate_experiment_configs( calculate_bwd: bool, dtype: torch.dtype, batch_sizes: List[int], - num_heads: List[Tuple[int, int]], + num_heads: List[tuple[int, int]], seq_lens: List[int], head_dims: List[int], score_mods_str: List[str], diff --git a/benchmarks/transformer/sdpa.py b/benchmarks/transformer/sdpa.py index d45970213e0..83fd2a2925e 100644 --- a/benchmarks/transformer/sdpa.py +++ b/benchmarks/transformer/sdpa.py @@ -2,7 +2,7 @@ import itertools from collections import defaultdict from contextlib import nullcontext from dataclasses import asdict, dataclass -from typing import Callable, List, Tuple +from typing import Callable, List from tabulate import tabulate from tqdm import tqdm @@ -68,7 +68,7 @@ class Experiment: def get_input( config: ExperimentConfig, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: q = torch.randn( (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim), dtype=config.dtype,