From b6d477fd5688f43aa0ea20d7cbcd60e86e8a6553 Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Sat, 20 Jul 2024 18:35:26 +0800
Subject: [PATCH] [BE][Easy][16/19] enforce style for empty lines in import segments in `torch/_i*/` (#129768)

See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501.

Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```
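For reviewers unfamiliar with the rule, the style being enforced looks like the toy module below (an illustrative sketch written for this description, not a file from the diff): stray blank lines inside an import segment are removed, and the segment is separated from the first top-level statement by exactly two blank lines.

```python
# Before: a stray blank line splits the import segment, and only one
# blank line separates the imports from module-level code.
import json
import os

from typing import Any

config: Any = json.loads(os.environ.get("CONFIG", "{}"))
```

```python
# After: the import segment is contiguous, and exactly two blank
# lines precede the first top-level statement.
import json
import os
from typing import Any


config: Any = json.loads(os.environ.get("CONFIG", "{}"))
```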
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129768
Approved by: https://github.com/jansel
---
 tools/linter/adapters/ufmt_linter.py | 1 -
 torch/_inductor/__init__.py | 1 +
 torch/_inductor/aoti_eager.py | 1 +
 torch/_inductor/async_compile.py | 3 +--
 torch/_inductor/autoheuristic/autoheuristic.py | 2 --
 torch/_inductor/autoheuristic/autoheuristic_utils.py | 2 +-
 .../autoheuristic/learned_heuristic_controller.py | 1 -
 torch/_inductor/autotune_process.py | 3 ++-
 torch/_inductor/bounds.py | 1 +
 torch/_inductor/codecache.py | 5 +++--
 torch/_inductor/codegen/aoti_hipify_utils.py | 2 +-
 torch/_inductor/codegen/codegen_device_driver.py | 1 +
 torch/_inductor/codegen/cpp.py | 5 ++---
 torch/_inductor/codegen/cpp_gemm_template.py | 4 ++--
 torch/_inductor/codegen/cpp_template.py | 2 +-
 torch/_inductor/codegen/cpp_template_kernel.py | 2 +-
 torch/_inductor/codegen/cpp_utils.py | 4 ++--
 torch/_inductor/codegen/cpp_wrapper_cpu.py | 2 +-
 torch/_inductor/codegen/cpp_wrapper_cuda.py | 1 +
 torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py | 3 +--
 torch/_inductor/codegen/cuda/cuda_env.py | 1 +
 torch/_inductor/codegen/cuda/cuda_kernel.py | 2 +-
 torch/_inductor/codegen/cuda/cuda_template.py | 3 ++-
 .../cutlass_lib_extensions/gemm_operation_extensions.py | 3 ++-
 torch/_inductor/codegen/cuda/cutlass_utils.py | 4 ++--
 torch/_inductor/codegen/cuda/gemm_template.py | 2 +-
 torch/_inductor/codegen/cuda_combined_scheduling.py | 1 -
 torch/_inductor/codegen/halide.py | 3 ++-
 torch/_inductor/codegen/memory_planning.py | 2 +-
 torch/_inductor/codegen/multi_kernel.py | 1 +
 .../_inductor/codegen/rocm/ck_universal_gemm_template.py | 1 +
 torch/_inductor/codegen/rocm/compile_command.py | 1 +
 torch/_inductor/codegen/rocm/rocm_benchmark_request.py | 2 +-
 torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py | 2 +-
 torch/_inductor/codegen/rocm/rocm_kernel.py | 3 +--
 torch/_inductor/codegen/rocm/rocm_template.py | 2 +-
 torch/_inductor/codegen/simd.py | 3 +--
 torch/_inductor/codegen/triton.py | 4 ++--
 torch/_inductor/codegen/triton_split_scan.py | 4 ----
 torch/_inductor/codegen/wrapper.py | 3 +--
 torch/_inductor/comm_analysis.py | 2 +-
 torch/_inductor/comms.py | 2 +-
 torch/_inductor/compile_fx.py | 4 +---
 torch/_inductor/compile_worker/__main__.py | 1 +
 torch/_inductor/compile_worker/subproc_pool.py | 1 +
 torch/_inductor/config.py | 1 +
 torch/_inductor/constant_folding.py | 1 +
 torch/_inductor/cpp_builder.py | 1 +
 torch/_inductor/cpu_vec_isa.py | 2 +-
 torch/_inductor/cudagraph_trees.py | 2 +-
 torch/_inductor/cudagraph_utils.py | 1 +
 torch/_inductor/debug.py | 3 +--
 torch/_inductor/decomposition.py | 1 +
 torch/_inductor/dependencies.py | 1 +
 torch/_inductor/exc.py | 1 +
 torch/_inductor/freezing.py | 3 +--
 torch/_inductor/fx_passes/b2b_gemm.py | 1 +
 torch/_inductor/fx_passes/binary_folding.py | 3 ++-
 torch/_inductor/fx_passes/ddp_fusion.py | 1 +
 torch/_inductor/fx_passes/decompose_mem_bound_mm.py | 2 +-
 torch/_inductor/fx_passes/efficient_conv_bn_eval.py | 2 --
 torch/_inductor/fx_passes/freezing_patterns.py | 3 ++-
 torch/_inductor/fx_passes/fuse_attention.py | 2 ++
 torch/_inductor/fx_passes/group_batch_fusion.py | 1 +
 torch/_inductor/fx_passes/joint_graph.py | 1 +
 torch/_inductor/fx_passes/micro_pipeline_tp.py | 3 ++-
 torch/_inductor/fx_passes/misc_patterns.py | 4 ++--
 torch/_inductor/fx_passes/mkldnn_fusion.py | 3 +--
 torch/_inductor/fx_passes/numeric_utils.py | 1 +
 torch/_inductor/fx_passes/pad_mm.py | 2 +-
 torch/_inductor/fx_passes/post_grad.py | 4 +---
 torch/_inductor/fx_passes/pre_grad.py | 2 +-
 torch/_inductor/fx_passes/quantization.py | 2 ++
 torch/_inductor/fx_passes/reinplace.py | 1 +
 torch/_inductor/fx_passes/replace_random.py | 2 ++
 torch/_inductor/fx_passes/split_cat.py | 1 +
 torch/_inductor/fx_utils.py | 1 +
 torch/_inductor/graph.py | 2 ++
 torch/_inductor/hooks.py | 1 +
 torch/_inductor/index_propagation.py | 2 +-
 torch/_inductor/inductor_prims.py | 1 +
 torch/_inductor/ir.py | 3 +--
 torch/_inductor/jagged_lowerings.py | 1 +
 torch/_inductor/kernel/bmm.py | 3 +--
 torch/_inductor/kernel/conv.py | 2 +-
 torch/_inductor/kernel/flex_attention.py | 2 ++
 torch/_inductor/kernel/flex_decoding.py | 1 +
 torch/_inductor/kernel/mm.py | 2 ++
 torch/_inductor/kernel/mm_common.py | 1 +
 torch/_inductor/kernel/mm_plus_mm.py | 1 +
 torch/_inductor/kernel/unpack_mixed_mm.py | 1 +
 torch/_inductor/lowering.py | 8 +++++++-
 torch/_inductor/metrics.py | 2 +-
 torch/_inductor/mkldnn_ir.py | 3 ---
 torch/_inductor/mkldnn_lowerings.py | 1 +
 torch/_inductor/ops_handler.py | 2 ++
 torch/_inductor/optimize_indexing.py | 1 +
 torch/_inductor/package/package.py | 1 +
 torch/_inductor/pattern_matcher.py | 2 +-
 torch/_inductor/quantized_lowerings.py | 2 ++
 torch/_inductor/runtime/coordinate_descent_tuner.py | 2 +-
 torch/_inductor/runtime/triton_heuristics.py | 2 +-
 torch/_inductor/scheduler.py | 1 +
 torch/_inductor/select_algorithm.py | 5 +----
 torch/_inductor/sizevars.py | 1 +
 torch/_inductor/subgraph_lowering.py | 1 +
 torch/_inductor/test_case.py | 1 -
 torch/_inductor/test_operators.py | 1 +
 torch/_inductor/utils.py | 2 ++
 torch/_inductor/virtualized.py | 1 +
 torch/_inductor/wrapper_benchmark.py | 2 ++
 111 files changed, 134 insertions(+), 92 deletions(-)

diff --git a/tools/linter/adapters/ufmt_linter.py b/tools/linter/adapters/ufmt_linter.py
index b9d7fcea0a6..d92921fd415 100644
--- a/tools/linter/adapters/ufmt_linter.py
+++ b/tools/linter/adapters/ufmt_linter.py
@@ -58,7 +58,6 @@ ISORT_SKIPLIST = re.compile(
             # torch/_[e-h]*/**
             "torch/_[e-h]*/**",
             # torch/_i*/**
-            "torch/_i*/**",
             # torch/_[j-z]*/**
             "torch/_[j-z]*/**",
             # torch/[a-c]*/**
diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py
index ad37a31ef73..72333d1139e 100644
--- a/torch/_inductor/__init__.py
+++ b/torch/_inductor/__init__.py
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch.fx
 import torch.utils._pytree as pytree

+
 __all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]
diff --git a/torch/_inductor/aoti_eager.py b/torch/_inductor/aoti_eager.py
index d77c764a00e..5bd2027baf8 100644
--- a/torch/_inductor/aoti_eager.py
+++ b/torch/_inductor/aoti_eager.py
@@ -7,6 +7,7 @@ from unittest import mock
 import torch
 import torch._export
 from torch._inductor.utils import is_cpu_device
+
 from .runtime.runtime_utils import cache_dir
diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py
index f2c2f182dbd..1ba70135bd7 100644
--- a/torch/_inductor/async_compile.py
+++ b/torch/_inductor/async_compile.py
@@ -31,15 +31,14 @@ from torch._inductor.compile_worker.subproc_pool import (
     SubprocPool,
 )
 from torch._inductor.compile_worker.watchdog import _async_compile_initializer
-
 from torch._inductor.runtime.compile_tasks import (
     _set_triton_ptxas_path,
     _worker_compile_triton,
 )
-
 from torch.hub import _Faketqdm, tqdm
 from torch.utils._triton import has_triton_package

+
 if TYPE_CHECKING:
     from torch._inductor.runtime.hints import HalideMeta
diff --git a/torch/_inductor/autoheuristic/autoheuristic.py b/torch/_inductor/autoheuristic/autoheuristic.py
index 199bb1a9b2b..2d1bcd03b0b 100644
--- a/torch/_inductor/autoheuristic/autoheuristic.py
+++ b/torch/_inductor/autoheuristic/autoheuristic.py
@@ -1,10 +1,8 @@
 import json
 import os
-
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
-
 from torch._inductor.autoheuristic.autoheuristic_utils import (
     AHContext,
     AHMetadata,
diff --git a/torch/_inductor/autoheuristic/autoheuristic_utils.py b/torch/_inductor/autoheuristic/autoheuristic_utils.py
index 78d9bc1b5a2..96644df9a54 100644
--- a/torch/_inductor/autoheuristic/autoheuristic_utils.py
+++ b/torch/_inductor/autoheuristic/autoheuristic_utils.py
@@ -1,7 +1,7 @@
 import functools
-
 from typing import Any, Callable, Dict, List, Tuple

+
 Feedback = float
 Choice = str
 Value = Any
diff --git a/torch/_inductor/autoheuristic/learned_heuristic_controller.py b/torch/_inductor/autoheuristic/learned_heuristic_controller.py
index d5ad218d550..3a7ccd6280c 100644
--- a/torch/_inductor/autoheuristic/learned_heuristic_controller.py
+++ b/torch/_inductor/autoheuristic/learned_heuristic_controller.py
@@ -1,7 +1,6 @@
 import importlib
 import inspect
 import pkgutil
-
 from collections import defaultdict
 from typing import Any, Dict, List, Optional
diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py
index 5e59054eadc..97931e684ff 100644
--- a/torch/_inductor/autotune_process.py
+++ b/torch/_inductor/autotune_process.py
@@ -28,7 +28,6 @@ import torch
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
 from torch import multiprocessing
 from torch._dynamo.testing import rand_strided
-
 from torch._inductor import ir
 from torch._inductor.codecache import (
     CppCodeCache,
@@ -38,6 +37,7 @@ from torch._inductor.codecache import (
     PyCodeCache,
 )

+
 if TYPE_CHECKING:
     from multiprocessing.process import BaseProcess
     from multiprocessing.queues import Queue
@@ -49,6 +49,7 @@ from . import config
 from .runtime.runtime_utils import do_bench_cpu, do_bench_gpu
 from .virtualized import V

+
 CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
 EXIT_HANDLER_REGISTERED = False
diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py
index b7bb37e5ee6..22678f8fa59 100644
--- a/torch/_inductor/bounds.py
+++ b/torch/_inductor/bounds.py
@@ -8,6 +8,7 @@ from sympy import Expr

 import torch
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
+
 from .ir import InterpreterShim, LoopBody, LoopBodyBlock
 from .utils import cache_on_self, dominated_nodes
 from .virtualized import V
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index a93e264c9e0..e331f56c87e 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -27,7 +27,7 @@ import threading
 import warnings
 from bisect import bisect_right
 from copy import copy
-from ctypes import c_void_p, cdll, CDLL
+from ctypes import c_void_p, CDLL, cdll
 from functools import partial
 from pathlib import Path
 from time import time, time_ns
@@ -58,6 +58,7 @@ from torch._inductor.codegen.rocm.compile_command import (
     rocm_compiler,
 )

+
 """
 codecache.py, cpp_builder.py and cpu_vec_isa.py import rule:
 https://github.com/pytorch/pytorch/issues/124245#issuecomment-2197778902
@@ -86,7 +87,6 @@ from torch._inductor.runtime.compile_tasks import (
 )
 from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
 from torch._inductor.utils import ALIGN_BYTES, clear_on_fresh_inductor_cache, is_linux
-
 from torch._logging import trace_structured
 from torch._subclasses.fake_tensor import (
     extract_tensor_metadata,
@@ -95,6 +95,7 @@ from torch._subclasses.fake_tensor import (
 )
 from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv

+
 if TYPE_CHECKING:
     from concurrent.futures import Future
diff --git a/torch/_inductor/codegen/aoti_hipify_utils.py b/torch/_inductor/codegen/aoti_hipify_utils.py
index 8a7aa833391..80085aa6d18 100644
--- a/torch/_inductor/codegen/aoti_hipify_utils.py
+++ b/torch/_inductor/codegen/aoti_hipify_utils.py
@@ -2,9 +2,9 @@ import re

 import torch
-
 from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE

+
 # It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
 # "...
 # from ..codecache import CudaKernelParamCache
diff --git a/torch/_inductor/codegen/codegen_device_driver.py b/torch/_inductor/codegen/codegen_device_driver.py
index f11188e1927..c31017fe647 100644
--- a/torch/_inductor/codegen/codegen_device_driver.py
+++ b/torch/_inductor/codegen/codegen_device_driver.py
@@ -1,5 +1,6 @@
 import torch

+
 # Provide aoti module launch hip/cuda drivers. This file is also used for unit testing purpose
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 0aefb4073f6..d0ea7252604 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -21,8 +21,8 @@ from torch.utils import _pytree as pytree
 from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
-from ..._dynamo.utils import counters

+from ..._dynamo.utils import counters
 from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics
 from ..codegen.wrapper import WrapperCodeGen
 from ..optimize_indexing import range_expressable_in_32_bits
@@ -46,7 +46,6 @@ from ..utils import (
     sympy_product,
     sympy_subs,
 )
-
 from ..virtualized import NullKernelHandler, ops, OpsValue, V
 from .common import (
     BackendFeature,
@@ -63,7 +62,6 @@ from .common import (
     OpOverrides,
     OptimizationContext,
 )
-
 from .cpp_utils import (
     cexpr,
     cexpr_index,
@@ -74,6 +72,7 @@ from .cpp_utils import (
     value_to_cpp,
 )

+
 _IS_WINDOWS = sys.platform == "win32"
 schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py
index 32e480df295..32c8f4c5e22 100644
--- a/torch/_inductor/codegen/cpp_gemm_template.py
+++ b/torch/_inductor/codegen/cpp_gemm_template.py
@@ -6,19 +6,19 @@ from unittest.mock import patch
 import torch
 import torch.utils
+
 from ..._dynamo.utils import counters
 from .. import ir, lowering as L
-
 from ..kernel.mm_common import mm_args
 from ..select_algorithm import DataProcessorTemplateWrapper
 from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
 from ..virtualized import ops, V
 from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
 from .cpp_template import CppTemplate
-
 from .cpp_template_kernel import CppTemplateKernel
 from .cpp_utils import GemmBlocking, get_gemm_template_output_and_compute_dtype

+
 GEMM_TEMPLATE = r"""
 {{template.header().getvalue()}}
diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py
index f45889e1434..a0ac77ae15b 100644
--- a/torch/_inductor/codegen/cpp_template.py
+++ b/torch/_inductor/codegen/cpp_template.py
@@ -3,7 +3,6 @@ import ctypes
 import functools
 import itertools
 import logging
-
 import sys
 from typing import Callable, List, Optional
 from unittest.mock import patch
@@ -17,6 +16,7 @@ from ..virtualized import V
 from .common import KernelTemplate
 from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py
index f5392f8e5db..a0aceadb37d 100644
--- a/torch/_inductor/codegen/cpp_template_kernel.py
+++ b/torch/_inductor/codegen/cpp_template_kernel.py
@@ -7,8 +7,8 @@ from sympy.parsing.sympy_parser import parse_expr

 import torch
 from torch.utils._sympy.symbol import SymT
-from .. import config, cpp_builder, ir, lowering as L

+from .. import config, cpp_builder, ir, lowering as L
 from ..autotune_process import CppBenchmarkRequest
 from ..select_algorithm import PartialRender
 from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py
index 04f951a740e..159010d82ee 100644
--- a/torch/_inductor/codegen/cpp_utils.py
+++ b/torch/_inductor/codegen/cpp_utils.py
@@ -2,7 +2,6 @@ import contextlib
 import copy
 import math
-
 from collections import namedtuple
 from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
@@ -11,12 +10,13 @@ import sympy

 import torch
 from torch.utils._sympy.symbol import symbol_is_type, SymT
+
 from .. import ir
 from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
 from ..virtualized import V
-
 from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs

+
 DTYPE_TO_CPP = {
     torch.float32: "float",
     torch.float64: "double",
diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py
index 02c4e609cad..e137fd00d36 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -10,10 +10,10 @@ import sympy
 from sympy import Expr

 import torch
-
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
 import torch._ops
 from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes
+
 from .. import config, ir
 from ..utils import _align, ALIGN_BYTES, cache_on_self, sympy_product
 from ..virtualized import V
diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py
index 9bb54e883ef..496f1126440 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cuda.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py
@@ -19,6 +19,7 @@ from .cpp_utils import DTYPE_TO_CPP
 from .cpp_wrapper_cpu import CppWrapperCpu
 from .wrapper import SymbolicCallArg

+
 if TYPE_CHECKING:
     from ..graph import GraphLowering
diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
index 91559164895..1d7d2bbb88b 100644
--- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
+++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
@@ -3,16 +3,15 @@ import logging
 from typing import cast, Sequence

 from ...._dynamo.utils import counters
-
 from ... import config
 from ...codecache import code_hash, get_path
-
 from ...ir import CUDATemplateBuffer
 from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode
 from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
 from ...virtualized import V
 from ..common import IndentedBuffer

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/codegen/cuda/cuda_env.py b/torch/_inductor/codegen/cuda/cuda_env.py
index 6171921173e..fa272314260 100644
--- a/torch/_inductor/codegen/cuda/cuda_env.py
+++ b/torch/_inductor/codegen/cuda/cuda_env.py
@@ -6,6 +6,7 @@ import torch

 from ... import config

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py
index b6256c0ccd0..68fba0e77ba 100644
--- a/torch/_inductor/codegen/cuda/cuda_kernel.py
+++ b/torch/_inductor/codegen/cuda/cuda_kernel.py
@@ -15,9 +15,9 @@ from ...ir import (
 from ...utils import sympy_product
 from ...virtualized import V
 from ..common import IndentedBuffer, Kernel, OpOverrides
-
 from ..cpp_utils import CppPrinter, DTYPE_TO_CPP

+
 if TYPE_CHECKING:
     from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py
index 24a02efe380..8b22636d46e 100644
--- a/torch/_inductor/codegen/cuda/cuda_template.py
+++ b/torch/_inductor/codegen/cuda/cuda_template.py
@@ -8,14 +8,15 @@ from unittest.mock import patch
 import sympy

 import torch
+
 from ...autotune_process import CUDABenchmarkRequest, TensorMeta
 from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
-
 from ...utils import IndentedBuffer, unique
 from ...virtualized import V
 from ..common import KernelTemplate
 from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py
index 4ee8af3949a..a94f817a208 100644
--- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py
+++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py
@@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
 from ..cutlass_utils import try_import_cutlass

+
 if try_import_cutlass():
     import enum

-    from cutlass_library.library import *  # noqa: F401, F403
     from cutlass_library.gemm_operation import *  # noqa: F401, F403
+    from cutlass_library.library import *  # noqa: F401, F403

     # copied / modified from original at
     # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658
diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py
index 69aa7f3f88c..a149bd66a21 100644
--- a/torch/_inductor/codegen/cuda/cutlass_utils.py
+++ b/torch/_inductor/codegen/cuda/cutlass_utils.py
@@ -3,7 +3,6 @@ import functools
 import logging
 import os
 import sys
-
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, List, Optional
@@ -11,12 +10,13 @@ from typing import Any, List, Optional
 import sympy

 import torch
+
 from ... import config
 from ...ir import Layout
-
 from ...runtime.runtime_utils import cache_dir
 from .cuda_env import get_cuda_arch, get_cuda_version

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py
index 764ee7709da..9b77336e2ef 100644
--- a/torch/_inductor/codegen/cuda/gemm_template.py
+++ b/torch/_inductor/codegen/cuda/gemm_template.py
@@ -17,11 +17,11 @@ from ...ir import (
     ReinterpretView,
 )
 from ..common import IndentedBuffer
-
 from . import cutlass_utils
 from .cuda_kernel import CUDATemplateKernel
 from .cuda_template import CUTLASSTemplate

+
 log = logging.getLogger(__name__)
 # Jinja template for GEMM Kernel, used by the CUTLASSGemmTemplate class below.
diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py
index 1d85c4666cf..1307ee74420 100644
--- a/torch/_inductor/codegen/cuda_combined_scheduling.py
+++ b/torch/_inductor/codegen/cuda_combined_scheduling.py
@@ -10,7 +10,6 @@ from ..scheduler import (
 )
 from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
 from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
-
 from .triton import TritonScheduling
diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py
index acf0ee0e957..259479ab045 100644
--- a/torch/_inductor/codegen/halide.py
+++ b/torch/_inductor/codegen/halide.py
@@ -25,6 +25,7 @@ import sympy

 import torch
 import torch._logging
+
 from ..._prims_common import is_integer_dtype
 from ...utils._sympy.functions import FloorDiv, ModularIndexing
 from ...utils._sympy.symbol import symbol_is_type, SymT
@@ -34,7 +35,6 @@ from ..codecache import HalideCodeCache
 from ..ir import get_reduction_combine_fn
 from ..metrics import is_metric_table_enabled, log_kernel_metadata
 from ..ops_handler import AddParenHandler, MockHandler
-
 from ..runtime.hints import HalideInputSpec, HalideMeta, ReductionHint
 from ..utils import (
     get_bounds_index_expr,
@@ -58,6 +58,7 @@ from .cpp import DTYPE_TO_CPP
 from .cpp_utils import cexpr
 from .simd import constant_repr, SIMDKernel, SIMDScheduling

+
 if TYPE_CHECKING:
     from ..ops_handler import ReductionType, StoreMode
diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py
index 435bd2d895c..60360597ec1 100644
--- a/torch/_inductor/codegen/memory_planning.py
+++ b/torch/_inductor/codegen/memory_planning.py
@@ -10,10 +10,10 @@ from typing import Any, Dict, Iterable, List, Optional, Protocol
 import sympy

 import torch
+
 from .. import config, ir
 from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer
 from ..virtualized import V
-
 from .wrapper import (
     AllocateLine,
     FreeIfNotReusedLine,
diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py
index 7d53a267cc2..970df0a4684 100644
--- a/torch/_inductor/codegen/multi_kernel.py
+++ b/torch/_inductor/codegen/multi_kernel.py
@@ -12,6 +12,7 @@ from ..utils import cache_on_self
 from ..virtualized import V
 from .common import TensorArg

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py
index eddd1a4069c..d9d6bf05429 100644
--- a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py
+++ b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py
@@ -13,6 +13,7 @@ from torch._inductor.ir import Buffer, Layout

 from ...utils import IndentedBuffer, try_import_ck_lib

+
 _, gen_ops_library, gen_ops_preselected, CKGemmOperation = try_import_ck_lib()
diff --git a/torch/_inductor/codegen/rocm/compile_command.py b/torch/_inductor/codegen/rocm/compile_command.py
index ef3b8848323..dddb0c56d27 100644
--- a/torch/_inductor/codegen/rocm/compile_command.py
+++ b/torch/_inductor/codegen/rocm/compile_command.py
@@ -6,6 +6,7 @@ from typing import List, Optional
 from torch._inductor import config
 from torch._inductor.utils import is_linux

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py
index da1c808aa17..4e1f3200365 100644
--- a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py
+++ b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py
@@ -7,10 +7,10 @@ from ctypes import byref, c_size_t, c_void_p
 from typing import Any, Callable, Iterable, List, Optional, Union

 import torch
-
 from torch._inductor.autotune_process import GPUDeviceBenchmarkRequest, TensorMeta
 from torch._inductor.codecache import DLLWrapper, ROCmCodeCache

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py
index 9cb6d3c60d4..2cd340dcfd4 100644
--- a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py
+++ b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py
@@ -8,9 +8,9 @@ from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, Scheduler
 from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
 from ...virtualized import V
 from ..common import IndentedBuffer
-
 from .rocm_template_buffer import ROCmTemplateBuffer

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/codegen/rocm/rocm_kernel.py b/torch/_inductor/codegen/rocm/rocm_kernel.py
index 708031427e8..e9c897c1334 100644
--- a/torch/_inductor/codegen/rocm/rocm_kernel.py
+++ b/torch/_inductor/codegen/rocm/rocm_kernel.py
@@ -6,12 +6,11 @@ from ...ir import Buffer, ChoiceCaller, IRNode, Layout, PrimitiveInfoType, Tenso
 from ...utils import sympy_product
 from ...virtualized import V
 from ..common import IndentedBuffer, Kernel, OpOverrides
-
 from ..cpp_utils import CppPrinter
-
 from .rocm_benchmark_request import ROCmBenchmarkRequest
 from .rocm_template_buffer import ROCmTemplateBuffer

+
 if TYPE_CHECKING:
     from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate
diff --git a/torch/_inductor/codegen/rocm/rocm_template.py b/torch/_inductor/codegen/rocm/rocm_template.py
index 11f623020d3..88c89a6ba1e 100644
--- a/torch/_inductor/codegen/rocm/rocm_template.py
+++ b/torch/_inductor/codegen/rocm/rocm_template.py
@@ -9,7 +9,6 @@ import sympy

 from ...autotune_process import TensorMeta
 from ...ir import Buffer, IRNode, Layout
-
 from ...utils import IndentedBuffer, unique
 from ...virtualized import V
 from ..common import KernelTemplate
@@ -17,6 +16,7 @@ from .rocm_benchmark_request import ROCmBenchmarkRequest
 from .rocm_kernel import ROCmTemplateCaller, ROCmTemplateKernel
 from .rocm_template_buffer import ROCmTemplateBuffer

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index cf187edc7cb..fb4c862266c 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -28,13 +28,12 @@ import sympy

 import torch
 import torch._logging
-
 from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
+
 from ..._dynamo.utils import counters
 from .. import config, ir, scheduler
 from ..codecache import code_hash
-
 from ..dependencies import Dep, MemoryDep, StarDep, WeakDep
 from ..ir import TritonTemplateBuffer
 from ..optimize_indexing import indexing_dtype_strength_reduction
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index b523ccfb381..f24e855e33d 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -27,14 +27,13 @@ import sympy
 import torch
 import torch._logging
 from torch._dynamo.utils import preserve_rng_state
-
 from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
 from torch._prims_common import is_integer_dtype
 from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
 from torch.utils._triton import has_triton_package
+
 from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
 from ...utils._sympy.value_ranges import ValueRanges
-
 from .. import config, ir
 from ..codecache import code_hash, get_path, PyCodeCache
 from ..metrics import is_metric_table_enabled, log_kernel_metadata
@@ -73,6 +72,7 @@ from .simd import (
 )
 from .triton_utils import config_of, signature_of, signature_to_meta

+
 if TYPE_CHECKING:
     from ..ir import IRNode
diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py
index 1e0475ffd0f..1f5f184ef95 100644
--- a/torch/_inductor/codegen/triton_split_scan.py
+++ b/torch/_inductor/codegen/triton_split_scan.py
@@ -1,16 +1,12 @@
 # mypy: allow-untyped-defs
 import functools
-
 from typing import Optional, Set

 import torch._inductor.runtime.hints
 from torch._inductor import config
 from torch._inductor.codegen.simd import IterationRangesRoot
-
 from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
-
 from torch._prims_common import prod
-
 from torch.utils._sympy.functions import CeilDiv
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 4be021daff3..779f8b69b72 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -10,7 +10,6 @@ import inspect
 import logging
 import operator
 import re
-
 import tempfile
 from itertools import count
 from typing import (
@@ -41,7 +40,6 @@ from torch.utils._sympy.singleton_int import SingletonInt
 from torch.utils._sympy.symbol import symbol_is_type, SymT

 from .. import async_compile, config, ir
-
 from ..codecache import output_code_log
 from ..ir import ReinterpretView
 from ..runtime import triton_heuristics
@@ -58,6 +56,7 @@ from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
 from .triton_utils import config_of, signature_to_meta

+
 if TYPE_CHECKING:
     import triton
diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py
index 71e8740a5fd..f8a233a3b9e 100644
--- a/torch/_inductor/comm_analysis.py
+++ b/torch/_inductor/comm_analysis.py
@@ -5,8 +5,8 @@ from enum import IntEnum
 import sympy

 import torch
-from . import ir

+from . import ir
 from .utils import get_dtype_size, sympy_product
 from .virtualized import V
diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py
index a6222670860..886c1361a3c 100644
--- a/torch/_inductor/comms.py
+++ b/torch/_inductor/comms.py
@@ -4,7 +4,6 @@ from __future__ import annotations

 import heapq
 import operator
-
 import sys
 from collections import defaultdict
 from typing import Dict, List, Set, TYPE_CHECKING
@@ -15,6 +14,7 @@ from . import config, ir
 from .dependencies import WeakDep
 from .utils import is_collective, is_wait

+
 overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")

 if TYPE_CHECKING:
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 95522097024..b1c72565696 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -8,15 +8,12 @@ import sys
 import time
 import warnings
 from itertools import count
-
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 from unittest import mock

 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
-
 import torch.fx
 import torch.utils._pytree as pytree
-
 from functorch.compile import min_cut_rematerialization_partition
 from torch._dynamo import (
     compiled_autograd,
@@ -77,6 +74,7 @@ from .utils import (
 )
 from .virtualized import V

+
 if config.is_fbcode():
     from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
 else:
diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py
index 547ad5ffcf7..0f6503b7901 100644
--- a/torch/_inductor/compile_worker/__main__.py
+++ b/torch/_inductor/compile_worker/__main__.py
@@ -9,6 +9,7 @@ from torch._inductor.compile_worker.subproc_pool import SubprocMain
 from torch._inductor.compile_worker.watchdog import _async_compile_initializer
 from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path

+
 log = logging.getLogger(__name__)

 _set_triton_ptxas_path()
diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py
index 72f9426a3ac..2a5ed34667b 100644
--- a/torch/_inductor/compile_worker/subproc_pool.py
+++ b/torch/_inductor/compile_worker/subproc_pool.py
@@ -18,6 +18,7 @@ from typing import Any, Callable, Dict
 from torch._inductor import config
 from torch._inductor.compile_worker.watchdog import _async_compile_initializer

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index f378242305d..1c39bf4e0d0 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1078,5 +1078,6 @@ if TYPE_CHECKING:

 from torch.utils._config_module import install_config_module

+
 # adds patch, save_config, etc
 install_config_module(sys.modules[__name__])
diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py
index 77bb97fa695..ea0ac48cb39 100644
--- a/torch/_inductor/constant_folding.py
+++ b/torch/_inductor/constant_folding.py
@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 import torch.utils._pytree as pytree

+
 aten = torch.ops.aten

 # We would like to split modules into two subgraphs for runtime weight updates to work correctly.
diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py
index 2dd5ff4f2c4..777999a43cc 100644
--- a/torch/_inductor/cpp_builder.py
+++ b/torch/_inductor/cpp_builder.py
@@ -24,6 +24,7 @@ from torch._inductor import config, exc
 from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA
 from torch._inductor.runtime.runtime_utils import cache_dir

+
 if config.is_fbcode():
     from triton.fb import build_paths  # noqa: F401
diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py
index 5e99e376beb..7d74cc415d8 100644
--- a/torch/_inductor/cpu_vec_isa.py
+++ b/torch/_inductor/cpu_vec_isa.py
@@ -3,7 +3,6 @@ import dataclasses
 import functools
 import os
 import platform
-
 import re
 import subprocess
 import sys
@@ -12,6 +11,7 @@ from typing import Any, Callable, Dict, List
 import torch
 from torch._inductor import config

+
 _IS_WINDOWS = sys.platform == "win32"
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index 2020d05b62b..20343139b2d 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -49,7 +49,6 @@ import traceback
 import warnings
 import weakref
 from collections import defaultdict
-
 from enum import auto, Enum
 from typing import (
     Any,
@@ -91,6 +90,7 @@ from torch.storage import UntypedStorage
 from torch.utils import _pytree as pytree
 from torch.utils.weak import TensorWeakRef

+
 if TYPE_CHECKING:
     from torch.types import _bool
diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py
index e2eb7c52390..b45b5663246 100644
--- a/torch/_inductor/cudagraph_utils.py
+++ b/torch/_inductor/cudagraph_utils.py
@@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 from torch._dynamo.utils import counters

+
 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py
index b0ad369c431..808a089aedc 100644
--- a/torch/_inductor/debug.py
+++ b/torch/_inductor/debug.py
@@ -15,10 +15,8 @@ from typing import Any, Dict, List, Optional
 from unittest.mock import patch

 import torch
-
 from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
 from torch import fx as fx
-
 from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
 from torch._dynamo.utils import get_debug_dir
 from torch.fx.graph_module import GraphModule
@@ -36,6 +34,7 @@ from .scheduler import (
 )
 from .virtualized import V

+
 log = logging.getLogger(__name__)

 SchedulerNodeList = List[Any]
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 1fe95059167..0ae33145f3d 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -37,6 +37,7 @@ from .utils import (
     use_scatter_fallback,
 )

+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
 prims = torch.ops.prims
diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py
index 3a68778c8a5..682dd2c0963 100644
--- a/torch/_inductor/dependencies.py
+++ b/torch/_inductor/dependencies.py
@@ -25,6 +25,7 @@ from .utils import (
 )
 from .virtualized import OpsHandler, ReductionType, V

+
 log = logging.getLogger(__name__)

 is_indirect = re.compile(r"indirect|tmp").search
diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py
index 8a172d8c29b..07c1eebf99b 100644
--- a/torch/_inductor/exc.py
+++ b/torch/_inductor/exc.py
@@ -6,6 +6,7 @@ import tempfile
 import textwrap
 from functools import lru_cache

+
 if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":

     @lru_cache(None)
diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py
index 2fecce80acf..c62e7d4893e 100644
--- a/torch/_inductor/freezing.py
+++ b/torch/_inductor/freezing.py
@@ -3,7 +3,6 @@ from __future__ import annotations

 import itertools
 import logging
-
 import weakref
 from typing import Any, List, Optional
@@ -13,12 +12,12 @@ from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
 from torch._functorch.aot_autograd import MutationType
 from torch._functorch.compile_utils import fx_graph_cse
 from torch._inductor.constant_folding import constant_fold, replace_node_with_constant
-
 from torch._inductor.fx_passes.freezing_patterns import freezing_passes
 from torch._inductor.fx_passes.post_grad import view_to_reshape

 from . import config

+
 aten = torch.ops.aten
 prims = torch.ops.prims
diff --git a/torch/_inductor/fx_passes/b2b_gemm.py b/torch/_inductor/fx_passes/b2b_gemm.py
index cbcf59e6082..e67b275546b 100644
--- a/torch/_inductor/fx_passes/b2b_gemm.py
+++ b/torch/_inductor/fx_passes/b2b_gemm.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import torch
+
 from ..._dynamo.utils import counters
 from ..ir import FixedLayout
 from ..pattern_matcher import (
diff --git a/torch/_inductor/fx_passes/binary_folding.py b/torch/_inductor/fx_passes/binary_folding.py
index 7453cde1ce9..93b148ac687 100644
--- a/torch/_inductor/fx_passes/binary_folding.py
+++ b/torch/_inductor/fx_passes/binary_folding.py
@@ -3,11 +3,12 @@ import functools
 import itertools

 import torch
-from ..._dynamo.utils import counters

+from ..._dynamo.utils import counters
 from ..pattern_matcher import Arg, CallFunction, KeywordArg
 from .freezing_patterns import register_binary_folding_pattern

+
 aten = torch.ops.aten
 prims = torch.ops.prims
diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py
index 6ef0f71a807..6c4e57a1d50 100644
--- a/torch/_inductor/fx_passes/ddp_fusion.py
+++ b/torch/_inductor/fx_passes/ddp_fusion.py
@@ -30,6 +30,7 @@ from .. import config
 from ..fx_utils import get_fake_args_kwargs
 from ..virtualized import V

+
 aten = torch.ops.aten
 logger: logging.Logger = logging.getLogger("comm_fusion")
diff --git a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py
index dba2f62e7d6..4c2fab6d41f 100644
--- a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py
+++ b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py
@@ -7,10 +7,10 @@ from torch import Tensor
 from torch._dynamo.utils import counters

 from .. import config
-
 from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern
 from .split_cat import construct_pattern_matcher_pass

+
 aten = torch.ops.aten
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
index c8165a1a392..6920629a4d3 100644
--- a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
+++ b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
@@ -1,7 +1,6 @@
 # mypy: allow-untyped-defs
 import torch
 import torch.nn as nn
-
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch.func import functional_call
@@ -12,7 +11,6 @@ from ..pattern_matcher import (
     Match,
     register_graph_pattern,
 )
-
 from .pre_grad import efficient_conv_bn_eval_pass
diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py
index 039fea2dcca..0d098515141 100644
--- a/torch/_inductor/fx_passes/freezing_patterns.py
+++ b/torch/_inductor/fx_passes/freezing_patterns.py
@@ -3,8 +3,8 @@ import functools

 import torch
 from torch._inductor.compile_fx import fake_tensor_prop
-from ..._dynamo.utils import counters

+from ..._dynamo.utils import counters
 from .. import config
 from ..pattern_matcher import (
     _return_true,
@@ -20,6 +20,7 @@ from ..pattern_matcher import (
     stable_topological_sort,
 )

+
 aten = torch.ops.aten

 # First pass_patterns[0] are applied, then [1], then [2]
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index fad49d40482..1a3b681230e 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -6,6 +6,7 @@ import math

 import torch
 from torch.nn.attention import sdpa_kernel, SDPBackend
+
 from ..._dynamo.utils import counters
 from ..pattern_matcher import (
     filter_nodes,
@@ -14,6 +15,7 @@ from ..pattern_matcher import (
     joint_fwd_bwd,
 )

+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py
index a5fe1870349..b55c4a67a72 100644
--- a/torch/_inductor/fx_passes/group_batch_fusion.py
+++ b/torch/_inductor/fx_passes/group_batch_fusion.py
@@ -28,6 +28,7 @@ from ..pattern_matcher import (
     stable_topological_sort,
 )

+
 try:
     # importing this will register fbgemm lowerings for inductor
     import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings  # noqa: F401
diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py
index 791f366e0ff..89165f99b6c 100644
--- a/torch/_inductor/fx_passes/joint_graph.py
+++ b/torch/_inductor/fx_passes/joint_graph.py
@@ -27,6 +27,7 @@ from ..pattern_matcher import (
 )
 from .replace_random import replace_random_passes

+
 log = logging.getLogger(__name__)
 patterns = PatternMatcherPass()
 aten = torch.ops.aten
diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py
index 391548b041b..28ade543d6e 100644
--- a/torch/_inductor/fx_passes/micro_pipeline_tp.py
+++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py
@@ -4,8 +4,8 @@ from dataclasses import dataclass
 from typing import cast, List, Set, Tuple, Union

 import torch
-from .. import inductor_prims

+from .. import inductor_prims
 from ..pattern_matcher import (
     CallFunction,
     Ignored,
@@ -16,6 +16,7 @@ from ..pattern_matcher import (
     register_graph_pattern,
 )

+
 aten = torch.ops.aten
 patterns = PatternMatcherPass()
diff --git a/torch/_inductor/fx_passes/misc_patterns.py b/torch/_inductor/fx_passes/misc_patterns.py
index f2d943cab24..d7873fede3c 100644
--- a/torch/_inductor/fx_passes/misc_patterns.py
+++ b/torch/_inductor/fx_passes/misc_patterns.py
@@ -1,14 +1,14 @@
 # mypy: allow-untyped-defs
 import functools
-
 from typing import Dict, Set, Tuple

 import torch
 from torch._dynamo.utils import counters
-
 from torch._ops import OpOverload, OpOverloadPacket
+
 from ..pattern_matcher import fwd_only, register_replacement

+
 aten = torch.ops.aten
diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py
index a7dc42107e1..34ddbf90b7f 100644
--- a/torch/_inductor/fx_passes/mkldnn_fusion.py
+++ b/torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -5,11 +5,9 @@ from functools import reduce
 from typing import Any, Tuple

 import torch
-
 from torch.fx.experimental.symbolic_shapes import has_free_symbols

 from .. import ir
-
 from ..lowering import lowerings as L
 from ..pattern_matcher import (
     Arg,
@@ -28,6 +26,7 @@ from .quantization import (
     _register_woq_lowerings,
 )

+
 if torch._C._has_mkldnn:
     aten = torch.ops.aten
     mkldnn = torch.ops.mkldnn
diff --git a/torch/_inductor/fx_passes/numeric_utils.py b/torch/_inductor/fx_passes/numeric_utils.py
index 5bad4ed9489..7069b100a38 100644
--- a/torch/_inductor/fx_passes/numeric_utils.py
+++ b/torch/_inductor/fx_passes/numeric_utils.py
@@ -12,6 +12,7 @@ import torch.optim as optim

 from .. import config

+
 logger: logging.Logger = logging.getLogger(__name__)

 MAIN_RANDOM_SEED = 1337
diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py
index f8e5f4f550d..0f161573352 100644
--- a/torch/_inductor/fx_passes/pad_mm.py
+++ b/torch/_inductor/fx_passes/pad_mm.py
@@ -23,7 +23,6 @@ from torch._subclasses.fake_tensor import FakeTensor
 from torch.utils._mode_utils import no_dispatch

 from ...utils._triton import has_triton
-
 from ..pattern_matcher import (
     fwd_only,
     gen_register_replacement,
@@ -33,6 +32,7 @@ from ..pattern_matcher import (
     SearchFn,
 )

+
 aten = torch.ops.aten
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index b49958de217..df942842639 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -14,9 +14,7 @@ from torch._decomp import register_decomposition
 from torch._dynamo.utils import counters, optimus_scuba_log
 from torch._inductor import comms
 from torch._inductor.virtualized import ops
-
 from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
-
 from torch._utils_internal import upload_graph
 from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
 from torch.fx.passes.graph_transform_observer import GraphTransformObserver
@@ -24,7 +22,6 @@ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
 from .. import config, ir, pattern_matcher
 from ..codegen.common import BackendFeature, has_backend_feature
 from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
-
 from ..lowering import lowerings as L
 from ..pattern_matcher import (
     _return_true,
@@ -54,6 +51,7 @@ from .pre_grad import is_same_dict, save_inductor_dict
 from .reinplace import reinplace_inplaceable_ops
 from .split_cat import POST_GRAD_PATTERNS

+
 if TYPE_CHECKING:
     from sympy import Expr
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index 90b8e709ca2..6c243a92c1e 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -18,7 +18,6 @@ from torch.nn import functional as F
 from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights

 from .. import config
-
 from ..fx_utils import matches_module_function_pattern
 from ..pattern_matcher import (
     init_once_fakemode,
@@ -30,6 +29,7 @@ from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS
 from .misc_patterns import numpy_compat_normalization
 from .split_cat import PRE_GRAD_PATTERNS

+
 log = logging.getLogger(__name__)

 efficient_conv_bn_eval_pass = PatternMatcherPass(
diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py
index 0ea5df8cf99..f40d5dc8c30 100644
--- a/torch/_inductor/fx_passes/quantization.py
+++ b/torch/_inductor/fx_passes/quantization.py
@@ -10,12 +10,14 @@ import torch
 from torch._dynamo.utils import counters
 from torch.fx.experimental.symbolic_shapes import has_free_symbols
 from torch.fx.node import map_arg
+
 from ..lowering import lowerings as L, require_channels_last
 from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
 from ..utils import pad_listlike
 from .freezing_patterns import register_freezing_graph_pattern
 from .post_grad import register_lowering_pattern

+
 aten = torch.ops.aten
 prims = torch.ops.prims
 quantized_decomposed = torch.ops.quantized_decomposed
diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py
index bca338e23a9..33d42bf1456 100644
--- a/torch/_inductor/fx_passes/reinplace.py
+++ b/torch/_inductor/fx_passes/reinplace.py
@@ -17,6 +17,7 @@ from torch.fx.immutable_collections import immutable_dict
 from torch.fx.passes.reinplace import _is_view_op
 from torch.utils import _pytree as pytree

+
 aten = torch.ops.aten
diff --git a/torch/_inductor/fx_passes/replace_random.py b/torch/_inductor/fx_passes/replace_random.py
index c028eb35379..f56a90b86dd 100644
--- a/torch/_inductor/fx_passes/replace_random.py
+++ b/torch/_inductor/fx_passes/replace_random.py
@@ -5,6 +5,7 @@ import logging
 import torch
 from torch.fx.passes.graph_transform_observer import GraphTransformObserver
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
+
 from .. import config, inductor_prims
 from ..pattern_matcher import (
     CallFunctionVarArgs,
@@ -14,6 +15,7 @@ from ..pattern_matcher import (
 )
 from ..virtualized import V

+
 log = logging.getLogger(__name__)
 patterns = PatternMatcherPass()
 aten = torch.ops.aten
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index 913536c4914..11a8e439902 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -28,6 +28,7 @@ from ..pattern_matcher import (
 )
 from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS

+
 log = logging.getLogger(__name__)

 _Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...]
diff --git a/torch/_inductor/fx_utils.py b/torch/_inductor/fx_utils.py
index 8f3ed2e9177..6b791e79940 100644
--- a/torch/_inductor/fx_utils.py
+++ b/torch/_inductor/fx_utils.py
@@ -15,6 +15,7 @@ from torch.fx.experimental.symbolic_shapes import (
 )
 from torch.utils import _pytree as pytree
 from torch.utils._pytree import tree_map
+
 from .virtualized import V
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index c513d6c1bb7..cdbf863f177 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -93,12 +93,14 @@ from .utils import (
 )
 from .virtualized import NullHandler, V

+
 if TYPE_CHECKING:
     from torch._higher_order_ops.effects import _EffectType

     from .codegen.wrapper import WrapperCodeGen

 from torch._inductor.codecache import output_code_log

+
 log = logging.getLogger(__name__)
 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
diff --git a/torch/_inductor/hooks.py b/torch/_inductor/hooks.py
index bf4a8bb090a..9d8aeecd283 100644
--- a/torch/_inductor/hooks.py
+++ b/torch/_inductor/hooks.py
@@ -2,6 +2,7 @@ import contextlib

 from typing import Callable, List, TYPE_CHECKING

+
 if TYPE_CHECKING:
     import torch
diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py
index 894f6836efa..fbfd4cf2896 100644
--- a/torch/_inductor/index_propagation.py
+++ b/torch/_inductor/index_propagation.py
@@ -31,8 +31,8 @@ import torch
 from torch._prims_common import dtype_to_type, is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
-from .utils import generate_assert

+from .utils import generate_assert
 from .virtualized import V
diff --git a/torch/_inductor/inductor_prims.py b/torch/_inductor/inductor_prims.py
index 4a50129470f..1f12d729b14 100644
--- a/torch/_inductor/inductor_prims.py
+++ b/torch/_inductor/inductor_prims.py
@@ -7,6 +7,7 @@ from typing import Optional, Sequence
 import torch
 from torch import _prims, Tensor

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index f8a4de9363e..dd49e1ecce9 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -31,9 +31,7 @@ import sympy
 from sympy import Expr, Integer

 import torch._export.serde.schema as export_schema
-
 import torch._logging
-
 import torch.fx
 import torch.utils._pytree as pytree
 from torch._dynamo.device_interface import get_interface_for_device
@@ -90,6 +88,7 @@ from .utils import (
 )
 from .virtualized import ops, V

+
 if TYPE_CHECKING:
     from .graph import GraphLowering
diff --git a/torch/_inductor/jagged_lowerings.py b/torch/_inductor/jagged_lowerings.py
index a53140f13e5..7a6f1f4632b 100644
--- a/torch/_inductor/jagged_lowerings.py
+++ b/torch/_inductor/jagged_lowerings.py
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
 import sympy

 import torch
+
 from .ir import Pointwise, TensorBox
 from .lowering import fallback_handler, is_integer_type, register_lowering
 from .virtualized import ops
diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py
index 7d1fbc0b35e..f6af2e1bb6e 100644
--- a/torch/_inductor/kernel/bmm.py
+++ b/torch/_inductor/kernel/bmm.py
@@ -16,11 +16,10 @@ from ..utils import (
     use_triton_template,
 )
 from ..virtualized import V
-
 from .mm import _is_static_problem
-
 from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options

+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py
index 15c750f0b66..36e12777efb 100644
--- a/torch/_inductor/kernel/conv.py
+++ b/torch/_inductor/kernel/conv.py
@@ -8,7 +8,6 @@ from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDi
 import torch

 from .. import config, ir
-
 from ..lowering import (
     add_layout_constraint,
     constrain_to_fx_strides,
@@ -31,6 +30,7 @@ from ..utils import (
 from ..virtualized import V
 from .mm_common import filtered_configs

+
 if TYPE_CHECKING:
     from ..ir import TensorBox
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 6f71384d8ae..8dfe6d39fa4 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -9,6 +9,7 @@ import sympy
 import torch
 from torch._inductor.virtualized import V
 from torch.utils._pytree import tree_map
+
 from .. import config
 from ..ir import (
     ComputedBuffer,
@@ -23,6 +24,7 @@ from ..ir import (
 from ..lowering import empty, empty_strided, lowerings, register_lowering
 from ..select_algorithm import autotune_select_algorithm, TritonTemplate

+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
index 78c4156f43e..562dd70d8c8 100644
--- a/torch/_inductor/kernel/flex_decoding.py
+++ b/torch/_inductor/kernel/flex_decoding.py
@@ -6,6 +6,7 @@ import sympy
 import torch
 from torch._inductor.virtualized import V
+
 from ..ir import FixedLayout, FlexibleLayout
 from ..lowering import empty_strided, lowerings
 from ..runtime.runtime_utils import next_power_of_2
diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py
index ad4630dc6f1..0283055c9e7 100644
--- a/torch/_inductor/kernel/mm.py
+++ b/torch/_inductor/kernel/mm.py
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate
 from torch._inductor.virtualized import V
+
 from .. import config as inductor_config
 from ..codegen.common import BackendFeature
 from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate
@@ -39,6 +40,7 @@ from .mm_common import (
     triton_config,
 )

+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py
index 9ffaba040e7..5153d4d9861 100644
--- a/torch/_inductor/kernel/mm_common.py
+++ b/torch/_inductor/kernel/mm_common.py
@@ -14,6 +14,7 @@ from .. import config as inductor_config
 from ..runtime.runtime_utils import next_power_of_2
 from ..utils import ceildiv as cdiv

+
 log = logging.getLogger(__name__)
diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py
index f2f810d1fe0..acc8d58de10 100644
--- a/torch/_inductor/kernel/mm_plus_mm.py
+++ b/torch/_inductor/kernel/mm_plus_mm.py
@@ -13,6 +13,7 @@ from ..utils import use_aten_gemm_kernels, use_triton_template
 from ..virtualized import V
 from .mm_common import mm_args, mm_grid, mm_options

+
 aten = torch.ops.aten

 aten_mm_plus_mm = ExternKernelChoice(
diff --git a/torch/_inductor/kernel/unpack_mixed_mm.py b/torch/_inductor/kernel/unpack_mixed_mm.py
index c483dbff2b8..674da97c165 100644
--- a/torch/_inductor/kernel/unpack_mixed_mm.py
+++ b/torch/_inductor/kernel/unpack_mixed_mm.py
@@ -5,6 +5,7 @@ from typing import List, TYPE_CHECKING
 from ..select_algorithm import autotune_select_algorithm, TritonTemplate
 from .mm_common import mm_args, mm_configs, mm_grid, mm_options

+
 if TYPE_CHECKING:
     from ..ir import ChoiceCaller
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 24a4bec5368..b959f513fa3 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -42,8 +42,8 @@ from torch.utils._sympy.functions import (
     IntTrueDiv,
     ModularIndexing,
 )
-from .._dynamo.utils import import_submodule

+from .._dynamo.utils import import_submodule
 from . import config, inductor_prims, ir, test_operators  # NOQA: F401
 from .decomposition import decompositions, get_decompositions
 from .ir import (
@@ -72,6 +72,7 @@ from .utils import (
 )
 from .virtualized import ops, V

+
 log = logging.getLogger(__name__)
 lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {}
 layout_constraints: Dict[torch._ops.OpOverload, Callable[..., Any]] = {}
@@ -6108,6 +6109,7 @@ def resize(x, size, *, memory_format=None):

 from torch._higher_order_ops.auto_functionalize import auto_functionalized

+
 make_fallback(auto_functionalized)
@@ -6395,17 +6397,21 @@ except (AttributeError, ImportError):

 # populate lowerings defined in kernel/*
 from . import kernel

+
 import_submodule(kernel)

 from . import quantized_lowerings

+
 quantized_lowerings.register_quantized_ops()
 quantized_lowerings.register_woq_mm_ops()

 from . import mkldnn_lowerings

+
 mkldnn_lowerings.register_onednn_fusion_ops()

 from . import jagged_lowerings

+
 jagged_lowerings.register_jagged_ops()
diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py
index 9fd3a60a4d2..07ac6702e32 100644
--- a/torch/_inductor/metrics.py
+++ b/torch/_inductor/metrics.py
@@ -8,12 +8,12 @@ import os
 import re
 from dataclasses import dataclass
 from functools import lru_cache
-
 from typing import Dict, List, Set, Tuple, TYPE_CHECKING, Union

 from torch._inductor import config
 from torch._inductor.utils import get_benchmark_name

+
 # Prevent circular import
 if TYPE_CHECKING:
     from torch._inductor.scheduler import (
diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py
index c7cf37151bd..5d82c87f0c9 100644
--- a/torch/_inductor/mkldnn_ir.py
+++ b/torch/_inductor/mkldnn_ir.py
@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Set
 import sympy

 import torch
-
 from torch._prims_common import make_channels_last_strides_for

 from .ir import (
@@ -22,9 +21,7 @@ from .ir import (
     NoneLayout,
     TensorBox,
 )
-
 from .utils import convert_shape_to_inductor, pad_listlike
-
 from .virtualized import V
diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py
index b743af79681..84ed0583fd1 100644
--- a/torch/_inductor/mkldnn_lowerings.py
+++ b/torch/_inductor/mkldnn_lowerings.py
@@ -5,6 +5,7 @@ from typing import List, Optional
 import torch
 import torch.utils._pytree as pytree
 from torch._inductor.kernel.mm_common import mm_args
+
 from . import ir, mkldnn_ir
 from .codegen.cpp_gemm_template import CppPackedGemmTemplate
 from .ir import TensorBox
diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py
index 041dd3464ee..b901de2abe2 100644
--- a/torch/_inductor/ops_handler.py
+++ b/torch/_inductor/ops_handler.py
@@ -18,8 +18,10 @@ import sympy

 import torch
 import torch.utils._pytree as pytree
+
 from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str

+
 T = TypeVar("T")
 StoreMode = Optional[Literal["atomic_add"]]
 ReductionType = Literal[
diff --git a/torch/_inductor/optimize_indexing.py b/torch/_inductor/optimize_indexing.py
index 63887b34736..c333bd88294 100644
--- a/torch/_inductor/optimize_indexing.py
+++ b/torch/_inductor/optimize_indexing.py
@@ -5,6 +5,7 @@ import sympy

 import torch
 from torch.utils._sympy.value_ranges import ValueRanges
+
 from .ir import LoopBody
 from .utils import dominated_nodes
diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py
index c1fafe8389e..34e0c910b26 100644
--- a/torch/_inductor/package/package.py
+++ b/torch/_inductor/package/package.py
@@ -14,6 +14,7 @@ import torch.utils._pytree as pytree
 from torch._inductor import config, exc
 from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
 from torch.export._tree_utils import reorder_kwargs
+
 from .build_package import build_package_contents
 from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 58c595ec9f2..c6e403c0372 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -38,7 +38,6 @@ successful match or a `FailedMatch` object for a failure to match.
 from __future__ import annotations

 import contextlib
-
 import dataclasses
 import functools
 import importlib
@@ -96,6 +95,7 @@ from . import config
 from .decomposition import select_decomp_table
 from .lowering import fallback_node_due_to_unsupported_type

+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
 prims = torch.ops.prims
diff --git a/torch/_inductor/quantized_lowerings.py b/torch/_inductor/quantized_lowerings.py
index 954a85abe52..852586dd8ad 100644
--- a/torch/_inductor/quantized_lowerings.py
+++ b/torch/_inductor/quantized_lowerings.py
@@ -1,7 +1,9 @@
 # mypy: allow-untyped-defs
 import torch
+
 from . import lowering

+
 quantized = torch.ops.quantized
 _quantized = torch.ops._quantized
 aten = torch.ops.aten
diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index 31ff9477461..133a3010166 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -5,9 +5,9 @@ import logging
 from typing import Callable, Optional

 from .hints import TRITON_MAX_BLOCK
-
 from .runtime_utils import red_text, triton_config_to_hashable

+
 try:
     import triton
 except ImportError:
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index e516153fa68..85af8650210 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 import torch

 from .coordinate_descent_tuner import CoordescTuner
-
 from .hints import (
     _NUM_THREADS_PER_WARP,
     AutotuneHint,
@@ -43,6 +42,7 @@ from .runtime_utils import (
     triton_config_to_hashable,
 )

+
 try:
     import triton
 except ImportError:
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 58619c1d630..f80df472039 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -3076,6 +3076,7 @@ def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[st
     from torch._inductor.codegen.cuda_combined_scheduling import (
         CUDACombinedScheduling,
     )
+
     from .codegen.simd import SIMDScheduling

     snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index cf9195f090c..b5dec6362d3 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -6,7 +6,6 @@ import inspect
 import itertools
 import json
 import logging
-
 import math
 import operator
 import os
@@ -16,7 +15,6 @@ import time
 from collections import namedtuple
 from concurrent.futures import as_completed, ThreadPoolExecutor
 from io import StringIO
-
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch
@@ -32,7 +30,6 @@ from .
import config, ir from .autotune_process import TensorMeta, TritonBenchmarkRequest from .codecache import code_hash, PersistentCache, PyCodeCache from .codegen.common import IndentedBuffer, KernelTemplate - from .codegen.triton import ( gen_common_triton_imports, texpr, @@ -40,7 +37,6 @@ from .codegen.triton import ( TritonPrinter, TritonScheduling, ) - from .codegen.triton_utils import config_of, signature_to_meta from .exc import CUDACompileError from .ir import ChoiceCaller, PrimitiveInfoType @@ -58,6 +54,7 @@ from .utils import ( ) from .virtualized import V + log = logging.getLogger(__name__) # correctness checks struggle with fp16/tf32 diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 7d97f5b2752..e013b728bad 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -32,6 +32,7 @@ from .utils import ( ) from .virtualized import V + log = logging.getLogger(__name__) diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index 4f7eec8ff50..58cac06be85 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -15,6 +15,7 @@ from .exc import SubgraphLoweringException from .ops_handler import SimpleCSEHandler from .virtualized import ops, V, WrapperHandler + T = TypeVar("T") diff --git a/torch/_inductor/test_case.py b/torch/_inductor/test_case.py index 3acc68ff22a..53a791685c6 100644 --- a/torch/_inductor/test_case.py +++ b/torch/_inductor/test_case.py @@ -6,7 +6,6 @@ from torch._dynamo.test_case import ( run_tests as dynamo_run_tests, TestCase as DynamoTestCase, ) - from torch._inductor import config from torch._inductor.utils import fresh_inductor_cache diff --git a/torch/_inductor/test_operators.py b/torch/_inductor/test_operators.py index 3c105ba7db2..a5c1d401f2d 100644 --- a/torch/_inductor/test_operators.py +++ b/torch/_inductor/test_operators.py @@ -3,6 +3,7 @@ import torch.library from torch import Tensor from torch.autograd import Function + if not torch._running_with_deploy(): _test_lib_def = torch.library.Library("_inductor_test", "DEF") _test_lib_def.define( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index e3016cfe8a1..18ae52d4fbf 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -58,9 +58,11 @@ from torch.utils._sympy.functions import ( ) from torch.utils._sympy.symbol import make_symbol, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + from . import config from .runtime.runtime_utils import ceildiv as runtime_ceildiv + log = logging.getLogger(__name__) _T = TypeVar("_T") diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 51ff55a00b7..23f374cbc1e 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -70,6 +70,7 @@ from .ops_handler import ( # noqa: F401 WrapperHandler, ) + if TYPE_CHECKING: import torch from torch._inductor.codegen.cpp_utils import LocalBufferContext diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index 976d0c7458e..46bf9d0fe1c 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -5,12 +5,14 @@ from collections import defaultdict import torch from torch.autograd import DeviceType + from .runtime.runtime_utils import ( create_bandwidth_info_str, do_bench_gpu, get_num_bytes, ) + _kernel_category_choices = [ "foreach", "persistent_reduction",