Mirror of https://github.com/saymrwulf/pytorch.git
Synced 2026-05-14 20:57:59 +00:00
[BE][Easy][16/19] enforce style for empty lines in import segments in torch/_i*/ (#129768)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129768
Approved by: https://github.com/jansel
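For context, a rough, hypothetical illustration of the import layout the linter normalizes to (the module and names below are invented for the example and are not taken from this PR): imports grouped into segments, no stray blank lines inside a segment, and blank lines only between segments and after the import block.

```python
# Hypothetical example module: the layout below is the kind of result the
# auto-generated changes aim for -- no blank lines inside an import segment,
# a single blank line between segments, blank lines only after the imports.
from __future__ import annotations

import logging
from typing import Any

import torch


log = logging.getLogger(__name__)


def as_tensor(value: Any) -> torch.Tensor:
    # Trivial helper so the sketch is runnable.
    return torch.as_tensor(value)
```

Most hunks below only add or delete blank lines around import blocks like these.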
This commit is contained in:
parent 8e478d4fb1
commit b6d477fd56
111 changed files with 134 additions and 92 deletions
@@ -58,7 +58,6 @@ ISORT_SKIPLIST = re.compile(
    # torch/_[e-h]*/**
    "torch/_[e-h]*/**",
    # torch/_i*/**
    "torch/_i*/**",
    # torch/_[j-z]*/**
    "torch/_[j-z]*/**",
    # torch/[a-c]*/**

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch.fx
import torch.utils._pytree as pytree

__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]

@@ -7,6 +7,7 @@ from unittest import mock
import torch
import torch._export
from torch._inductor.utils import is_cpu_device

from .runtime.runtime_utils import cache_dir

@@ -31,15 +31,14 @@ from torch._inductor.compile_worker.subproc_pool import (
    SubprocPool,
)
from torch._inductor.compile_worker.watchdog import _async_compile_initializer

from torch._inductor.runtime.compile_tasks import (
    _set_triton_ptxas_path,
    _worker_compile_triton,
)

from torch.hub import _Faketqdm, tqdm
from torch.utils._triton import has_triton_package

if TYPE_CHECKING:
    from torch._inductor.runtime.hints import HalideMeta

@@ -1,10 +1,8 @@
import json
import os

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch

from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
    AHMetadata,

@@ -1,7 +1,7 @@
import functools

from typing import Any, Callable, Dict, List, Tuple

Feedback = float
Choice = str
Value = Any

@@ -1,7 +1,6 @@
import importlib
import inspect
import pkgutil

from collections import defaultdict
from typing import Any, Dict, List, Optional

@@ -28,7 +28,6 @@ import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch import multiprocessing
from torch._dynamo.testing import rand_strided

from torch._inductor import ir
from torch._inductor.codecache import (
    CppCodeCache,

@@ -38,6 +37,7 @@ from torch._inductor.codecache import (
    PyCodeCache,
)

if TYPE_CHECKING:
    from multiprocessing.process import BaseProcess
    from multiprocessing.queues import Queue

@@ -49,6 +49,7 @@ from . import config
from .runtime.runtime_utils import do_bench_cpu, do_bench_gpu
from .virtualized import V

CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
EXIT_HANDLER_REGISTERED = False
@@ -8,6 +8,7 @@ from sympy import Expr

import torch
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges

from .ir import InterpreterShim, LoopBody, LoopBodyBlock
from .utils import cache_on_self, dominated_nodes
from .virtualized import V

@@ -27,7 +27,7 @@ import threading
import warnings
from bisect import bisect_right
from copy import copy
from ctypes import c_void_p, cdll, CDLL
from ctypes import c_void_p, CDLL, cdll
from functools import partial
from pathlib import Path
from time import time, time_ns

@@ -58,6 +58,7 @@ from torch._inductor.codegen.rocm.compile_command import (
    rocm_compiler,
)

"""
codecache.py, cpp_builder.py and cpu_vec_isa.py import rule:
https://github.com/pytorch/pytorch/issues/124245#issuecomment-2197778902

@@ -86,7 +87,6 @@ from torch._inductor.runtime.compile_tasks import (
)
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import ALIGN_BYTES, clear_on_fresh_inductor_cache, is_linux

from torch._logging import trace_structured
from torch._subclasses.fake_tensor import (
    extract_tensor_metadata,

@@ -95,6 +95,7 @@ from torch._subclasses.fake_tensor import (
)
from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv

if TYPE_CHECKING:
    from concurrent.futures import Future

@@ -2,9 +2,9 @@
import re

import torch

from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE

# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
# "...
# from ..codecache import CudaKernelParamCache

@@ -1,5 +1,6 @@
import torch

# Provide aoti module launch hip/cuda drivers. This file is also used for unit testing purpose

@@ -21,8 +21,8 @@ from torch.utils import _pytree as pytree
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from ..._dynamo.utils import counters

from ..._dynamo.utils import counters
from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics
from ..codegen.wrapper import WrapperCodeGen
from ..optimize_indexing import range_expressable_in_32_bits

@@ -46,7 +46,6 @@ from ..utils import (
    sympy_product,
    sympy_subs,
)

from ..virtualized import NullKernelHandler, ops, OpsValue, V
from .common import (
    BackendFeature,
@@ -63,7 +62,6 @@ from .common import (
    OpOverrides,
    OptimizationContext,
)

from .cpp_utils import (
    cexpr,
    cexpr_index,

@@ -74,6 +72,7 @@ from .cpp_utils import (
    value_to_cpp,
)

_IS_WINDOWS = sys.platform == "win32"
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")

@@ -6,19 +6,19 @@ from unittest.mock import patch

import torch
import torch.utils

from ..._dynamo.utils import counters
from .. import ir, lowering as L

from ..kernel.mm_common import mm_args
from ..select_algorithm import DataProcessorTemplateWrapper
from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
from ..virtualized import ops, V
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
from .cpp_template import CppTemplate

from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import GemmBlocking, get_gemm_template_output_and_compute_dtype

GEMM_TEMPLATE = r"""
{{template.header().getvalue()}}

@@ -3,7 +3,6 @@ import ctypes
import functools
import itertools
import logging

import sys
from typing import Callable, List, Optional
from unittest.mock import patch

@@ -17,6 +16,7 @@ from ..virtualized import V
from .common import KernelTemplate
from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel

log = logging.getLogger(__name__)

@@ -7,10 +7,10 @@ from sympy.parsing.sympy_parser import parse_expr

import torch
from torch.utils._sympy.symbol import SymT
from .. import config, cpp_builder, ir, lowering as L

from .. import config, cpp_builder, ir, lowering as L
from ..autotune_process import CppBenchmarkRequest
from ..select_algorithm import PartialRender
from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix

@@ -2,7 +2,6 @@
import contextlib
import copy
import math

from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch

@@ -11,12 +10,13 @@ import sympy

import torch
from torch.utils._sympy.symbol import symbol_is_type, SymT

from .. import ir
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
from ..virtualized import V

from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs

DTYPE_TO_CPP = {
    torch.float32: "float",
    torch.float64: "double",

@@ -10,10 +10,10 @@ import sympy
from sympy import Expr

import torch

import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
import torch._ops
from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes

from .. import config, ir
from ..utils import _align, ALIGN_BYTES, cache_on_self, sympy_product
from ..virtualized import V

@@ -19,6 +19,7 @@ from .cpp_utils import DTYPE_TO_CPP
from .cpp_wrapper_cpu import CppWrapperCpu
from .wrapper import SymbolicCallArg

if TYPE_CHECKING:
    from ..graph import GraphLowering

@@ -3,16 +3,15 @@ import logging
from typing import cast, Sequence

from ...._dynamo.utils import counters

from ... import config
from ...codecache import code_hash, get_path

from ...ir import CUDATemplateBuffer
from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode
from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
from ...virtualized import V
from ..common import IndentedBuffer

log = logging.getLogger(__name__)
@@ -6,6 +6,7 @@ import torch

from ... import config

log = logging.getLogger(__name__)

@@ -15,9 +15,9 @@ from ...ir import (
from ...utils import sympy_product
from ...virtualized import V
from ..common import IndentedBuffer, Kernel, OpOverrides

from ..cpp_utils import CppPrinter, DTYPE_TO_CPP

if TYPE_CHECKING:
    from torch._inductor.codegen.cuda.cuda_template import CUDATemplate

@@ -8,14 +8,15 @@ from unittest.mock import patch
import sympy

import torch

from ...autotune_process import CUDABenchmarkRequest, TensorMeta
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout

from ...utils import IndentedBuffer, unique
from ...virtualized import V
from ..common import KernelTemplate
from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel

log = logging.getLogger(__name__)

@@ -1,11 +1,12 @@
# mypy: allow-untyped-defs
from ..cutlass_utils import try_import_cutlass

if try_import_cutlass():
    import enum

    from cutlass_library.library import * # noqa: F401, F403
    from cutlass_library.gemm_operation import * # noqa: F401, F403
    from cutlass_library.library import * # noqa: F401, F403

    # copied / modified from original at
    # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658

@@ -3,7 +3,6 @@ import functools
import logging
import os
import sys

from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional

@@ -11,12 +10,13 @@ from typing import Any, List, Optional
import sympy

import torch

from ... import config
from ...ir import Layout

from ...runtime.runtime_utils import cache_dir
from .cuda_env import get_cuda_arch, get_cuda_version

log = logging.getLogger(__name__)

@@ -17,11 +17,11 @@ from ...ir import (
    ReinterpretView,
)
from ..common import IndentedBuffer

from . import cutlass_utils
from .cuda_kernel import CUDATemplateKernel
from .cuda_template import CUTLASSTemplate

log = logging.getLogger(__name__)

# Jinja template for GEMM Kernel, used by the CUTLASSGemmTemplate class below.

@@ -10,7 +10,6 @@ from ..scheduler import (
)
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling

from .triton import TritonScheduling

@@ -25,6 +25,7 @@ import sympy

import torch
import torch._logging

from ..._prims_common import is_integer_dtype
from ...utils._sympy.functions import FloorDiv, ModularIndexing
from ...utils._sympy.symbol import symbol_is_type, SymT

@@ -34,7 +35,6 @@ from ..codecache import HalideCodeCache
from ..ir import get_reduction_combine_fn
from ..metrics import is_metric_table_enabled, log_kernel_metadata
from ..ops_handler import AddParenHandler, MockHandler

from ..runtime.hints import HalideInputSpec, HalideMeta, ReductionHint
from ..utils import (
    get_bounds_index_expr,
@@ -58,6 +58,7 @@ from .cpp import DTYPE_TO_CPP
from .cpp_utils import cexpr
from .simd import constant_repr, SIMDKernel, SIMDScheduling

if TYPE_CHECKING:
    from ..ops_handler import ReductionType, StoreMode

@@ -10,10 +10,10 @@ from typing import Any, Dict, Iterable, List, Optional, Protocol
import sympy

import torch

from .. import config, ir
from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer
from ..virtualized import V

from .wrapper import (
    AllocateLine,
    FreeIfNotReusedLine,

@@ -12,6 +12,7 @@ from ..utils import cache_on_self
from ..virtualized import V
from .common import TensorArg

log = logging.getLogger(__name__)

@@ -13,6 +13,7 @@ from torch._inductor.ir import Buffer, Layout

from ...utils import IndentedBuffer, try_import_ck_lib

_, gen_ops_library, gen_ops_preselected, CKGemmOperation = try_import_ck_lib()

@@ -6,6 +6,7 @@ from typing import List, Optional
from torch._inductor import config
from torch._inductor.utils import is_linux

log = logging.getLogger(__name__)

@@ -7,10 +7,10 @@ from ctypes import byref, c_size_t, c_void_p
from typing import Any, Callable, Iterable, List, Optional, Union

import torch

from torch._inductor.autotune_process import GPUDeviceBenchmarkRequest, TensorMeta
from torch._inductor.codecache import DLLWrapper, ROCmCodeCache

log = logging.getLogger(__name__)

@@ -8,9 +8,9 @@ from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, Scheduler
from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
from ...virtualized import V
from ..common import IndentedBuffer

from .rocm_template_buffer import ROCmTemplateBuffer

log = logging.getLogger(__name__)

@@ -6,12 +6,11 @@ from ...ir import Buffer, ChoiceCaller, IRNode, Layout, PrimitiveInfoType, Tenso
from ...utils import sympy_product
from ...virtualized import V
from ..common import IndentedBuffer, Kernel, OpOverrides

from ..cpp_utils import CppPrinter

from .rocm_benchmark_request import ROCmBenchmarkRequest
from .rocm_template_buffer import ROCmTemplateBuffer

if TYPE_CHECKING:
    from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate

@@ -9,7 +9,6 @@ import sympy

from ...autotune_process import TensorMeta
from ...ir import Buffer, IRNode, Layout

from ...utils import IndentedBuffer, unique
from ...virtualized import V
from ..common import KernelTemplate

@@ -17,6 +16,7 @@ from .rocm_benchmark_request import ROCmBenchmarkRequest
from .rocm_kernel import ROCmTemplateCaller, ROCmTemplateKernel
from .rocm_template_buffer import ROCmTemplateBuffer

log = logging.getLogger(__name__)
@@ -28,13 +28,12 @@ import sympy

import torch
import torch._logging

from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT

from ..._dynamo.utils import counters
from .. import config, ir, scheduler
from ..codecache import code_hash

from ..dependencies import Dep, MemoryDep, StarDep, WeakDep
from ..ir import TritonTemplateBuffer
from ..optimize_indexing import indexing_dtype_strength_reduction

@@ -27,14 +27,13 @@ import sympy
import torch
import torch._logging
from torch._dynamo.utils import preserve_rng_state

from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._triton import has_triton_package

from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges

from .. import config, ir
from ..codecache import code_hash, get_path, PyCodeCache
from ..metrics import is_metric_table_enabled, log_kernel_metadata

@@ -73,6 +72,7 @@ from .simd import (
)
from .triton_utils import config_of, signature_of, signature_to_meta

if TYPE_CHECKING:
    from ..ir import IRNode

@@ -1,16 +1,12 @@
# mypy: allow-untyped-defs
import functools

from typing import Optional, Set

import torch._inductor.runtime.hints
from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot

from torch._inductor.codegen.triton import triton_compute_type, TritonKernel

from torch._prims_common import prod

from torch.utils._sympy.functions import CeilDiv

@@ -10,7 +10,6 @@ import inspect
import logging
import operator
import re

import tempfile
from itertools import count
from typing import (

@@ -41,7 +40,6 @@ from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.symbol import symbol_is_type, SymT

from .. import async_compile, config, ir

from ..codecache import output_code_log
from ..ir import ReinterpretView
from ..runtime import triton_heuristics

@@ -58,6 +56,7 @@ from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
from .triton_utils import config_of, signature_to_meta

if TYPE_CHECKING:
    import triton

@@ -5,8 +5,8 @@ from enum import IntEnum
import sympy

import torch
from . import ir

from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V

@@ -4,7 +4,6 @@ from __future__ import annotations

import heapq
import operator

import sys
from collections import defaultdict
from typing import Dict, List, Set, TYPE_CHECKING

@@ -15,6 +14,7 @@ from . import config, ir
from .dependencies import WeakDep
from .utils import is_collective, is_wait

overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")

if TYPE_CHECKING:
@@ -8,15 +8,12 @@ import sys
import time
import warnings
from itertools import count

from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from unittest import mock

import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools

import torch.fx
import torch.utils._pytree as pytree

from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo import (
    compiled_autograd,

@@ -77,6 +74,7 @@ from .utils import (
)
from .virtualized import V

if config.is_fbcode():
    from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
else:

@@ -9,6 +9,7 @@ from torch._inductor.compile_worker.subproc_pool import SubprocMain
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path

log = logging.getLogger(__name__)

_set_triton_ptxas_path()

@@ -18,6 +18,7 @@ from typing import Any, Callable, Dict
from torch._inductor import config
from torch._inductor.compile_worker.watchdog import _async_compile_initializer

log = logging.getLogger(__name__)

@@ -1078,5 +1078,6 @@ if TYPE_CHECKING:

from torch.utils._config_module import install_config_module

# adds patch, save_config, etc
install_config_module(sys.modules[__name__])

@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Tuple
import torch
import torch.utils._pytree as pytree

aten = torch.ops.aten

# We would like to split modules into two subgraphs for runtime weight updates to work correctly.

@@ -24,6 +24,7 @@ from torch._inductor import config, exc
from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA
from torch._inductor.runtime.runtime_utils import cache_dir

if config.is_fbcode():
    from triton.fb import build_paths # noqa: F401

@@ -3,7 +3,6 @@ import dataclasses
import functools
import os
import platform

import re
import subprocess
import sys

@@ -12,6 +11,7 @@ from typing import Any, Callable, Dict, List
import torch
from torch._inductor import config

_IS_WINDOWS = sys.platform == "win32"

@@ -49,7 +49,6 @@ import traceback
import warnings
import weakref
from collections import defaultdict

from enum import auto, Enum
from typing import (
    Any,
@@ -91,6 +90,7 @@ from torch.storage import UntypedStorage
from torch.utils import _pytree as pytree
from torch.utils.weak import TensorWeakRef

if TYPE_CHECKING:
    from torch.types import _bool

@@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch._dynamo.utils import counters

perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")

@@ -15,10 +15,8 @@ from typing import Any, Dict, List, Optional
from unittest.mock import patch

import torch

from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
from torch import fx as fx

from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule

@@ -36,6 +34,7 @@ from .scheduler import (
)
from .virtualized import V

log = logging.getLogger(__name__)

SchedulerNodeList = List[Any]

@@ -37,6 +37,7 @@ from .utils import (
    use_scatter_fallback,
)

log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims

@@ -25,6 +25,7 @@ from .utils import (
)
from .virtualized import OpsHandler, ReductionType, V

log = logging.getLogger(__name__)
is_indirect = re.compile(r"indirect|tmp").search

@@ -6,6 +6,7 @@ import tempfile
import textwrap
from functools import lru_cache

if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":

    @lru_cache(None)

@@ -3,7 +3,6 @@ from __future__ import annotations

import itertools
import logging

import weakref
from typing import Any, List, Optional, Tuple

@@ -13,12 +12,12 @@ from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
from torch._functorch.aot_autograd import MutationType
from torch._functorch.compile_utils import fx_graph_cse
from torch._inductor.constant_folding import constant_fold, replace_node_with_constant

from torch._inductor.fx_passes.freezing_patterns import freezing_passes
from torch._inductor.fx_passes.post_grad import view_to_reshape

from . import config

aten = torch.ops.aten
prims = torch.ops.prims

@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import torch

from ..._dynamo.utils import counters
from ..ir import FixedLayout
from ..pattern_matcher import (
@@ -3,11 +3,12 @@ import functools
import itertools

import torch
from ..._dynamo.utils import counters

from ..._dynamo.utils import counters
from ..pattern_matcher import Arg, CallFunction, KeywordArg
from .freezing_patterns import register_binary_folding_pattern

aten = torch.ops.aten
prims = torch.ops.prims

@@ -30,6 +30,7 @@ from .. import config
from ..fx_utils import get_fake_args_kwargs
from ..virtualized import V

aten = torch.ops.aten
logger: logging.Logger = logging.getLogger("comm_fusion")

@@ -7,10 +7,10 @@ from torch import Tensor
from torch._dynamo.utils import counters

from .. import config

from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern
from .split_cat import construct_pattern_matcher_pass

aten = torch.ops.aten
log = logging.getLogger(__name__)

@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import torch
import torch.nn as nn

from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch.func import functional_call

@@ -12,7 +11,6 @@ from ..pattern_matcher import (
    Match,
    register_graph_pattern,
)

from .pre_grad import efficient_conv_bn_eval_pass

@@ -3,8 +3,8 @@ import functools

import torch
from torch._inductor.compile_fx import fake_tensor_prop
from ..._dynamo.utils import counters

from ..._dynamo.utils import counters
from .. import config
from ..pattern_matcher import (
    _return_true,

@@ -20,6 +20,7 @@ from ..pattern_matcher import (
    stable_topological_sort,
)

aten = torch.ops.aten

# First pass_patterns[0] are applied, then [1], then [2]

@@ -6,6 +6,7 @@ import math

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

from ..._dynamo.utils import counters
from ..pattern_matcher import (
    filter_nodes,

@@ -14,6 +15,7 @@ from ..pattern_matcher import (
    joint_fwd_bwd,
)

log = logging.getLogger(__name__)
aten = torch.ops.aten

@@ -28,6 +28,7 @@ from ..pattern_matcher import (
    stable_topological_sort,
)

try:
    # importing this will register fbgemm lowerings for inductor
    import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401
@@ -27,6 +27,7 @@ from ..pattern_matcher import (
)
from .replace_random import replace_random_passes

log = logging.getLogger(__name__)
patterns = PatternMatcherPass()
aten = torch.ops.aten

@@ -4,8 +4,8 @@ from dataclasses import dataclass
from typing import cast, List, Set, Tuple, Union

import torch
from .. import inductor_prims

from .. import inductor_prims
from ..pattern_matcher import (
    CallFunction,
    Ignored,

@@ -16,6 +16,7 @@ from ..pattern_matcher import (
    register_graph_pattern,
)

aten = torch.ops.aten
patterns = PatternMatcherPass()

@@ -1,14 +1,14 @@
# mypy: allow-untyped-defs
import functools

from typing import Dict, Set, Tuple

import torch
from torch._dynamo.utils import counters

from torch._ops import OpOverload, OpOverloadPacket

from ..pattern_matcher import fwd_only, register_replacement

aten = torch.ops.aten

@@ -5,11 +5,9 @@ from functools import reduce
from typing import Any, Tuple

import torch

from torch.fx.experimental.symbolic_shapes import has_free_symbols

from .. import ir

from ..lowering import lowerings as L
from ..pattern_matcher import (
    Arg,

@@ -28,6 +26,7 @@ from .quantization import (
    _register_woq_lowerings,
)

if torch._C._has_mkldnn:
    aten = torch.ops.aten
    mkldnn = torch.ops.mkldnn

@@ -12,6 +12,7 @@ import torch.optim as optim

from .. import config

logger: logging.Logger = logging.getLogger(__name__)

MAIN_RANDOM_SEED = 1337

@@ -23,7 +23,6 @@ from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._mode_utils import no_dispatch

from ...utils._triton import has_triton

from ..pattern_matcher import (
    fwd_only,
    gen_register_replacement,

@@ -33,6 +32,7 @@ from ..pattern_matcher import (
    SearchFn,
)

aten = torch.ops.aten

@@ -14,9 +14,7 @@ from torch._decomp import register_decomposition
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._inductor import comms
from torch._inductor.virtualized import ops

from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype

from torch._utils_internal import upload_graph
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
@@ -24,7 +22,6 @@ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from .. import config, ir, pattern_matcher
from ..codegen.common import BackendFeature, has_backend_feature
from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage

from ..lowering import lowerings as L
from ..pattern_matcher import (
    _return_true,

@@ -54,6 +51,7 @@ from .pre_grad import is_same_dict, save_inductor_dict
from .reinplace import reinplace_inplaceable_ops
from .split_cat import POST_GRAD_PATTERNS

if TYPE_CHECKING:
    from sympy import Expr

@@ -18,7 +18,6 @@ from torch.nn import functional as F
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights

from .. import config

from ..fx_utils import matches_module_function_pattern
from ..pattern_matcher import (
    init_once_fakemode,

@@ -30,6 +29,7 @@ from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS
from .misc_patterns import numpy_compat_normalization
from .split_cat import PRE_GRAD_PATTERNS

log = logging.getLogger(__name__)

efficient_conv_bn_eval_pass = PatternMatcherPass(

@@ -10,12 +10,14 @@ import torch
from torch._dynamo.utils import counters
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.fx.node import map_arg

from ..lowering import lowerings as L, require_channels_last
from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
from ..utils import pad_listlike
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern

aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

@@ -17,6 +17,7 @@ from torch.fx.immutable_collections import immutable_dict
from torch.fx.passes.reinplace import _is_view_op
from torch.utils import _pytree as pytree

aten = torch.ops.aten

@@ -5,6 +5,7 @@ import logging
import torch
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.passes.shape_prop import _extract_tensor_metadata

from .. import config, inductor_prims
from ..pattern_matcher import (
    CallFunctionVarArgs,

@@ -14,6 +15,7 @@ from ..pattern_matcher import (
)
from ..virtualized import V

log = logging.getLogger(__name__)
patterns = PatternMatcherPass()
aten = torch.ops.aten

@@ -28,6 +28,7 @@ from ..pattern_matcher import (
)
from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS

log = logging.getLogger(__name__)

_Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...]

@@ -15,6 +15,7 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map

from .virtualized import V
@@ -93,12 +93,14 @@ from .utils import (
)
from .virtualized import NullHandler, V

if TYPE_CHECKING:
    from torch._higher_order_ops.effects import _EffectType
    from .codegen.wrapper import WrapperCodeGen

from torch._inductor.codecache import output_code_log

log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")

@@ -2,6 +2,7 @@
import contextlib
from typing import Callable, List, TYPE_CHECKING

if TYPE_CHECKING:
    import torch

@@ -31,8 +31,8 @@ import torch
from torch._prims_common import dtype_to_type, is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from .utils import generate_assert

from .utils import generate_assert
from .virtualized import V

@@ -7,6 +7,7 @@ from typing import Optional, Sequence
import torch
from torch import _prims, Tensor

log = logging.getLogger(__name__)

@@ -31,9 +31,7 @@ import sympy
from sympy import Expr, Integer

import torch._export.serde.schema as export_schema

import torch._logging

import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device

@@ -90,6 +88,7 @@ from .utils import (
)
from .virtualized import ops, V

if TYPE_CHECKING:
    from .graph import GraphLowering

@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
import sympy

import torch

from .ir import Pointwise, TensorBox
from .lowering import fallback_handler, is_integer_type, register_lowering
from .virtualized import ops

@@ -16,11 +16,10 @@ from ..utils import (
    use_triton_template,
)
from ..virtualized import V

from .mm import _is_static_problem

from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options

log = logging.getLogger(__name__)
aten = torch.ops.aten

@@ -8,7 +8,6 @@ from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDi
import torch

from .. import config, ir

from ..lowering import (
    add_layout_constraint,
    constrain_to_fx_strides,

@@ -31,6 +30,7 @@ from ..utils import (
from ..virtualized import V
from .mm_common import filtered_configs

if TYPE_CHECKING:
    from ..ir import TensorBox
@@ -9,6 +9,7 @@ import sympy
import torch
from torch._inductor.virtualized import V
from torch.utils._pytree import tree_map

from .. import config
from ..ir import (
    ComputedBuffer,

@@ -23,6 +24,7 @@ from ..ir import (
from ..lowering import empty, empty_strided, lowerings, register_lowering
from ..select_algorithm import autotune_select_algorithm, TritonTemplate

log = logging.getLogger(__name__)
aten = torch.ops.aten

@@ -6,6 +6,7 @@ import sympy

import torch
from torch._inductor.virtualized import V

from ..ir import FixedLayout, FlexibleLayout
from ..lowering import empty_strided, lowerings
from ..runtime.runtime_utils import next_power_of_2

@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
import torch
from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate
from torch._inductor.virtualized import V

from .. import config as inductor_config
from ..codegen.common import BackendFeature
from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate

@@ -39,6 +40,7 @@ from .mm_common import (
    triton_config,
)

log = logging.getLogger(__name__)
aten = torch.ops.aten

@@ -14,6 +14,7 @@ from .. import config as inductor_config
from ..runtime.runtime_utils import next_power_of_2
from ..utils import ceildiv as cdiv

log = logging.getLogger(__name__)

@@ -13,6 +13,7 @@ from ..utils import use_aten_gemm_kernels, use_triton_template
from ..virtualized import V
from .mm_common import mm_args, mm_grid, mm_options

aten = torch.ops.aten

aten_mm_plus_mm = ExternKernelChoice(

@@ -5,6 +5,7 @@ from typing import List, TYPE_CHECKING
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options

if TYPE_CHECKING:
    from ..ir import ChoiceCaller

@@ -42,8 +42,8 @@ from torch.utils._sympy.functions import (
    IntTrueDiv,
    ModularIndexing,
)
from .._dynamo.utils import import_submodule

from .._dynamo.utils import import_submodule
from . import config, inductor_prims, ir, test_operators # NOQA: F401
from .decomposition import decompositions, get_decompositions
from .ir import (

@@ -72,6 +72,7 @@ from .utils import (
)
from .virtualized import ops, V

log = logging.getLogger(__name__)
lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {}
layout_constraints: Dict[torch._ops.OpOverload, Callable[..., Any]] = {}
@@ -6108,6 +6109,7 @@ def resize(x, size, *, memory_format=None):

from torch._higher_order_ops.auto_functionalize import auto_functionalized

make_fallback(auto_functionalized)

@@ -6395,17 +6397,21 @@ except (AttributeError, ImportError):
# populate lowerings defined in kernel/*
from . import kernel

import_submodule(kernel)

from . import quantized_lowerings

quantized_lowerings.register_quantized_ops()
quantized_lowerings.register_woq_mm_ops()

from . import mkldnn_lowerings

mkldnn_lowerings.register_onednn_fusion_ops()

from . import jagged_lowerings

jagged_lowerings.register_jagged_ops()

@@ -8,12 +8,12 @@ import os
import re
from dataclasses import dataclass
from functools import lru_cache

from typing import Dict, List, Set, Tuple, TYPE_CHECKING, Union

from torch._inductor import config
from torch._inductor.utils import get_benchmark_name

# Prevent circular import
if TYPE_CHECKING:
    from torch._inductor.scheduler import (

@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Set
import sympy

import torch

from torch._prims_common import make_channels_last_strides_for

from .ir import (

@@ -22,9 +21,7 @@ from .ir import (
    NoneLayout,
    TensorBox,
)

from .utils import convert_shape_to_inductor, pad_listlike

from .virtualized import V

@@ -5,6 +5,7 @@ from typing import List, Optional
import torch
import torch.utils._pytree as pytree
from torch._inductor.kernel.mm_common import mm_args

from . import ir, mkldnn_ir
from .codegen.cpp_gemm_template import CppPackedGemmTemplate
from .ir import TensorBox

@@ -18,8 +18,10 @@ import sympy

import torch
import torch.utils._pytree as pytree

from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str

T = TypeVar("T")
StoreMode = Optional[Literal["atomic_add"]]
ReductionType = Literal[

@@ -5,6 +5,7 @@ import sympy

import torch
from torch.utils._sympy.value_ranges import ValueRanges

from .ir import LoopBody
from .utils import dominated_nodes

@@ -14,6 +14,7 @@ import torch.utils._pytree as pytree
from torch._inductor import config, exc
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export._tree_utils import reorder_kwargs

from .build_package import build_package_contents
from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION

@@ -38,7 +38,6 @@ successful match or a `FailedMatch` object for a failure to match.
from __future__ import annotations

import contextlib

import dataclasses
import functools
import importlib

@@ -96,6 +95,7 @@ from . import config
from .decomposition import select_decomp_table
from .lowering import fallback_node_due_to_unsupported_type

log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims

@@ -1,7 +1,9 @@
# mypy: allow-untyped-defs
import torch

from . import lowering

quantized = torch.ops.quantized
_quantized = torch.ops._quantized
aten = torch.ops.aten
Some files were not shown because too many files have changed in this diff.