diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index b99e9d0c94d..2fe8f5dd2e3 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -758,11 +758,10 @@ const auto sinc_string = jiterator_stringify( T sinc(T a) { if (a == T(0)) { return T(1); - } else { - constexpr T pi = T(3.14159265358979323846L); - T product = pi * a; - return std::sin(product) / product; } + constexpr T pi = T(3.14159265358979323846L); + T product = pi * a; + return std::sin(product) / product; } ); // sinc_string diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh b/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh old mode 100644 new mode 100755 diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 388b8d1a5f6..c80c46dc1e7 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025 -add_loop_inductor,compile_time_instruction_count,30150000000,0.015 +add_loop_inductor,compile_time_instruction_count,29630000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43980000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,26240000000,0.015 @@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000, -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17150000000,0.015 @@ -62,4 +62,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000, -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10340000000,0.015 +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10390000000,0.015 diff --git a/buckbuild.bzl b/buckbuild.bzl index 17153b5df77..65141ac9b5a 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -997,6 +997,7 @@ def define_buck_targets( "Config.h": ":generate_aten_config[Config.h]", }, labels = labels, + visibility = ["PUBLIC"], ) fb_xplat_cxx_library( diff --git a/c10/build.bzl b/c10/build.bzl index d4192a46852..6ecae511223 100644 --- a/c10/build.bzl +++ b/c10/build.bzl @@ -22,3 +22,15 @@ def define_targets(rules): [], ), ) + + rules.cc_library( + name = "c10_headers", + deps = [ + "//c10/core:base_headers", + "//c10/macros", + "//c10/util:base_headers", + "//c10/util:bit_cast", + "//c10/util:ssize", + ], + visibility = ["//visibility:public"], + ) diff --git a/c10/core/build.bzl b/c10/core/build.bzl index 45fc5ea3390..fe9a31a2da4 100644 --- a/c10/core/build.bzl +++ b/c10/core/build.bzl @@ -90,6 +90,22 @@ def define_targets(rules): alwayslink = True, ) + rules.cc_library( + name = "base_headers", + srcs = [], + hdrs = rules.glob( + [ + "*.h", + "impl/*.h", + ], + exclude = [ + "CPUAllocator.h", + "impl/alloc_cpu.h", + ], + ), + visibility = ["//visibility:public"], + ) + rules.filegroup( name = "headers", srcs = rules.glob( @@ -101,5 +117,5 @@ def define_targets(rules): "alignment.h", ], ), - visibility = ["//c10:__pkg__"], + visibility = ["//visibility:public"], ) diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 8bcb1f7a53e..04fd7eee18f 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -477,5 +477,31 @@ inline float2 sinc(float2 inp) { return float2(re, im) / a2; } +template +inline T spherical_bessel_j0(T x) { + if (::metal::isinf(x)) + return T(0.0); + T x2 = x * x; + T k1 = static_cast(-1.0); + T k2 = static_cast(1.0); + + if (::metal::abs(x) < T(0.5)) { + return T(1.0) + + x2 * + (k1 / T(6.0) + + x2 * + (k2 / T(120.0) + + x2 * + (k1 / T(5040.0) + + x2 * + (k2 / T(362880.0) + + x2 * + (k1 / T(39916800.0) + + x2 * (k2 / T(6227020800.0))))))); + } + + return ::metal::sin(x) / x; +} + } // namespace metal } // namespace c10 diff --git a/c10/util/build.bzl b/c10/util/build.bzl index a6f95ae7516..5e1dc6fbfbf 100644 --- a/c10/util/build.bzl +++ b/c10/util/build.bzl @@ -80,6 +80,18 @@ def define_targets(rules): ], ) + rules.cc_library( + name = "base_headers", + hdrs = rules.glob( + ["*.h"], + exclude = [ + "bit_cast.h", + "ssize.h", + ], + ), + visibility = ["//visibility:public"], + ) + rules.filegroup( name = "headers", srcs = rules.glob( diff --git a/test/distributed/tensor/test_convolution_ops.py b/test/distributed/tensor/test_convolution_ops.py index 867e89e778b..5d40a18f067 100644 --- a/test/distributed/tensor/test_convolution_ops.py +++ b/test/distributed/tensor/test_convolution_ops.py @@ -5,13 +5,15 @@ import copy import torch import torch.nn as nn -from torch.distributed._tensor import ( - DeviceMesh, +from torch.distributed import DeviceMesh, init_device_mesh +from torch.distributed.tensor import ( distribute_module, distribute_tensor, + DTensor, Replicate, Shard, ) +from torch.nn import functional as F from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -181,6 +183,28 @@ class DistConvolutionOpsTest(DTensorTestBase): f"Too large relative mse for bias tensor, expected less equal 1e-6, got {bias_mse_rel}", ) + @with_comms + @skip_if_lt_x_gpu(2) + def test_conv_backward_none_grad_inp(self): + device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(self.world_size,) + ) + conv = nn.Conv2d(64, 64, 3, padding=1).train() + x = torch.randn(1, 64, 32, 32) + x_dt = DTensor.from_local(x, device_mesh, [Replicate()]) + w = conv.weight + w_dt = torch.nn.Parameter(DTensor.from_local(w, device_mesh, [Replicate()])) + + b = conv.bias + b_dt = torch.nn.Parameter(DTensor.from_local(b, device_mesh, [Replicate()])) + + res = F.conv2d(x_dt, w_dt, b_dt, padding=1) + dres = torch.rand_like(res) + res.backward(dres) + self.assertTrue(w_dt.grad is not None) + self.assertTrue(b_dt.grad is not None) + self.assertTrue(x_dt.grad is None) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 2ab444a4b68..522b6815ada 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2007,7 +2007,7 @@ class DistributedDataParallelTest( replica_devices = [dev0] # Tells _test_grad_layout to construct ConvNet with all layers on this process's first assigned device. layer_devs = dev0 - local_batch_size = 8 + local_batch_size = 16 self._test_grad_layout(replica_devices, layer_devs, local_batch_size) @requires_nccl() @@ -2021,7 +2021,7 @@ class DistributedDataParallelTest( replica_devices = None # Tells _test_grad_layout to constructs this process's ConvNet on 2 devices, with 2 layers on each device. layer_devs = [dev0] * 2 + [dev1] * 2 - local_batch_size = 8 + local_batch_size = 16 self._test_grad_layout(replica_devices, layer_devs, local_batch_size) @requires_nccl() diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 12258b956bc..c874a578b4e 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -1744,10 +1744,11 @@ class GraphModule(torch.nn.Module): class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase): def setUp(self): + self._prev = torch._dynamo.config.enable_trace_contextlib torch._dynamo.config.enable_trace_contextlib = True def tearDown(self): - torch._dynamo.config.enable_trace_contextlib = False + torch._dynamo.config.enable_trace_contextlib = self._prev def test_ctx_basic0(self): @contextlib.contextmanager @@ -2236,7 +2237,7 @@ class GraphModule(torch.nn.Module): eager = EagerAndRecordGraphs() out = torch.compile(backend=eager, fullgraph=False)(fn)(x) self.assertEqual(expected, out) - self.assertEqual(len(eager.graphs), 1) + self.assertEqual(len(eager.graphs), 0) def test_graph_break_before_and_after___enter__(self): @contextlib.contextmanager @@ -2262,7 +2263,7 @@ class GraphModule(torch.nn.Module): eager = EagerAndRecordGraphs() out = torch.compile(backend=eager, fullgraph=False)(fn)(x) self.assertEqual(expected, out) - self.assertEqual(len(eager.graphs), 1) + self.assertEqual(len(eager.graphs), 0) def test_graph_break_before___enter___and_disable___exit__(self): @contextlib.contextmanager @@ -2292,7 +2293,7 @@ class GraphModule(torch.nn.Module): eager = EagerAndRecordGraphs() out = torch.compile(backend=eager, fullgraph=False)(fn)(x) self.assertEqual(expected, out) - self.assertEqual(len(eager.graphs), 1) + self.assertEqual(len(eager.graphs), 0) def test_disable___enter__(self): def h(x): @@ -2573,7 +2574,7 @@ class GraphModule(torch.nn.Module): eager = EagerAndRecordGraphs() out = torch.compile(backend=eager, fullgraph=False)(fn)(x) self.assertEqual(expected, out) - self.assertEqual(len(eager.graphs), 1) + self.assertEqual(len(eager.graphs), 0) def test_dynamo_disable_ctx(self): @contextlib.contextmanager @@ -2623,7 +2624,7 @@ class GraphModule(torch.nn.Module): eager = EagerAndRecordGraphs() out = torch.compile(backend=eager, fullgraph=False, dynamic=False)(f)(x) self.assertEqual(expected, out) - self.assertEqual(len(eager.graphs), 3) + self.assertEqual(len(eager.graphs), 2) @parametrize("name", ("suppress", "stdout", "stderr")) def test_contextlib_suppress(self, name): diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index a72b87d8f6d..94d722fc059 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -404,6 +404,21 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): self.assertEqual(ref[0], res[0]) self.assertEqual(ref[1], res[1]) + def test_raise_GeneratorExit(self): + # GeneratorExit does not inherit from Exception + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + try: + raise GeneratorExit + except Exception: + return t.sin() + except BaseException: + return t.cos() + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.cos()) + def test_speculation_exception(self): log = SpeculationLog() log.next("fake", 555, "fake", Instruction(1, "fake", 1, 1)) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py new file mode 100644 index 00000000000..764db540575 --- /dev/null +++ b/test/dynamo/test_generator.py @@ -0,0 +1,1809 @@ +# Owner(s): ["module: dynamo"] +import itertools +import sys +import unittest +from collections import OrderedDict + +import torch +import torch._dynamo.test_case +import torch._dynamo.testing +from torch._dynamo.exc import InternalTorchDynamoError, Unsupported +from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm +from torch._dynamo.utils import counters +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) + + +class GeneratorTestsBase(torch._dynamo.test_case.TestCase): + def setUp(self): + super().setUp() + self._old = torch._dynamo.config.enable_faithful_generator_behavior + torch._dynamo.config.enable_faithful_generator_behavior = True + + def tearDown(self): + super().tearDown() + torch._dynamo.config.enable_faithful_generator_behavior = self._old + + def _compile_check(self, fn, args=None, fullgraph=True): + eager = EagerAndRecordGraphs() + if args is None: + args = (torch.randn(2),) + r = torch.compile(fn, backend=eager, fullgraph=fullgraph)(*args) + self.assertGreater(len(eager.graphs), 0) + return r + + +class GeneratorTests(GeneratorTestsBase): + def test_generator_simple(self): + def whoo(): + yield 1 + yield 2 + yield 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo() + t = t + next(gen) + t = t + next(gen) + t = t + next(gen) + return t + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + 6) + + def test_infinite_generator(self): + def whoo(): + i = 0 + while True: + yield i + i += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo() + t = t + next(gen) + t = t + next(gen) + t = t + next(gen) + return t + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + 3) + + def test_infinite_generator_2(self): + def whoo(t): + i = 0 + while True: + yield t + i + i += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return list(zip(range(3), whoo(t))) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, list(zip(range(3), whoo(t)))) + + def test_infinite_generator_3(self): + def whoo(i): + while True: + yield i + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return list(zip(range(3), whoo(1))), t.sin() + + t = torch.randn(2) + y, _ = fn(t) + self.assertEqual(y, list(zip(range(3), whoo(1)))) + + def test_graph_break_in_generator(self): + def whoo(): + yield 1 + torch._dynamo.graph_break() + yield 2 + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=False) + def fn(t): + gen = whoo() + s = next(gen) + s += next(gen) + return t + s + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + 3) + self.assertEqual(len(eager.graphs), 0) + + def test_graph_break_in_generator_2(self): + def whoo(x): + yield x.sin() + torch._dynamo.graph_break() + yield x.cos() + + def call_whoo(x): + gen = whoo(x) + sin = next(gen) + cos = next(gen) + return sin, cos + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=False) + def fn(t): + sin, cos = call_whoo(t) + return sin + cos + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.sin() + t.cos()) + self.assertEqual(len(eager.graphs), 1) + self.assertExpectedInline( + normalize_gm(eager.graphs[0].print_readable(False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_stack0_0_: "f32[2]", L_stack0_1_: "f32[2]"): + l_stack0_0_ = L_stack0_0_ + l_stack0_1_ = L_stack0_1_ + + add: "f32[2]" = l_stack0_0_ + l_stack0_1_; l_stack0_0_ = l_stack0_1_ = None + return (add,) +""", + ) + + def test_reconstruct_generator_with_local_var_mutation(self): + def whoo(t): + x = 0 + yield t.sin() + x + x += 1 + yield t.cos() + x + x += 1 + yield t.tan() + x + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = whoo(t) + next(gen) + return t.sin(), gen + + t = torch.randn(2) + y, g = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(list(g), [t.cos() + 1, t.tan() + 2]) + + def test_reconstruct_generator_with_dict_mutation(self): + counters.clear() + + def whoo(t, d): + d[2] = t + yield t.sin() + yield t.cos() + d[3] = t + 1 + yield t.tan() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t, d): + gen = whoo(t, d) + next(gen) + return t.sin(), whoo(t, d) + + t = torch.randn(2) + d = {1: t} + fn(t, d) + self.assertEqual(len(counters["unimplemented"]), 1) + self.assertEqual( + dict(counters["unimplemented"]), + { + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications.": 1 + }, + ) + + def test_reconstruct_generator_with_dict_mutation_before(self): + def whoo(t, d): + d[2] = t + yield t.sin() + yield t.cos() + yield t.tan() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t, d): + gen = whoo(t, d) + next(gen) + return t.sin(), gen + + t = torch.randn(2) + d = {1: t} + y, g = fn(t, d) + self.assertEqual(y, t.sin()) + self.assertEqual(list(g), [t.cos(), t.tan()]) + self.assertEqual(d, {1: t, 2: t}) + + def test_reconstruct_generator_with_object_mutation(self): + class Counter: + def __init__(self): + self.x = 0 + + def incr(self): + self.x += 1 + + def whoo(t, c): + c.incr() + yield t.sin() + yield t.cos() + c.incr() + yield t.tan() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t, c): + gen = whoo(t, c) + next(gen) + return t.sin(), gen + + t = torch.randn(2) + c = Counter() + fn(t, c) + self.assertEqual(len(counters["unimplemented"]), 1) + self.assertEqual( + dict(counters["unimplemented"]), + { + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications.": 1 + }, + ) + + def test_reconstruct_generator_with_object_mutation_before(self): + class Counter: + def __init__(self): + self.x = 0 + + def incr(self): + self.x += 1 + + def whoo(t, c): + c.incr() + yield t.sin() + yield t.cos() + yield t.tan() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t, c): + gen = whoo(t, c) + next(gen) + # We should be able to reconstruct the generator as there's no object + # mutation after the first yield + return t.sin(), gen + + t = torch.randn(2) + c = Counter() + y, g = fn(t, c) + self.assertEqual(c.x, 1) + self.assertEqual(y, t.sin()) + self.assertEqual(list(g), [t.cos(), t.tan()]) + + def test_graph_break_and_reconstruct_generator(self): + def whoo(t): + yield t.sin() + yield t.cos() + yield t.tan() + + def g(t): + torch._dynamo.graph_break() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = whoo(t) + next(gen) + g(t) + return t.sin(), list(gen) + + t = torch.randn(2) + y, gen = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(list(gen), [t.cos(), t.tan()]) + + def test_graph_break_in_generator_while_reconstructing(self): + counters.clear() + + def whoo(): + yield 1 + torch._dynamo.graph_break() + yield 2 + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=False) + def fn(t): + gen = whoo() + s = next(gen) + torch._dynamo.graph_break() + s += next(gen) + return t + s + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + 3) + self.assertEqual(len(eager.graphs), 0) + + def test_generator_as_argument(self): + # The inline tracer needs to be kept in sync if an already advanced generator + # is given to a compiled function. + def whoo(): + yield 1 + yield 2 + yield 3 + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=True) + def fn(t, ctx): + return t + next(ctx) + + t = torch.randn(2) + ctx = whoo() + next(ctx) + with self.assertRaisesRegex( + Unsupported, "Generator as graph argument is not supported" + ): + fn(t, ctx) + + def test_generator_as_argument_2(self): + def whoo(x): + yield x.sin() + yield x.cos() + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=True) + def fn(t, ctx): + return t + next(ctx) + + t = torch.randn(2) + ctx = whoo(t) + next(ctx) + with self.assertRaisesRegex( + Unsupported, "Generator as graph argument is not supported" + ): + fn(t, ctx) + + def test_generator_as_argument_3(self): + # The inline tracer needs to be kept in sync if an already advanced generator + # is given to a compiled function. + def whoo(): + yield 1 + yield 2 + yield 3 + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=True) + def fn(t, ctx): + return t + next(ctx) + + t = torch.randn(2) + ctx = whoo() + with self.assertRaisesRegex( + Unsupported, "Generator as graph argument is not supported" + ): + fn(t, ctx) + + def test_generator_as_argument_4(self): + def whoo(x): + yield x.sin() + yield x.cos() + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=True) + def fn(t, ctx): + return t + next(ctx) + + t = torch.randn(2) + ctx = whoo(t) + with self.assertRaisesRegex( + Unsupported, "Generator as graph argument is not supported" + ): + fn(t, ctx) + + def test_islice_chain(self): + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=True) + def fn(t): + tmp1 = [t + 1, t + 2] + tmp2 = [t + 3, t + 4] + return list(itertools.chain(tmp1, tmp2)) + + t = torch.tensor([1.0]) + y = fn(t) + self.assertEqual(y, [t + 1, t + 2, t + 3, t + 4]) + + def test_zip_generator(self): + def whoo(t): + yield t + 1 + yield t + 2 + yield t + 3 + + def fn(t): + return zip(range(3), whoo(t)), t.sin() + + t = torch.randn(2) + z, _ = self._compile_check(fn, args=(t,)) + self.assertEqual(list(z), list(zip(range(3), whoo(t)))) + + @unittest.expectedFailure + def test_zip_generator_2(self): + def bar(t, i): + return t + i + + def whoo(t): + yield bar(t, 1) + yield bar(t, 2) + yield bar(t, 3) + + def fn(t): + return zip(range(3), whoo(t)) + + t = torch.randn(3) + y = self._compile_check(fn, args=(t,), fullgraph=False) + expected = list(zip(range(3), whoo(t))) + self.assertEqual(expected, list(y)) + + def test_zip_subgenerator(self): + def subgen(t): + yield t + 1 + yield t + 2 + + def whoo(t): + yield from subgen(t) + yield t + 3 + + def fn(t): + return zip(range(3), whoo(t)), t.sin() + + t = torch.randn(2) + z, _ = self._compile_check(fn, args=(t,)) + self.assertEqual(list(z), list(zip(range(3), whoo(t)))) + + def test_list_zip_generator(self): + def whoo(t): + yield t + 1 + yield t + 2 + yield t + 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return list(zip(range(3), whoo(t))) + + t = torch.randn(3) + y = fn(t) + expected = list(zip(range(3), whoo(t))) + self.assertEqual(expected, y) + + def test_zip_infinite_generator(self): + def whoo(t): + i = 0 + while True: + yield t + i + i += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return list(zip(range(3), whoo(t))) + + t = torch.randn(3) + y = fn(t) + expected = list(zip(range(3), whoo(t))) + self.assertEqual(expected, y) + + @parametrize("container", [list, tuple, dict, OrderedDict]) + def test_dict_tuple_list_generator(self, container): + def whoo(t): + yield 1, t + 1 + yield 2, t + 2 + yield 3, t + 3 + + def fn(t): + gen = whoo(t) + return container(gen) + + t = torch.randn(2) + expected = fn(t) + got = torch.compile(backend="eager", fullgraph=True)(fn)(t) + self.assertEqual(expected, got) + + def test_return_generator(self): + def whoo(t): + yield t + 1 + yield t + 2 + yield t + 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + return gen + + t = torch.tensor([1.0]) + gen = fn(t) + self.assertEqual(list(gen), [t + 1, t + 2, t + 3]) + + def test_return_tuple_generator(self): + def whoo(t): + yield t.sin() + yield t.cos() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + g1, g2 = whoo(t), whoo(t + 1) + return (g1, g2), t.sin() + + t = torch.randn(2) + (g1, g2), _ = fn(t) + self.assertEqual(list(g1), [t.sin(), t.cos()]) + self.assertEqual(list(g2), [(t + 1).sin(), (t + 1).cos()]) + + def test_return_advanced_generator(self): + def whoo(t): + yield t + 1 + yield t + 2 + yield t + 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + next(gen) + return gen + + t = torch.tensor([1.0]) + gen = fn(t) + self.assertEqual(list(gen), [t + 2, t + 3]) + + def test_return_exhaust_generator(self): + def whoo(t): + yield t + 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + next(gen) + return gen + + t = torch.tensor([1.0]) + gen = fn(t) + with self.assertRaises(StopIteration): + next(gen) + + @unittest.expectedFailure + def test_reconstruct_generator_tensor_mutation(self): + def whoo(t): + yield t.sin_() + yield t.cos_() + + def fn(t): + gen = whoo(t) + return gen + + with self.assertRaisesRegex( + Unsupported, + "Cannot reconstruct a generator with variable mutations", + ): + self._compile_check(fn) + + def test_subgenerator(self): + def subgen(t): + yield t + 1 + yield t + 2 + + def main_gen(t): + yield from subgen(t) + yield t + 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = main_gen(t) + return list(gen) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, [t + 1, t + 2, t + 3]) + + def test_return_subgenerator(self): + def subgen(t): + yield t + 1 + yield t + 2 + + def main_gen(t): + yield from subgen(t) + yield t + 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = main_gen(t) + next(gen) + return gen + + t = torch.randn(2) + gen = fn(t) + self.assertEqual(list(gen), [t + 2, t + 3]) + + def test_dynamo_disable_generator(self): + @torch._dynamo.disable + def main_gen(t): + yield t + 1 + yield t + 2 + yield t + 3 + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = main_gen(t) + return list(gen) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, [t + 1, t + 2, t + 3]) + + def test_dynamo_disable_sub_generator(self): + @torch._dynamo.disable + def subgen(t): + yield t + 2 + yield t + 3 + + def main_gen(t): + yield t + 1 + yield from subgen(t) + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = main_gen(t) + return list(gen) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, [t + 1, t + 2, t + 3]) + + def test_graph_break_outside_generator(self): + def whoo(t): + yield t + 1 + yield t + 2 + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = whoo(t) + x = next(gen) + torch._dynamo.graph_break() + y = next(gen) + return x + y + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, (t + 1) + (t + 2)) + + def test_graph_break_before_calling_generator(self): + def whoo(t): + for perm in itertools.product(itertools.permutations((0, 1, 2)), repeat=1): + yield sum(perm[0]) + + def fn(t): + s = 0 + for b, p in itertools.product(whoo(t), itertools.permutations((4, 5))): + s += b + return s + + t = torch.randn(2) + expected = fn(t) + got = torch.compile(backend="eager", fullgraph=False)(fn)(t) + self.assertEqual(expected, got) + + def test_generator_with_side_effects(self): + counters.clear() + i = 0 + + def whoo(t): + nonlocal i + for j in range(5): + i += 1 + yield t + j + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return whoo(t), t.sin() + + t = torch.randn(2) + fn(t) + self.assertEqual(len(counters["unimplemented"]), 1) + self.assertEqual( + dict(counters["unimplemented"]), + { + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications.": 1 + }, + ) + + def test_subgenerator_with_side_effects(self): + i = 0 + + def subgen(t): + nonlocal i + i += 1 + yield t + i += 1 + yield t + 1 + + def whoo(t): + nonlocal i + yield from subgen(t) + i += 1 + yield t + 2 + i += 1 + yield t + 3 + i += 1 + yield t + 4 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return whoo(t), t.sin() + + t = torch.randn(2) + gen, y = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(len(list(gen)), 5) + self.assertTrue( + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications." in dict(counters["unimplemented"]) + ) + + def test_generator_with_side_effects_graph_break(self): + i = 0 + + def whoo(t): + nonlocal i + for j in range(5): + i += 1 + yield t + j + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = whoo(t) + torch._dynamo.graph_break() + next(gen) + return gen, t.sin() + + t = torch.randn(2) + gen, y = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(len(list(gen)), 4) + self.assertTrue( + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications." in dict(counters["unimplemented"]) + ) + + def test_generator_with_side_effects_graph_break_2(self): + i = 0 + + def whoo(t): + nonlocal i + for j in range(5): + i += 1 + yield t + j + torch._dynamo.graph_break() + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=False) + def fn(t): + gen = whoo(t) + return list(zip(range(3), gen)) + + t = torch.randn(2) + fn(t) + self.assertEqual(len(eager.graphs), 0) + + @unittest.skipIf(sys.version_info < (3, 12), "Test CLEANUP_THROW") + @unittest.expectedFailure + def test_cleanup_throw(self): + def nested_generator(): + try: + yield 1 + yield 2 + except StopIteration: + return 123 # noqa: B901 + + def outer_generator(): + yield from nested_generator() + yield 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = outer_generator() + next(gen) # Start the outer generator and enter the nested generato + + i = 0 + try: + # Force an exception while the generator is running + i = gen.throw(StopIteration("stop")) + except RuntimeError: + pass + return (i, t.sin()) + + t = torch.randn(2) + i, y = self._compile_check(fn, args=(t,)) + self.assertEqual(i, 3) + self.assertEqual(y, t.sin()) + + def test_iter(self): + def whoo(): + i = 0 + while True: + yield i + i += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + s = 0 + for i in whoo(): + if i > 5: + break + s += i + return t + s + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + sum(range(6))) + + +class TestGeneratorSend(GeneratorTestsBase): + def test_send(self): + def double(): + x = yield + yield x * 2 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = double() + next(gen) + return gen.send(t) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t * 2) + + @parametrize("fullgraph", [True, False]) + def test_send_stop_iteration(self, fullgraph): + def double(): + x = yield + yield x * 2 + + @torch.compile(backend="eager", fullgraph=fullgraph) + def fn(t): + gen = double() + next(gen) + a = gen.send(t) + b = gen.send(t) # should result in StopIteration + return a + b + + t = torch.randn(2) + if fullgraph: + with self.assertRaisesRegex(Unsupported, "Observed exception"): + fn(t) + else: + with self.assertRaises(StopIteration): + fn(t) + + +class TestGeneratorClose(GeneratorTestsBase): + def test_close(self): + def whoo(t): + yield t.sin() + yield t.cos() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + i = next(gen) + gen.close() + return i + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.sin()) + + def test_close_subgen(self): + z = 0 + + def subgen(t): + nonlocal z + z = 1 + yield t.sin() + z = 3 + yield t.cos() + + def whoo(t): + yield from subgen(t) + yield t.tan() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + i = next(gen) + gen.close() + return i + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(z, 1) + + def test_close_with_side_effects(self): + L = [] + z = 0 + + def whoo(t): + nonlocal z + try: + L.append(1) + yield t.sin() + L.append(2) + yield t.cos() + finally: + L.append(z) + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal z + gen = whoo(t) + i = next(gen) + z = -123 + gen.close() + L.append(len(L)) + return i + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(L, [1, -123, 2]) + + def test_close_capture_GeneratorExit_return(self): + z = 0 + + def whoo(t): + nonlocal z + try: + z += 1 + yield t.sin() + yield t.cos() + except GeneratorExit: + z += 10 + return t.tan() # noqa: B901 + finally: + z += 100 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal z + gen = whoo(t) + i = next(gen) + y = gen.close() + return (i, y) + + t = torch.randn(2) + (i, y) = fn(t) + self.assertEqual(i, t.sin()) + self.assertEqual(y, t.tan()) + self.assertEqual(z, 111) + + @parametrize("fullgraph", [True, False]) + def test_close_capture_GeneratorExit(self, fullgraph): + z = 0 + + def whoo(t): + nonlocal z + try: + yield t.sin() + yield t.cos() + except GeneratorExit: + yield t.tan() + finally: + z = 1 + + @torch.compile(backend="eager", fullgraph=fullgraph) + def fn(t): + nonlocal z + gen = whoo(t) + i = next(gen) + gen.close() + return i + + t = torch.randn(2) + if fullgraph: + # This should actually be RuntimeError("generator ignored GeneratorExit") + # but Dynamo swallow the exception and raises Unsupported instead + with self.assertRaisesRegex(Unsupported, "Observed exception"): + fn(t) + else: + with self.assertRaisesRegex( + RuntimeError, "generator ignored GeneratorExit" + ): + fn(t) + + def test_close_capture_and_reraise_GeneratorExit(self): + L = [] + z = 0 + + def whoo(t): + nonlocal z + try: + L.append(1) + yield t.sin() + yield t.cos() + except GeneratorExit: + L.append(z) + z = -1 + raise + finally: + L.append(z) + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal z + gen = whoo(t) + i = next(gen) + z = -123 + gen.close() + L.append(456) + return i + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(L, [1, -123, -1, 456]) + + @parametrize("exc", [RuntimeError, AttributeError]) + def test_close_capture_and_reraise_exc(self, exc): + def whoo(t): + try: + yield t.sin() + yield t.cos() + except GeneratorExit as e: + raise exc from e + finally: + pass + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + i = next(gen) + gen.close() + return i + + t = torch.randn(2) + with self.assertRaises(exc): + fn(t) + + def test_close_with_subgen(self): + L = [] + z = 0 + + def subgen(t): + yield t.sin() + yield t.cos() + + def whoo(t): + nonlocal z + L.append(10) + yield from subgen(t) + L.append(20) + try: + L.append(1) + z = 4 + yield t.tan() + finally: + L.append(z) + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal z + gen = whoo(t) + i = next(gen) + z = -123 + gen.close() + L.append(456) + return i, t.sin() + + t = torch.randn(2) + y, _ = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(L, [10, 456]) + self.assertEqual(z, -123) + + def test_close_after_close(self): + z = 0 + + def whoo(t): + nonlocal z + try: + z += 1 + yield t.sin() + yield t.cos() + finally: + # finally should only be executed once + z += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + i = next(gen) + gen.close() + return (i, gen.close()) + + t = torch.randn(2) + (i, y) = fn(t) + self.assertEqual(i, t.sin()) + self.assertEqual(y, None) + self.assertEqual(z, 2) + + @parametrize("fullgraph", [True, False]) + def test_next_after_close(self, fullgraph): + def whoo(t): + yield t.sin() + yield t.cos() + + @torch.compile(backend="eager", fullgraph=fullgraph) + def fn(t): + gen = whoo(t) + gen.close() + a = next(gen) + return [t.sin(), a] + + t = torch.randn(3) + if fullgraph: + with self.assertRaises(Unsupported): + fn(t) + else: + with self.assertRaises(StopIteration): + fn(t) + + def test_close_after_exception(self): + def whoo(t): + raise ValueError("foo") + yield t.cos() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + try: + next(gen) + except ValueError: + pass + b = gen.close() + return [t.sin(), b] + + t = torch.randn(2) + y, b = fn(t) + self.assertEqual(y, t.sin()) + self.assertIsNone(b) + + def test_close_handling_finally(self): + z = 0 + + def whoo(t): + nonlocal z + try: + yield t.sin() + yield t.cos() + except GeneratorExit: + z += 1 + return t.tan() # noqa: B901 + finally: + z += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + next(gen) + b = gen.close() + return t.sin(), b + + t = torch.randn(2) + y, b = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(b, t.tan()) + self.assertEqual(z, 2) + + +class TestGeneratorThrow(GeneratorTestsBase): + def test_throw(self): + def whoo(t): + try: + yield t.sin() + except RuntimeError: + yield t.cos() + + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(y, t.sin() + t.cos()) + + @unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE") + def test_throw_with_finally(self): + z = 0 + + def whoo(): + nonlocal z + z = 0 + try: + try: + yield 1 + except ValueError: + yield 2 + finally: + z += 2 + except ValueError: + z += 33 + yield 4 + finally: + z += 1 + z += 10 + + def f(x): + gen = whoo() + next(gen) + gen.throw(ValueError) + return x.sin() + + self._compile_check(f) + self.assertEqual(z, 3) + + def test_throw_without_finally(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + z += 10 + except RuntimeError: + z += 100 + yield t.cos() + z += 1_000 + z += 10_000 + + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(y, t.sin() + t.cos()) + self.assertEqual(z, 101) + + def test_throw_three_arguments(self): + def whoo(t): + try: + yield t.sin() + except ValueError: + yield t.cos() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(ValueError, "Error", None) + return a + b + + t = torch.randn(2) + with self.assertRaises(InternalTorchDynamoError): + fn(t) + + def test_throw_no_yield_after_throw(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError: + z += 10 + finally: + z += 100 + + def fn(t): + gen = whoo(t) + a = next(gen) + try: + gen.throw(ValueError) + except StopIteration: + return a + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(z, 111) + self.assertEqual(y, t.sin()) + + def test_throw_not_catch(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError: + z += 10 + yield t.cos() + finally: + z += 100 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + with self.assertRaises(RuntimeError): + fn(t) + + def test_throw_raise_difference_exc(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError as e: + z += 10 + raise RuntimeError from e + finally: + z += 100 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(ValueError) + return a + b + + t = torch.randn(2) + with self.assertRaises(RuntimeError): + fn(t) + + def test_throw_yield_finally(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except RuntimeError: + z += 10 + yield t.cos() + finally: + z += 100 + yield t.tan() # RuntimeError: generator ignored GeneratorExit + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + with self.assertRaises(Unsupported): + fn(t) + + @unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE") + def test_throw_try_except_finally(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError: + z += 10 + yield t.cos() + except RuntimeError: + z += 100 + yield t.tan() + finally: + z += 1000 + z += 10_000 + + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(y, t.sin() + t.tan()) + self.assertEqual(z, 1 + 100 + 1000) + + def test_exception_context_with_yield(self): + def f(): + yield + + def fn(t): + gen = f() + gen.send(None) + try: + gen.throw(ValueError) + except ValueError: + z = 1 + except Exception as e: + raise AssertionError from e + assert z == 1 + return t.sin() + + self._compile_check(fn) + + +class GeneratorCloseCPythonTests(GeneratorTestsBase): + # Taken from commit + # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 + # changed the tests a little bit to run them inside dynamo + # + replaced all self.assert* calls to plain assert statements + + def test_close_no_return_value(self): + def f(): + yield + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + gen.send(None) + assert gen.close() is None + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_return_value(self): + def f(): + try: + yield + # close() raises GeneratorExit here, which is caught + except GeneratorExit: + return 0 # noqa: B901 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + gen.send(None) + assert gen.close() == 0 + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_not_catching_exit(self): + def f(): + yield + # close() raises GeneratorExit here, which isn't caught and + # therefore propagates -- no return value + return 0 # noqa: B901 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + gen.send(None) + assert gen.close() is None + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_not_started(self): + def f(): + try: + yield + except GeneratorExit: + return 0 # noqa: B901 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + assert gen.close() is None + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_exhausted(self): + def f(): + try: + yield + except GeneratorExit: + return 0 # noqa: B901 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + next(gen) + z = 0 + try: + next(gen) # -> StopIteration + except StopIteration: + z = 1 + except Exception as e: + # anything other than StopIteration should fail + raise AssertionError from e + assert z == 1 + assert gen.close() is None + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_closed(self): + def f(): + try: + yield + except GeneratorExit: + return 0 # noqa: B901 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + gen.send(None) + assert gen.close() == 0 + assert gen.close() is None + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_raises(self): + def f(): + try: + yield + except GeneratorExit: + pass + raise RuntimeError + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + gen.send(None) + z = 0 + try: + gen.close() # -> RuntimeError + except RuntimeError: + z = 1 + except Exception as e: + raise AssertionError from e + assert z == 1 + return t.sin() + + t = torch.randn(2) + fn(t) + + +class GeneratorThrowCpythonTests(GeneratorTestsBase): + # Taken from commit + # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 + # changed the tests a little bit to run them inside dynamo + # + replaced all self.assert* calls to plain assert statements + + @unittest.expectedFailure + def test_exception_context_with_yield(self): + def f(): + try: + raise KeyError("a") + except Exception: + yield + + def fn(t): + gen = f() + gen.send(None) + try: + gen.throw(ValueError) + except ValueError as e: + context = e.__context__ + assert (type(context), context.args) == (KeyError, ("a",)) + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + + @unittest.expectedFailure + def test_exception_context_with_yield_inside_generator(self): + # Check that the context is also available from inside the generator + # with yield, as opposed to outside. + def f(): + z = 0 + try: + raise KeyError("a") + except Exception: + try: + yield + except Exception as exc: + z = 1 + assert type(exc) == ValueError + context = exc.__context__ + assert (type(context), context.args) == (KeyError, ("a",)) + yield "b" + finally: + assert z == 1 + + def fn(t): + gen = f() + gen.send(None) + actual = gen.throw(ValueError) + # This ensures that the assertions inside were executed. + assert actual == "b" + return t.sin() + + self._compile_check(fn) + + @unittest.expectedFailure + def test_exception_context_with_yield_from(self): + def f(): + yield + + def g(): + try: + raise KeyError("a") + except Exception: + yield from f() + + def fn(t): + gen = g() + gen.send(None) + try: + gen.throw(ValueError) + except ValueError as e: + context = e.__context__ + assert (type(context), context.args) == (KeyError, ("a",)) + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + + def test_exception_context_with_yield_from_with_context_cycle(self): + # Check trying to create an exception context cycle: + # https://bugs.python.org/issue40696 + has_cycle = None + + def f(): + yield + + def g(exc): + nonlocal has_cycle + try: + raise exc + except Exception: + try: + yield from f() + except Exception as exc: + has_cycle = exc is exc.__context__ + yield + + def fn(t): + exc = KeyError("a") + gen = g(exc) + gen.send(None) + gen.throw(exc) + # This also distinguishes from the initial has_cycle=None. + assert has_cycle is False + return t.sin() + + self._compile_check(fn) + + def test_throw_after_none_exc_type(self): + def g(): + try: + raise KeyError + except KeyError: + pass + + try: + yield + except Exception: + raise RuntimeError # noqa: B904 + + def fn(t): + gen = g() + gen.send(None) + z = 0 + try: + gen.throw(ValueError) + except RuntimeError: + z += 1 + except Exception: + raise AssertionError # noqa: B904 + assert z == 1 + return t.sin() + + self._compile_check(fn) + + +class GeneratorCPythonTests(GeneratorTestsBase): + # Taken from commit + # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 + # changed the tests a little bit to run them inside dynamo + # + replaced all self.assert* calls to plain assert statements + + def test_send_non_none_to_new_gen(self): + def f(): + yield 1 + + def fn(t): + g = f() + z = 0 + try: + g.send(0) + except TypeError: + z += 1 + except Exception as e: + raise AssertionError from e + assert z == 1 + assert next(g) == 1 + return t.sin() + + self._compile_check(fn) + + def test_issue103488(self): + def gen_raises(): + yield 1 + raise ValueError + + def loop(): + try: + for _ in gen_raises(): + if True is False: # noqa: PLR0133 + return + except ValueError: + pass + + def fn(t): + # This should not raise + loop() + return t.sin() + + self._compile_check(fn) + + +instantiate_parametrized_tests(GeneratorTests) +instantiate_parametrized_tests(TestGeneratorSend) +instantiate_parametrized_tests(TestGeneratorClose) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 3704d9e5c53..6960698382d 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -7000,6 +7000,12 @@ class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase): ) def test_hops_compile(self, device, dtype, op, backend): # Ensure HOPs can be compiled + + if backend == "aot_eager" and op.name == "invoke_quant": + raise unittest.SkipTest( + "TODO: partitioner fails. migrate canonicalization to aot eager backend" + ) + sample_inputs_itr = op.sample_inputs( device, dtype, requires_grad=op.supports_autograd ) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 746f25a2e8a..239412b8370 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9579,21 +9579,6 @@ def ___make_guard_fn(): ): compiled_fn(x) - # FIXME(XuehaiPan): do not inline infinite generator if it does not raise errors in eager mode - def fn(x): - def gen(): - while True: - yield x - - return list(zip(range(10), gen())) - - x = torch.randn([0, 1, 2, 3, 4, 5]) - compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) - with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, "infinite generator" - ): - compiled_fn(x) - def test_itertools_islice(self): counters.clear() diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index eae6ec46477..eb1a8d2d6ca 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1418,9 +1418,9 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(opt_model(a, b, c, d), correct)) if torch._dynamo.config.assume_static_by_default: - self.assertExpectedInline(cnt.frame_count, """4""") + self.assertExpectedInline(cnt.frame_count, """2""") else: - self.assertExpectedInline(cnt.frame_count, """5""") + self.assertExpectedInline(cnt.frame_count, """3""") def test_hf_model_output(self): ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10)) @@ -6510,6 +6510,27 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): ).sum() self.assertEqual(actual, expected) + def test_incompatible_configs(self): + with torch._dynamo.config.patch( + suppress_errors=False, fail_on_recompile_limit_hit=False + ): + torch.compile(lambda: None) + + with torch._dynamo.config.patch( + suppress_errors=True, fail_on_recompile_limit_hit=False + ): + torch.compile(lambda: None) + + with torch._dynamo.config.patch( + suppress_errors=False, fail_on_recompile_limit_hit=True + ): + torch.compile(lambda: None) + + with torch._dynamo.config.patch( + suppress_errors=True, fail_on_recompile_limit_hit=True + ), self.assertRaises(AssertionError): + torch.compile(lambda: None) + class ReproTestsDevice(torch._dynamo.test_case.TestCase): def test_sub_alpha_scalar_repro(self, device): diff --git a/test/higher_order_ops/test_invoke_quant.py b/test/higher_order_ops/test_invoke_quant.py new file mode 100644 index 00000000000..96addfe1aae --- /dev/null +++ b/test/higher_order_ops/test_invoke_quant.py @@ -0,0 +1,183 @@ +# Owner(s): ["module: higher order operators"] +# flake8: noqa: B950 + +import contextlib +import logging +import unittest + +import torch +import torch._dynamo +import torch._functorch +import torch._inductor +import torch._inductor.decomposition +from torch._higher_order_ops import InvokeQuant +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + Ignored, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from torch.testing import FileCheck +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase + + +invoke_quant_tracer = InvokeQuant() + + +@skipIfTorchDynamo("Not a torch._dynamo test") +class TestInvokeQuant(TestCase): + backend = "" + + def test_simple(self): + def gn(x, y): + return (torch.mul(x, y) + y,) + + def fn(x, y): + return invoke_quant_tracer( + gn, (x, y), scheme="nf4", quant_options=invoke_quant_tracer + )[0] + + x = torch.randn(8, requires_grad=False) + y = torch.randn(8, requires_grad=False) + ref = gn(x, y)[0] + + x_clone = x.clone().detach().requires_grad_(False) + y_clone = y.clone().detach().requires_grad_(False) + res = torch.compile(fn, backend=self.backend)(x_clone, y_clone) + self.assertEqual(ref, res) + + def test_construct_inline(self): + def gn(x, y): + return (torch.mul(x, y) + y,) + + def fn(x, y): + return InvokeQuant(codegen_low_precision=False)(gn, (x, y), scheme="nf4")[0] + + x = torch.randn(8, requires_grad=False) + y = torch.randn(8, requires_grad=False) + ref = gn(x, y)[0] + + x_clone = x.clone().detach().requires_grad_(False) + y_clone = y.clone().detach().requires_grad_(False) + res = torch.compile(fn, backend=self.backend)(x_clone, y_clone) + self.assertEqual(ref, res) + + def test_inline(self): + def gn(x, y): + return (torch.mul(x, y) + y,) + + def fn(x, y): + return InvokeQuant()(gn, (x, y), scheme="nf4")[0] + + x = torch.randn(8, requires_grad=False) + y = torch.randn(8, requires_grad=False) + ref = gn(x, y)[0] + + x_clone = x.clone().detach().requires_grad_(False) + y_clone = y.clone().detach().requires_grad_(False) + res = torch.compile(fn, backend=self.backend)(x_clone, y_clone) + self.assertEqual(ref, res) + + def test_multiple(self): + torch._logging.set_logs(post_grad_graphs=True) + + def gn(x, y): + return torch.mul(x, y) + y + + def fn(x, y, z): + o1 = invoke_quant_tracer(gn, (x, y), scheme="nf4") + o2 = invoke_quant_tracer(gn, (y, z), scheme="nf4") + return o1 + o2 + + x = torch.randn(8, requires_grad=False) + y = torch.randn(8, requires_grad=False) + z = torch.randn(8, requires_grad=False) + ref = fn(x, y, z) + + log_context = ( + contextlib.nullcontext() + if self.backend != "inductor" + else self.assertLogs(logger="torch._inductor", level=logging.DEBUG) + ) + + with log_context as log: + res = torch.compile(fn, backend=self.backend)(x, y, z) + + self.assertEqual(ref, res) + + if self.backend == "inductor": + logs = "\n".join(r.getMessage() for r in log.records) + f = FileCheck() + f.check("AFTER POST GRAD") + f.check("subgraph0").check("subgraph1") + for _ in range(2): + f.check("torch.ops.higher_order.invoke_quant(").check_same("nf4") + f.run(logs) + + +class TestInvokeQuantEager(TestInvokeQuant): + backend = "eager" + + +class TestInvokeQuantAotEager(TestInvokeQuant): + backend = "aot_eager" + + +class TestInvokeQuantInductor(TestInvokeQuant): + backend = "inductor" + + def test_pattern_matching(self): + counter = 0 + + test_pass = PatternMatcherPass() + + def my_pass(g): + return test_pass.apply(g) + + def gn(x, y): + return torch.mul(x, y) + y + + def fn(x, y, z): + return invoke_quant_tracer(gn, (x, y), scheme="nf4") @ z + + def fn_no_match(x, y, z): + return invoke_quant_tracer(gn, (x, y)) @ z + + x = torch.randn(64, 64, requires_grad=False) + y = torch.randn(64, 64, requires_grad=False) + z = torch.randn(64, 64, requires_grad=False) + + @register_graph_pattern( + CallFunction( + torch.ops.aten.mm, + CallFunction( + torch.ops.higher_order.invoke_quant, + Ignored(), + Ignored(), + Ignored(), + scheme="nf4", + ), + Arg(), + ), + pass_dict=test_pass, + ) + def quant_matching(match: Match, *args, **kwargs): + nonlocal counter + counter += 1 + + with unittest.mock.patch( + "torch._inductor.config.post_grad_custom_pre_pass", my_pass + ): + torch.compile(fn)(x, y, z) + self.assertTrue(counter == 1) + + torch.compile(fn_no_match)(x, y, z) + self.assertTrue(counter == 1) + + +del TestInvokeQuant + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/test_block_analysis.py b/test/inductor/test_block_analysis.py new file mode 100644 index 00000000000..5cf932d52e8 --- /dev/null +++ b/test/inductor/test_block_analysis.py @@ -0,0 +1,102 @@ +# Owner(s): ["module: inductor"] + +import sympy + +import torch +from torch._inductor.codegen.block_analysis import BlockPatternMatcher +from torch._inductor.virtualized import V +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TestCase, +) +from torch.testing._internal.inductor_utils import dummy_graph +from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing + + +# Some useful symbols +x, y = sympy.symbols("x y") + + +@instantiate_parametrized_tests +class BlockAnalysisTest(TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + # Create a GraphLowering, so we can access V.graph. + cls.graph = dummy_graph() + + @parametrize( + "stride,symbol,expr", + [ + (5, x, Identity(5 * x)), + (4, y, 4 * Identity(y)), + (3, x, Identity(3) * x), + ], + ) + def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr): + # Test that we can handle an identity expression in affine indexing. + matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol) + self.assertEqual(matched_stride, stride) + + @parametrize( + "dims,strides,symbol,expr", + [ + ( + (2, 4), + (4, 1), + x, + 4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4), + ), + ( + (3, 9), + (5, 2), + x, + 5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9), + ), + ((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))), + ], + ) + def test_mod_div_identity( + self, + dims: tuple[int], + strides: tuple[int], + symbol: sympy.Symbol, + expr: sympy.Expr, + ): + # Test that we can handle an identity expression in modular indexing. + numel = int(torch.prod(torch.Tensor(dims))) + num_dims = len(dims) + with V.set_graph_handler(self.graph): + match_result = BlockPatternMatcher.match_mod_div_block_expr( + expr, symbol, numel, num_dims + ) + + # Check the matched block dimensions. + self.assertNotEqual(match_result, None) + matched_dims, matched_strides, matched_block_index_exprs = match_result + self.assertEqual(matched_dims, dims) + self.assertEqual(matched_strides, strides) + + @parametrize( + "symbol,expr,subexpr", + [ + (x, Identity(x), x), + (x, Identity(x + 5), x), + (y, Identity(x + 2 * y) + 5, 2 * y), + ], + ) + def test_subexpr_identity( + self, + symbol: sympy.Symbol, + expr: sympy.Expr, + subexpr: sympy.Expr, + ): + matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol) + self.assertEqual(matched_subexpr, subexpr) + + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 8b4382061b0..99440593c2b 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -2510,9 +2510,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): @supported_platform def test_strided_backwards(self): shape = (1, 2, 4096, 64) - Q = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16) - K = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16) - V = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16) + Q = torch.randn(shape, requires_grad=True, device="cuda") + K = torch.randn(shape, requires_grad=True, device="cuda") + V = torch.randn(shape, requires_grad=True, device="cuda") func = torch.compile(flex_attention, dynamic=True, fullgraph=True) K_sliced = K[:, :, :-128] diff --git a/test/inductor/test_op_completeness.py b/test/inductor/test_op_completeness.py index 04fac4870fd..23d59a78941 100644 --- a/test/inductor/test_op_completeness.py +++ b/test/inductor/test_op_completeness.py @@ -5,19 +5,23 @@ from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides from torch._inductor.codegen.halide import HalideOverrides from torch._inductor.codegen.mps import MetalOverrides from torch._inductor.codegen.triton import TritonKernelOverrides -from torch._inductor.ops_handler import list_ops, OP_NAMES +from torch._inductor.ops_handler import list_ops, OP_NAMES, OpsHandler from torch._inductor.test_case import TestCase class TestOpCompleteness(TestCase): def verify_ops_handler_completeness(self, handler): - op_names = list_ops(handler) - if OP_NAMES == op_names: - return - print(f"Missing ops: {OP_NAMES - op_names}") - print(f"Extra ops: {op_names - OP_NAMES}") - self.assertEqual(", ".join(OP_NAMES - op_names), "") - self.assertEqual(", ".join(op_names - OP_NAMES), "") + for op in OP_NAMES: + self.assertIsNot( + getattr(handler, op), + getattr(OpsHandler, op), + msg=f"{handler} must implement {op}", + ) + extra_ops = list_ops(handler) - OP_NAMES + if extra_ops: + raise AssertionError( + f"{handler} has an extra ops: {extra_ops}, add them to OpHandler class or prefix with `_`" + ) def test_triton_overrides(self): self.verify_ops_handler_completeness(TritonKernelOverrides) diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm2d_eval_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm2d_eval_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm2d_train_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm2d_train_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm3d_eval_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm3d_eval_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm3d_train_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm3d_train_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Bilinear_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Bilinear_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv1d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv1d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv2d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv2d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv3d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv3d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose1d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose1d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose2d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose2d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose3d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose3d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_GRUCell_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_GRUCell_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_GroupNorm_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_GroupNorm_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_LSTMCell_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_LSTMCell_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_LayerNorm_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_LayerNorm_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Linear_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Linear_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RNNCell_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RNNCell_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm2d_eval_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm2d_eval_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm2d_train_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm2d_train_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm3d_eval_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm3d_eval_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm3d_train_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm3d_train_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Bilinear_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Bilinear_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv1d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv1d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv2d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv2d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv3d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv3d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose1d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose1d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose2d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose2d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose3d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose3d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_GRUCell_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_GRUCell_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_GroupNorm_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_GroupNorm_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_LSTMCell_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_LSTMCell_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_LayerNorm_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_LayerNorm_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Linear_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Linear_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RNNCell_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RNNCell_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 8c1f1b12f36..daa39964374 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -50,7 +50,7 @@ class TestSortAndSelect(TestCase): return ((b != b) | (a <= b)).all().item() else: - error( # noqa: F821 + raise ValueError( f'unknown order "{order}", must be "ascending" or "descending"' ) diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 5cd93027417..e804e289c1c 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import ( ) from torch.utils._sympy.functions import ( FloorDiv, + Identity, OpaqueUnaryFn_cos, simple_floordiv_gcd, ) @@ -34,7 +35,8 @@ from torch.utils._sympy.reference import ( ) from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve -from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.value_ranges import ValueRanges +from torch._inductor.bounds import ValueRangeAnalysis UNARY_OPS = [ @@ -954,6 +956,17 @@ class TestSingletonInt(TestCase): self.assertEqual(j1.free_symbols, set()) +class TestIdentity(TestCase): + def test_expand_identity(self): + """ + Test removing an identity via expansion. + """ + x = sympy.Symbol("x") + arg = x + sympy.S.One + expr = Identity(arg) + expanded = expr.expand(identity=True) + self.assertEqual(expanded.count(Identity), 0) + self.assertEqual(expanded, arg) instantiate_parametrized_tests(TestValueRanges) instantiate_parametrized_tests(TestSympyInterp) diff --git a/test/test_transformers.py b/test/test_transformers.py index eab1cb8a605..af711a6fb67 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -324,7 +324,7 @@ class TestTransformers(NNTestCase): encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8)) except AssertionError: continue - self.assertFalse(e, "Failed to catch unsupported uint8 type exception") # noqa: F821 + self.assertFalse(e, "Failed to catch unsupported uint8 type exception") test_train_bool = encoder(test, src_key_padding_mask=pad_mask) encoder.eval() @@ -335,7 +335,7 @@ class TestTransformers(NNTestCase): encoder(test, src_key_padding_mask=pad_mask.to(torch.int64)) except AssertionError as e: continue - self.assertFalse(e, "Failed to catch unsupported Long type exception") # noqa: F821 + self.assertFalse(e, "Failed to catch unsupported Long type exception") test_eval_bool = encoder(test, src_key_padding_mask=pad_mask) l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item() diff --git a/tools/build_with_debinfo.py b/tools/build_with_debinfo.py index 17b7d06bb0f..73c9dba0090 100755 --- a/tools/build_with_debinfo.py +++ b/tools/build_with_debinfo.py @@ -78,6 +78,9 @@ def create_build_plan() -> list[tuple[str, str]]: if line.startswith(": &&") and line.endswith("&& :"): line = line[4:-4] line = line.replace("-O2", "-g").replace("-O3", "-g") + # Build Metal shaders with debug infomation + if "xcrun metal " in line and "-frecord-sources" not in line: + line += " -frecord-sources -gline-tables-only" try: name = line.split("-o ", 1)[1].split(" ")[0] rc.append((name, line)) diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 893e7056281..2aed88f713b 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -26,7 +26,10 @@ from .exc import IncorrectUsage, unimplemented from .source import AttrSource, Source from .utils import is_safe_constant, rot_n_helper from .variables.base import ValueMutationExisting, VariableTracker -from .variables.functions import FunctionDecoratedByContextlibContextManagerVariable +from .variables.functions import ( + ContextlibContextManagerLocalGeneratorObjectVariable, + LocalGeneratorObjectVariable, +) from .variables.nn_module import NNModuleVariable from .variables.tensor import ( NumpyNdarrayVariable, @@ -162,14 +165,20 @@ class PyCodegen: return if value.is_realized() and isinstance( - value, FunctionDecoratedByContextlibContextManagerVariable + value, ContextlibContextManagerLocalGeneratorObjectVariable ): raise IncorrectUsage( "NYI: Returning a @contextmanager object from a torch.compile function" ) # Dynamo normally prefers codegen from source to account for aliasing. - if value.source is not None and allow_cache: + if ( + value.source is not None + and allow_cache + and not ( + value.is_realized() and isinstance(value, LocalGeneratorObjectVariable) + ) + ): # There's a corner case for export: for instance, if the computation # graph is just identity on an input tensor, Dynamo would just emit # a `LOAD_FAST` from the input source, rather than generating an diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index dcc4de1e7f0..84ed3be5a70 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -52,6 +52,7 @@ skip_code_recursive_on_recompile_limit_hit = True # raise a hard error if cache limit is hit. If you are on a model where you # know you've sized the cache correctly, this can help detect problems when # you regress guards/specialization. This works best when recompile_limit = 1. +# This flag is incompatible with: suppress_errors. # [@compile_ignored: runtime_behaviour] fail_on_recompile_limit_hit = False @@ -164,6 +165,7 @@ traceable_tensor_subclasses: set[type[Any]] = set() # This is a good way to get your model to work one way or another, but you may # lose optimization opportunities this way. Devs, if your benchmark model is failing # this way, you should figure out why instead of suppressing it. +# This flag is incompatible with: fail_on_recompile_limit_hit. suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False)) # Record and write an execution record of the current frame to a file @@ -417,6 +419,10 @@ enable_cpp_symbolic_shape_guards = False # Enable tracing through contextlib.contextmanager enable_trace_contextlib = True +# Enable tracing generator functions lazily. If False, Dynamo will exhaust +# generators upon first execution. And if True, the generator will be accessed lazily +enable_faithful_generator_behavior = True + # Inline inbuilt nn modules inline_inbuilt_nn_modules = Config( # type: ignore[var-annotated] default=True, diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 045cd350b60..789ed41d3a2 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -833,6 +833,13 @@ def is_inductor_supported(): return False +def check_for_incompatible_configs(): + # Some of the configs should be mutually exclusive + assert not ( + config.suppress_errors and config.fail_on_recompile_limit_hit + ), "Dynamo configs suppress_error and fail_on_recompile_limit_hit can not both be active at the same time." + + def optimize(*args, **kwargs): def rebuild_ctx(): ca_kwargs_override = config.compiled_autograd_kwargs_override @@ -885,6 +892,7 @@ def _optimize( ... """ check_if_dynamo_supported() + check_for_incompatible_configs() # Note: The hooks object could be global instead of passed around, *however* that would make # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls. # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index eb62d4d30db..7a524d5017f 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -162,6 +162,17 @@ class AttributeMutationError(Unsupported): super().__init__(msg) +class InfiniteGeneratorError(Unsupported): + # Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT + def __init__(self, msg: str) -> None: + super().__init__(msg) + + +class SideEffectsError(Unsupported): + def __init__(self, msg: str) -> None: + super().__init__(msg) + + class CondOpArgsMismatchError(ArgsMismatchError): """ Internal error from cond() due to arguments mismatch. @@ -267,12 +278,17 @@ class ObservedKeyError(ObservedLookupError): pass +class ObservedGeneratorExit(ObservedException): + pass + + class ObservedAttributeError(ObservedException): # An AttributeError exception to be raised from inside Dynamo tracing. This can happen on user defined object __getattr__ pass class ObservedRuntimeError(ObservedException): + # A RuntimeError exception to be raised from inside Dynamo tracing. This can happen on generator.throw(..) method pass @@ -280,17 +296,32 @@ class ObservedNotImplementedError(ObservedException): pass +class ObservedTypeError(ObservedException): + # A TypeError exception to be raised from inside Dynamo tracing. This can happen on generator.send(..) method + pass + + observed_exception_map = { StopIteration: ObservedUserStopIteration, LookupError: ObservedLookupError, IndexError: ObservedIndexError, + GeneratorExit: ObservedGeneratorExit, KeyError: ObservedKeyError, AttributeError: ObservedAttributeError, RuntimeError: ObservedRuntimeError, NotImplementedError: ObservedNotImplementedError, + TypeError: ObservedTypeError, } +def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]: + if exc_type not in observed_exception_map: + observed_exception_map[exc_type] = type( + f"Observed{exc_type.__name__}Error", (ObservedException,), {} + ) + return observed_exception_map[exc_type] + + def raise_observed_exception( exc_type: type[Exception], tx: InstructionTranslatorBase, diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 35219afb9f0..818ef871988 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1936,6 +1936,9 @@ class SubgraphTracer(fx.Tracer): # backward recomputation of the checkpoint region doesn't affect its correctness. self.allow_side_effects_under_checkpoint = False + # True if this tracer is currently tracing (reconstructing) into a Python generator + self.is_reconstructing_generator = False + self.debug_level: int = parent.debug_level + 1 if parent is not None else 0 self._cur_code = None diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index c24f4d821c3..2cdf6a0cc41 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -7,7 +7,7 @@ import warnings import weakref from collections.abc import MutableMapping from types import CellType -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING import torch.nn @@ -19,7 +19,7 @@ from .bytecode_transformation import ( create_instruction, ) from .codegen import PyCodegen -from .exc import unimplemented +from .exc import SideEffectsError, unimplemented from .source import GlobalSource, LocalCellSource, LocalSource, Source from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new, tuple_new from .variables.base import ( @@ -34,6 +34,10 @@ from .variables.base import ( from .variables.user_defined import FrozenDataClassVariable +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + def _manual_dict_setitem(dict_from, dict_to, mro_index): # Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have # to be careful because we don't want to trigger the user defined object @@ -134,6 +138,14 @@ class SideEffects: and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint ) + def is_reconstructing_generator(self): + output_graph = self.output_graph_weakref() + + return ( + output_graph + and output_graph.current_tx.output.current_tracer.is_reconstructing_generator + ) + def check_allowed_side_effect(self, item): from torch._dynamo.variables.misc import AutogradFunctionContextVariable @@ -143,6 +155,14 @@ class SideEffects: return True if self.should_allow_side_effects_under_checkpoint(): return True + if self.is_reconstructing_generator(): + # This is missing the case where one mutates a tensor. See + # test_generator.py::test_reconstruct_generator_tensor_mutation + raise SideEffectsError( + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications." + ) if not is_side_effect_safe(item.mutation_type): unimplemented( "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)" @@ -842,7 +862,7 @@ class SideEffects: @contextlib.contextmanager -def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: ignore[name-defined] # noqa: F821 +def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): assert tx.output.current_tracer.under_activation_checkpoint orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint try: @@ -850,3 +870,13 @@ def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: i yield finally: tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val + + +@contextlib.contextmanager +def disallow_side_effects_in_generator(tx: "InstructionTranslator"): + orig_val = tx.output.current_tracer.is_reconstructing_generator + try: + tx.output.current_tracer.is_reconstructing_generator = True + yield + finally: + tx.output.current_tracer.is_reconstructing_generator = orig_val diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 8e2b1bfa61c..9a8edf94478 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -85,7 +85,8 @@ from .variables.ctx_manager import ( from .variables.dicts import ConstDictVariable, SetVariable from .variables.functions import ( BaseUserFunctionVariable, - FunctionDecoratedByContextlibContextManagerVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, NestedUserFunctionVariable, SkipFunctionVariable, UserFunctionVariable, @@ -290,6 +291,34 @@ def _step_logger(): return torchdynamo_logging.get_step_logger(log) +@contextlib.contextmanager +def save_and_restart_speculation_log(tx: "InstructionTranslatorBase"): + # When reconstructing a generator after a graph break, we advance it until + # it is fully exhausted. This process adds new entries to the speculation + # log that were not previously observed. Without temporarily clearing the + # speculation log, this could lead to a divergence error. + + entries = tx.speculation_log.entries + index = tx.speculation_log.index + try: + tx.speculation_log.entries = [] + tx.speculation_log.index = 0 + yield + finally: + tx.speculation_log.entries = entries + tx.speculation_log.index = index + + +@contextlib.contextmanager +def temporarely_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"): + try: + tmp = tx.output.should_exit + tx.output.should_exit = False + yield + finally: + tx.output.should_exit = tmp + + @dataclasses.dataclass class BlockStackEntry: # Current instruction that pushes something to block_stack @@ -922,11 +951,22 @@ class InstructionTranslatorBase( raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}") self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] + def inline_generator_function(self, fn, args, kwargs): + """ + Redirect the call to the generator "call_function" + """ + if not isinstance(fn, LocalGeneratorFunctionVariable): + fn = LocalGeneratorFunctionVariable(fn) + return fn.call_function(self, args, kwargs) + def inline_user_function_return(self, fn, args, kwargs): """ A call to some user defined function by inlining it. """ - return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + if config.enable_faithful_generator_behavior and is_generator(fn.get_code()): + return self.inline_generator_function(fn, args, kwargs) + else: + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) def get_line_of_code_header(self, lineno=None): if lineno is None: @@ -1499,6 +1539,15 @@ class InstructionTranslatorBase( self._raise_exception_variable(inst) unimplemented("raise ... from ...") + def CLEANUP_THROW(self, inst): + # https://github.com/python/cpython/pull/96010 + tos = self.stack[-1] + assert isinstance(tos, ExceptionVariable) + if tos.exc_type is StopIteration: + unimplemented("CLEANUP_THROW with StopIteration") + else: + self.RERAISE(inst) + def RERAISE(self, inst): if sys.version_info >= (3, 11): # RERAISE is currently supported in a narrow case of `raise ... from None` @@ -3083,7 +3132,20 @@ class InstructionTranslator(InstructionTranslatorBase): return True return False + def replace_tos_if_return_is_generator(self): + if ( + len(self.stack) + and (tos := self.stack[-1]) + and isinstance(tos, LocalGeneratorObjectVariable) + ): + self.stack[-1] = ListIteratorVariable( + tos.force_unpack_var_sequence(self), + mutation_type=ValueMutationNew(), + ) + def _return(self, inst): + self.replace_tos_if_return_is_generator() + if ( not config.allow_empty_graphs and self.output.count_calls() == 0 @@ -3093,6 +3155,7 @@ class InstructionTranslator(InstructionTranslatorBase): and not self.one_graph ): raise exc.SkipFrame("because no content in function call") + self.instruction_pointer = None _step_logger()( logging.INFO, @@ -3179,8 +3242,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase): func: VariableTracker, args: list[VariableTracker], kwargs, - *, - stop_generator_on_yield: bool = False, ): if isinstance(func, SkipFunctionVariable): unimplemented("inline with functions in skip files") @@ -3189,7 +3250,8 @@ class InliningInstructionTranslator(InstructionTranslatorBase): ( UserFunctionVariable, NestedUserFunctionVariable, - FunctionDecoratedByContextlibContextManagerVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, ), ) result = InliningInstructionTranslator.check_inlineable(func) @@ -3254,9 +3316,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase): parent.symbolic_globals, parent.symbolic_torch_function_state, func, - stop_generator_on_yield=stop_generator_on_yield, ) else: + # need the line below to make MyPy happy + assert not isinstance(func, LocalGeneratorObjectVariable) tracer = InliningInstructionTranslator( parent, code, @@ -3302,24 +3365,30 @@ class InliningInstructionTranslator(InstructionTranslatorBase): log.debug("DONE INLINING %s", code) - if is_generator(code): - assert isinstance(self, InliningGeneratorInstructionTranslator) - # The first flag tells us if we consume generators lazily or not - # and the second is if the generator is exhausted. - # In the future, generators should be lazily consumed and the first - # flag (stop_generator_on_yield) will not be needed. - if self.stop_generator_on_yield and self.generator_exhausted: + if config.enable_faithful_generator_behavior or ( + isinstance(self, InliningGeneratorInstructionTranslator) + and self.is_generator_from_ctx_manager + ): + if ( + is_generator(code) + and isinstance(self, InliningGeneratorInstructionTranslator) + and self.generator_exhausted + ): + assert isinstance(self, InliningGeneratorInstructionTranslator) # When the generator returns None, we raise StopIteration - r = self.symbolic_result - assert r.as_python_constant() is None exc.raise_observed_exception(StopIteration, self) else: + return self.symbolic_result + else: + if is_generator(code): + assert isinstance(self, InliningGeneratorInstructionTranslator) + assert self.symbolic_result.as_python_constant() is None return ListIteratorVariable( self.generated_items, mutation_type=ValueMutationNew(), ) - else: - return self.symbolic_result + else: + return self.symbolic_result def __init__( self, @@ -3438,27 +3507,26 @@ class InliningInstructionTranslator(InstructionTranslatorBase): class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): generated_items: list[VariableTracker] # Flag wether or not the InlineGenerator should consume the entire iterator - stop_generator_on_yield: bool - def __init__(self, *args, stop_generator_on_yield: bool = False, **kwargs) -> None: + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.generated_items = [] - # In the future, generators should run lazily (i.e. when next(...) is called) - # TODO: Set this to True by default, so that dynamo follows CPython more - # closely - self.stop_generator_on_yield = stop_generator_on_yield self.generator_exhausted = False + self.is_generator_from_ctx_manager = False def YIELD_VALUE(self, inst: Instruction): top = self.pop() self.generated_items.append(top) if len(self.generated_items) > MAX_ITERATOR_LIMIT: - unimplemented( + raise exc.InfiniteGeneratorError( "Too many yield values in generator. Maybe you are inlining an infinite generator. " f"If not, please report a bug at {PT2_ISSUE_TRACKER_URL}", ) self.push(ConstantVariable.create(None)) - if self.stop_generator_on_yield: + if ( + config.enable_faithful_generator_behavior + or self.is_generator_from_ctx_manager + ): self.symbolic_result = top # Stop tracing raise YieldValueOp @@ -3500,10 +3568,6 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): self.pop() self.push(ConstantVariable.create(ex.value)) else: - self.push(val) - # Add the value to yield into generated_items and replace the top of the stack with None - self.YIELD_VALUE(inst) - # Repeat the YIELD_FROM instruction in the next eval loop assert ( isinstance(self.instruction_pointer, int) @@ -3511,11 +3575,15 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): ) self.instruction_pointer -= 1 + self.push(val) + # Add the value to yield into generated_items and replace the top of the stack with None + self.YIELD_VALUE(inst) + def SEND(self, inst): assert len(self.stack) >= 2 val = self.pop() tos = self.stack[-1] - if isinstance(tos, ListIteratorVariable) or ( + if isinstance(tos, (ListIteratorVariable, LocalGeneratorObjectVariable)) or ( isinstance(tos, UserDefinedObjectVariable) and isinstance(tos.value, collections.abc.Iterator) ): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 7b459ffcbb9..d06ffccfc4c 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -32,8 +32,9 @@ from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper from .variables import ( BuiltinVariable, FunctionalCallVariable, - FunctionDecoratedByContextlibContextManagerVariable, FunctorchHigherOrderVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, SkipFunctionVariable, @@ -3209,6 +3210,7 @@ LEGACY_MOD_INLINELIST = { "torch._higher_order_ops.while_loop", "torch._higher_order_ops.associative_scan", "torch._higher_order_ops.scan", + "torch._higher_order_ops._invoke_quant", "torch._higher_order_ops.utils", "torch.nn.attention.flex_attention", "torch.ao.quantization.pt2e.export_utils", @@ -3619,7 +3621,8 @@ def check_verbose(obj, is_inlined_call=False): UserFunctionVariable, UserMethodVariable, NestedUserFunctionVariable, - FunctionDecoratedByContextlibContextManagerVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, ), ): try: @@ -3639,7 +3642,14 @@ def check_verbose(obj, is_inlined_call=False): # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: set[str] = set() rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) - if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)): + if issubclass( + rule, + ( + UserFunctionVariable, + LocalGeneratorFunctionVariable, + PolyfilledFunctionVariable, + ), + ): return SkipResult( False, f"inlined according trace_rules.lookup {reasons.pop()}", diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 9fc28fe50a6..ba7a10267e2 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -39,6 +39,8 @@ from .functions import ( FunctionDecoratedByContextlibContextManagerVariable, FunctoolsPartialVariable, FunctoolsWrapsVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, SkipFunctionVariable, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 6e411dfe50d..709c76f74d7 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -773,7 +773,13 @@ class BuiltinVariable(VariableTracker): tx, [v.realize() for v in args], kwargs ) - if inspect.isclass(fn) and issubclass(fn, Exception): + if inspect.isclass(fn) and ( + issubclass(fn, Exception) + # GeneratorExit doens't inherit from Exception + # >>> issubclass(GeneratorExit, Exception) + # False + or fn is GeneratorExit + ): def create_exception_class_object( tx: "InstructionTranslator", args, kwargs @@ -1425,6 +1431,13 @@ class BuiltinVariable(VariableTracker): mutation_type=ValueMutationNew(), ) + def _call_iter_tuple_generator(self, tx, obj, *args, **kwargs): + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), # exhaust generator + mutation_type=ValueMutationNew(), + ) + def _call_tuple_list(self, tx, obj=None, *args, **kwargs): if isinstance(obj, variables.IteratorVariable): cls = variables.BaseListVariable.cls_for(self.fn) @@ -1432,6 +1445,8 @@ class BuiltinVariable(VariableTracker): list(obj.force_unpack_var_sequence(tx)), mutation_type=ValueMutationNew(), ) + elif isinstance(obj, variables.LocalGeneratorObjectVariable): + return self._call_iter_tuple_generator(tx, obj, *args, **kwargs) else: return self._call_iter_tuple_list(tx, obj, *args, **kwargs) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 29a3cb18abd..45e04bbb80c 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -4,17 +4,30 @@ import builtins import functools import inspect import itertools +import sys import types from collections.abc import Sequence -from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar from typing_extensions import Never from unittest.mock import patch import torch from .. import polyfills, variables -from ..bytecode_transformation import create_call_function, create_rot_n -from ..exc import raise_observed_exception, unimplemented, Unsupported +from ..bytecode_transformation import create_call_function, create_rot_n, is_generator +from ..exc import ( + get_dynamo_observed_exception, + handle_observed_exception, + IncorrectUsage, + InfiniteGeneratorError, + ObservedException, + ObservedGeneratorExit, + ObservedUserStopIteration, + raise_observed_exception, + SkipFrame, + unimplemented, + Unsupported, +) from ..guards import GuardBuilder, install_guard from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource from ..utils import ( @@ -378,61 +391,426 @@ class BuiltinMethodVariable(BaseUserFunctionVariable): return obj_vt.call_method(tx, name, args, kwargs) -class FunctionDecoratedByContextlibContextManagerVariable(BaseUserFunctionVariable): - # TODO(guilherme): replace this with a generic GeneratorFunctionVariable +class LocalGeneratorObjectVariable(VariableTracker): + def __init__( + self, + code: types.CodeType, + f_globals, + inline_tracer: Optional["InstructionTranslator"], + **kwargs, + ): + super().__init__(**kwargs) + self.code = code + self.f_globals = f_globals + self.inline_tracer = inline_tracer + def get_code(self): + return self.code + + def get_filename(self): + return self.get_code().co_filename + + def get_name(self): + return self.get_code().co_name + + def get_function(self): + raise NotImplementedError + + def has_self(self): + return False + + def __name__(self): + return self.get_name() + + def __str__(self): + return f"{self.__class__.__name__}({self.get_name()})" + + __repr__ = __str__ + + def reconstruct(self, codegen): + from torch._dynamo.side_effects import disallow_side_effects_in_generator + from torch._dynamo.symbolic_convert import ( + InstructionTranslator, + save_and_restart_speculation_log, + temporarely_allow_writes_to_output_graph, + ) + + tx = InstructionTranslator.current_tx() + save = save_and_restart_speculation_log(tx) + disallow = disallow_side_effects_in_generator(tx) + temp = temporarely_allow_writes_to_output_graph(tx) + + with save, disallow, temp: + tracer = self._get_inline_tracer(tx) + if not tracer.generator_exhausted: + self.remaining_items = self.force_unpack_var_sequence(tx) + variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen) + + def bind_args(self, tx, args, kwargs): + return self.fn.bind_args(tx, args, kwargs) + + def get_globals(self): + return self.f_globals + + def python_type(self): + return types.GeneratorType + + def _get_inline_tracer(self, tx): + from torch._dynamo.symbolic_convert import InliningInstructionTranslator + + if self.inline_tracer is None: + self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( + tx, self, [], {} + ) + return self.inline_tracer + + def next_variable(self, tx): + tracer = self._get_inline_tracer(tx) + + if self._is_generator_exhausted(): + raise_observed_exception(StopIteration, tx) + + try: + # Hierarchically, tx can be seen as the parent of the inline tracer + # created on call_function. Any exception needs to be propagated to tx + # for Dynamo to behave correctly + with patch.dict(counters, {"unimplemented": counters["inline_call"]}): + return tracer.inline_call_() + except ObservedException as e: + tx.exn_vt_stack.extend(tracer.exn_vt_stack) + raise e + except InfiniteGeneratorError: + # test/dynamo/test_misc.py::test_iterator_limit + raise + except Unsupported as e: + torch._C._dynamo.eval_frame.skip_code(self.get_code()) + raise SkipFrame from e + finally: + counters["unimplemented"] |= counters["inline_call"] + + def has_unpack_var_sequence(self, tx): + return False + + def has_force_unpack_var_sequence(self, tx) -> builtins.bool: + return True + + def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: + result = [] + while True: + try: + result.append(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + return result + + def _setup_exception(self, tx, exc): + tracer = self._get_inline_tracer(tx) + tracer.push(exc) + try: + tracer._raise_exception_variable(None) + except ObservedException as e: + # if no handler is available (i.e. user code doesn't catch it), the + # exception is raised again. + tracer.exception_handler(e) + + def _is_generator_just_started(self): + return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 + + def _is_generator_exhausted(self): + return getattr(self.inline_tracer, "generator_exhausted", False) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__next__": + return self.next_variable(tx) + elif name == "__iter__": + # iter(gen) returns itself + return self + elif name == "send": + # Sends a value into the generator function. Returns the next value + # yielded by the generator, or raises StopIteration if the generator + # exits without yielding another value + if self._is_generator_just_started() and len(args): + # can't send non-None value to a just-started generator + # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen + if not all( + isinstance(arg, ConstantVariable) and arg.value is None + for arg in args + ): + raise_observed_exception(TypeError, tx) + tracer = self._get_inline_tracer(tx) + tracer.push_many(args) + return self.next_variable(tx) + elif name == "close": + # * Raises a GeneratorExit at the point where the generator function was paused. + # * If the generator function catches the exception and returns a + # value, this value is returned from close() - Python 3.13+ + # * If the generator function is already closed, or raises GeneratorExit + # (by not catching the exception), close() returns None. + # * If the generator yields a value, a RuntimeError is raised. + # * If the generator raises any other exception, it is propagated to the caller. + # * If the generator has already exited due to an exception or normal + # exit, close() returns None and has no other effect. + + # Return None if close is called on a just-started generator + # See test GeneratorCloseCpythonTests::test_close_not_started + + tracer = self._get_inline_tracer(tx) + if self._is_generator_just_started() or self._is_generator_exhausted(): + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + + # Raise GeneratorExit to see if user code catches it. Any other exception + # is propagated to the parent frame. + try: + self._setup_exception( + tx, variables.ExceptionVariable(GeneratorExit, ()) + ) + # There's an extra block on Python 3.12+ to handle StopIteration + # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397 + # + # 1 0 RETURN_GENERATOR + # 2 POP_TOP + # 4 RESUME 0 + + # 2 6 LOAD_CONST 1 (1) + # 8 YIELD_VALUE 1 + # 10 RESUME 1 + # 12 POP_TOP + # 14 RETURN_CONST 0 (None) + # >> 16 CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR) + # 18 RERAISE 1 + # ExceptionTable: + # 4 to 14 -> 16 [0] lasti + if ( + sys.version_info >= (3, 12) + and tracer.next_instruction.opname == "CALL_INTRINSIC_1" + ): + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + except ObservedGeneratorExit: + # If it doesn't catch, we just return None, as per the text above + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + + try: + # Raise RuntimeError if the generator yields any other value + if self.next_variable(tx): + raise_observed_exception(RuntimeError, tx) + except ObservedGeneratorExit: + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + except ObservedUserStopIteration: + # In Python 3.13+, one can capture GeneratorExit and return a value + # See test_generator.py::test_close_capture_GeneratorExit_return + # https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26 + # https://github.com/python/cpython/pull/104771 + assert tracer.symbolic_result is not None + return tracer.symbolic_result + elif name == "throw": + # * Raises an exception at the point where the generator was paused, and + # returns the next value yielded by the generator. + # * If the generator exits without yielding, raise StopIteration + # * If the generator function does not catch the passed-in exception, + # or raises a different exception, then that exception propagates to the caller. + + if len(args) > 1: + raise IncorrectUsage( + "the (type, exc, tb) signature of throw() is deprecated, " + "use the single-arg signature instead." + ) + + # Setup the exception table and jump target in case of try...finally + tracer = self._get_inline_tracer(tx) + try: + self._setup_exception(tx, args[0]) + except ObservedException: + # propagate the exception back to the parent caller + tx.exn_vt_stack.extend(tracer.exn_vt_stack) + raise + + retval = self.next_variable(tx) + + # The exception raised before is still active. We need to check the exception + # table one more time to find the next target. But why? Let’s walk + # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M + # + # z = 0 + # def whoo(): + # global z + # z = 0 + # try: + # yield 1 + # except ValueError: + # yield 2 + # finally: + # z += 1 + # z += 10 + # + # gen = whoo() + # next(gen) + # gen.throw(ValueError) + # print('z', z) -> z = 1 + # + # ... + # >> 58 PUSH_EXC_INFO + # + # 8 60 LOAD_GLOBAL 2 (ValueError) + # 70 CHECK_EXC_MATCH + # 72 POP_JUMP_IF_FALSE 7 (to 88) + # 74 POP_TOP + # + # 9 76 LOAD_CONST 3 (2) + # 78 YIELD_VALUE 3 <------ ValueError is still active here + # 80 RESUME 1 + # 82 POP_TOP + # 84 POP_EXCEPT + # 86 jump_backward 34 (to 20) + # ... + # + # ExceptionTable: + # 4 to 8 -> 124 [0] lasti + # 12 to 18 -> 58 [0] + # 20 to 56 -> 124 [0] lasti + # 58 to 82 -> 90 [1] lasti <------ move to 90 + # 84 to 86 -> 96 [0] + # 88 to 88 -> 90 [1] lasti + # 90 to 94 -> 96 [0] + # 96 to 116 -> 118 [1] lasti + # 118 to 122 -> 124 [0] lasti + # + # In this scenario, a generator can yield after `throw()` is called. Even + # after the exception is raised a few lines above, it remains active + # within the `78 YIELD_VALUE` instruction. When the generator resumes + # after the second yield on instruction `80 RESUME`, we cannot simply + # return the control flow to the next instruction. Instead, one must + # check the exception table (or equivalent) to find the next target + # In this case, it says the instruction pointer must be moved to 90. + # + # Without this step, if we let the trace proceed to the next + # instruction, it would follow the control flow where the exception + # raised by `throw()` was handled and swallowed, potentially leading + # to incorrect behavior. + exc_type = type("__InternalThrowException", (Exception,), {}) + + try: + self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) + self.next_variable(tx) + except get_dynamo_observed_exception(exc_type): + # We should get back the exception raised before. + pass + except ObservedException: + # Propagate anything else back to the parent caller + tx.exn_vt_stack.extend(tracer.exn_vt_stack) + else: + raise_observed_exception(RuntimeError, tracer) + return retval + + super().call_method(tx, name, args, kwargs) + + +class ContextlibContextManagerLocalGeneratorObjectVariable( + LocalGeneratorObjectVariable +): + """ + .. note:: + + This is only used when the function is annotated with @contextlib.contextmanager + + It is a special case of a generator function as we do not allow return a context manager + from a torch.compile function. + """ + + +class LocalGeneratorFunctionVariable(BaseUserFunctionVariable): """functions that behaves like iterators .. note:: - This is only used when the function is annotated with @contextlib.contextmanager + This is a wrapper around (Nested)UserFunctionVariable """ - def __init__(self, vt: VariableTracker, **kwargs): + def __init__( + self, + vt: VariableTracker, + *, + generator_cls=LocalGeneratorObjectVariable, + **kwargs, + ): super().__init__(**kwargs) self.vt = vt - self.inline_tracer = None + self.generator_cls = generator_cls def __getattr__(self, name): if name in self.__class__.__dict__.keys(): return getattr(self, name) return getattr(self.vt, name) + def _build_inline_tracer(self, tx, args, kwargs): + from torch._dynamo.symbolic_convert import InliningInstructionTranslator + + return InliningInstructionTranslator.build_inline_tracer( + tx, + self, + args, + kwargs, + ) + def call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - from torch._dynamo.bytecode_transformation import is_generator + assert is_generator(self.vt.get_code()) - assert is_generator(self.get_code()) - from torch._dynamo.symbolic_convert import InliningInstructionTranslator + inline_tracer = self._build_inline_tracer(tx, args, kwargs) + code = self.vt.get_code() + f_globals = self.vt.get_globals() - self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( - tx, - self, - [*self.self_args(), *args], - kwargs, - stop_generator_on_yield=True, + # calling a generator returns a generator object + return self.generator_cls( + code, + f_globals, + inline_tracer, + source=self.source, ) - return self - def next_variable(self, tx): - from torch._dynamo import exc +class FunctionDecoratedByContextlibContextManagerVariable( + LocalGeneratorFunctionVariable +): + """ + .. note:: - tracer = self.inline_tracer + This is only used when the function is annotated with @contextlib.contextmanager + """ - try: - # Hierarchically, tx can be seen as the parent of the inline tracer - # created on call_function. Any exception needs to be propagated to tx - # for Dynamo to behave correctly - with patch.dict(counters, {"unimplemented": counters["inline_call"]}): - return tracer.inline_call_().next_variable(tx) - except exc.ObservedException as e: - tx.exn_vt_stack.extend(tracer.exn_vt_stack) - raise e + def __init__(self, vt, **kwargs): + super().__init__( + vt, + generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable, + **kwargs, + ) + + def _build_inline_tracer(self, tx, args, kwargs): + # NOTE: This only exists to not break support for context manager when + # config.enable_faithful_generator_behavior = False and + # config.enable_trace_contextlib = True. In case the former is false, + # Dynamo should still be able to trace through @contextmanager functions + tracer = super()._build_inline_tracer(tx, args, kwargs) + assert isinstance( + tracer, + torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator, + ) + tracer.is_generator_from_ctx_manager = True + return tracer class UserMethodVariable(UserFunctionVariable): diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 3c55c1b2afc..bcadb5941ff 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -234,6 +234,7 @@ class SuperVariable(VariableTracker): class ExceptionVariable(VariableTracker): + # The ExceptionVariable corresponds to the BaseException class in Python def __init__(self, exc_type, args, **kwargs) -> None: super().__init__(**kwargs) self.exc_type = exc_type diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index a9dd87d3046..fe646d35b05 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -493,7 +493,7 @@ class UserDefinedClassVariable(UserDefinedVariable): ): if not torch._dynamo.config.enable_trace_contextlib: unimplemented("contextlib.contextmanager") - # Replace UserFunctionVariable by FunctionDecoratedBycontextlibContextManagerVariable + # Wrap UserFunctionVariable in FunctionDecoratedByContextlibContextManagerVariable # if the function is annotated with @contextlib.contextmanager # This shouldn't be necessary once generator functions are fully # supported in dynamo @@ -805,6 +805,11 @@ class UserDefinedObjectVariable(UserDefinedVariable): # of the cmp_eq polyfill function. return ConstantVariable.create(self.value is other.value) + if torch._dynamo.config.enable_faithful_generator_behavior and isinstance( + self.value, types.GeneratorType + ): + unimplemented("Generator as graph argument is not supported") + # check for methods implemented in C++ if isinstance(method, types.FunctionType): source = ( diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index dc9e5af16da..10f1767167a 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -1822,7 +1822,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa metadata: ViewAndMutationMeta = fw_metadata # type: ignore[assignment] maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta num_symints_saved_for_bw = num_symints_saved_for_bw_ - _compiled_autograd_should_lift = False _aot_id = aot_config.aot_id _lazy_backward_info = lazy_backward_info @@ -1989,7 +1988,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa # https://github.com/pytorch/pytorch/pull/92348/files#r1072962107 class CompiledFunctionBackward(torch.autograd.Function): # CompiledFunctionBackward is not yet supported in dynamo skipfiles - _compiled_autograd_should_lift = False _aot_id = aot_config.aot_id @staticmethod diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index 47dbbd941f2..c8a9da5e78d 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -1,3 +1,8 @@ +from torch._higher_order_ops._invoke_quant import ( + invoke_quant, + invoke_quant_packed, + InvokeQuant, +) from torch._higher_order_ops.aoti_call_delegate import aoti_call_delegate from torch._higher_order_ops.associative_scan import associative_scan from torch._higher_order_ops.auto_functionalize import ( @@ -51,6 +56,9 @@ __all__ = [ "executorch_call_delegate", "call_torchbind", "run_const_graph", + "InvokeQuant", + "invoke_quant", + "invoke_quant_packed", "wrap_with_set_grad_enabled", "wrap_with_autocast", "wrap_activation_checkpoint", diff --git a/torch/_higher_order_ops/_invoke_quant.py b/torch/_higher_order_ops/_invoke_quant.py new file mode 100644 index 00000000000..cfbb7b1cc55 --- /dev/null +++ b/torch/_higher_order_ops/_invoke_quant.py @@ -0,0 +1,72 @@ +# mypy: allow-untyped-defs +# need to fix prim_hop_base type annotations first + +import dataclasses +from typing import Optional + +import torch +from torch._higher_order_ops.prim_hop_base import FunctionWithNoFreeVars, PrimHOPBase + + +class InvokeQuantTracer(PrimHOPBase): + def __init__(self) -> None: + super().__init__("invoke_quant_packed") + + def __call__(self, subgraph, operands, *, scheme=None, quant_options=None): + subgraph = FunctionWithNoFreeVars(subgraph) + return super().__call__( + subgraph, operands, scheme=scheme, quant_options=quant_options + ) + + +invoke_quant_packed = InvokeQuantTracer() + + +class InvokeQuantUnpacked(PrimHOPBase): + def __init__(self) -> None: + super().__init__("invoke_quant") + + def __call__(self, subgraph, *operands, scheme=None): + return super().__call__(subgraph, operands, scheme=scheme) + + def _call_FakeTensorMode( + self, mode, subgraph, operands, scheme: Optional[str] = None, **kwargs + ): + # TODO: this should probably route through FakeTensorMode to reuse caching + with mode: + return subgraph(*operands[0], **kwargs) + + +invoke_quant = InvokeQuantUnpacked() + + +@dataclasses.dataclass(frozen=True, repr=True) +class InvokeQuant: + """ + Invoke a quantization function that will be preserved as a single operator. Preservation + as a single operator aids in pattern matching and custom lowerings. + + The operation appears as: + torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=scheme) + + Args: + codegen_low_precision: Use observed subgraph dtypes for codegen instead of + upcasting to fp32. Can improve performance for prologue fusion but + requires careful testing of numerics. + """ + + codegen_low_precision: bool = True + + def __call__( + self, + *args, + scheme: Optional[str] = None, + **kwargs, + ): + if not torch._utils.is_compiling(): + return args[0](*args[1], **kwargs) + + if scheme is not None: + kwargs["scheme"] = scheme + + return invoke_quant_packed(*args, **kwargs, quant_options=self) # type: ignore[call-arg] diff --git a/torch/_inductor/analyze_preserves_zero_mask.py b/torch/_inductor/analyze_preserves_zero_mask.py index a03439c2bae..974960b9589 100644 --- a/torch/_inductor/analyze_preserves_zero_mask.py +++ b/torch/_inductor/analyze_preserves_zero_mask.py @@ -1,6 +1,6 @@ import dataclasses import itertools -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import sympy @@ -9,6 +9,7 @@ from torch._inductor import config from torch._inductor.dtype_propagation import DtypePropagationOpsHandler from torch._inductor.index_propagation import SymPyOps, TypedExpr +from .ops_handler import DefaultHandler from .virtualized import StoreMode, V @@ -20,7 +21,7 @@ def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol: return sympy.Symbol(f"unknown_{count}") -class PreservesZeros(SymPyOps): +class PreservesZeros(SymPyOps, DefaultHandler): """ For prologue kernels where the loads are masked, does the final store of this kernel preserve the zeros. @@ -31,41 +32,32 @@ class PreservesZeros(SymPyOps): self.store_preserves_zeros: Optional[bool] = None self.dtype_prop = DtypePropagationOpsHandler() - @staticmethod - def load(name: str, index: sympy.Expr) -> TypedExpr: + def load(self, name: str, index: sympy.Expr) -> TypedExpr: # In prologue fusion, all loads get broadcasted - dtype = V.get_ops_handler().dtype_prop.load(name, index) + dtype = self.dtype_prop.load(name, index) return TypedExpr( sympy.Float(0) if dtype.is_floating_point else sympy.Integer(0), dtype ) - @staticmethod def store( - name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None + self, name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None ) -> None: - self = V.get_ops_handler() assert isinstance(self, PreservesZeros) # should only have a single store in prologue assert self.store_preserves_zeros is None self.store_preserves_zeros = value.is_constant() and value.expr == 0 - @staticmethod - def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr: - self = V.get_ops_handler() + def indirect_indexing(self, *args: Any, **kwargs: Any) -> sympy.Expr: return construct_symbol(next(self.count), torch.int32) - def __getattr__(self, name: str) -> Callable[..., Any]: + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: from torch._inductor.codegen.common import OpDecompositions - def inner(*args: Any, **kwargs: Any) -> TypedExpr: - if hasattr(OpDecompositions, name): - return getattr(OpDecompositions, name)(*args, **kwargs).value + if hasattr(OpDecompositions, name): + return getattr(OpDecompositions, name)(*args, **kwargs).value - nonlocal self - dtype = getattr(self.dtype_prop, name)(*args, **kwargs) - return TypedExpr(construct_symbol(next(self.count), dtype), dtype) - - return inner + dtype = getattr(self.dtype_prop, name)(*args, **kwargs) + return TypedExpr(construct_symbol(next(self.count), dtype), dtype) def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool: @@ -88,7 +80,7 @@ class DTypeContainer: is_scalar: bool = False -class RecordLowPrecisionOps: +class RecordLowPrecisionOps(DefaultHandler): def __init__(self) -> None: self.low_precision_numeric_op = False self.dtype_prop = DtypePropagationOpsHandler() @@ -97,9 +89,8 @@ class RecordLowPrecisionOps: "constant", ) - @staticmethod - def load(name: str, index: sympy.Expr) -> DTypeContainer: - return DTypeContainer(V.get_ops_handler().dtype_prop.load(name, index)) + def load(self, name: str, index: sympy.Expr) -> DTypeContainer: + return DTypeContainer(self.dtype_prop.load(name, index)) @staticmethod def store( @@ -111,28 +102,25 @@ class RecordLowPrecisionOps: def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr: return sympy.S.Zero - def __getattr__(self, name: str) -> Callable[..., Any]: - def low_prec_float(dtype: torch.dtype) -> bool: - return dtype.is_floating_point and dtype.itemsize < 4 + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs) + out = DTypeContainer(out_dtype, is_scalar=(name == "constant")) + if name == "constant": + out = DTypeContainer(torch.float, is_scalar=True) - def inner(*args: Any, **kwargs: Any) -> DTypeContainer: - out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs) - out = DTypeContainer(out_dtype, is_scalar=(name == "constant")) - if name == "constant": - out = DTypeContainer(torch.float, is_scalar=True) + uses_low_prec = any( + isinstance(dtype_cont, DTypeContainer) and low_prec_float(dtype_cont.dtype) + for dtype_cont in itertools.chain((out,), args, kwargs.values()) + ) - uses_low_prec = any( - isinstance(dtype_cont, DTypeContainer) - and low_prec_float(dtype_cont.dtype) - for dtype_cont in itertools.chain((out,), args, kwargs.values()) - ) + if uses_low_prec and name not in self.non_numeric_ops: + self.low_precision_numeric_op = True - if uses_low_prec and name not in self.non_numeric_ops: - self.low_precision_numeric_op = True + return out - return out - return inner +def low_prec_float(dtype: torch.dtype) -> bool: + return dtype.is_floating_point and dtype.itemsize < 4 def can_codegen_without_upcasts( diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 3df87ada0dd..69c331646f8 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,14 +1,22 @@ import logging import operator from functools import partial -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union +import sympy from sympy import Expr import torch -from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.value_ranges import ( + bound_sympy, + SymPyValueRangeAnalysis, + ValueRanges, +) +from ..utils._sympy.functions import PowByNatural +from ..utils._sympy.numbers import int_oo from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock +from .ops_handler import DefaultHandler, ReductionType, StoreMode from .utils import cache_on_self, dominated_nodes from .virtualized import V @@ -139,3 +147,113 @@ class BoundVars: # assert bound is None or bound == bound_sympy(expr, self.replacement_vals) self.replacement_vals[name] = bound return bound + + +class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler): + def __init__(self) -> None: + self.name = "ValueRangeAnalysis" + boolean_operators = ( + "xor", + "logical_and", + "logical_or", + "logical_not", + ) + for op in boolean_operators: + setattr(self, op, self.bool_handler) + + @staticmethod + def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]: + # just assuming bools can have both values + return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type] + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + # many ops are unlikely to show up in optimizable indexing compute, + # so we dont have full coverage + return ValueRanges.unknown() + + def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]: + return ValueRanges.unknown() + + def store( + self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None + ) -> None: + return + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Any, + ) -> ValueRanges[Any]: + return ValueRanges.unknown() + + @classmethod + def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]: + assert isinstance(index, ValueRanges) + return cls.to_dtype(index, dtype) + + @staticmethod + def to_dtype( + x: Any, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> ValueRanges[Any]: + x = ValueRanges.wrap(x) + + if dtype == torch.bool: + if x.is_singleton(): + return ValueRanges.wrap(x.lower != 0) + elif x.is_bool: + return x + elif 0 not in x: + return ValueRanges.wrap(sympy.true) + else: + return ValueRanges(sympy.false, sympy.true) + + def cast(x: Any, dtype: torch.dtype) -> sympy.Expr: + # dtype is int or float + if dtype.is_floating_point: + return sympy.Float(x) + else: + if x in (int_oo, -int_oo): + return x + try: + return sympy.Integer(x) + except TypeError: + # inf cannot be cast to Integer + return x + + if x.is_bool: + if x.is_singleton(): + val = 1 if x.lower else 0 + return ValueRanges.wrap(cast(val, dtype)) + else: + return ValueRanges(cast(0, dtype), cast(1, dtype)) + else: + # int to float or float to int + return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype)) + + @staticmethod + def square(x: Any) -> ValueRanges[Any]: + return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) + + @staticmethod + def neg(x: Any) -> ValueRanges[Any]: + return ValueRanges.decreasing_map(x, operator.neg) + + # TODO: this is slightly inaccurate because truncdiv operates at integer + # precision, but we're going through float truediv which means we can + # potentially lose precision on the bounds + @classmethod + def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]: + x = cls.truediv(a, b) + if x == ValueRanges.unknown(): + return x + + return cls.trunc(x) + + @classmethod + def sub(cls, a: Any, b: Any) -> ValueRanges[Any]: + return cls.add(a, cls.neg(b)) diff --git a/torch/_inductor/codegen/block_analysis.py b/torch/_inductor/codegen/block_analysis.py index 484fa135986..1c816eb8e29 100644 --- a/torch/_inductor/codegen/block_analysis.py +++ b/torch/_inductor/codegen/block_analysis.py @@ -17,8 +17,8 @@ class BlockPatternMatcher: Matches block indexing expressions. """ - @staticmethod - def get_subexpr_involving_symbol(expr: Expr, symbol: Symbol) -> Expr: + @classmethod + def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr: """ Given a sympy expression, return the subexpression comprised only of terms involving the specified symbol. @@ -26,6 +26,7 @@ class BlockPatternMatcher: For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`, this returns `x * 5 + x ** 2`. """ + expr = cls._preprocess(expr) return sympy.S.Zero + sum( term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols ) @@ -42,6 +43,11 @@ class BlockPatternMatcher: numels.appendleft(numel) return [*numels] + @staticmethod + def _preprocess(expr: Expr) -> Expr: + # Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y. + return expr.expand(identity=True) + @classmethod def match_mod_div_block_expr( cls, @@ -54,6 +60,7 @@ class BlockPatternMatcher: Matches modular indexing expressions, converting them to implied block dimensions and strides. See triton.py for more information. """ + index = cls._preprocess(index) # Pattern match to find the strides and offset. wild = functools.partial(sympy.Wild, exclude=[index_var]) @@ -141,3 +148,21 @@ class BlockPatternMatcher: ) return dims, strides, block_index_exprs + + @classmethod + def match_affine_block_expr( + cls, + index: Expr, + index_var: Symbol, + ) -> Optional[Expr]: + """ + Matches simple expressions of the form stride * index, returning the + stride. + """ + index = cls._preprocess(index) + stride = sympy.Wild("stride", exclude=[index_var]) + m = index.match(index_var * stride) + if m is None: + return None + + return m[stride] diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index c254aacdb17..fec37fb6002 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -37,11 +37,11 @@ from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT -from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges from .. import config, metrics from ..dtype_propagation import DtypePropagationOpsHandler -from ..ops_handler import BasicMathOps +from ..ops_handler import BasicMathOpsMixin, DefaultHandler from ..utils import ( boolean_ops, DeferredLineBase, @@ -764,7 +764,7 @@ def _all_in_parens(string: str) -> bool: return True -class OpOverrides(BasicMathOps, OpDecompositions): +class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]): @staticmethod def paren(string: OpVarT) -> OpVarT: if ( @@ -948,6 +948,16 @@ class OpOverrides(BasicMathOps, OpDecompositions): f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend" ) + def output(self, *args: OpVarT) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear at codegen time" + ) + + def placeholder(self, index: int) -> OpVarT: + raise AssertionError( + f"{type(self).__name__}: ops.placeholder should not appear at codegen time" + ) + @staticmethod def _unimplemented(name: str) -> Callable[..., OpVarT]: def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT: @@ -1225,12 +1235,6 @@ pointwise_overrides_data: dict[str, OverridesData] = dict( ) -if TYPE_CHECKING: - - class _typecheck_OpOverrides(OpOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class DeferredLine(DeferredLineBase): """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" @@ -2253,84 +2257,84 @@ class KernelTemplate: raise NotImplementedError -class CSEProxy: +class CSEProxy(DefaultHandler): name = "CSEProxy" def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]): super().__init__() + from ..bounds import ValueRangeAnalysis + self.vr_analysis = ValueRangeAnalysis() self.kernel = kernel self.parent_handler = parent_handler - def __getattr__(self, name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] - def inner(*args: Any, **kwargs: Any) -> CSEVariable: - bounds = self._bound_variable(name, *args, **kwargs) + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + bounds = self._bound_variable(name, *args, **kwargs) - value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type] - dtype_handler = DtypePropagationOpsHandler() + value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type] + dtype_handler = DtypePropagationOpsHandler() - output_idx = 0 + output_idx = 0 - def do_cse(v: str) -> CSEVariable: - # cpp backend doesnt set current device - TODO: fix - if V.graph.current_device is not None: - device_str = V.graph.get_current_device_or_throw().type - triton_backend = ( - config.cpu_backend == "triton" - if device_str == "cpu" - else config.cuda_backend == "triton" - if device_str != "mps" - else False - ) - else: - triton_backend = False - - # only triton backend tracks dtype currently - if triton_backend: - if name == "masked": - output_dtype = value.dtype - else: - output_dtype = getattr( - dtype_handler, - name, - )(*args, **kwargs) - else: - # cpp backend doesnt track dtype yet - output_dtype = None - - csevar = V.kernel.cse.generate( - V.kernel.compute, - v, - bounds=bounds, - dtype=output_dtype, + def do_cse(v: str) -> CSEVariable: + # cpp backend doesnt set current device - TODO: fix + if V.graph.current_device is not None: + device_str = V.graph.get_current_device_or_throw().type + triton_backend = ( + config.cpu_backend == "triton" + if device_str == "cpu" + else config.cuda_backend == "triton" + if device_str != "mps" + else False ) + else: + triton_backend = False - nonlocal output_idx - if config.test_configs.runtime_triton_dtype_assert and triton_backend: - from torch._inductor.codegen.triton import triton_type + # only triton backend tracks dtype currently + if triton_backend: + if name == "masked": + output_dtype = value.dtype + else: + output_dtype = getattr( + dtype_handler, + name, + )(*args, **kwargs) + else: + # cpp backend doesnt track dtype yet + output_dtype = None - # we tree_map over the output, so we need to fetch corresponding dtype - if isinstance(output_dtype, (list, tuple)): - output_dtype = output_dtype[output_idx] + csevar = V.kernel.cse.generate( + V.kernel.compute, + v, + bounds=bounds, + dtype=output_dtype, + ) - V.kernel.compute.writeline( - f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})" - ) - output_idx += 1 + nonlocal output_idx + if config.test_configs.runtime_triton_dtype_assert and triton_backend: + from torch._inductor.codegen.triton import triton_type - csevar.update_on_args(name, args, kwargs) + # we tree_map over the output, so we need to fetch corresponding dtype + if isinstance(output_dtype, (list, tuple)): + output_dtype = output_dtype[output_idx] - return csevar + V.kernel.compute.writeline( + f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})" + ) + output_idx += 1 - return pytree.tree_map(do_cse, value) + csevar.update_on_args(name, args, kwargs) - return inner + return csevar + + return pytree.tree_map(do_cse, value) def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]: """ If the variable comes from an FX node, we forward the bound we have already computed Else, if the variable when codegen'ing another op, we try to compute its bounds """ + from ..bounds import ValueRangeAnalysis from ..select_algorithm import TritonTemplateKernel if isinstance(V.kernel, TritonTemplateKernel): @@ -2568,8 +2572,3 @@ class CSEProxy: sorter, sorter_indices, ) - - -# Use mypy to check protocol implemented correctly -def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: - return h diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 560e75c648f..f2bdebf3c1b 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -33,7 +33,7 @@ from ..utils import ( sympy_index_symbol, sympy_subs, ) -from ..virtualized import _ops as ops, OpsHandler, V +from ..virtualized import _ops as ops, V from .common import ( BackendFeature, CSEVariable, @@ -563,12 +563,6 @@ class HalideOverrides(OpOverrides): HalideOverrides._initialize_pointwise_overrides("halide") -if TYPE_CHECKING: - - class _typecheck_HalideOverrides(HalideOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class HalideCSEVariable(CSEVariable): undefined_re = re.compile(r"\b(tmp\d+)\[\?\]") diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 86cbb6f5361..dd3ff699e8a 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: import sympy - from ..ops_handler import OpsHandler, ReductionType, StoreMode + from ..ops_handler import ReductionType, StoreMode from ..scheduler import Scheduler, SchedulerNode from .common import OpVarT @@ -367,12 +367,6 @@ class MetalOverrides(OpOverrides): MetalOverrides._initialize_pointwise_overrides("mps") -if TYPE_CHECKING: - - class _typecheck_MetalOverrides(MetalOverrides, OpsHandler[Any]): - pass # mypy will error if we got any of the signatures wrong - - class MetalKernel(SIMDKernel): overrides = MetalOverrides # type: ignore[assignment] suffix = ";" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index c0898d13c26..5f3e1b78783 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -31,6 +31,7 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_ty from ...utils._sympy.value_ranges import ValueRanges from .. import config, ir, metrics from ..codecache import code_hash, get_path, PyCodeCache +from ..ops_handler import DefaultHandler from ..runtime.benchmarking import benchmarker from ..runtime.hints import ( AutotuneHint, @@ -60,7 +61,7 @@ from ..utils import ( triton_version_uses_attrs_dict, upcast_compute_type, ) -from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V +from ..virtualized import _ops as ops, ReductionType, StoreMode, V from ..wrapper_benchmark import get_kernel_category_by_source_code from .block_analysis import BlockPatternMatcher from .common import ( @@ -1427,12 +1428,6 @@ class TritonKernelOverrides(TritonOverrides): return (mantissa, exponent) -if TYPE_CHECKING: - - class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class HelperFunctions: """An ordered set of helper functions.""" @@ -1795,7 +1790,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): and self.index_dtype == "tl.int32" ): - def match_strided_block( + def match_affine_block( index: sympy.Expr, range_tree: IterationRangesRoot ) -> Optional[BlockParameters]: """ @@ -1804,16 +1799,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): This implies stride (s,), and shape (XBLOCK,). """ - symbol = range_tree.symbol() - stride = sympy.Wild("stride", exclude=[symbol]) - m = index.match(symbol * stride) - if m is None: + stride = BlockPatternMatcher.match_affine_block_expr( + index, range_tree.symbol() + ) + if stride is None: return None return BlockParameters( shape=[range_tree.numel], block_shape=[TritonSymbols.get_block_size(range_tree)], - strides=[m[stride]], + strides=[stride], offsets=[TritonSymbols.get_block_offset(range_tree)], ) @@ -1922,7 +1917,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): Match a block indexing subexpression involving a single range tree. """ for match_func in ( - match_strided_block, + match_affine_block, match_mod_div_block, ): match = match_func(expr, range_tree) @@ -2872,24 +2867,23 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): dtype_handler = DtypePropagationOpsHandler() - class CSEProxy: - def __getattr__(self, name: str) -> Callable[..., CSEVariable]: - def inner(*args, **kwargs): - nonlocal helper_name - helper_name += f"_{name}" + class CSEProxy(DefaultHandler): + def _default( + self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> Any: + nonlocal helper_name + helper_name += f"_{name}" - output_dtype = getattr( - dtype_handler, - name, - )(*args, **kwargs) + output_dtype = getattr( + dtype_handler, + name, + )(*args, **kwargs) - return cse.generate( - helper, - getattr(overrides, name)(*args, **kwargs), - dtype=output_dtype, - ) - - return inner + return cse.generate( + helper, + getattr(overrides, name)(*args, **kwargs), + dtype=output_dtype, + ) with helper.indent(), V.set_ops_handler(CSEProxy()): outputs = fn(*args) diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index f66e2e791a1..36000a50cb8 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -4,7 +4,7 @@ import itertools import logging import re from collections.abc import Sequence -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union from unittest.mock import patch import sympy @@ -15,6 +15,7 @@ from torch.utils._ordered_set import OrderedSet from ..utils._sympy.symbol import make_symbol, SymT from .codegen.common import index_prevent_reordering +from .ops_handler import DefaultHandler from .utils import ( get_dtype_size, reduction_num_outputs, @@ -23,7 +24,7 @@ from .utils import ( sympy_subs, VarRanges, ) -from .virtualized import OpsHandler, ReductionType, V +from .virtualized import ReductionType, V T = TypeVar("T") @@ -737,19 +738,16 @@ def canonicalization_prefix() -> str: # ops handler which computes all the free unbacked symbols for an IR -class FreeUnbackedSymbolsOpsHandler: +class FreeUnbackedSymbolsOpsHandler(DefaultHandler): symbols: OrderedSet[sympy.Symbol] def __init__(self) -> None: self.symbols = OrderedSet() - def __getattr__(self, name: str) -> Callable[..., Any]: - def inner(*args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None: - for a in itertools.chain(args, kwargs.values()): - if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): - self.symbols |= free_unbacked_symbols(a) - - return inner + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + for a in itertools.chain(args, kwargs.values()): + if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): + self.symbols |= free_unbacked_symbols(a) def indirect_indexing( self, @@ -791,12 +789,6 @@ class FreeUnbackedSymbolsOpsHandler: body() -def _typecheck_FreeUnbackedSymbolsOpsHandler( - h: FreeUnbackedSymbolsOpsHandler, -) -> OpsHandler[None]: - return h - - def extract_free_unbacked_symbols( fn: Callable[..., Any], index: Sequence[sympy.Expr], diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index efe0ebe2caf..256079c8071 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -368,6 +368,16 @@ class DtypePropagationOpsHandler: ) -> None: return None + def output(self, *args: DTypeArg) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear here" + ) + + def placeholder(self, index: int) -> torch.dtype: + raise AssertionError( + f"{type(self).__name__}: ops.placeholder should not appear here" + ) + if TYPE_CHECKING: diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 91998631a49..92e5a0c189f 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -2,6 +2,7 @@ import functools import itertools import logging +import operator import typing from collections import Counter from typing import Any, Union @@ -442,6 +443,81 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule): remove_redundant_views(gm) +def canonicalize_quant_mapping(gm: torch.fx.GraphModule): + """ + + + torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'quant_invoke_0_0', (arg0_1, arg1_1)); + -> + torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4'); + """ + graph = gm.graph + invoke_quant_invocations = graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_quant_packed + ) + for invoke_quant in invoke_quant_invocations: + kwargs = dict(invoke_quant.kwargs) + + quant_options_node = kwargs.pop("quant_options", None) + if quant_options_node is not None: + assert isinstance(quant_options_node, torch.fx.Node) + quant_options = torch._higher_order_ops.InvokeQuant( + *invoke_quant.kwargs["quant_options"].args, + **invoke_quant.kwargs["quant_options"].kwargs, + ) + else: + quant_options = None + + subgraph, args = invoke_quant.args + with gm.graph.inserting_before(invoke_quant): + invoke_quant_replacement = graph.call_function( + torch._higher_order_ops.invoke_quant, + (subgraph, *args), + kwargs, + ) + invoke_quant_replacement.meta.update(subgraph.meta) + invoke_quant_replacement.meta["quant_options"] = quant_options + + invoke_quant.replace_all_uses_with(invoke_quant_replacement) + graph.erase_node(invoke_quant) + + if quant_options_node and len(quant_options_node.users) == 0: + graph.erase_node(quant_options_node) + + first_user = next(iter(invoke_quant_replacement.users)) + + if ( + len(invoke_quant_replacement.users) == 1 + and len(subgraph.users) == 1 + and first_user.target == operator.getitem + and first_user.args[1] == 0 + ): + subgraph_graph = getattr(gm, subgraph.target) + output_node = torch._inductor.utils.output_node(subgraph_graph) + assert ( + isinstance(output_node.args[0], (list, tuple)) + and len(output_node.args[0]) == 1 + ) + + unpacked_output = output_node.args[0][0] + output_node.args = (unpacked_output,) + if "val" in output_node.meta: + output_node.meta["val"] = output_node.meta["val"][0] + subgraph_graph.recompile() + + invoke_quant_replacement.meta.update(first_user.meta) + first_user.replace_all_uses_with(invoke_quant_replacement) + graph.erase_node(first_user) + + +def canonicalize_aten_ir_passes(gm: torch.fx.GraphModule): + """ + Canonicalization passes that will run immediately after aot autograd + tracing. Thsis must be run before all other graph passes. + """ + canonicalize_quant_mapping(gm) + + def joint_graph_passes(graph: torch.fx.GraphModule): """ Run FX transformations on the joint forwards+backwards graph. @@ -454,6 +530,15 @@ def joint_graph_passes(graph: torch.fx.GraphModule): lazy_init() count = 0 + # must occur before other passes + canonicalize_aten_ir_passes(graph) + + if config.joint_custom_pre_pass is not None: + GraphTransformObserver(graph, "joint_custom_pre_pass").apply_graph_pass( + config.joint_custom_pre_pass + ) + count += 1 + from .post_grad import remove_noop_ops GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops) diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 310df89ffa3..2e564041340 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -21,8 +21,9 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing. """ import itertools +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Callable, Literal, Optional, overload, Union +from typing import Any, Literal, Optional, overload, Union from typing_extensions import TypeAlias import sympy @@ -32,6 +33,7 @@ from torch._prims_common import dtype_to_type, is_integer_dtype from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges +from .ops_handler import DefaultHandler from .sizevars import evaluate_expr from .utils import generate_assert from .virtualized import V @@ -185,7 +187,7 @@ class IndexPropVar: IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]] -class IndexPropagation: +class IndexPropagation(DefaultHandler): """Ops wrapper that tries to propagate constant and index_expr values through the computation. This aims to maximize the compile time simplification possible, and convert @@ -247,19 +249,19 @@ class IndexPropagation: def fallback( self, name: Literal["indirect_indexing"], - args: tuple[Any, ...], + args: Sequence[Any], kwargs: dict[str, Any], ) -> IndexPropVar: ... @overload def fallback( - self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] ) -> IndexPropResult: ... def fallback( - self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] ) -> IndexPropResult: # Fallback to the wrapped handler new_args = [self.unwrap(a) for a in args] @@ -267,7 +269,7 @@ class IndexPropagation: return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs)) def propagate_sympy( - self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] ) -> IndexPropResult: # Build a new SymPy expression from this ops call def unwrap(a: Union[Any, IndexPropVar]) -> Any: @@ -288,22 +290,19 @@ class IndexPropagation: return self.fallback(name, args, kwargs) return IndexPropVar.new_symbolic(new_expr) - def __getattr__(self, name: str) -> Callable[..., IndexPropResult]: - def inner(*args: Any, **kwargs: Any) -> IndexPropResult: - if not hasattr(SymPyOps, name): - return self.fallback(name, args, kwargs) + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + if not hasattr(SymPyOps, name): + return self.fallback(name, args, kwargs) - var_arguments = [ - a - for a in itertools.chain(args, kwargs.values()) - if isinstance(a, IndexPropVar) - ] - if not all(v.is_symbolic for v in var_arguments): - return self.fallback(name, args, kwargs) + var_arguments = [ + a + for a in itertools.chain(args, kwargs.values()) + if isinstance(a, IndexPropVar) + ] + if not all(v.is_symbolic for v in var_arguments): + return self.fallback(name, args, kwargs) - return self.propagate_sympy(name, args, kwargs) - - return inner + return self.propagate_sympy(name, args, kwargs) def statically_true(self, e): """ diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index d90398ca043..6a800d9e81d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -75,7 +75,7 @@ from .dependencies import ( var_builder, ) from .loop_body import LoopBody -from .ops_handler import OpCounterCSE, OpCountResult +from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode from .runtime.benchmarking import benchmarker from .runtime.hints import DeviceProperties, ReductionHint from .utils import ( @@ -916,9 +916,9 @@ class Pointwise(Loops): output_name: Optional[str], indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], - ) -> OpsValue: + ) -> None: loader = self.make_loader() - return ops.store(output_name, indexer(vars), loader(vars)) + return ops.store(output_name or "unnamed", indexer(vars), loader(vars)) def constant_to_device(self, device: torch.device) -> IRNode: """Move this to a given device. Requires that all reads are to constants.""" @@ -932,7 +932,7 @@ class Pointwise(Loops): @ir_dataclass class Scatter(Pointwise): output_indexer: Callable[[Sequence[Expr]], Expr] - scatter_mode: Optional[str] = None + scatter_mode: StoreMode = None def constant_to_device(self, device: torch.device) -> IRNode: """Move this to a given device. Requires that all reads are to constants.""" @@ -952,8 +952,10 @@ class Scatter(Pointwise): output_name: Optional[str], indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], - ) -> OpsValue: + ) -> None: loader = self.make_loader() + if output_name is None: + output_name = "unnamed" return ops.store( output_name, indexer(self.output_indexer(vars)), @@ -1038,7 +1040,7 @@ def get_reduction_combine_fn( @ir_dataclass class Reduction(Loops): reduction_ranges: Sequence[_IntLike] - reduction_type: str + reduction_type: ReductionType # self.dtype represents the dst dtype src_dtype: torch.dtype reduction_hint: ReductionHint @@ -1065,14 +1067,14 @@ class Reduction(Loops): indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], reduction_vars: Sequence[Symbol], - ) -> OpsValue: + ) -> None: value = ops.reduction( self.dtype, self.src_dtype, self.reduction_type, self.inner_fn(vars, reduction_vars), ) - return ops.store_reduction(output_name, indexer(vars), value) + return ops.store_reduction(output_name or "unnamed", indexer(vars), value) def index_length(self) -> int: return len(self.ranges) + len(self.reduction_ranges) @@ -1110,7 +1112,7 @@ class Reduction(Loops): inner_fn: Callable[..., OpsValue], ranges: Sequence[_IntLike], reduction_ranges: Sequence[_IntLike], - reduction_type: str, + reduction_type: Union[ReductionType, Literal["scan"]], reduction_numel: Expr, input_node: Optional[IRNode] = None, ) -> tuple[ReductionHint, _IntLike]: @@ -1196,7 +1198,7 @@ class Reduction(Loops): inner_fn=inner_fn, ranges=ranges, reduction_ranges=reduction_ranges, - reduction_type=reduction_type, + reduction_type=reduction_type if reduction_type != "scan" else "sum", src_dtype=src_dtype, reduction_hint=ReductionHint.DEFAULT, ) @@ -1323,7 +1325,7 @@ class Reduction(Loops): inner_fn: Callable[..., Any], ranges: Sequence[Expr], reduction_ranges: Sequence[Expr], - reduction_type: str, + reduction_type: ReductionType, reduction_hint: ReductionHint = ReductionHint.DEFAULT, input_node: Optional[IRNode] = None, ) -> TensorBox: @@ -1593,7 +1595,7 @@ class Reduction(Loops): original_reduction_ranges: Sequence[Expr], new_ranges: list[Expr], new_reduction_ranges: list[Integer], - reduction_type: str, + reduction_type: ReductionType, split: _IntLike, reduction_hint: ReductionHint, ) -> TensorBox: @@ -1655,7 +1657,7 @@ class Reduction(Loops): inner_fn: Callable[..., Any], ranges: Sequence[Expr], reduction_ranges: Sequence[Expr], - reduction_type: str, + reduction_type: ReductionType, split: _IntLike, reduction_hint: ReductionHint, ) -> TensorBox: @@ -1696,7 +1698,7 @@ class Reduction(Loops): original_reduction_ranges: Sequence[Expr], new_ranges: list[Integer], new_reduction_ranges: list[Integer], - reduction_type: str, + reduction_type: ReductionType, reduction_hint: ReductionHint, ) -> TensorBox: """ @@ -1735,7 +1737,7 @@ class WelfordReduction(Reduction): inner_fns: Sequence[Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]], ranges: Sequence[Integer], reduction_ranges: Sequence[Integer], - reduction_type: str, + reduction_type: ReductionType, reduction_hint: ReductionHint, output_index: int, ) -> None: @@ -1767,7 +1769,7 @@ class WelfordReduction(Reduction): indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], reduction_vars: Sequence[Symbol], - ) -> OpsValue: + ) -> None: values = ops.reduction( self.dtype, self.src_dtype, @@ -1775,7 +1777,7 @@ class WelfordReduction(Reduction): self.inner_fn(vars, reduction_vars), ) value = values[self.output_index] - return ops.store_reduction(output_name, indexer(vars), value) + return ops.store_reduction(output_name or "unnamed", indexer(vars), value) @classmethod def create( # type: ignore[override] @@ -1785,7 +1787,7 @@ class WelfordReduction(Reduction): inner_fns: Sequence[Callable[..., Any]], ranges: list[Integer], reduction_ranges: list[Integer], - reduction_type: str, + reduction_type: ReductionType, reduction_hint: ReductionHint = ReductionHint.DEFAULT, ) -> Sequence[TensorBox]: assert reduction_type in ("welford_reduce", "welford_combine") @@ -1911,7 +1913,7 @@ class WelfordReduction(Reduction): inner_fns: Sequence[Callable[..., Any]], ranges: list[Integer], reduction_ranges: list[Integer], - reduction_type: str, + reduction_type: ReductionType, split: _IntLike, reduction_hint: ReductionHint, ) -> Sequence[TensorBox]: @@ -2031,11 +2033,13 @@ class Scan(Loops): indexer: Callable[[Sequence[_IntLike]], Never], vars: Sequence[Expr], scan_vars: Sequence[Symbol], - ) -> OpsValue: + ) -> None: idx = self.reindex(vars, scan_vars) - values = [inner_fn(idx) for inner_fn in self.inner_fns] + values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) result = ops.scan(self.dtypes, self.combine_fn, values) - return ops.store(output_name, indexer(idx), result[self.output_index]) + return ops.store( + output_name or "unnamed", indexer(idx), result[self.output_index] + ) def get_reduction_type(self) -> Optional[str]: # return self.scan_op @@ -2229,11 +2233,13 @@ class Sort(Loops): indexer: Callable[[Sequence[Expr]], Expr], vars: Sequence[Expr], reduction_vars: Sequence[Expr], - ) -> OpsValue: + ) -> None: idx = self.reindex(vars, reduction_vars) - values = [inner_fn(idx) for inner_fn in self.inner_fns] + values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) result = ops.sort(self.dtypes, values, self.stable, self.descending) - return ops.store(output_name, indexer(idx), result[self.output_index]) + return ops.store( + output_name or "unnamed", indexer(idx), result[self.output_index] + ) def get_reduction_type(self) -> Optional[str]: return "sort" @@ -3790,7 +3796,7 @@ class Buffer(IRNode): def loader(index): # type: ignore[no-untyped-def] indexer = self.make_indexer() - return ops.load(self.name, indexer(index)) + return ops.load(self.name or "unnamed", indexer(index)) return loader @@ -3983,7 +3989,7 @@ class ComputedBuffer(OperationBuffer): return self.data.make_loader() return super().make_loader() - def get_store_function(self) -> Callable[..., OpsValue]: + def get_store_function(self) -> Callable[..., None]: indexer = self.get_layout().as_fixed().make_indexer() if isinstance(self.data, (Reduction, Scan, Sort)): return partial(self.data.store_reduction, self.name, indexer) diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 21f63a11b67..c3a3ab7133e 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -17,6 +17,7 @@ from torch.utils._sympy.symbol import SymT from . import config, dependencies from .codegen.common import index_prevent_reordering +from .ops_handler import DefaultHandler, OpsHandler, WrapperHandler from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs from .virtualized import ops, V @@ -439,179 +440,13 @@ class LoopBodyBlock: def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]): self.body = body - - def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs): - return tracer.create_proxy( - "call_module", - "get_index", - (body.add_index_expr(expr, mtype, **kwargs),), - {}, - ) - - class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] - self.name = "CaptureIndexing" - - def load(self, name: str, index: sympy.Expr): - index = add_index(index, MemoryUsageType.LOAD, buffer_name=name) - return self._inner.load(name, index) - - def load_seed(self, name: str, index: int): - assert isinstance(index, int) - body.add_index_expr( - sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name - ) - return self._inner.load_seed(name, index) - - def store(self, name, index, value, mode=None): - index = add_index( - index, MemoryUsageType.STORE, buffer_name=name, mode=mode - ) - return self._inner.store(name, index, value, mode) - - def store_reduction(self, name, index, value): - index = add_index( - index, MemoryUsageType.STORE_REDUCTION, buffer_name=name - ) - return self._inner.store_reduction(name, index, value) - - def reduction(self, dtype, src_dtype, reduction_type, value): - result = self._inner.reduction(dtype, src_dtype, reduction_type, value) - if "welford" in reduction_type: - return tuple(result[i] for i in range(3)) - return result - - def index_expr(self, index, dtype): - if isinstance(index, (int, sympy.Integer)): - return self._inner.constant(int(index), dtype) - index = add_index(index, MemoryUsageType.INDEX_EXPR) - return self._inner.index_expr(index, dtype) - - def check_bounds(self, index, size, lower, upper): - index = add_index(index, MemoryUsageType.CHECK_BOUNDS) - size = add_index(size, MemoryUsageType.CHECK_BOUNDS) - return self._inner.check_bounds(index, size, lower, upper) - - def bucketize( - self, - values: T, - boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], - boundary_indices: T, - indexing_dtype: torch.dtype, - right: bool, - sorter: Optional[tuple[str, sympy.Expr]] = None, - sorter_indices: Optional[T] = None, - ) -> T: - """ - See [Note: Inductor bucketize op] - """ - boundaries = ( - boundaries[0], - add_index( - boundaries[1], - MemoryUsageType.BUCKETIZE, - buffer_name=boundaries[0], - ), - add_index( - boundaries[2], - MemoryUsageType.BUCKETIZE, - buffer_name=boundaries[0], - ), - add_index( - boundaries[3], - MemoryUsageType.BUCKETIZE, - buffer_name=boundaries[0], - ), - ) - if sorter is not None: - sorter = ( - sorter[0], - add_index( - sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0] - ), - ) - - return self._inner.bucketize( - values, - boundaries, - boundary_indices, - indexing_dtype, - right, - sorter, - sorter_indices, - ) - - @staticmethod - def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): - """ - Recursively capture the masked out body in another LoopBodyBlock - """ - name = self.body.add_submodule(None, "masked_subblock") - self.body.submodules[name] = self.body.bind_masked_shim(name) - self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) - return tracer.create_proxy( - "call_module", name, (mask_proxy, other_proxy), {} - ) - - @staticmethod - def scan( - dtype_proxy, - combine_fn: Callable[ - [tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...] - ], - value_proxy, - ): - shim = self.body.bind_scan_shim(combine_fn) - name = self.body.add_submodule(shim, "scan") - result = tracer.create_proxy( - "call_module", - name, - (dtype_proxy, value_proxy), - {}, - ) - # Proxies are iterable, but some methods expect tuples/lists - return tuple(result[i] for i in range(len(value_proxy))) - - def sort(self, dtypes, values, stable, descending): - result = self._inner.sort(dtypes, values, stable, descending) - # Proxies are iterable, but some methods expect tuples/lists - return tuple(result[i] for i in range(len(values))) - - def frexp(self, value_proxy): - result = self._inner.frexp(value_proxy) - # Proxies are iterable, but some methods expect tuples/lists - return (result[0], result[1]) - - @staticmethod - def indirect_indexing(index_proxy, size, check=True, wrap_neg=True): - """ - Flow data from tensors into indexing formulas. - Introduce a call_module to update the indexing. - """ - - var = self.body.add_indirect(size) - set_indirect = self.body.bind_set_indirect_shim( - var, size, check, wrap_neg - ) - tracer.create_proxy( - "call_module", - self.body.add_submodule(set_indirect, f"set_{var}"), - (index_proxy,), - {}, - ) - return var - - @staticmethod - def output(result): - tracer.create_proxy("output", "output", (result,), {}) - tracer = LightTracer() proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) from .index_propagation import IndexPropagation - from .sizevars import SimplifyIndexing handler: Any = CountOps( - SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges), + CaptureIndexing(proxy_ops, body, tracer), body.op_counts, ) if config.constant_and_index_propagation: @@ -653,11 +488,187 @@ class LoopBodyBlock: return copy -class CountOps: - def __init__(self, inner: Any, counts: collections.Counter[str]): +class CountOps(DefaultHandler): + def __init__(self, inner: OpsHandler[Any], counts: collections.Counter[str]): self._inner = inner self._counts = counts - def __getattr__(self, name): + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: self._counts[name] += 1 - return getattr(self._inner, name) + return getattr(self._inner, name)(*args, **kwargs) + + +class CaptureIndexing(WrapperHandler): + name = "CaptureIndexing" + + def __init__( + self, + inner: OpsHandler[Any], + body: LoopBody, + tracer: LightTracer, + ): + super().__init__(inner) + self.body = body + self.tracer = tracer + + def _add_index(self, expr: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any): + return self.tracer.create_proxy( + "call_module", + "get_index", + (self.body.add_index_expr(expr, mtype, **kwargs),), + {}, + ) + + def _simplify(self, expr: sympy.Expr) -> sympy.Expr: + return V.graph.sizevars.simplify_with_ranges(expr, self.body.var_ranges) + + def load(self, name: str, index: sympy.Expr): + index = self._simplify(index) + index = self._add_index(index, MemoryUsageType.LOAD, buffer_name=name) + return self._inner.load(name, index) + + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + self.body.add_index_expr( + sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name + ) + return self._inner.load_seed(name, index) + + def store(self, name, index, value, mode=None): + index = self._simplify(index) + index = self._add_index( + index, MemoryUsageType.STORE, buffer_name=name, mode=mode + ) + return self._inner.store(name, index, value, mode) + + def store_reduction(self, name, index, value): + index = self._simplify(index) + index = self._add_index( + index, MemoryUsageType.STORE_REDUCTION, buffer_name=name + ) + return self._inner.store_reduction(name, index, value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + result = self._inner.reduction(dtype, src_dtype, reduction_type, value) + if "welford" in reduction_type: + return tuple(result[i] for i in range(3)) + return result + + def index_expr(self, index, dtype): + index = self._simplify(index) + if isinstance(index, (int, sympy.Integer)): + return self._inner.constant(int(index), dtype) + index = self._add_index(index, MemoryUsageType.INDEX_EXPR) + return self._inner.index_expr(index, dtype) + + def check_bounds(self, index, size, lower, upper): + index = self._simplify(index) + index = self._add_index(index, MemoryUsageType.CHECK_BOUNDS) + size = self._add_index(size, MemoryUsageType.CHECK_BOUNDS) + return self._inner.check_bounds(index, size, lower, upper) + + def bucketize( + self, + values: T, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: T, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> T: + """ + See [Note: Inductor bucketize op] + """ + boundaries = ( + boundaries[0], + self._add_index( + boundaries[1], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + self._add_index( + boundaries[2], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + self._add_index( + boundaries[3], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + ) + if sorter is not None: + sorter = ( + sorter[0], + self._add_index( + sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0] + ), + ) + + return self._inner.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) + + def masked(self, mask_proxy, masked_body: Callable[..., Any], other_proxy): + """ + Recursively capture the masked out body in another LoopBodyBlock + """ + name = self.body.add_submodule(None, "masked_subblock") + self.body.submodules[name] = self.body.bind_masked_shim(name) + self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) + return self.tracer.create_proxy( + "call_module", name, (mask_proxy, other_proxy), {} + ) + + def scan( + self, + dtype_proxy, + combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]], + value_proxy, + ): + shim = self.body.bind_scan_shim(combine_fn) + name = self.body.add_submodule(shim, "scan") + result = self.tracer.create_proxy( + "call_module", + name, + (dtype_proxy, value_proxy), + {}, + ) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(value_proxy))) + + def sort(self, dtypes, values, stable, descending): + result = self._inner.sort(dtypes, values, stable, descending) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(values))) + + def frexp(self, value_proxy): + result = self._inner.frexp(value_proxy) + # Proxies are iterable, but some methods expect tuples/lists + return (result[0], result[1]) + + def indirect_indexing(self, index_proxy, size, check=True, wrap_neg=True): + """ + Flow data from tensors into indexing formulas. + Introduce a call_module to update the indexing. + """ + + var = self.body.add_indirect(size) + set_indirect = self.body.bind_set_indirect_shim(var, size, check, wrap_neg) + self.tracer.create_proxy( + "call_module", + self.body.add_submodule(set_indirect, f"set_{var}"), + (index_proxy,), + {}, + ) + return var + + def output(self, *result): + self.tracer.create_proxy("output", "output", result, {}) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index eacf4dbb3d0..3383a4e0773 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -12,7 +12,7 @@ import os import warnings from collections import defaultdict from collections.abc import Iterable, Sequence -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec from unittest.mock import patch @@ -81,6 +81,10 @@ from .utils import ( from .virtualized import ops, V +if TYPE_CHECKING: + from .ops_handler import ReductionType + + _T = TypeVar("_T") _P = ParamSpec("_P") @@ -5633,7 +5637,7 @@ def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): ) -def make_reduction(reduction_type: str, override_return_dtype=None): +def make_reduction(reduction_type: ReductionType, override_return_dtype=None): def inner(x, axis=None, keepdims=False, *, dtype=None): kwargs = _make_reduction_inner( x, @@ -6750,6 +6754,23 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, operands): return list(map(TensorBox.create, result)) +@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None) +def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None): + output = None + for i, node in enumerate(subgraph_fn.graph_module.graph.nodes): + if node.op == "placeholder": + V.graph.env[node] = operands[i] + continue + # todo getattr + elif node.op == "output": + args, kwargs = V.graph.fetch_args_kwargs_from_env(node) + output = torch.fx.Interpreter.output(V.graph, node, args, kwargs) + else: + V.graph.env[node] = V.graph.run_node(node) + + return output + + @register_lowering(associative_scan_op, type_promotion_kind=None) def associative_scan(combine_fn: ir.Subgraph, xs): from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 935c5f6fc36..0118d29368c 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -1,10 +1,12 @@ # mypy: allow-untyped-defs from __future__ import annotations +import inspect import itertools import re +import warnings +from io import StringIO from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union -from typing_extensions import Protocol from unittest.mock import patch import sympy @@ -38,12 +40,8 @@ def _arg_str(a: object) -> str: return str(a) -# NB: This is not done as a parent class, because our ops handlers -# implementations make heavy use of __getattr__ magic, and pre-existing -# stubs for methods would interfere with this mechanism. -# # See OpDecompositions for superclass that desugars operations like reciprocal/square. -class OpsHandler(Protocol[T]): +class OpsHandler(Generic[T]): """ Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``, as well as the contract for op handlers. The type T signifies the domain @@ -67,49 +65,30 @@ class OpsHandler(Protocol[T]): ops handlers. Handlers are often defined using metaprogramming (e.g. _initialize_pointwise_overrides), - which means you will get type errors if you subclass OpsHandler since mypy doesn't know - about the methods added via metaprogramming and thinks the class is still abstract. - Instead, you should add a block like: - - if TYPE_CHECKING: - - class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - Which will check the signatures of non-meta-programmed methods and gives decent error messages. - - Some older parts of the code use a pattern like: - - def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]: - return h - - This pattern only works if the class defines a __getattr__ method, which we are moving away from. - Additionally, this pattern generates horrible error messages if the signatures are wrong. - It gives zero information about what the problem is, which makes the pattern harmful. - - Instead of that, we have tests in test/inductor/test_op_completeness.py which check that all - operators are implemented after all the metaprogramming has run. + which means you will not get type errors for those methods. We have tests in + test/inductor/test_op_completeness.py which check that all operators are implemented after + all the metaprogramming has run. """ def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T: """Produces a scalar constant of type dtype.""" - ... + raise NotImplementedError def load_seed(self, name: str, offset: T) -> T: """Computes inductor_prims.lookup_seed.""" - ... + raise NotImplementedError def rand(self, seed: T, offset: T) -> T: """Computes inductor_prims.random with mode="rand". offset has dtype int32.""" - ... + raise NotImplementedError def randn(self, seed: T, offset: T) -> T: """Computes inductor_prims.random with mode="randn". offset has dtype int32.""" - ... + raise NotImplementedError def randint64(self, seed: T, offset: T, low: T, high: T) -> T: """Computes inductor_prims.randint. offset has dtype int32.""" - ... + raise NotImplementedError def masked(self, mask: T, body: Callable[[], T], other: T) -> T: """ @@ -123,13 +102,13 @@ class OpsHandler(Protocol[T]): Contrast this with ops.where, which can multiplex between two values that have been unconditionally computed. """ - ... + raise NotImplementedError def where(self, condition: T, input: T, other: T) -> T: """ Computes torch.where: when condition is true, return input; otherwise return other. """ - ... + raise NotImplementedError def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T: """ @@ -137,7 +116,7 @@ class OpsHandler(Protocol[T]): an indexing expression, thus the name; however, it can also be used in non-indexing situations. """ - ... + raise NotImplementedError def to_dtype( self, @@ -150,7 +129,7 @@ class OpsHandler(Protocol[T]): Convert x to dtype. src_dtype can be optionally set to specify what the original dtype of x was, which can improve code generation (used by torch to(dtype=dtype)). """ - ... + raise NotImplementedError def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: """ @@ -164,38 +143,38 @@ class OpsHandler(Protocol[T]): int64 depending on if we've shown that all the indexing operations can be done in int32. """ - ... + raise NotImplementedError def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: """ Convert x to dtype with ceiling semantics. See also trunc_to_int. """ - ... + raise NotImplementedError def floor_to_int(self, x: T, dtype: torch.dtype) -> T: """ Convert x to dtype with ceiling semantics. See also trunc_to_int. """ - ... + raise NotImplementedError def round_to_int(self, x: T, dtype: torch.dtype) -> T: """ Convert x to dtype with round-to-even semantics. See also trunc_to_int. """ - ... + raise NotImplementedError def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: """ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) src_dtype must be the original type of x. """ - ... + raise NotImplementedError def identity(self, x: T) -> T: """ Returns x as is. This is used to trigger CSE. """ - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These operations are only available in a "kernel" context. Check @@ -217,13 +196,13 @@ class OpsHandler(Protocol[T]): NB: This is typically mandatory to implement for any analysis, because you MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol). """ - ... + raise NotImplementedError def load(self, name: str, index: sympy.Expr) -> T: """ Load from the memory location 'name', offset by some indexing expression 'index'. """ - ... + raise NotImplementedError def store( self, @@ -236,7 +215,7 @@ class OpsHandler(Protocol[T]): Store 'value' to the memory location 'name' offset by 'expr'. If specified, 'mode' can require the store to be an atomic addition. """ - ... + raise NotImplementedError # TODO: Better explain how the "collective" semantics of these ops; # remember that the input value is a scalar, you can't reduce on it in the @@ -258,7 +237,7 @@ class OpsHandler(Protocol[T]): function returns multiple outputs; consult reduction_num_outputs to determine the amount in metaprogramming applications. """ - ... + raise NotImplementedError # TODO: in practice, this seems to actually return None, but not returning # a T makes common __getattr__ idioms not type correctly. Figure out if @@ -268,7 +247,7 @@ class OpsHandler(Protocol[T]): Store the fully accumulated result of 'reduction' to the memory location 'name' offset by 'expr'. """ - ... + raise NotImplementedError def scan( self, @@ -280,7 +259,7 @@ class OpsHandler(Protocol[T]): Perform an associative scan on 'value'. """ # TODO: Improve the description with some pseudocode - ... + raise NotImplementedError def sort( self, @@ -292,7 +271,7 @@ class OpsHandler(Protocol[T]): """ Sort values along the reduction dimension. """ - ... + raise NotImplementedError def bucketize( self, @@ -305,231 +284,231 @@ class OpsHandler(Protocol[T]): sorter_indices: Optional[T] = None, ) -> T: # See [Note: Inductor bucketize op] - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # The following ops have semantics that correspond exactly to the torch # operation with the same corresponding name. def abs(self, x0: T) -> T: - ... + raise NotImplementedError def exp(self, x0: T) -> T: - ... + raise NotImplementedError def exp2(self, x0: T) -> T: - ... + raise NotImplementedError def expm1(self, x0: T) -> T: - ... + raise NotImplementedError def sqrt(self, x0: T) -> T: - ... + raise NotImplementedError def relu(self, x0: T) -> T: - ... + raise NotImplementedError def minimum(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def maximum(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def cos(self, x0: T) -> T: - ... + raise NotImplementedError def sin(self, x0: T) -> T: - ... + raise NotImplementedError def lgamma(self, x0: T) -> T: - ... + raise NotImplementedError def erf(self, x0: T) -> T: - ... + raise NotImplementedError def cosh(self, x0: T) -> T: - ... + raise NotImplementedError def sinh(self, x0: T) -> T: - ... + raise NotImplementedError def acos(self, x0: T) -> T: - ... + raise NotImplementedError def acosh(self, x0: T) -> T: - ... + raise NotImplementedError def asin(self, x0: T) -> T: - ... + raise NotImplementedError def asinh(self, x0: T) -> T: - ... + raise NotImplementedError def atan2(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def atan(self, x0: T) -> T: - ... + raise NotImplementedError def atanh(self, x0: T) -> T: - ... + raise NotImplementedError def copysign(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def erfc(self, x0: T) -> T: - ... + raise NotImplementedError def erfinv(self, x0: T) -> T: - ... + raise NotImplementedError def frexp(self, x0: T): - ... + raise NotImplementedError def hypot(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def log10(self, x0: T) -> T: - ... + raise NotImplementedError def log2(self, x0: T) -> T: - ... + raise NotImplementedError def nextafter(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def logical_and(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def logical_not(self, x0: T) -> T: - ... + raise NotImplementedError def logical_or(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def logical_xor(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_and(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_not(self, x0: T) -> T: - ... + raise NotImplementedError def bitwise_or(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_xor(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_left_shift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_right_shift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def rsqrt(self, x0: T) -> T: - ... + raise NotImplementedError def log1p(self, x0: T) -> T: - ... + raise NotImplementedError def tan(self, x0: T) -> T: - ... + raise NotImplementedError def tanh(self, x0: T) -> T: - ... + raise NotImplementedError def sigmoid(self, x0: T) -> T: - ... + raise NotImplementedError def signbit(self, x0: T) -> T: - ... + raise NotImplementedError def fmod(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def log(self, x0: T) -> T: - ... + raise NotImplementedError def isinf(self, x0: T) -> T: - ... + raise NotImplementedError def isnan(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation # This rounds half to even to break ties def round(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def floor(self, x0: T) -> T: - ... + raise NotImplementedError def sign(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def trunc(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def ceil(self, x0: T) -> T: - ... + raise NotImplementedError def neg(self, x0: T) -> T: - ... + raise NotImplementedError def reciprocal(self, x0: T) -> T: - ... + raise NotImplementedError def eq(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def ne(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def lt(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def gt(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def le(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def ge(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def add(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def sub(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def mul(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def pow(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def and_(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def or_(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def xor(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError # These are metaprogrammed by MockHandler._init_cls def lshift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def rshift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These are "special" operators. These only exist if the target @@ -537,124 +516,124 @@ class OpsHandler(Protocol[T]): # pointwise_overrides_data. def airy_ai(self, x: T) -> T: - ... + raise NotImplementedError def bessel_j0(self, x: T) -> T: - ... + raise NotImplementedError def bessel_j1(self, x: T) -> T: - ... + raise NotImplementedError def bessel_y0(self, x: T) -> T: - ... + raise NotImplementedError def bessel_y1(self, x: T) -> T: - ... + raise NotImplementedError def digamma(self, x: T) -> T: - ... + raise NotImplementedError def erfcx(self, x: T) -> T: - ... + raise NotImplementedError def fma(self, x: T, y: T, z: T) -> T: - ... + raise NotImplementedError def igamma(self, x: T, y: T) -> T: - ... + raise NotImplementedError def igammac(self, x: T, y: T) -> T: - ... + raise NotImplementedError def gammainc(self, x: T, y: T) -> T: - ... + raise NotImplementedError def gammaincc(self, x: T, y: T) -> T: - ... + raise NotImplementedError def i0(self, x: T) -> T: - ... + raise NotImplementedError def i0e(self, x: T) -> T: - ... + raise NotImplementedError def i1(self, x: T) -> T: - ... + raise NotImplementedError def i1e(self, x: T) -> T: - ... + raise NotImplementedError def log_ndtr(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_i0(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_i1(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_k0(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_k1(self, x: T) -> T: - ... + raise NotImplementedError def ndtr(self, x: T) -> T: - ... + raise NotImplementedError def ndtri(self, x: T) -> T: - ... + raise NotImplementedError def polygamma(self, x: T, y: T) -> T: - ... + raise NotImplementedError def scaled_modified_bessel_k0(self, x: T) -> T: - ... + raise NotImplementedError def scaled_modified_bessel_k1(self, x: T) -> T: - ... + raise NotImplementedError def spherical_bessel_j0(self, x: T) -> T: - ... + raise NotImplementedError def zeta(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_t(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_u(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_v(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_w(self, x: T, y: T) -> T: - ... + raise NotImplementedError def legendre_polynomial_p(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T: - ... + raise NotImplementedError def hermite_polynomial_h(self, x: T, y: T) -> T: - ... + raise NotImplementedError def hermite_polynomial_he(self, x: T, y: T) -> T: - ... + raise NotImplementedError def laguerre_polynomial_l(self, x: T, y: T) -> T: - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These operators are a bit special, because they are conventionally @@ -665,42 +644,42 @@ class OpsHandler(Protocol[T]): """C-style trunc division between integers only. Computes the true division of two numbers and rounds the result to zero. """ - ... + raise NotImplementedError def floordiv(self, x0: T, x1: T) -> T: """Python-style floor division between integers only. Computes the true division of two numbers and floors the result. If you want floor division for floats, do regular truediv and floor the result. """ - ... + raise NotImplementedError def truediv(self, x0: T, x1: T) -> T: """True division between floats. Integer inputs are NOT valid. To do Python-style (int, int) -> float division, use int_truediv""" - ... + raise NotImplementedError def int_truediv(self, x0: T, x1: T) -> T: """True division between integers. This is NOT the same as promoting to float and doing integer division, there is a bespoke algorithm for doing the division in higher precision than the above. """ - ... + raise NotImplementedError def mod(self, x0: T, x1: T) -> T: """C-style modulus, take sign from LHS (x0).""" - ... + raise NotImplementedError def remainder(self, x0: T, x1: T) -> T: """Python-style modulus, take sign from RHS (x1).""" - ... + raise NotImplementedError def square(self, x0: T) -> T: - ... + raise NotImplementedError def check_bounds( self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool ) -> None: - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # In CUDA, optimized implementations of other mathematical operations are @@ -716,25 +695,25 @@ class OpsHandler(Protocol[T]): # for many analyses it's not conveniently available.) def libdevice_abs(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_exp(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_sqrt(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_cos(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_sin(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_sigmoid(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_log(self, x0: T) -> T: - ... + raise NotImplementedError # halide-only def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T: @@ -750,7 +729,15 @@ class OpsHandler(Protocol[T]): is_pure: bool = True, pack: int = 1, ) -> T: - ... + raise NotImplementedError + + def output(self, *args: T) -> None: + """This is a fake op used in analysis but not codegen""" + raise NotImplementedError + + def placeholder(self, index: int) -> T: + """This is a fake op used in analysis but not codegen""" + raise NotImplementedError _ignore_op_re = re.compile(r"_.*|paren").fullmatch @@ -763,15 +750,86 @@ def list_ops(cls: type[Any]): OP_NAMES = list_ops(OpsHandler) -def _return_none(*args, **kwargs): - return None +class DefaultHandler(OpsHandler[Any]): + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + """ + Default implementation for all ops. Override in a subclass to + provide generic op behavior. + + Args: + name: name of the op, see OpHandler.{name} + args: positional args passed to the op + kwargs: keyword args passed to the op + + Returns: + return value of the op + + """ + raise NotImplementedError + + def __getattr__(self, name: str) -> Any: + def fallback(*args: Any, **kwargs: Any) -> Any: + return self._default(name, args, kwargs) + + # would like to remove this function entirely, but it's used in MTIA backend + warnings.warn(f"undefined OpHandler.{name}, please add missing op schema") + return fallback + + @staticmethod + def _call_default(target: str): + def call_default(self, *args, **kwargs): + return self._default(target, args, kwargs) + + call_default.__name__ = target + return call_default + + @classmethod + def _init_cls(cls): + """ + Here we codegen many functions of the form: + + def add(self, a, b): + return self._default('add', (a, b), {}) + + and install them in cls. This is the same as _call_default above, + but is about 1.2x faster since CPython varargs parsing is slow. + """ + code = StringIO() + for target in OP_NAMES: + sig = inspect.signature(getattr(OpsHandler, target)) + if all( + p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + and p.default is inspect.Parameter.empty + for p in sig.parameters.values() + ): + self_arg, *args = sig.parameters.keys() + assert self_arg == "self" + code.write( + f""" + def {target}(self, {', '.join(args)}): + return self._default({target!r}, ({', '.join(args)}, ), {{}}) + """.strip() + ) + code.write("\n\n") + else: + # slower fallback for ops with default or variadic arguments + setattr(cls, target, cls._call_default(target)) + + ctx: dict[str, Any] = {} + exec(code.getvalue(), ctx) + for target, impl in ctx.items(): + if target in OP_NAMES: + setattr(cls, target, impl) -class NoopHandler: +DefaultHandler._init_cls() + + +class NoopHandler(DefaultHandler): name = "NoopHandler" - def __getattr__(self, name: str) -> Callable[..., None]: - return _return_none + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return None @staticmethod def masked(mask, body, other) -> None: @@ -794,12 +852,7 @@ class NoopHandler: return sympy.S.Zero -# Use mypy to check protocol implemented correctly -def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]: - return h - - -class BasicMathOps: +class BasicMathOpsMixin: @staticmethod def add(a, b): return f"{a} + {b}" @@ -878,16 +931,14 @@ class BasicMathOps: return f"-{a}" -class MockHandler(BasicMathOps): +class MockHandler(BasicMathOpsMixin, DefaultHandler): name = "MockHandler" - def __getattr__(self, name): - def inner(*args, **kwargs): - fargs = [_arg_str(a) for a in args] - fargs.extend(f"{k}={v}" for k, v in kwargs.items()) - return f"ops.{name}({', '.join(fargs)})" - - return inner + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + fargs = [*map(_arg_str, args)] + for k, v in kwargs.items(): + fargs.append(f"{k}={_arg_str(v)}") + return f"ops.{name}({', '.join(fargs)})" @staticmethod def masked(mask, body, other) -> str: @@ -916,15 +967,10 @@ class MockHandler(BasicMathOps): return sympy_index_symbol(str(index_var)) -# Use mypy to check protocol implemented correctly -def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]: - return h - - -class KernelFormatterHandler: - def __init__(self, parent_handler): +class KernelFormatterHandler(DefaultHandler): + def __init__(self, parent_handler: OpsHandler[Any]): self.parent_handler = parent_handler - self.output = IndentedBuffer(1) + self._output = IndentedBuffer(1) self.var_counter = itertools.count() @staticmethod @@ -936,8 +982,8 @@ class KernelFormatterHandler: names = ["index", "rindex"] if rindex is not None else ["index"] formatter = KernelFormatterHandler(MockHandler()) - with formatter.output.indent(-1): - formatter.output.writeline(f"def inner_fn({', '.join(names)}):") + with formatter._output.indent(-1): + formatter._output.writeline(f"def inner_fn({', '.join(names)}):") for name, arg in zip(names, args): if arg: lhs = ", ".join( @@ -946,7 +992,7 @@ class KernelFormatterHandler: for v in arg ] ) - formatter.output.writeline(f"{lhs} = {name}") + formatter._output.writeline(f"{lhs} = {name}") with V.set_ops_handler(formatter), patch.object( FlexibleLayout, "allow_indexing", True @@ -954,21 +1000,19 @@ class KernelFormatterHandler: result = ir_fn(*args) return formatter.getvalue(result) - def __getattr__(self, name) -> Callable[..., Any]: - def inner(*args, **kwargs): - line = getattr(self.parent_handler, name)(*args, **kwargs) - if name == "indirect_indexing": - return line + def indirect_indexing(self, *args, **kwargs) -> sympy.Symbol: + return self.parent_handler.indirect_indexing(*args, **kwargs) - def write(line): - # replace line with a new variable name - varname = f"tmp{next(self.var_counter)}" - self.output.writeline(f"{varname} = {line}") - return varname + def _write(self, line): + # replace line with a new variable name + varname = f"tmp{next(self.var_counter)}" + self._output.writeline(f"{varname} = {line}") + return varname - return pytree.tree_map(write, line) - - return inner + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return pytree.tree_map( + self._write, getattr(self.parent_handler, name)(*args, **kwargs) + ) def reduction( self, @@ -980,44 +1024,28 @@ class KernelFormatterHandler: line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value) num_values = reduction_num_outputs(reduction_type) varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)] - self.output.writeline(f"{','.join(varnames)} = {line}") + self._output.writeline(f"{','.join(varnames)} = {line}") return tuple(varnames) if num_values > 1 else varnames[0] def getvalue(self, result): - self.output.writeline(f"return {result}") - return self.output.getvalue() + self._output.writeline(f"return {result}") + return self._output.getvalue() -# Use mypy to check protocol implemented correctly -def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]: - return h - - -class WrapperHandler(Generic[T]): - def __init__(self, inner: Any): +class WrapperHandler(DefaultHandler): + def __init__(self, inner: OpsHandler[Any]): self._inner = inner - def __getattr__(self, item): - return getattr(self._inner, item) + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return getattr(self._inner, name)(*args, **kwargs) -# Use mypy to check protocol implemented correctly -def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]: - return h - - -class AddParenHandler(WrapperHandler[T]): - def __getattr__(self, name): - def inner(*args, **kwargs): - val = getattr(self._inner, name)(*args, **kwargs) - return f"({val})" - - return inner - - -# Use mypy to check protocol implemented correctly -def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]: - return h +class AddParenHandler(WrapperHandler): + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + val = getattr(self._inner, name)(*args, **kwargs) + if not val or isinstance(val, (sympy.Expr, tuple, list)): + return val + return f"({val})" class OpCountResult(NamedTuple): @@ -1027,26 +1055,23 @@ class OpCountResult(NamedTuple): nontrivial_read_count: int -class OpCounterCSE: +class OpCounterCSE(DefaultHandler): """Shim to count how many ops are used""" - def __init__(self, inner): + def __init__(self, inner: OpsHandler[Any]): super().__init__() self.parent_handler = inner self.op_count = 0 - self.var_names = {} - self._used_ops = OrderedSet[str]() + self.var_names: dict[str, str] = {} + self._used_ops: OrderedSet[str] = OrderedSet() self._read_names: list[str] = [] self._nontrivial_read_count = 0 - def __getattr__(self, name): - def inner(*args, **kwargs): - return pytree.tree_map( - self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) - ) - + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: self._used_ops.add(name) - return inner + return pytree.tree_map( + self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) + ) def _update_count(self, val): varname = self.var_names.get(val) @@ -1111,58 +1136,45 @@ class OpCounterCSE: ) -def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]: - return h - - class ExtractConstantsHandler(NoopHandler): - def __init__(self, device): + def __init__(self, device: Optional[torch.device]): self.device = device def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant: from torch._inductor import ir - return ir.Constant(value=value, dtype=dtype, device=self.device) + return ir.Constant( + value=value, dtype=dtype, device=self.device or torch.get_default_device() + ) -def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) -> OpsHandler[Any]: - return h - - -class SimpleCSEHandler(WrapperHandler[T]): +class SimpleCSEHandler(WrapperHandler): """Wraps the underlying handler with a CSE pass NOTE: Compared to codegen level CSE this is simplified as it doesn't support stores which require load cache invalidation. """ - def __init__(self, inner: OpsHandler[T]): + def __init__(self, inner: Any): super().__init__(inner) - self.cse_cache: dict[str, Union[T, tuple[T, ...]]] = {} + self.cse_cache: dict[str, Union[Any, tuple[Any, ...]]] = {} self.mock = MockHandler() def indirect_indexing(self, *args, **kwargs) -> sympy.Expr: return super().indirect_indexing(*args, **kwargs) # type: ignore[misc] - def store(self, *args, **kwargs) -> T: + def store(self, *args, **kwargs) -> None: raise NotImplementedError("store not implemented") - def store_reduction(self, *args, **kwargs) -> T: + def store_reduction(self, *args, **kwargs) -> None: raise NotImplementedError("store not implemented") - def __getattr__(self, name) -> Callable[..., Any]: - def inner(*args, **kwargs): - key = getattr(self.mock, name)(*args, **kwargs) - val = self.cse_cache.get(key) - if val is not None: - return val - - val = getattr(self._inner, name)(*args, **kwargs) - self.cse_cache[key] = val + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + key = getattr(self.mock, name)(*args, **kwargs) + val = self.cse_cache.get(key) + if val is not None: return val - return inner - - -def _typecheck_SimpleCSEHandler(h: SimpleCSEHandler[Any]) -> OpsHandler[Any]: - return h + val = getattr(self._inner, name)(*args, **kwargs) + self.cse_cache[key] = val + return val diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 393e282d03c..66f60b12f16 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -564,10 +564,6 @@ class CompiledFxGraph(OutputCode): return artifact_path -def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode: - return h - - @dataclasses.dataclass class CompiledAOTI(OutputCode): """ @@ -591,10 +587,6 @@ class CompiledAOTI(OutputCode): pass -def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode: - return h - - @dataclasses.dataclass class MockFXGraphCacheOutput(OutputCode): gm: Any = None diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bd96e830bcc..d016af99954 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -742,7 +742,7 @@ class TritonTemplateKernel(TritonKernel): template_mask = self.template_mask class StoreOutputSubstitution(V.WrapperHandler): # type: ignore[name-defined] - self.name = name + name = "StoreOutputSubstitution" def store( self, diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 532073a377f..fcc549bf652 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -99,6 +99,8 @@ class SizeVarAllocator: if result is None: result = self._simplify_with_ranges(expr, var_ranges) cache[key] = result + if result != expr: + cache[(result, *var_ranges.items())] = result return result return simplify_with_ranges diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index ce35959c532..d7992385735 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -135,7 +135,7 @@ class InputDescriptor: device: torch.device -class TracingOpsHandler(WrapperHandler[T]): +class TracingOpsHandler(WrapperHandler): def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None: parent = tracer.create_proxy("placeholder", "ops", (), {}) super().__init__(parent) @@ -149,8 +149,8 @@ class TracingOpsHandler(WrapperHandler[T]): def placeholder(self, idx: int) -> torch.fx.Proxy: return self.placeholders[idx] - def output(self, *args: tuple[object]) -> torch.fx.Node: - return self.tracer.create_node( + def output(self, *args: tuple[object]) -> None: + self.tracer.create_node( "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {} ) diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 4de41846166..1ee1ef5a744 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -59,11 +59,12 @@ from __future__ import annotations from contextlib import AbstractContextManager, contextmanager from threading import local -from typing import Any, Callable, Generic, TYPE_CHECKING, TypeVar, Union +from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union from torch.utils._ordered_set import OrderedSet from .ops_handler import ( # noqa: F401 + DefaultHandler, KernelFormatterHandler, MockHandler, OpsHandler, @@ -154,7 +155,9 @@ class NullKernelHandler(NullHandler): self.index_dtype = "tl.int64" -_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler) +_ops: Virtualized[OpsHandler[Any]] = Virtualized( + "ops", cast(type[OpsHandler[Any]], MockHandler) +) _graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler) _real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler) _fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler) @@ -272,18 +275,15 @@ class OpsValue: return ops.bitwise_left_shift(self, n) -class OpsWrapper: +class OpsWrapper(DefaultHandler): """This wraps any returned IR values into an `OpsValue` instance, so that we can overload the magic methods for writing mathematical expressions fluently. """ - def __getattr__(self, name): - def inner(*args, **kwargs): - new_args = [OpsWrapper._unwrap(a) for a in args] - new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} - return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) - - return inner + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + new_args = [OpsWrapper._unwrap(a) for a in args] + new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} + return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) @staticmethod def _unwrap(x): @@ -306,7 +306,7 @@ class OpsWrapper: return _ops.indirect_indexing(index, size, check, wrap_neg) -ops = OpsWrapper() +ops: OpsHandler[Any] = OpsWrapper() class _V: @@ -314,8 +314,10 @@ class _V: KernelFormatterHandler = KernelFormatterHandler WrapperHandler = WrapperHandler - set_ops_handler: Callable[[Any], Any] = _ops._set_handler - get_ops_handler: Callable[[], Any] = _ops._get_handler + set_ops_handler: Callable[ + [OpsHandler[Any]], AbstractContextManager[None] + ] = _ops._set_handler + get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler get_real_inputs: Callable[[], Any] = _real_inputs._get_handler diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 057ed0fe63e..2dd16890880 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -11100,8 +11100,8 @@ are designed to work with this function. See the examples below. Args: {input} - indices (tensor): the indices into :attr:`input`. Must have long dtype. - dim (int, optional): dimension to select along. + indices (LongTensor): the indices into :attr:`input`. Must have long dtype. + dim (int, optional): dimension to select along. Default: 0 Keyword args: {out} diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 1bcb2575145..219759ea37b 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -331,9 +331,6 @@ class FunctionMeta(type): name + "Backward", (BackwardCFunction,), {"_forward_cls": cls} ) backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER) # type: ignore[attr-defined] - backward_fn._compiled_autograd_should_lift = attrs.get( # type: ignore[attr-defined] - "_compiled_autograd_should_lift", True - ) backward_fn._bw_module = None # type: ignore[attr-defined] if getattr(cls, "_lazy_backward_info", None): backward_fn._bw_module = cls._lazy_backward_info.bw_module # type: ignore[attr-defined] diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index dd0b7a927bf..67a307cd043 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -301,14 +301,6 @@ bool PyNode::is_aot_backward() const { return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id"); } -auto PyNode::compiled_autograd_should_lift() const -> bool { - pybind11::gil_scoped_acquire gil; - static PyObject* attr_name = - PyUnicode_InternFromString("_compiled_autograd_should_lift"); - THPObjectPtr should_lift(PyObject_GetAttr(obj, attr_name)); - return PyObject_IsTrue(should_lift.get()) == 1; -} - void PyNode::compiled_args(CompiledNodeArgs& args) { static PyObject* method_name = PyUnicode_InternFromString("_compiled_autograd_key"); diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 2f28c765ab0..f6f0979dc25 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -50,8 +50,6 @@ struct PyNode : public Node { const variable_list& inputs, SwapSavedVariables& saved) override; - bool compiled_autograd_should_lift() const; - // THPFunction this Function is wrapping. Owning! PyObject* obj; diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 9d188c9c26d..ff83d687f8a 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -417,6 +417,21 @@ class TORCH_API Backend : public torch::CustomClassHolder { "Backend ", getBackendName(), " does not support getMemAllocator")); } + // Allocate tensor (aten::empty) from backend's communication-optimized memory + // pool + virtual at::Tensor allocateTensor(long size, at::TensorOptions options = {}) { + TORCH_CHECK( + false, + c10::str( + "Backend ", getBackendName(), " does not support allocateTensor")); + } + + // Returns true if backend supports tensor allocation + virtual bool supportsTensorAlloc() { + // Change to true in concrete backend if supported + return false; + } + protected: // Implementations of this interface need to call this to setup // appropriate logging etc. diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 9b5c5962479..99fc244af02 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -340,19 +340,26 @@ ncclResult_t NCCLComm::checkForNcclError() { #endif } -ncclResult_t NCCLComm::registerSegment(void* ptr, size_t size) { +ncclResult_t NCCLComm::registerSegment( + void* ptr, + size_t size, + bool errorOnRereg /*=true*/) { LockType lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER // We register only segments from cache allocator // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always // maps to a unique handle and should not be registered before the current // ptr is deregistered and freed. - TORCH_CHECK( - registeredSegmentHandles_.count(ptr) == 0, - "Segment with ptr ", - ptr, - " has already been registered on ncclComm_ ", - ncclComm_); + if (registeredSegmentHandles_.count(ptr) > 0) { + TORCH_CHECK( + !errorOnRereg, + "Segment with ptr ", + ptr, + " has already been registered on ncclComm_ ", + ncclComm_); + // Skip below + return ncclSuccess; + } void* handle = nullptr; // Use getNcclComm to make sure comm is ready before calling nccl APIs diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 1ec81494856..c7cd0a30924 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -284,7 +284,10 @@ class NCCLComm { ncclResult_t checkForNcclError(); - ncclResult_t registerSegment(void* ptr, size_t size); + ncclResult_t registerSegment( + void* ptr, + size_t size, + bool errorOnRereg = true); ncclResult_t deregisterSegment(void* ptr); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index d69fb2f5c36..cd9363ec337 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1175,7 +1175,8 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) { ncclComm->registerSegment( // NOLINTNEXTLINE(performance-no-int-to-ptr) reinterpret_cast(segmentInfo.address), - segmentInfo.total_size); + segmentInfo.total_size, + /*errorOnRereg=*/false); // ignores reregistration error } } @@ -1455,6 +1456,14 @@ void ProcessGroupNCCL::shutdown() { // Use long interval to avoid acquiring CPU too frequently ncclComm->waitReady(true); } + // Deregister memory pool after finalizing all collectives + if (memPool_) { + try { + deregisterMemPool(memPool_.get()); + } catch (...) { + LOG(ERROR) << logPrefix() << "Failed to deregister memory pool, ignoring"; + } + } // Tell watchdog to (1) flush its queue and (2) do not use comm objects // anymore because I am going to destroy them now LOG(INFO) << logPrefix() << "Operations flushed, joining watchdog thread."; @@ -5422,6 +5431,46 @@ std::shared_ptr ProcessGroupNCCL::getMemAllocator() { return ncclMemAllocator; } +at::Tensor ProcessGroupNCCL::allocateTensor( + long size, + at::TensorOptions options) { + // Some checks + TORCH_CHECK_VALUE(options.has_device(), "Tensor options must include device"); + auto device = options.device(); + TORCH_CHECK_VALUE( + device.is_cuda(), + "NCCL tensor allocator expects cuda type but got " + c10::str(device)) + + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // Create memory pool + if (!memPool_) { + // Needs a CUDAAllocator + auto allocator = + reinterpret_cast( + getMemAllocator().get()); + // Pool is created + memPool_ = std::make_unique(allocator); + LOG(INFO) << logPrefix() << "Created memory pool"; + } + + // Allocate tensor under this MemPool's context + auto ctx = c10::cuda::MemPoolContext(memPool_.get()); + c10::cuda::CUDACachingAllocator::beginAllocateToPool( + memPool_->device(), memPool_->id(), [](cudaStream_t) { return true; }); + at::Tensor tensor = at::empty({size}, options); + // Also need to ncclCommRegister the pool in case new segments are created; + // reregistration of old segments will be ignored + registerMemPool(memPool_.get()); + c10::cuda::CUDACachingAllocator::endAllocateToPool( + memPool_->device(), memPool_->id()); + c10::cuda::CUDACachingAllocator::releasePool( + memPool_->device(), memPool_->id()); + LOG(INFO) << logPrefix() << "Allocated tensor of size " << size + << " from memory pool"; + return tensor; +} + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 002b3a1a143..185d9bebe6e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -774,6 +774,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::shared_ptr getMemAllocator() override; + // Allocate tensor from communication-optimized memory pool + at::Tensor allocateTensor(long size, at::TensorOptions options = {}) override; + + bool supportsTensorAlloc() override { + return true; + } + // Performs NCCL user buffer registration for all buffers in // the given MemPool void registerMemPool(c10::cuda::MemPool* pool); @@ -1294,6 +1301,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Internal cached value: use NCCL non-blocking API mode or not. // Use `useNonblocking()` method instead of accessing this variable directly. std::optional useNonblocking_{std::nullopt}; + + // Communication-optimized memory pool associated with this PG + std::unique_ptr memPool_ = nullptr; }; // Dumps the NCCL comm traces and additional information about the Process diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 03c1380bfe7..800269fe14e 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -1157,14 +1157,44 @@ void Reducer::initialize_buckets( offset += length; } - // Allocate the bucket's flattened `gradients` tensor. // Make gradient type in the reduced precision if mixed precision is // enabled. This ensures that the type is correct when e.g. rebuilding // buckets. if (mixed_precision_param_dtype_.has_value()) { options = options.dtype(mixed_precision_param_dtype_); } - bucket.gradients = at::empty({static_cast(offset)}, options); + + // Allocate the bucket's flattened `gradients` tensor. + auto bucketSize = static_cast(offset); + // Check if we can use comm-optimized memory pool to allocate tensor + c10::intrusive_ptr backend = nullptr; + // An environment variable to disable comm-optimized memory pool. + // Default is 0, which means comm-optimized memory pool is enabled. + // Users can set it to 1 in case of seeing regression or OOM (because this + // comm MemPool may not share space with regular compute MemPool). + bool ddpDisableCommMem = + (getCvarString({"DDP_DISABLE_COMM_MEM"}, "0") == "1"); + try { + backend = process_group_->getDefaultBackend(); + } catch (...) { + // Sometimes the backend type can be `UNDEFINED` rather than `NCCL` or + // `GLOO`. In this case, we just fall back to the regular way of + // creating tensor + LOG(INFO) + << "Reducer: default comm backend not found, skipping bucket memory optimization"; + } + if (ddpDisableCommMem == 0 && backend != nullptr && + backend->supportsTensorAlloc()) { + // Comm-optimized memory pool is available, use it to allocate tensor + LOG(INFO) + << "Reducer: found comm-optimized memory allocator, using it to create bucket"; + bucket.gradients = backend->allocateTensor(bucketSize, options); + } else { + // Plain creation of tensor + LOG(INFO) + << "Reducer: comm-optimized memory allocator not found, using regular one"; + bucket.gradients = at::empty({bucketSize}, options); + } // Note: "Gradient Layout Contract" // diff --git a/torch/distributed/tensor/_ops/_conv_ops.py b/torch/distributed/tensor/_ops/_conv_ops.py index 9842ed1fde3..2198986d50c 100644 --- a/torch/distributed/tensor/_ops/_conv_ops.py +++ b/torch/distributed/tensor/_ops/_conv_ops.py @@ -105,4 +105,8 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding: [0], tensor_meta=bias_tensor_meta, ) + # TODO: actually the output_mask is not respected here, we should + # set the corresponding spec to `None` if the output_mask is not `False` + # for a certain output Tensor. This also applies to the conv handler + # in torch/distributed/tensor/_tp_conv.py return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec]) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index cc5a80e2e82..e81957506d6 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -57,12 +57,14 @@ class ShardingPropagator: OpOverload, Callable[[DeviceMesh, OpSchema], StrategyType], ] = {} - # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop + # op map to save static argnum to decide to reuse sharding prop cache or + # re-run sharding prop self.op_to_schema_info: dict[OpOverload, RuntimeSchemaInfo] = {} self.propagate_op_sharding = LocalLRUCache( self.propagate_op_sharding_non_cached ) - # op map to save indices of shape (and stride) args which may need to be modified in sharding prop + # op map to save indices of shape (and stride) args which may need to be + # modified in sharding prop self.op_to_shape_and_stride_idx: dict[ OpOverload, Union[int, tuple[int, int]] ] = { @@ -171,10 +173,12 @@ class ShardingPropagator: # Either error due to ShardingPropagator or due to incorrect OutputSpec if not isinstance(output_tensor_meta, (tuple, list)): raise ValueError( - "ShardingPropagator error: output does not have an associated TensorMeta" + "ShardingPropagator error: output does not have an associated " + "TensorMeta" ) raise ValueError( - f"For the op {op.name()}, `output_specs` has 1 output which does not equal the " + f"For the op {op.name()}, `output_specs` has 1 output which does " + "not equal the " f"number of op outputs: {len(output_tensor_meta)}." ) output_specs.tensor_meta = output_tensor_meta @@ -183,16 +187,35 @@ class ShardingPropagator: output_specs ) != len(output_tensor_meta): raise ValueError( - f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the " + f"For the op {op.name()}, `output_specs` has {len(output_specs)} " + "outputs which does not equal the " f"number of op outputs {_length(output_tensor_meta)}." ) + for i, spec in enumerate(output_specs): if isinstance(spec, DTensorSpec): output_tensor_meta_i = output_tensor_meta[i] if not isinstance(output_tensor_meta_i, TensorMeta): - raise ValueError( - f"ShardingPropagator error: output {i} does not have an associated TensorMeta" - ) + # NOTE: aten.convolution_backward.default is an exception and it + # needs extra handling because the first Tensor in the output + # tuple can be `None` if the input Tensor to convolution op has + # `requires_grad=False` (e.g. convolution layer is the first + # layer in the model). We explicitly allow its corresponding + # TensorMeta to be `None`. + if ( + op == aten.convolution_backward.default + and i == 0 + and output_tensor_meta_i is None + ): + assert isinstance(output_specs, list) + output_specs[i] = None + continue + else: + raise ValueError( + f"ShardingPropagator error: output {i} of {op.name()} " + "does not have an associated TensorMeta" + ) + spec.tensor_meta = output_tensor_meta_i def propagate(self, op_info: OpInfo) -> None: diff --git a/torch/distributed/tensor/_tp_conv.py b/torch/distributed/tensor/_tp_conv.py index e9ae126e3c5..f3e908f3e7a 100644 --- a/torch/distributed/tensor/_tp_conv.py +++ b/torch/distributed/tensor/_tp_conv.py @@ -215,12 +215,13 @@ def tp_convolution_backward( # step4 aggregate gradients for edge pixels grad_in_tensor = local_results[0] - grad_in_tensor = _ring_send_recv_aggregate( - grad_in_tensor, d1, d2, left, right, rank, size - ) + if grad_in_tensor is not None: + grad_in_tensor = _ring_send_recv_aggregate( + grad_in_tensor, d1, d2, left, right, rank, size + ) + local_results = list(local_results) + local_results[0] = grad_in_tensor - local_results = list(local_results) - local_results[0] = grad_in_tensor local_results = cast(tuple[object, ...], local_results) return local_results diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 9435f136183..bd326614fa1 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -11,6 +11,8 @@ from torch.testing._internal.common_device_type import onlyCUDA from torch.testing._internal.common_dtype import all_types_and, custom_types from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput from torch._higher_order_ops.invoke_subgraph import mark_compile_region +from torch._higher_order_ops import InvokeQuant, invoke_quant_packed + def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( @@ -218,6 +220,24 @@ def simple_scan(init, xs): return torch._higher_order_ops.scan(combine_fn, init, xs) +quant_tracer = InvokeQuant() + + +def simple_invoke_quant(x): + def fn(x, y): + return (torch.sin(x) * y,) + + return quant_tracer(fn, (x, x))[0] * 2. + + +def simple_invoke_quant_packed(x): + def fn(x): + return (torch.sin(x),) + + return invoke_quant_packed(fn, (x,))[0] * 2. + + + hop_db = [ OpInfo( name="scan", @@ -300,6 +320,45 @@ hop_db = [ # "torch.compile with aot_autograd does not currently support double backward." supports_gradgrad=False, ), + OpInfo( + name="invoke_quant", + variant_test_name="simple", + op=simple_invoke_quant, + sample_inputs_func=sample_inputs_invoke_subgraph, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=True, + # "torch.compile with aot_autograd does not currently support double backward." + skips=( + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"), + DecorateInfo( + unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export" + ), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), + ), + # "torch.compile with aot_autograd does not currently support double backward." + supports_gradgrad=False, + ), + OpInfo( + name="invoke_quant_packed", + variant_test_name="simple", + op=simple_invoke_quant_packed, + sample_inputs_func=sample_inputs_invoke_subgraph, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=True, + # "torch.compile with aot_autograd does not currently support double backward." + supports_gradgrad=False, + ), OpInfo( name="while_loop", variant_test_name="simple", diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 3110c3947af..13de003b330 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -10,6 +10,9 @@ import os from subprocess import CalledProcessError import sys import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch.fx.experimental.proxy_tensor import make_fx +from torch._inductor.graph import GraphLowering +from torch._inductor.compile_fx import shape_env_from_inputs from torch._inductor.codecache import CppCodeCache from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu from torch._inductor.utils import GPU_TYPES, get_gpu_type @@ -142,6 +145,21 @@ IS_H100 = LazyVal( IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu()) +def dummy_graph() -> GraphLowering: + """ + Create a graph. This is useful for unit testing code which accesses + V.graph.sizevars. + """ + example_inputs = [torch.randn(10) for _ in range(2)] + gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs) + shape_env = shape_env_from_inputs(example_inputs) + graph = GraphLowering( + gm, + shape_env=shape_env, + ) + + return graph + def maybe_skip_size_asserts(op): """ For certain ops, there meta and eager implementation returns differents diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 06a1c2bd5d4..299eb999676 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -505,7 +505,7 @@ class JitTestCase(JitCommonTestCase): script_outputs = scripted_fn(*recording_inputs) with self.capture_stdout(): opt_script_outputs = scripted_fn(*recording_inputs) - with self.capture_stdout() as _python_stdout: + with self.capture_stdout(): python_outputs = python_fn(*inputs) if not IS_WINDOWS: self.assertExpected(script_stdout[0], subname='stdout') diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index ae0a1eee398..15db18f3307 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1286,6 +1286,10 @@ class Identity(sympy.Function): def _eval_is_integer(self): return self.args[0].is_integer # type: ignore[attr-defined] + def _eval_expand_identity(self, **hints): + # Removes the identity op. + return self.args[0] + def make_opaque_unary_fn(name): class OpaqueUnaryFn(sympy.Function): diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index eb85b6798ea..784f9e7ba05 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -49,7 +49,7 @@ from .numbers import int_oo, IntInfinity, NegativeIntInfinity log = logging.getLogger(__name__) -__all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"] +__all__ = ["ValueRanges", "bound_sympy"] _T = TypeVar("_T", sympy.Expr, SympyBoolean) @@ -1004,108 +1004,6 @@ class SymPyValueRangeAnalysis: return ValueRanges.increasing_map(x, TruncToFloat) -class ValueRangeAnalysis(SymPyValueRangeAnalysis): - def __init__(self) -> None: - self.name = "ValueRangeAnalysis" - boolean_operators = ( - "xor", - "logical_and", - "logical_or", - "logical_not", - ) - for op in boolean_operators: - setattr(self, op, self.bool_handler) - - @staticmethod - def bool_handler(*args, **kwargs): - # just assuming bools can have both values - return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type] - - @staticmethod - def default_handler(*args, **kwargs): - # many ops are unlikely to show up in optimizable indexing compute, - # so we dont have full coverage - return ValueRanges.unknown() - - def load(self, name: str, index: sympy.Expr): - return ValueRanges.unknown() - - def store(self, name, index, value, mode=None): - return - - def reduction(self, name, dtype, src_dtype, reduction_type, index, value): - return ValueRanges.unknown() - - @classmethod - def index_expr(cls, index, dtype): - assert isinstance(index, ValueRanges) - return cls.to_dtype(index, dtype) - - @staticmethod - def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): - x = ValueRanges.wrap(x) - - if dtype == torch.bool: - if x.is_singleton(): - return ValueRanges.wrap(x.lower != 0) - elif x.is_bool: - return x - elif 0 not in x: - return ValueRanges.wrap(sympy.true) - else: - return ValueRanges(sympy.false, sympy.true) - - def cast(x, dtype): - # dtype is int or float - if dtype.is_floating_point: - return sympy.Float(x) - else: - if x in (int_oo, -int_oo): - return x - try: - return sympy.Integer(x) - except TypeError: - # inf cannot be cast to Integer - return x - - if x.is_bool: - if x.is_singleton(): - val = 1 if x.lower else 0 - return ValueRanges.wrap(cast(val, dtype)) - else: - return ValueRanges(cast(0, dtype), cast(1, dtype)) - else: - # int to float or float to int - return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype)) - - @staticmethod - def square(x): - return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) - - @staticmethod - def neg(x): - return ValueRanges.decreasing_map(x, operator.neg) - - # TODO: this is slightly inaccurate because truncdiv operates at integer - # precision, but we're going through float truediv which means we can - # potentially lose precision on the bounds - @classmethod - def truncdiv(cls, a, b): - x = cls.truediv(a, b) - if x == ValueRanges.unknown(): - return x - - return cls.trunc(x) - - @classmethod - def sub(cls, a, b): - return cls.add(a, cls.neg(b)) - - def __getattr__(self, name): - log.debug("unhandled ValueRange op %s", name) - return self.default_handler - - def bound_sympy( expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: