diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py
new file mode 100644
index 00000000000..64d70789699
--- /dev/null
+++ b/test/dynamo/test_compiler_bisector.py
@@ -0,0 +1,112 @@
+# Owner(s): ["module: dynamo"]
+
+import unittest
+from contextlib import contextmanager
+from importlib import import_module
+
+import torch
+import torch._prims_common as utils
+from torch._dynamo.test_case import TestCase
+from torch._inductor import config
+from torch._inductor.bisect_helper import BisectionManager
+from torch.testing._internal.inductor_utils import HAS_CUDA
+
+
+aten = torch.ops.aten
+
+requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
+
+f32 = torch.float32
+i64 = torch.int64
+i32 = torch.int32
+
+
+@requires_cuda
+class TestCompilerBisector(TestCase):
+    def test_bad_decomp(self):
+        mod = import_module("torch._inductor.compile_fx")
+
+        def bad_exp_decomp(self, rate=1, generator=None):
+            assert generator is None
+            torch._check(
+                not utils.is_complex_dtype(self.dtype)
+                and not utils.is_integer_dtype(self.dtype)
+                and not utils.is_boolean_dtype(self.dtype),
+                lambda: f"Exponential distribution is a continuous probability distribution. \
+                dtype must be a floating point but you specified {self.dtype}",
+            )
+            torch._check(
+                rate > 0.0,
+                lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
+            )
+            return torch.rand_like(self) * float("nan")
+
+        @contextmanager
+        def patch_exp_decomp():
+            from torch._inductor.compile_fx import select_decomp_table as old_decomp
+
+            def get_decomp():
+                out = old_decomp()
+                out = out.copy()
+                out[aten.exponential.default] = bad_exp_decomp
+                return out
+
+            torch._inductor.compile_fx.select_decomp_table = get_decomp
+            try:
+                yield
+
+            finally:
+                torch._inductor.compile_fx.select_decomp_table = old_decomp
+
+        def vq(x):
+            return (x + 3).exponential_() * 10.5
+
+        def test_fn():
+            torch._dynamo.reset()
+            with patch_exp_decomp():
+                vq_compiled = torch.compile(vq)
+                x = torch.randn(4, 400, 256).cuda()
+                with torch._dynamo.utils.preserve_rng_state():
+                    out = vq(x)
+                out_compiled = vq_compiled(x)
+
+            return not out_compiled.isnan().any()
+
+        out = BisectionManager.do_bisect(test_fn)
+        self.assertEqual(out.backend, "aot_eager_decomp_partition")
+        self.assertEqual(out.subsystem, "decomposition")
+        self.assertEqual(out.bisect_number, 1)
+        self.assertTrue("aten.exponential" in out.debug_info)
+
+    def test_bad_lowering(self):
+        def test_fn():
+            torch._dynamo.reset()
+            with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):
+
+                def my_func(x):
+                    return ((x * -1) - 0.01).relu()
+
+                inp = torch.rand([100], device="cuda")
+
+                return torch.allclose(torch.compile(my_func)(inp), my_func(inp))
+
+        out = BisectionManager.do_bisect(test_fn)
+        self.assertEqual(out.backend, "inductor")
+        self.assertEqual(out.subsystem, "lowerings")
+        self.assertEqual(out.bisect_number, 2)
+        self.assertTrue("relu" in out.debug_info)
+
+    def test_eager_backend(self):
+        # should indicate problem with first backend
+        def test_fn():
+            return False
+
+        out = BisectionManager.do_bisect(test_fn)
+        self.assertEqual(out.backend, "eager")
+        self.assertEqual(out.subsystem, None)
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
diff --git a/test/run_test.py b/test/run_test.py
index b2b04cfdad5..846990f6153 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -220,6 +220,7 @@ RUN_PARALLEL_BLOCKLIST = [
     "test_cuda_nvml_based_avail",
     # temporarily sets a global config
     "test_autograd_fallback",
+    "dynamo/test_compiler_bisector",
 ] + FSDP_TEST
 
 # Test files that should always be run serially with other test files,
diff --git a/torch/__init__.py b/torch/__init__.py
index 1ab317925ec..fa30c53bf96 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -2446,6 +2446,12 @@ def compile(
         )
     if mode is None and options is None:
         mode = "default"
+
+    from torch._inductor.bisect_helper import BisectionManager
+
+    if bisect_backend := BisectionManager.get_backend():
+        backend = bisect_backend
+
     if backend == "inductor":
         backend = _TorchCompileInductorWrapper(mode, options, dynamic)
     else:
diff --git a/torch/_inductor/bisect_helper.py b/torch/_inductor/bisect_helper.py
new file mode 100644
index 00000000000..1b9284a8340
--- /dev/null
+++ b/torch/_inductor/bisect_helper.py
@@ -0,0 +1,495 @@
+import collections
+import dataclasses
+import functools
+import os
+import shutil
+import sys
+from typing import Callable, Dict, List, Optional, Tuple
+
+from torch._inductor.runtime.cache_dir_utils import cache_dir
+
+
+# Set the subdirectory name
+SUBDIR_NAME = "bisect"
+
+# Dictionary of backend -> subsystems
+BACKENDS: Dict[str, List[str]] = {
+    # run dynamo without aot_autograd
+    "eager": [],
+    # run dynamo with aot_autograd, but no partitioner or decomps
+    "aot_eager": [],
+    # run dynamo with aot autograd, decompositions and partitioner
+    "aot_eager_decomp_partition": [
+        "decomposition"  # number of decompositions we apply in tracing
+    ],  # TODO - add cse ?
+    "inductor": [
+        "post_grad_passes",  # passes applied individually on forward, and backward in inductor
+        "lowerings",  # lowering aten operators to inductor
+    ],  # TODO - add more - fusions, amp numeric mode ?
+}
+
+subsystem_call_counter: Dict[str, int] = collections.Counter()
+call_counter_debug_info: Dict[int, str] = {}
+
+
+def reset_counters() -> None:
+    subsystem_call_counter.clear()
+    call_counter_debug_info.clear()
+
+
+@functools.lru_cache(None)
+def get_env_val(env_str: str) -> Optional[str]:
+    return os.environ.get(env_str, None)
+
+
+@dataclasses.dataclass
+class BisectionResult:
+    """
+    backend: torch.compile backend responsible for failure
+    subsystem: optional, registered component identified for failure
+    bisect_number: optional, number of times the subsystem needed to be applied to trigger failure
+    debug_info: associated info of the triggering bisect application of subsystem
+    """
+
+    backend: str
+    subsystem: Optional[str] = None
+    bisect_number: Optional[int] = None
+    debug_info: Optional[str] = None
+
+
+class BisectionManager:
+    """
+    Bisects a failing torch.compile run: first across the backends in BACKENDS,
+    then across the subsystems of the failing backend, and finally over the
+    number of times that subsystem is applied.
+
+    Progress is persisted as small text files under get_dir(), which is what
+    lets the command line interface resume across separate invocations.
+    """
+
+    bisection_enabled: bool = False
+
+    @classmethod
+    def get_dir(cls) -> str:
+        return f"{cache_dir()}/{SUBDIR_NAME}"
+
+    @classmethod
+    def write_lines_to_file(cls, file_path: str, lines: List[str]) -> None:
+        os.makedirs(os.path.dirname(file_path), exist_ok=True)
+        with open(file_path, "w") as file:
+            file.writelines(lines)
+
+    @classmethod
+    def read_lines_from_file(cls, file_path: str) -> List[str]:
+        if os.path.exists(file_path):
+            with open(file_path) as file:
+                return file.readlines()
+        return []
+
+    @classmethod
+    def update_run_state(
+        cls, backend_name: str, subsystem_name: str, run_state: str
+    ) -> None:
+        file_path = os.path.join(
+            cls.get_dir(), backend_name, f"{subsystem_name}_run_state.txt"
+        )
+        cls.write_lines_to_file(file_path, [run_state])
+
+    @classmethod
+    def update_bisect_status(cls, backend_name: str, subsystem_name: str) -> None:
+        file_path = os.path.join(cls.get_dir(), "bisect_status.txt")
+        lines = [f"backend={backend_name}\n", f"subsystem={subsystem_name}\n"]
+        cls.write_lines_to_file(file_path, lines)
+
+    @classmethod
+    def update_bisect_range(
+        cls, backend_name: str, subsystem_name: str, low: int, high: int
+    ) -> None:
+        file_path = os.path.join(
+            cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt"
+        )
+        lines = [f"low={low}\n", f"high={high}\n"]
+        cls.write_lines_to_file(file_path, lines)
+
+    @classmethod
+    def get_backend(cls) -> Optional[str]:
+        """
+        Returns the active backend, if any
+        """
+        if val := get_env_val("TORCH_BISECT_BACKEND"):
+            return val
+
+        file_path = os.path.join(cls.get_dir(), "bisect_status.txt")
+        lines = cls.read_lines_from_file(file_path)
+        for line in lines:
+            if line.startswith("backend="):
+                return line.strip().split("=")[1]
+        return None
+
+    @classmethod
+    def get_subsystem(cls) -> Optional[str]:
+        """
+        Returns the active subsystem, if any
+        """
+
+        if val := get_env_val("TORCH_BISECT_SUBSYSTEM"):
+            return val
+
+        file_path = os.path.join(cls.get_dir(), "bisect_status.txt")
+        lines = cls.read_lines_from_file(file_path)
+        for line in lines:
+            if line.startswith("subsystem="):
+                return line.strip().split("=")[1]
+        return None
+
+    @classmethod
+    def get_run_state(cls, backend_name: str, subsystem_name: str) -> Optional[str]:
+        """
+        Returns the current stage of bisecting, if Any
+        """
+
+        file_path = os.path.join(
+            cls.get_dir(), backend_name, f"{subsystem_name}_run_state.txt"
+        )
+        lines = cls.read_lines_from_file(file_path)
+        if lines:
+            out = lines[0].strip()
+            assert out in ("test_disable", "find_max_bounds", "bisect")
+            return out
+        return None
+
+    @classmethod
+    def get_bisect_range(
+        cls, backend_name: str, subsystem_name: str
+    ) -> Tuple[int, int]:
+        file_path = os.path.join(
+            cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt"
+        )
+        lines = cls.read_lines_from_file(file_path)
+        low = None
+        high = None
+        for line in reversed(lines):
+            if line.startswith("low="):
+                low = int(line.strip().split("=")[1])
+            elif line.startswith("high="):
+                high = int(line.strip().split("=")[1])
+
+            if low is not None and high is not None:
+                break
+
+        if low is None or high is None:
+            raise RuntimeError(
+                f"Trying to get bisect range when it is not set: subsystem {subsystem_name}"
+            )
+
+        return low, high
+
+    @classmethod
+    def delete_bisect_status(cls) -> None:
+        if os.path.exists(cls.get_dir()):
+            shutil.rmtree(cls.get_dir())
+            print("Bisection status deleted.")
+        else:
+            print("No bisection status found.")
+
+    @classmethod
+    def get_system_counter(cls, name: str, increment: bool = True) -> int:
+        global subsystem_call_counter
+        curr = subsystem_call_counter[name]
+        if increment:
+            subsystem_call_counter[name] += 1
+        return curr
+
+    @classmethod
+    def disable_subsystem(
+        cls,
+        backend: str,
+        subsystem: str,
+        debug_info: Optional[Callable[[], str]] = None,
+    ) -> bool:
+        """
+        Returns True if this application of `subsystem` should be skipped.
+
+        In the "bisect" run state every call bumps the per-subsystem counter and
+        calls whose counter is above the midpoint of the persisted range are
+        disabled. `debug_info` is only invoked once that range has narrowed to at
+        most three candidates.
+        """
+        if not cls.bisection_enabled:
+            return False
+
+        if cls.get_backend() != backend:
+            return False
+
+        if cls.get_subsystem() != subsystem:
+            return False
+
+        if val := get_env_val("TORCH_BISECT_MAX"):
+            counter = cls.get_system_counter(subsystem, increment=True)
+            return counter > int(val)
+
+        run_state = cls.get_run_state(backend, subsystem)
+        if run_state == "test_disable":
+            # First run, disable completely
+            return True
+        elif run_state == "find_max_bounds":
+            # Second run, record the call count as the upper bound and keep the subsystem enabled
+            cls.update_bisect_range(
+                backend,
+                subsystem,
+                0,
+                cls.get_system_counter(subsystem, increment=True),
+            )
+            return False
+        else:
+            assert run_state == "bisect"
+            # If the environment variable is not set, use the bisection range midpoint
+            low, high = cls.get_bisect_range(backend, subsystem)
+            # calls numbered above the midpoint are disabled for this run
+            midpoint = (low + high) // 2
+            call_counter = cls.get_system_counter(subsystem)
+
+            if (
+                call_counter >= low
+                and call_counter <= high
+                and (high - low) <= 2  # range is narrow: capture debug info
+                and debug_info is not None
+            ):
+                call_counter_debug_info[call_counter] = debug_info()
+
+            return call_counter > midpoint
+
+    @classmethod
+    def advance_subsystem(cls, curr_backend: str, curr_subsystem: str) -> Optional[str]:
+        """
+        Tries to move to the next subsystem within the current system.
+        """
+        print(f"Disabling {curr_subsystem} did not fix the issue.")
+
+        current_subsystems = BACKENDS[curr_backend]
+        current_subsystem_index = current_subsystems.index(curr_subsystem)
+
+        if current_subsystem_index < len(current_subsystems) - 1:
+            curr_subsystem = current_subsystems[current_subsystem_index + 1]
+            cls.update_bisect_status(curr_backend, curr_subsystem)
+            cls.update_run_state(curr_backend, curr_subsystem, "test_disable")
+            print(f"Moving to the next subsystem: {curr_backend} - {curr_subsystem}")
+            return curr_subsystem
+        else:
+            print(
+                f"All subsystems in {curr_backend} have been checked. The issue is not in this system."
+            )
+            return None
+
+    @classmethod
+    def advance_backend(cls, curr_backend: str) -> Optional[str]:
+        """
+        Tries to move to the next backend.
+        """
+        current_system_index = list(BACKENDS.keys()).index(curr_backend)
+
+        if current_system_index < len(BACKENDS) - 1:
+            curr_backend = list(BACKENDS.keys())[current_system_index + 1]
+            cls.update_bisect_status(curr_backend, "")
+            print(f"Moving to the next system: {curr_backend}")
+            return curr_backend
+        else:
+            return None
+
+    @classmethod
+    def perform_bisection(
+        cls,
+        curr_backend: str,
+        curr_subsystem: str,
+        fn: Callable[[], bool],
+        cli_interface: bool = True,
+    ) -> bool:
+        """
+        Perform the bisection process for the current system and subsystem. Returns True if the issue is found, False otherwise.
+        """
+        while True:
+            run_state = cls.get_run_state(curr_backend, curr_subsystem)
+            reset_counters()
+            if run_state == "test_disable":
+                if not fn():
+                    next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem)
+                    if not next_subsystem:
+                        return False
+                    curr_subsystem = next_subsystem
+                else:
+                    # fn passed with the subsystem fully disabled, so it is implicated
+                    print(
+                        f"Disabling {curr_subsystem} fixed the issue. Starting bisect by getting upper bound."
+                    )
+                    cls.update_run_state(
+                        curr_backend, curr_subsystem, "find_max_bounds"
+                    )
+            elif run_state == "find_max_bounds":
+                if fn():
+                    raise RuntimeError(
+                        f"Function succeeded with 'find_max_bounds' status for {curr_backend} - {curr_subsystem}."
+                    )
+                else:
+                    _, high = cls.get_bisect_range(curr_backend, curr_subsystem)
+                    print(f"Upper bound of {high} found for {curr_backend}.")
+                    cls.update_run_state(curr_backend, curr_subsystem, "bisect")
+            elif run_state == "bisect":
+                low, high = cls.get_bisect_range(curr_backend, curr_subsystem)
+                midpoint = (low + high) // 2
+                print(
+                    f"Bisecting {curr_backend} - {curr_subsystem} (Range: [{low}, {high}], Midpoint: {midpoint})"
+                )
+                if fn():
+                    cls.update_bisect_range(
+                        curr_backend, curr_subsystem, midpoint + 1, high
+                    )
+                else:
+                    cls.update_bisect_range(curr_backend, curr_subsystem, low, midpoint)
+                low, high = cls.get_bisect_range(curr_backend, curr_subsystem)
+                if low == high:
+                    print(
+                        f"Binary search completed for {curr_backend} - {curr_subsystem}. The bisect number is {low}. "
+                        f"Debug info: {call_counter_debug_info.get(low, 'not found')}"
+                    )
+                    return True
+            else:
+                raise RuntimeError(f"Unexpected run_state {run_state}")
+
+            if cli_interface:
+                sys.exit(0)
+
+    @classmethod
+    def initialize_system(cls) -> None:
+        curr_backend = next(iter(BACKENDS.keys()))
+        curr_subsystem = ""
+        cls.update_bisect_status(curr_backend, curr_subsystem)
+        print(f"Starting bisection process with system: {curr_backend}")
+
+    @classmethod
+    def do_bisect(
+        cls, fn: Callable[[], bool], cli_interface: bool = False
+    ) -> Optional[BisectionResult]:
+        """
+        Runs the bisection; `fn` returns True for a good run and False when the
+        failure reproduces.
+
+        Returns a BisectionResult naming the first backend for which `fn` fails
+        (plus the subsystem and bisect number when they can be identified), or
+        None if `fn` passes for every backend.
+        """
+        if not cli_interface:
+            bisection_enabled_orig = cls.bisection_enabled
+            cls.delete_bisect_status()
+            cls.bisection_enabled = True
+
+            # TODO - cli interface, and in-process different directories
+            class DisableBisect:
+                def __del__(self) -> None:
+                    cls.bisection_enabled = bisection_enabled_orig
+                    cls.delete_bisect_status()
+
+            cleanup = DisableBisect()
+
+        curr_backend = cls.get_backend()
+        curr_subsystem = cls.get_subsystem()
+
+        if not curr_backend:
+            cls.initialize_system()
+            curr_backend = cls.get_backend()
+            curr_subsystem = cls.get_subsystem()
+
+        while True:
+            assert curr_backend is not None
+            reset_counters()
+            if curr_subsystem:
+                result = cls.perform_bisection(
+                    curr_backend, curr_subsystem, fn, cli_interface=cli_interface
+                )
+                if result:
+                    curr_subsystem = cls.get_subsystem()
+                    assert curr_subsystem is not None
+                    low, _ = cls.get_bisect_range(curr_backend, curr_subsystem)
+                    return BisectionResult(
+                        curr_backend,
+                        curr_subsystem,
+                        low,
+                        call_counter_debug_info.get(low, None),
+                    )
+
+                next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem)
+                if not next_subsystem:
+                    print(
+                        f"The issue is in the {curr_backend} system, but could not identify subsystem."
+                    )
+                    assert curr_backend is not None
+                    return BisectionResult(curr_backend)
+
+                curr_subsystem = next_subsystem
+            else:
+                if fn():
+                    next_backend = cls.advance_backend(curr_backend)
+                    if not next_backend:
+                        print("All systems have been checked.")
+                        return None
+
+                    curr_backend = next_backend
+                else:
+                    current_subsystems = BACKENDS[curr_backend]
+                    if current_subsystems:
+                        curr_subsystem = current_subsystems[0]
+                        cls.update_bisect_status(curr_backend, curr_subsystem)
+                        cls.update_run_state(
+                            curr_backend, curr_subsystem, "test_disable"
+                        )
+                        print(
+                            f"The issue is in the {curr_backend} system. Moving to the first subsystem: {curr_subsystem}"
+                        )
+                    else:
+                        print(f"The issue is in the {curr_backend} system.")
+                        return BisectionResult(curr_backend)
+
+            if cli_interface:
+                sys.exit(0)
+
+
+def command_line_usage() -> None:
+    if len(sys.argv) < 2:
+        print("Usage: python bisect_update.py <start|end|good|bad>")
+        sys.exit(1)
+
+    bisection_manager = BisectionManager()
+    command = sys.argv[1]
+
+    if command == "end":
+        bisection_manager.delete_bisect_status()
+        sys.exit(0)
+
+    if command == "start":
+        bisection_manager.delete_bisect_status()
+        bisection_manager.initialize_system()
+        sys.exit(0)
+
+    if command not in ["good", "bad"]:
+        print("Invalid command. Must be 'good', 'bad', 'start', or 'end'.")
+        sys.exit(1)
+
+    def test_function() -> bool:
+        return command == "good"
+
+    if not bisection_manager.get_backend():
+        raise ValueError("Must call start prior to good or bad")
+
+    bisection_manager.do_bisect(test_function, cli_interface=True)
+
+
+def get_is_bisection_enabled() -> bool:
+    return (
+        BisectionManager.get_subsystem() is not None
+        or BisectionManager.get_backend() is not None
+    )
+
+
+BisectionManager.bisection_enabled = get_is_bisection_enabled()
+
+if __name__ == "__main__":
+    command_line_usage()
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 4907a76f95c..e2ab200f71c 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -1285,6 +1285,12 @@ class FxGraphCache:
                 "Freezing may introduce constants that aren't static across runs"
             )
 
+        from torch._inductor.bisect_helper import BisectionManager
+
+        if BisectionManager.bisection_enabled:
+            log.debug("dont cache graph when bisect enabled")
+            raise BypassFxGraphCache
+
         # The treatment of guards in the caching implementation requires that
         # we have a shape env.
         if FxGraphCache._get_shape_env() is None:
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 89b9f8785f7..2490a0bb6f4 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -4,7 +4,7 @@ import itertools
 import logging
 import operator
 from collections import Counter, defaultdict
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List, Optional, Set
 
 import torch
 import torch._inductor as inductor
@@ -65,6 +65,19 @@ pass_patterns = [
 ]
 
 
+def apply_pass(pass_fn: Callable[[], object], name: Optional[str] = None) -> None:
+    # TODO - we should just make this part of GraphTransformObserver
+    from torch._inductor.bisect_helper import BisectionManager
+
+    debug_info: Optional[Callable[[], str]] = None
+    if name is not None:
+        debug_info = lambda: name  # noqa: E731
+
+    if BisectionManager.disable_subsystem("inductor", "post_grad_passes", debug_info):
+        return
+    pass_fn()
+
+
 def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     """
     Passes that run on after grad. This is called once on the forwards
@@ -80,23 +93,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         gm.graph.eliminate_dead_code()
 
     if is_inference and config.reorder_for_locality:
-        reorder_for_locality(gm.graph)
+        apply_pass(lambda: reorder_for_locality(gm.graph), "reorder_for_locality")
 
     fake_tensor_updater = FakeTensorUpdater(gm.graph)
 
-    if config.post_grad_custom_pre_pass is not None:
+    if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass:
         with GraphTransformObserver(
             gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform
         ):
-            config.post_grad_custom_pre_pass(gm.graph)
+            apply_pass(
+                lambda: post_grad_custom_pre_pass(gm.graph), "post_grad_custom_pre_pass"
+            )
 
     if config.pattern_matcher:
         lazy_init()
         optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
-        group_batch_fusion_passes(gm.graph, pre_grad=False)
-        remove_noop_ops(gm.graph)
-        for patterns in pass_patterns:
-            patterns.apply(gm.graph)  # type: ignore[arg-type]
+        apply_pass(
+            lambda: group_batch_fusion_passes(gm.graph, pre_grad=False),
+            "group_batch_fusion_passes",
+        )
+        apply_pass(lambda: remove_noop_ops(gm.graph), "remove_noop_ops")
+        for i, patterns in enumerate(pass_patterns):
+            apply_pass(lambda: patterns.apply(gm.graph), f"pass_pattern_{i}")  # type: ignore[arg-type]
         for pass_name in config.post_grad_fusion_options:
             # skip all patterns for group batch fusions
             if pass_name in POST_GRAD_FUSIONS:
@@ -105,7 +123,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
             inductor_before_change = save_inductor_dict(
                 [pattern_matcher_pass.pass_name]
             )
-            pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
+            apply_pass(lambda: pattern_matcher_pass.apply(gm.graph), pass_name)  # type: ignore[arg-type]
             if not is_same_dict(counters["inductor"], inductor_before_change):
                 optimus_scuba_log[
                     f"{pattern_matcher_pass.pass_name}_post_grad"
@@ -117,30 +135,40 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         micro_pipeline_tp_pass(gm.graph)
 
     if config._fuse_ddp_communication:
-        fuse_ddp_communication(
-            gm.graph,
-            config._fuse_ddp_communication_passes,
-            config._fuse_ddp_bucket_size,
+        apply_pass(
+            lambda: fuse_ddp_communication(
+                gm.graph,
+                config._fuse_ddp_communication_passes,
+                config._fuse_ddp_bucket_size,
+            ),
+            "fuse_ddp_communication",
         )
 
-    if config.post_grad_custom_post_pass is not None:
+    if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
         with GraphTransformObserver(
             gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform
         ):
-            config.post_grad_custom_post_pass(gm.graph)
+            apply_pass(
+                lambda: post_grad_custom_post_pass(gm.graph),
+                "post_grad_custom_post_pass",
+            )
 
-    stable_topological_sort(gm.graph)
+    apply_pass(lambda: stable_topological_sort(gm.graph), "stable_sort")
 
-    move_constructors_to_gpu(gm.graph)
+    apply_pass(lambda: move_constructors_to_gpu(gm.graph), "move_constructors_to_cuda")
 
     fake_tensor_updater.incremental_update()
 
     # Keep these last, since they introduces mutation. Look at
     # ./fx_passes/README.md for a discussion of mutation invariants.
-    reinplace_inplaceable_ops(gm.graph)
-    decompose_auto_functionalized(gm.graph)
+    apply_pass(lambda: reinplace_inplaceable_ops(gm.graph), "reinplace_inplaceable_ops")
+    apply_pass(
+        lambda: decompose_auto_functionalized(gm.graph), "decompose_auto_functionalized"
+    )
 
-    comms.reinplace_fsdp_all_gather(gm.graph)
+    apply_pass(
+        lambda: comms.reinplace_fsdp_all_gather(gm.graph), "reinplace_fsdp_all_gather"
+    )
 
     gm.recompile()
     optimus_scuba_log["after_recompile_post_grad"] = upload_graph(gm.graph)
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 7859e8126cc..cebd8089a35 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -1304,6 +1304,8 @@ class GraphLowering(torch.fx.Interpreter):
         def debug(msg: str) -> None:
             log.debug("lowering %s %s", LazyString(n.format_node), msg)
 
+        from torch._inductor.bisect_helper import BisectionManager
+
         buffer_watermark = len(self.buffers)
         operation_watermark = len(self.operations)
 
@@ -1320,7 +1322,12 @@ class GraphLowering(torch.fx.Interpreter):
             if (
                 n.op == "call_function"
                 and n.target is not operator.getitem
-                and fallback_node_due_to_unsupported_type(n)
+                and (
+                    fallback_node_due_to_unsupported_type(n)
+                    or BisectionManager.disable_subsystem(
+                        "inductor", "lowerings", lambda: repr(n)
+                    )
+                )
             ):
                 debug("fallback_handler")
                 result = fallback_handler(n.target, add_to_fallback_set=False)(
diff --git a/torch/_inductor/runtime/cache_dir_utils.py b/torch/_inductor/runtime/cache_dir_utils.py
new file mode 100644
index 00000000000..1a2aabc572c
--- /dev/null
+++ b/torch/_inductor/runtime/cache_dir_utils.py
@@ -0,0 +1,23 @@
+import getpass
+import os
+import re
+import tempfile
+
+
+# Factoring out to file without torch dependencies
+
+
+def cache_dir() -> str:
+    cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
+    if cache_dir is None:
+        os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir()
+    os.makedirs(cache_dir, exist_ok=True)
+    return cache_dir
+
+
+def default_cache_dir() -> str:
+    sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
+    return os.path.join(
+        tempfile.gettempdir(),
+        "torchinductor_" + sanitized_username,
+    )
diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py
index 446dbc71c61..e7e25876632 100644
--- a/torch/_inductor/runtime/runtime_utils.py
+++ b/torch/_inductor/runtime/runtime_utils.py
@@ -3,13 +3,13 @@ from __future__ import annotations
 
 import contextlib
 import functools
-import getpass
 import operator
-import os
-import re
-import tempfile
 
 import torch
+from torch._inductor.runtime.cache_dir_utils import (  # noqa: F401
+    cache_dir,
+    default_cache_dir,
+)
 
 
 def conditional_product(*args):
@@ -86,22 +86,6 @@ def get_max_y_grid():
     return 65535
 
 
-def cache_dir() -> str:
-    cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
-    if cache_dir is None:
-        os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir()
-    os.makedirs(cache_dir, exist_ok=True)
-    return cache_dir
-
-
-def default_cache_dir():
-    sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
-    return os.path.join(
-        tempfile.gettempdir(),
-        "torchinductor_" + sanitized_username,
-    )
-
-
 try:
     import colorama
 
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 5bf9cfae810..4f7895fdca5 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -2202,7 +2202,14 @@ def maybe_handle_decomp(
     args: Tuple[object, ...],
     kwargs: Dict[str, object],
 ) -> object:
+    from torch._inductor.bisect_helper import BisectionManager
+
     if op in CURRENT_DECOMPOSITION_TABLE:
+        if BisectionManager.disable_subsystem(
+            "aot_eager_decomp_partition", "decomposition", lambda: repr(op)
+        ):
+            return NotImplemented
+
         with proxy_mode:
             proxy_mode.decomp_layers += 1
             out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)