[BE][Easy][16/19] enforce style for empty lines in import segments in torch/_i*/ (#129768)

See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.
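
For reference, the style being enforced is the usual isort/usort convention: import segments (standard library, third-party/`torch`, and relative imports) are separated by a blank line, and the first module-level statement is separated from the import block by a blank line. The sketch below is an illustrative composite of the hunks further down, not code from any single file:

```python
# Illustrative composite of the import style this PR enforces
# (module names mirror hunks below; the file itself is hypothetical).
import logging
from typing import Any, Dict

import torch
import torch.utils._pytree as pytree

from . import config
from .virtualized import V

# A blank line now separates the import block from module-level code.
log = logging.getLogger(__name__)
aten = torch.ops.aten
```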

You can review these PRs, ignoring whitespace-only changes, via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```
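
If you want to reproduce the auto-generated formatting locally, something along the following lines should work; this is a hedged sketch that assumes `lintrunner` is set up for the PyTorch checkout and that the formatter is registered under the `UFMT` linter name:

```bash
# Sketch only: install and initialize lintrunner, then apply the UFMT
# (usort + black) fixes to the files touched by this commit.
pip install lintrunner
lintrunner init
lintrunner --take UFMT --apply-patches $(git diff --name-only HEAD~1)
```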

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129768
Approved by: https://github.com/jansel
author Xuehai Pan, 2024-07-20 18:35:26 +08:00 (committed by PyTorch MergeBot)
parent 8e478d4fb1
commit b6d477fd56
111 changed files with 134 additions and 92 deletions

View file

@ -58,7 +58,6 @@ ISORT_SKIPLIST = re.compile(
# torch/_[e-h]*/**
"torch/_[e-h]*/**",
# torch/_i*/**
"torch/_i*/**",
# torch/_[j-z]*/**
"torch/_[j-z]*/**",
# torch/[a-c]*/**
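
The first hunk above removes `"torch/_i*/**"` from the linter's skip list, which is what brings `torch/_inductor/` (and any other `torch/_i*` directories) under the import-style check. A rough, hypothetical sketch of how a glob skip list like this gates files (not the linter's actual implementation):

```python
# Hypothetical sketch of skiplist matching; not the linter's real code.
from fnmatch import fnmatch

skiplist = ["torch/_[e-h]*/**", "torch/_[j-z]*/**"]  # "torch/_i*/**" removed

path = "torch/_inductor/codecache.py"
skipped = any(fnmatch(path, pattern) for pattern in skiplist)
print(skipped)  # False: torch/_i*/ files are no longer skipped
```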

View file

@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch.fx
import torch.utils._pytree as pytree
__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]

View file

@ -7,6 +7,7 @@ from unittest import mock
import torch
import torch._export
from torch._inductor.utils import is_cpu_device
from .runtime.runtime_utils import cache_dir

View file

@ -31,15 +31,14 @@ from torch._inductor.compile_worker.subproc_pool import (
SubprocPool,
)
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import (
_set_triton_ptxas_path,
_worker_compile_triton,
)
from torch.hub import _Faketqdm, tqdm
from torch.utils._triton import has_triton_package
if TYPE_CHECKING:
from torch._inductor.runtime.hints import HalideMeta

View file

@ -1,10 +1,8 @@
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,

View file

@ -1,7 +1,7 @@
import functools
from typing import Any, Callable, Dict, List, Tuple
Feedback = float
Choice = str
Value = Any

View file

@ -1,7 +1,6 @@
import importlib
import inspect
import pkgutil
from collections import defaultdict
from typing import Any, Dict, List, Optional

View file

@ -28,7 +28,6 @@ import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch import multiprocessing
from torch._dynamo.testing import rand_strided
from torch._inductor import ir
from torch._inductor.codecache import (
CppCodeCache,
@ -38,6 +37,7 @@ from torch._inductor.codecache import (
PyCodeCache,
)
if TYPE_CHECKING:
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
@ -49,6 +49,7 @@ from . import config
from .runtime.runtime_utils import do_bench_cpu, do_bench_gpu
from .virtualized import V
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
EXIT_HANDLER_REGISTERED = False

View file

@ -8,6 +8,7 @@ from sympy import Expr
import torch
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from .ir import InterpreterShim, LoopBody, LoopBodyBlock
from .utils import cache_on_self, dominated_nodes
from .virtualized import V

View file

@ -27,7 +27,7 @@ import threading
import warnings
from bisect import bisect_right
from copy import copy
from ctypes import c_void_p, cdll, CDLL
from ctypes import c_void_p, CDLL, cdll
from functools import partial
from pathlib import Path
from time import time, time_ns
@ -58,6 +58,7 @@ from torch._inductor.codegen.rocm.compile_command import (
rocm_compiler,
)
"""
codecache.py, cpp_builder.py and cpu_vec_isa.py import rule:
https://github.com/pytorch/pytorch/issues/124245#issuecomment-2197778902
@ -86,7 +87,6 @@ from torch._inductor.runtime.compile_tasks import (
)
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import ALIGN_BYTES, clear_on_fresh_inductor_cache, is_linux
from torch._logging import trace_structured
from torch._subclasses.fake_tensor import (
extract_tensor_metadata,
@ -95,6 +95,7 @@ from torch._subclasses.fake_tensor import (
)
from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
if TYPE_CHECKING:
from concurrent.futures import Future

View file

@ -2,9 +2,9 @@
import re
import torch
from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE
# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
# "...
# from ..codecache import CudaKernelParamCache

View file

@ -1,5 +1,6 @@
import torch
# Provide aoti module launch hip/cuda drivers. This file is also used for unit testing purpose

View file

@ -21,8 +21,8 @@ from torch.utils import _pytree as pytree
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from ..._dynamo.utils import counters
from ..._dynamo.utils import counters
from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics
from ..codegen.wrapper import WrapperCodeGen
from ..optimize_indexing import range_expressable_in_32_bits
@ -46,7 +46,6 @@ from ..utils import (
sympy_product,
sympy_subs,
)
from ..virtualized import NullKernelHandler, ops, OpsValue, V
from .common import (
BackendFeature,
@ -63,7 +62,6 @@ from .common import (
OpOverrides,
OptimizationContext,
)
from .cpp_utils import (
cexpr,
cexpr_index,
@ -74,6 +72,7 @@ from .cpp_utils import (
value_to_cpp,
)
_IS_WINDOWS = sys.platform == "win32"
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")

View file

@ -6,19 +6,19 @@ from unittest.mock import patch
import torch
import torch.utils
from ..._dynamo.utils import counters
from .. import ir, lowering as L
from ..kernel.mm_common import mm_args
from ..select_algorithm import DataProcessorTemplateWrapper
from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
from ..virtualized import ops, V
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
from .cpp_template import CppTemplate
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import GemmBlocking, get_gemm_template_output_and_compute_dtype
GEMM_TEMPLATE = r"""
{{template.header().getvalue()}}

View file

@ -3,7 +3,6 @@ import ctypes
import functools
import itertools
import logging
import sys
from typing import Callable, List, Optional
from unittest.mock import patch
@ -17,6 +16,7 @@ from ..virtualized import V
from .common import KernelTemplate
from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel
log = logging.getLogger(__name__)

View file

@ -7,8 +7,8 @@ from sympy.parsing.sympy_parser import parse_expr
import torch
from torch.utils._sympy.symbol import SymT
from .. import config, cpp_builder, ir, lowering as L
from .. import config, cpp_builder, ir, lowering as L
from ..autotune_process import CppBenchmarkRequest
from ..select_algorithm import PartialRender
from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix

View file

@ -2,7 +2,6 @@
import contextlib
import copy
import math
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
@ -11,12 +10,13 @@ import sympy
import torch
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import ir
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
from ..virtualized import V
from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs
DTYPE_TO_CPP = {
torch.float32: "float",
torch.float64: "double",

View file

@ -10,10 +10,10 @@ import sympy
from sympy import Expr
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
import torch._ops
from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes
from .. import config, ir
from ..utils import _align, ALIGN_BYTES, cache_on_self, sympy_product
from ..virtualized import V

View file

@ -19,6 +19,7 @@ from .cpp_utils import DTYPE_TO_CPP
from .cpp_wrapper_cpu import CppWrapperCpu
from .wrapper import SymbolicCallArg
if TYPE_CHECKING:
from ..graph import GraphLowering

View file

@ -3,16 +3,15 @@ import logging
from typing import cast, Sequence
from ...._dynamo.utils import counters
from ... import config
from ...codecache import code_hash, get_path
from ...ir import CUDATemplateBuffer
from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode
from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
from ...virtualized import V
from ..common import IndentedBuffer
log = logging.getLogger(__name__)

View file

@ -6,6 +6,7 @@ import torch
from ... import config
log = logging.getLogger(__name__)

View file

@ -15,9 +15,9 @@ from ...ir import (
from ...utils import sympy_product
from ...virtualized import V
from ..common import IndentedBuffer, Kernel, OpOverrides
from ..cpp_utils import CppPrinter, DTYPE_TO_CPP
if TYPE_CHECKING:
from torch._inductor.codegen.cuda.cuda_template import CUDATemplate

View file

@ -8,14 +8,15 @@ from unittest.mock import patch
import sympy
import torch
from ...autotune_process import CUDABenchmarkRequest, TensorMeta
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
from ...utils import IndentedBuffer, unique
from ...virtualized import V
from ..common import KernelTemplate
from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
log = logging.getLogger(__name__)

View file

@ -1,11 +1,12 @@
# mypy: allow-untyped-defs
from ..cutlass_utils import try_import_cutlass
if try_import_cutlass():
import enum
from cutlass_library.library import * # noqa: F401, F403
from cutlass_library.gemm_operation import * # noqa: F401, F403
from cutlass_library.library import * # noqa: F401, F403
# copied / modified from original at
# https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658

View file

@ -3,7 +3,6 @@ import functools
import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional
@ -11,12 +10,13 @@ from typing import Any, List, Optional
import sympy
import torch
from ... import config
from ...ir import Layout
from ...runtime.runtime_utils import cache_dir
from .cuda_env import get_cuda_arch, get_cuda_version
log = logging.getLogger(__name__)

View file

@ -17,11 +17,11 @@ from ...ir import (
ReinterpretView,
)
from ..common import IndentedBuffer
from . import cutlass_utils
from .cuda_kernel import CUDATemplateKernel
from .cuda_template import CUTLASSTemplate
log = logging.getLogger(__name__)
# Jinja template for GEMM Kernel, used by the CUTLASSGemmTemplate class below.

View file

@ -10,7 +10,6 @@ from ..scheduler import (
)
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
from .triton import TritonScheduling

View file

@ -25,6 +25,7 @@ import sympy
import torch
import torch._logging
from ..._prims_common import is_integer_dtype
from ...utils._sympy.functions import FloorDiv, ModularIndexing
from ...utils._sympy.symbol import symbol_is_type, SymT
@ -34,7 +35,6 @@ from ..codecache import HalideCodeCache
from ..ir import get_reduction_combine_fn
from ..metrics import is_metric_table_enabled, log_kernel_metadata
from ..ops_handler import AddParenHandler, MockHandler
from ..runtime.hints import HalideInputSpec, HalideMeta, ReductionHint
from ..utils import (
get_bounds_index_expr,
@ -58,6 +58,7 @@ from .cpp import DTYPE_TO_CPP
from .cpp_utils import cexpr
from .simd import constant_repr, SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
from ..ops_handler import ReductionType, StoreMode

View file

@ -10,10 +10,10 @@ from typing import Any, Dict, Iterable, List, Optional, Protocol
import sympy
import torch
from .. import config, ir
from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer
from ..virtualized import V
from .wrapper import (
AllocateLine,
FreeIfNotReusedLine,

View file

@ -12,6 +12,7 @@ from ..utils import cache_on_self
from ..virtualized import V
from .common import TensorArg
log = logging.getLogger(__name__)

View file

@ -13,6 +13,7 @@ from torch._inductor.ir import Buffer, Layout
from ...utils import IndentedBuffer, try_import_ck_lib
_, gen_ops_library, gen_ops_preselected, CKGemmOperation = try_import_ck_lib()

View file

@ -6,6 +6,7 @@ from typing import List, Optional
from torch._inductor import config
from torch._inductor.utils import is_linux
log = logging.getLogger(__name__)

View file

@ -7,10 +7,10 @@ from ctypes import byref, c_size_t, c_void_p
from typing import Any, Callable, Iterable, List, Optional, Union
import torch
from torch._inductor.autotune_process import GPUDeviceBenchmarkRequest, TensorMeta
from torch._inductor.codecache import DLLWrapper, ROCmCodeCache
log = logging.getLogger(__name__)

View file

@ -8,9 +8,9 @@ from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, Scheduler
from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
from ...virtualized import V
from ..common import IndentedBuffer
from .rocm_template_buffer import ROCmTemplateBuffer
log = logging.getLogger(__name__)

View file

@ -6,12 +6,11 @@ from ...ir import Buffer, ChoiceCaller, IRNode, Layout, PrimitiveInfoType, Tenso
from ...utils import sympy_product
from ...virtualized import V
from ..common import IndentedBuffer, Kernel, OpOverrides
from ..cpp_utils import CppPrinter
from .rocm_benchmark_request import ROCmBenchmarkRequest
from .rocm_template_buffer import ROCmTemplateBuffer
if TYPE_CHECKING:
from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate

View file

@ -9,7 +9,6 @@ import sympy
from ...autotune_process import TensorMeta
from ...ir import Buffer, IRNode, Layout
from ...utils import IndentedBuffer, unique
from ...virtualized import V
from ..common import KernelTemplate
@ -17,6 +16,7 @@ from .rocm_benchmark_request import ROCmBenchmarkRequest
from .rocm_kernel import ROCmTemplateCaller, ROCmTemplateKernel
from .rocm_template_buffer import ROCmTemplateBuffer
log = logging.getLogger(__name__)

View file

@ -28,13 +28,12 @@ import sympy
import torch
import torch._logging
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from ..._dynamo.utils import counters
from .. import config, ir, scheduler
from ..codecache import code_hash
from ..dependencies import Dep, MemoryDep, StarDep, WeakDep
from ..ir import TritonTemplateBuffer
from ..optimize_indexing import indexing_dtype_strength_reduction

View file

@ -27,14 +27,13 @@ import sympy
import torch
import torch._logging
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._triton import has_triton_package
from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir
from ..codecache import code_hash, get_path, PyCodeCache
from ..metrics import is_metric_table_enabled, log_kernel_metadata
@ -73,6 +72,7 @@ from .simd import (
)
from .triton_utils import config_of, signature_of, signature_to_meta
if TYPE_CHECKING:
from ..ir import IRNode

View file

@ -1,16 +1,12 @@
# mypy: allow-untyped-defs
import functools
from typing import Optional, Set
import torch._inductor.runtime.hints
from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
from torch._prims_common import prod
from torch.utils._sympy.functions import CeilDiv

View file

@ -10,7 +10,6 @@ import inspect
import logging
import operator
import re
import tempfile
from itertools import count
from typing import (
@ -41,7 +40,6 @@ from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import async_compile, config, ir
from ..codecache import output_code_log
from ..ir import ReinterpretView
from ..runtime import triton_heuristics
@ -58,6 +56,7 @@ from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
from .triton_utils import config_of, signature_to_meta
if TYPE_CHECKING:
import triton

View file

@ -5,8 +5,8 @@ from enum import IntEnum
import sympy
import torch
from . import ir
from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V

View file

@ -4,7 +4,6 @@ from __future__ import annotations
import heapq
import operator
import sys
from collections import defaultdict
from typing import Dict, List, Set, TYPE_CHECKING
@ -15,6 +14,7 @@ from . import config, ir
from .dependencies import WeakDep
from .utils import is_collective, is_wait
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
if TYPE_CHECKING:

View file

@ -8,15 +8,12 @@ import sys
import time
import warnings
from itertools import count
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from unittest import mock
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
import torch.fx
import torch.utils._pytree as pytree
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo import (
compiled_autograd,
@ -77,6 +74,7 @@ from .utils import (
)
from .virtualized import V
if config.is_fbcode():
from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
else:

View file

@ -9,6 +9,7 @@ from torch._inductor.compile_worker.subproc_pool import SubprocMain
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path
log = logging.getLogger(__name__)
_set_triton_ptxas_path()

View file

@ -18,6 +18,7 @@ from typing import Any, Callable, Dict
from torch._inductor import config
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
log = logging.getLogger(__name__)

View file

@ -1078,5 +1078,6 @@ if TYPE_CHECKING:
from torch.utils._config_module import install_config_module
# adds patch, save_config, etc
install_config_module(sys.modules[__name__])

View file

@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Tuple
import torch
import torch.utils._pytree as pytree
aten = torch.ops.aten
# We would like to split modules into two subgraphs for runtime weight updates to work correctly.

View file

@ -24,6 +24,7 @@ from torch._inductor import config, exc
from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA
from torch._inductor.runtime.runtime_utils import cache_dir
if config.is_fbcode():
from triton.fb import build_paths # noqa: F401

View file

@ -3,7 +3,6 @@ import dataclasses
import functools
import os
import platform
import re
import subprocess
import sys
@ -12,6 +11,7 @@ from typing import Any, Callable, Dict, List
import torch
from torch._inductor import config
_IS_WINDOWS = sys.platform == "win32"

View file

@ -49,7 +49,6 @@ import traceback
import warnings
import weakref
from collections import defaultdict
from enum import auto, Enum
from typing import (
Any,
@ -91,6 +90,7 @@ from torch.storage import UntypedStorage
from torch.utils import _pytree as pytree
from torch.utils.weak import TensorWeakRef
if TYPE_CHECKING:
from torch.types import _bool

View file

@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch._dynamo.utils import counters
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")

View file

@ -15,10 +15,8 @@ from typing import Any, Dict, List, Optional
from unittest.mock import patch
import torch
from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
from torch import fx as fx
from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
@ -36,6 +34,7 @@ from .scheduler import (
)
from .virtualized import V
log = logging.getLogger(__name__)
SchedulerNodeList = List[Any]

View file

@ -37,6 +37,7 @@ from .utils import (
use_scatter_fallback,
)
log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims

View file

@ -25,6 +25,7 @@ from .utils import (
)
from .virtualized import OpsHandler, ReductionType, V
log = logging.getLogger(__name__)
is_indirect = re.compile(r"indirect|tmp").search

View file

@ -6,6 +6,7 @@ import tempfile
import textwrap
from functools import lru_cache
if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
@lru_cache(None)

View file

@ -3,7 +3,6 @@ from __future__ import annotations
import itertools
import logging
import weakref
from typing import Any, List, Optional, Tuple
@ -13,12 +12,12 @@ from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
from torch._functorch.aot_autograd import MutationType
from torch._functorch.compile_utils import fx_graph_cse
from torch._inductor.constant_folding import constant_fold, replace_node_with_constant
from torch._inductor.fx_passes.freezing_patterns import freezing_passes
from torch._inductor.fx_passes.post_grad import view_to_reshape
from . import config
aten = torch.ops.aten
prims = torch.ops.prims

View file

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import torch
from ..._dynamo.utils import counters
from ..ir import FixedLayout
from ..pattern_matcher import (

View file

@ -3,11 +3,12 @@ import functools
import itertools
import torch
from ..._dynamo.utils import counters
from ..._dynamo.utils import counters
from ..pattern_matcher import Arg, CallFunction, KeywordArg
from .freezing_patterns import register_binary_folding_pattern
aten = torch.ops.aten
prims = torch.ops.prims

View file

@ -30,6 +30,7 @@ from .. import config
from ..fx_utils import get_fake_args_kwargs
from ..virtualized import V
aten = torch.ops.aten
logger: logging.Logger = logging.getLogger("comm_fusion")

View file

@ -7,10 +7,10 @@ from torch import Tensor
from torch._dynamo.utils import counters
from .. import config
from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern
from .split_cat import construct_pattern_matcher_pass
aten = torch.ops.aten
log = logging.getLogger(__name__)

View file

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch.func import functional_call
@ -12,7 +11,6 @@ from ..pattern_matcher import (
Match,
register_graph_pattern,
)
from .pre_grad import efficient_conv_bn_eval_pass

View file

@ -3,8 +3,8 @@ import functools
import torch
from torch._inductor.compile_fx import fake_tensor_prop
from ..._dynamo.utils import counters
from ..._dynamo.utils import counters
from .. import config
from ..pattern_matcher import (
_return_true,
@ -20,6 +20,7 @@ from ..pattern_matcher import (
stable_topological_sort,
)
aten = torch.ops.aten
# First pass_patterns[0] are applied, then [1], then [2]

View file

@ -6,6 +6,7 @@ import math
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from ..._dynamo.utils import counters
from ..pattern_matcher import (
filter_nodes,
@ -14,6 +15,7 @@ from ..pattern_matcher import (
joint_fwd_bwd,
)
log = logging.getLogger(__name__)
aten = torch.ops.aten

View file

@ -28,6 +28,7 @@ from ..pattern_matcher import (
stable_topological_sort,
)
try:
# importing this will register fbgemm lowerings for inductor
import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401

View file

@ -27,6 +27,7 @@ from ..pattern_matcher import (
)
from .replace_random import replace_random_passes
log = logging.getLogger(__name__)
patterns = PatternMatcherPass()
aten = torch.ops.aten

View file

@ -4,8 +4,8 @@ from dataclasses import dataclass
from typing import cast, List, Set, Tuple, Union
import torch
from .. import inductor_prims
from .. import inductor_prims
from ..pattern_matcher import (
CallFunction,
Ignored,
@ -16,6 +16,7 @@ from ..pattern_matcher import (
register_graph_pattern,
)
aten = torch.ops.aten
patterns = PatternMatcherPass()

View file

@ -1,14 +1,14 @@
# mypy: allow-untyped-defs
import functools
from typing import Dict, Set, Tuple
import torch
from torch._dynamo.utils import counters
from torch._ops import OpOverload, OpOverloadPacket
from ..pattern_matcher import fwd_only, register_replacement
aten = torch.ops.aten

View file

@ -5,11 +5,9 @@ from functools import reduce
from typing import Any, Tuple
import torch
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from .. import ir
from ..lowering import lowerings as L
from ..pattern_matcher import (
Arg,
@ -28,6 +26,7 @@ from .quantization import (
_register_woq_lowerings,
)
if torch._C._has_mkldnn:
aten = torch.ops.aten
mkldnn = torch.ops.mkldnn

View file

@ -12,6 +12,7 @@ import torch.optim as optim
from .. import config
logger: logging.Logger = logging.getLogger(__name__)
MAIN_RANDOM_SEED = 1337

View file

@ -23,7 +23,6 @@ from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._mode_utils import no_dispatch
from ...utils._triton import has_triton
from ..pattern_matcher import (
fwd_only,
gen_register_replacement,
@ -33,6 +32,7 @@ from ..pattern_matcher import (
SearchFn,
)
aten = torch.ops.aten

View file

@ -14,9 +14,7 @@ from torch._decomp import register_decomposition
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._inductor import comms
from torch._inductor.virtualized import ops
from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
from torch._utils_internal import upload_graph
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
@ -24,7 +22,6 @@ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from .. import config, ir, pattern_matcher
from ..codegen.common import BackendFeature, has_backend_feature
from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
from ..lowering import lowerings as L
from ..pattern_matcher import (
_return_true,
@ -54,6 +51,7 @@ from .pre_grad import is_same_dict, save_inductor_dict
from .reinplace import reinplace_inplaceable_ops
from .split_cat import POST_GRAD_PATTERNS
if TYPE_CHECKING:
from sympy import Expr

View file

@ -18,7 +18,6 @@ from torch.nn import functional as F
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
from .. import config
from ..fx_utils import matches_module_function_pattern
from ..pattern_matcher import (
init_once_fakemode,
@ -30,6 +29,7 @@ from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS
from .misc_patterns import numpy_compat_normalization
from .split_cat import PRE_GRAD_PATTERNS
log = logging.getLogger(__name__)
efficient_conv_bn_eval_pass = PatternMatcherPass(

View file

@ -10,12 +10,14 @@ import torch
from torch._dynamo.utils import counters
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.fx.node import map_arg
from ..lowering import lowerings as L, require_channels_last
from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
from ..utils import pad_listlike
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern
aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

View file

@ -17,6 +17,7 @@ from torch.fx.immutable_collections import immutable_dict
from torch.fx.passes.reinplace import _is_view_op
from torch.utils import _pytree as pytree
aten = torch.ops.aten

View file

@ -5,6 +5,7 @@ import logging
import torch
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from .. import config, inductor_prims
from ..pattern_matcher import (
CallFunctionVarArgs,
@ -14,6 +15,7 @@ from ..pattern_matcher import (
)
from ..virtualized import V
log = logging.getLogger(__name__)
patterns = PatternMatcherPass()
aten = torch.ops.aten

View file

@ -28,6 +28,7 @@ from ..pattern_matcher import (
)
from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS
log = logging.getLogger(__name__)
_Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...]

View file

@ -15,6 +15,7 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map
from .virtualized import V

View file

@ -93,12 +93,14 @@ from .utils import (
)
from .virtualized import NullHandler, V
if TYPE_CHECKING:
from torch._higher_order_ops.effects import _EffectType
from .codegen.wrapper import WrapperCodeGen
from torch._inductor.codecache import output_code_log
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")

View file

@ -2,6 +2,7 @@
import contextlib
from typing import Callable, List, TYPE_CHECKING
if TYPE_CHECKING:
import torch

View file

@ -31,8 +31,8 @@ import torch
from torch._prims_common import dtype_to_type, is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from .utils import generate_assert
from .utils import generate_assert
from .virtualized import V

View file

@ -7,6 +7,7 @@ from typing import Optional, Sequence
import torch
from torch import _prims, Tensor
log = logging.getLogger(__name__)

View file

@ -31,9 +31,7 @@ import sympy
from sympy import Expr, Integer
import torch._export.serde.schema as export_schema
import torch._logging
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
@ -90,6 +88,7 @@ from .utils import (
)
from .virtualized import ops, V
if TYPE_CHECKING:
from .graph import GraphLowering

View file

@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
import sympy
import torch
from .ir import Pointwise, TensorBox
from .lowering import fallback_handler, is_integer_type, register_lowering
from .virtualized import ops

View file

@ -16,11 +16,10 @@ from ..utils import (
use_triton_template,
)
from ..virtualized import V
from .mm import _is_static_problem
from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options
log = logging.getLogger(__name__)
aten = torch.ops.aten

View file

@ -8,7 +8,6 @@ from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDi
import torch
from .. import config, ir
from ..lowering import (
add_layout_constraint,
constrain_to_fx_strides,
@ -31,6 +30,7 @@ from ..utils import (
from ..virtualized import V
from .mm_common import filtered_configs
if TYPE_CHECKING:
from ..ir import TensorBox

View file

@ -9,6 +9,7 @@ import sympy
import torch
from torch._inductor.virtualized import V
from torch.utils._pytree import tree_map
from .. import config
from ..ir import (
ComputedBuffer,
@ -23,6 +24,7 @@ from ..ir import (
from ..lowering import empty, empty_strided, lowerings, register_lowering
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
log = logging.getLogger(__name__)
aten = torch.ops.aten

View file

@ -6,6 +6,7 @@ import sympy
import torch
from torch._inductor.virtualized import V
from ..ir import FixedLayout, FlexibleLayout
from ..lowering import empty_strided, lowerings
from ..runtime.runtime_utils import next_power_of_2

View file

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
import torch
from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate
from torch._inductor.virtualized import V
from .. import config as inductor_config
from ..codegen.common import BackendFeature
from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate
@ -39,6 +40,7 @@ from .mm_common import (
triton_config,
)
log = logging.getLogger(__name__)
aten = torch.ops.aten

View file

@ -14,6 +14,7 @@ from .. import config as inductor_config
from ..runtime.runtime_utils import next_power_of_2
from ..utils import ceildiv as cdiv
log = logging.getLogger(__name__)

View file

@ -13,6 +13,7 @@ from ..utils import use_aten_gemm_kernels, use_triton_template
from ..virtualized import V
from .mm_common import mm_args, mm_grid, mm_options
aten = torch.ops.aten
aten_mm_plus_mm = ExternKernelChoice(

View file

@ -5,6 +5,7 @@ from typing import List, TYPE_CHECKING
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options
if TYPE_CHECKING:
from ..ir import ChoiceCaller

View file

@ -42,8 +42,8 @@ from torch.utils._sympy.functions import (
IntTrueDiv,
ModularIndexing,
)
from .._dynamo.utils import import_submodule
from .._dynamo.utils import import_submodule
from . import config, inductor_prims, ir, test_operators # NOQA: F401
from .decomposition import decompositions, get_decompositions
from .ir import (
@ -72,6 +72,7 @@ from .utils import (
)
from .virtualized import ops, V
log = logging.getLogger(__name__)
lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {}
layout_constraints: Dict[torch._ops.OpOverload, Callable[..., Any]] = {}
@ -6108,6 +6109,7 @@ def resize(x, size, *, memory_format=None):
from torch._higher_order_ops.auto_functionalize import auto_functionalized
make_fallback(auto_functionalized)
@ -6395,17 +6397,21 @@ except (AttributeError, ImportError):
# populate lowerings defined in kernel/*
from . import kernel
import_submodule(kernel)
from . import quantized_lowerings
quantized_lowerings.register_quantized_ops()
quantized_lowerings.register_woq_mm_ops()
from . import mkldnn_lowerings
mkldnn_lowerings.register_onednn_fusion_ops()
from . import jagged_lowerings
jagged_lowerings.register_jagged_ops()

View file

@ -8,12 +8,12 @@ import os
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List, Set, Tuple, TYPE_CHECKING, Union
from torch._inductor import config
from torch._inductor.utils import get_benchmark_name
# Prevent circular import
if TYPE_CHECKING:
from torch._inductor.scheduler import (

View file

@ -4,7 +4,6 @@ from typing import Any, List, Optional, Set
import sympy
import torch
from torch._prims_common import make_channels_last_strides_for
from .ir import (
@ -22,9 +21,7 @@ from .ir import (
NoneLayout,
TensorBox,
)
from .utils import convert_shape_to_inductor, pad_listlike
from .virtualized import V

View file

@ -5,6 +5,7 @@ from typing import List, Optional
import torch
import torch.utils._pytree as pytree
from torch._inductor.kernel.mm_common import mm_args
from . import ir, mkldnn_ir
from .codegen.cpp_gemm_template import CppPackedGemmTemplate
from .ir import TensorBox

View file

@ -18,8 +18,10 @@ import sympy
import torch
import torch.utils._pytree as pytree
from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str
T = TypeVar("T")
StoreMode = Optional[Literal["atomic_add"]]
ReductionType = Literal[

View file

@ -5,6 +5,7 @@ import sympy
import torch
from torch.utils._sympy.value_ranges import ValueRanges
from .ir import LoopBody
from .utils import dominated_nodes

View file

@ -14,6 +14,7 @@ import torch.utils._pytree as pytree
from torch._inductor import config, exc
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export._tree_utils import reorder_kwargs
from .build_package import build_package_contents
from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION

View file

@ -38,7 +38,6 @@ successful match or a `FailedMatch` object for a failure to match.
from __future__ import annotations
import contextlib
import dataclasses
import functools
import importlib
@ -96,6 +95,7 @@ from . import config
from .decomposition import select_decomp_table
from .lowering import fallback_node_due_to_unsupported_type
log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims

View file

@ -1,7 +1,9 @@
# mypy: allow-untyped-defs
import torch
from . import lowering
quantized = torch.ops.quantized
_quantized = torch.ops._quantized
aten = torch.ops.aten

Some files were not shown because too many files have changed in this diff.