[ghstack-poisoned]
This commit is contained in:
Xuehai Pan 2025-02-10 22:00:04 +08:00
commit 125261b2d3
116 changed files with 4203 additions and 975 deletions

View file

@ -758,11 +758,10 @@ const auto sinc_string = jiterator_stringify(
T sinc(T a) {
if (a == T(0)) {
return T(1);
} else {
constexpr T pi = T(3.14159265358979323846L);
T product = pi * a;
return std::sin(product) / product;
}
constexpr T pi = T(3.14159265358979323846L);
T product = pi * a;
return std::sin(product) / product;
}
); // sinc_string

View file

View file

@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
add_loop_inductor,compile_time_instruction_count,30150000000,0.015
add_loop_inductor,compile_time_instruction_count,29630000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43980000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,26240000000,0.015
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17150000000,0.015
@ -62,4 +62,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000,
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10340000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10390000000,0.015

1 add_loop_eager compile_time_instruction_count 3096000000 0.015
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 945100000 0.015
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 18980000000 0.015
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 17250000000 17150000000 0.015
9 basic_modules_ListOfLinears_inductor_gpu compile_time_instruction_count 10885050825 0.2
10 update_hint_regression compile_time_instruction_count 1686000000 0.02
11 sum_floordiv_regression compile_time_instruction_count 1041000000 0.015
12 symint_sum compile_time_instruction_count 3324000000 0.015
13 aotdispatcher_inference_nosubclass_cpu compile_time_instruction_count 2028000000 0.015
14 aotdispatcher_inference_subclass_cpu compile_time_instruction_count 5836000000 0.015
15 aotdispatcher_partitioner_cpu compile_time_instruction_count 9167000000 0.015
16 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3863000000 0.015
17 aotdispatcher_training_subclass_cpu compile_time_instruction_count 10340000000 10390000000 0.015
18
19
20
26
27
28
29
30
31
32
62
63
64
65

View file

@ -997,6 +997,7 @@ def define_buck_targets(
"Config.h": ":generate_aten_config[Config.h]",
},
labels = labels,
visibility = ["PUBLIC"],
)
fb_xplat_cxx_library(

View file

@ -22,3 +22,15 @@ def define_targets(rules):
[],
),
)
rules.cc_library(
name = "c10_headers",
deps = [
"//c10/core:base_headers",
"//c10/macros",
"//c10/util:base_headers",
"//c10/util:bit_cast",
"//c10/util:ssize",
],
visibility = ["//visibility:public"],
)

View file

@ -90,6 +90,22 @@ def define_targets(rules):
alwayslink = True,
)
rules.cc_library(
name = "base_headers",
srcs = [],
hdrs = rules.glob(
[
"*.h",
"impl/*.h",
],
exclude = [
"CPUAllocator.h",
"impl/alloc_cpu.h",
],
),
visibility = ["//visibility:public"],
)
rules.filegroup(
name = "headers",
srcs = rules.glob(
@ -101,5 +117,5 @@ def define_targets(rules):
"alignment.h",
],
),
visibility = ["//c10:__pkg__"],
visibility = ["//visibility:public"],
)

View file

@ -477,5 +477,31 @@ inline float2 sinc(float2 inp) {
return float2(re, im) / a2;
}
// Computes sin(x) / x, i.e. the order-zero spherical Bessel function, for a
// scalar x of type T.
//   - Infinite x returns 0.
//   - For |x| < 0.5 a polynomial in x^2 is evaluated instead of dividing by x,
//     so x == 0 yields 1 rather than 0 / 0.
//   - Otherwise the closed form sin(x) / x is used directly.
template <typename T>
inline T spherical_bessel_j0(T x) {
if (::metal::isinf(x))
return T(0.0);
T x2 = x * x;
// Signs of the alternating series terms.
T k1 = static_cast<T>(-1.0);
T k2 = static_cast<T>(1.0);
if (::metal::abs(x) < T(0.5)) {
// Nested (Horner-style) evaluation of
//   1 - x^2/3! + x^4/5! - x^6/7! + x^8/9! - x^10/11! + x^12/13!
// which is the Taylor expansion of sin(x)/x truncated after the x^12 term
// (6 = 3!, 120 = 5!, 5040 = 7!, 362880 = 9!, 39916800 = 11!,
// 6227020800 = 13!).
return T(1.0) +
x2 *
(k1 / T(6.0) +
x2 *
(k2 / T(120.0) +
x2 *
(k1 / T(5040.0) +
x2 *
(k2 / T(362880.0) +
x2 *
(k1 / T(39916800.0) +
x2 * (k2 / T(6227020800.0)))))));
}
return ::metal::sin(x) / x;
}
} // namespace metal
} // namespace c10

View file

@ -80,6 +80,18 @@ def define_targets(rules):
],
)
rules.cc_library(
name = "base_headers",
hdrs = rules.glob(
["*.h"],
exclude = [
"bit_cast.h",
"ssize.h",
],
),
visibility = ["//visibility:public"],
)
rules.filegroup(
name = "headers",
srcs = rules.glob(

View file

@ -5,13 +5,15 @@ import copy
import torch
import torch.nn as nn
from torch.distributed._tensor import (
DeviceMesh,
from torch.distributed import DeviceMesh, init_device_mesh
from torch.distributed.tensor import (
distribute_module,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.nn import functional as F
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@ -181,6 +183,28 @@ class DistConvolutionOpsTest(DTensorTestBase):
f"Too large relative mse for bias tensor, expected less equal 1e-6, got {bias_mse_rel}",
)
@with_comms
@skip_if_lt_x_gpu(2)
def test_conv_backward_none_grad_inp(self):
device_mesh = init_device_mesh(
device_type="cuda", mesh_shape=(self.world_size,)
)
conv = nn.Conv2d(64, 64, 3, padding=1).train()
x = torch.randn(1, 64, 32, 32)
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
w = conv.weight
w_dt = torch.nn.Parameter(DTensor.from_local(w, device_mesh, [Replicate()]))
b = conv.bias
b_dt = torch.nn.Parameter(DTensor.from_local(b, device_mesh, [Replicate()]))
res = F.conv2d(x_dt, w_dt, b_dt, padding=1)
dres = torch.rand_like(res)
res.backward(dres)
self.assertTrue(w_dt.grad is not None)
self.assertTrue(b_dt.grad is not None)
self.assertTrue(x_dt.grad is None)
if __name__ == "__main__":
run_tests()

View file

@ -2007,7 +2007,7 @@ class DistributedDataParallelTest(
replica_devices = [dev0]
# Tells _test_grad_layout to construct ConvNet with all layers on this process's first assigned device.
layer_devs = dev0
local_batch_size = 8
local_batch_size = 16
self._test_grad_layout(replica_devices, layer_devs, local_batch_size)
@requires_nccl()
@ -2021,7 +2021,7 @@ class DistributedDataParallelTest(
replica_devices = None
# Tells _test_grad_layout to constructs this process's ConvNet on 2 devices, with 2 layers on each device.
layer_devs = [dev0] * 2 + [dev1] * 2
local_batch_size = 8
local_batch_size = 16
self._test_grad_layout(replica_devices, layer_devs, local_batch_size)
@requires_nccl()

View file

@ -1744,10 +1744,11 @@ class GraphModule(torch.nn.Module):
class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase):
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_contextlib
torch._dynamo.config.enable_trace_contextlib = True
def tearDown(self):
torch._dynamo.config.enable_trace_contextlib = False
torch._dynamo.config.enable_trace_contextlib = self._prev
def test_ctx_basic0(self):
@contextlib.contextmanager
@ -2236,7 +2237,7 @@ class GraphModule(torch.nn.Module):
eager = EagerAndRecordGraphs()
out = torch.compile(backend=eager, fullgraph=False)(fn)(x)
self.assertEqual(expected, out)
self.assertEqual(len(eager.graphs), 1)
self.assertEqual(len(eager.graphs), 0)
def test_graph_break_before_and_after___enter__(self):
@contextlib.contextmanager
@ -2262,7 +2263,7 @@ class GraphModule(torch.nn.Module):
eager = EagerAndRecordGraphs()
out = torch.compile(backend=eager, fullgraph=False)(fn)(x)
self.assertEqual(expected, out)
self.assertEqual(len(eager.graphs), 1)
self.assertEqual(len(eager.graphs), 0)
def test_graph_break_before___enter___and_disable___exit__(self):
@contextlib.contextmanager
@ -2292,7 +2293,7 @@ class GraphModule(torch.nn.Module):
eager = EagerAndRecordGraphs()
out = torch.compile(backend=eager, fullgraph=False)(fn)(x)
self.assertEqual(expected, out)
self.assertEqual(len(eager.graphs), 1)
self.assertEqual(len(eager.graphs), 0)
def test_disable___enter__(self):
def h(x):
@ -2573,7 +2574,7 @@ class GraphModule(torch.nn.Module):
eager = EagerAndRecordGraphs()
out = torch.compile(backend=eager, fullgraph=False)(fn)(x)
self.assertEqual(expected, out)
self.assertEqual(len(eager.graphs), 1)
self.assertEqual(len(eager.graphs), 0)
def test_dynamo_disable_ctx(self):
@contextlib.contextmanager
@ -2623,7 +2624,7 @@ class GraphModule(torch.nn.Module):
eager = EagerAndRecordGraphs()
out = torch.compile(backend=eager, fullgraph=False, dynamic=False)(f)(x)
self.assertEqual(expected, out)
self.assertEqual(len(eager.graphs), 3)
self.assertEqual(len(eager.graphs), 2)
@parametrize("name", ("suppress", "stdout", "stderr"))
def test_contextlib_suppress(self, name):

View file

@ -404,6 +404,21 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
def test_raise_GeneratorExit(self):
# GeneratorExit does not inherit from Exception
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
try:
raise GeneratorExit
except Exception:
return t.sin()
except BaseException:
return t.cos()
t = torch.randn(2)
y = fn(t)
self.assertEqual(y, t.cos())
def test_speculation_exception(self):
log = SpeculationLog()
log.next("fake", 555, "fake", Instruction(1, "fake", 1, 1))

File diff suppressed because it is too large Load diff

View file

@ -7000,6 +7000,12 @@ class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase):
)
def test_hops_compile(self, device, dtype, op, backend):
# Ensure HOPs can be compiled
if backend == "aot_eager" and op.name == "invoke_quant":
raise unittest.SkipTest(
"TODO: partitioner fails. migrate canonicalization to aot eager backend"
)
sample_inputs_itr = op.sample_inputs(
device, dtype, requires_grad=op.supports_autograd
)

View file

@ -9579,21 +9579,6 @@ def ___make_guard_fn():
):
compiled_fn(x)
# FIXME(XuehaiPan): do not inline infinite generator if it does not raise errors in eager mode
def fn(x):
def gen():
while True:
yield x
return list(zip(range(10), gen()))
x = torch.randn([0, 1, 2, 3, 4, 5])
compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "infinite generator"
):
compiled_fn(x)
def test_itertools_islice(self):
counters.clear()

View file

@ -1418,9 +1418,9 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(opt_model(a, b, c, d), correct))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnt.frame_count, """4""")
self.assertExpectedInline(cnt.frame_count, """2""")
else:
self.assertExpectedInline(cnt.frame_count, """5""")
self.assertExpectedInline(cnt.frame_count, """3""")
def test_hf_model_output(self):
ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10))
@ -6510,6 +6510,27 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
).sum()
self.assertEqual(actual, expected)
def test_incompatible_configs(self):
with torch._dynamo.config.patch(
suppress_errors=False, fail_on_recompile_limit_hit=False
):
torch.compile(lambda: None)
with torch._dynamo.config.patch(
suppress_errors=True, fail_on_recompile_limit_hit=False
):
torch.compile(lambda: None)
with torch._dynamo.config.patch(
suppress_errors=False, fail_on_recompile_limit_hit=True
):
torch.compile(lambda: None)
with torch._dynamo.config.patch(
suppress_errors=True, fail_on_recompile_limit_hit=True
), self.assertRaises(AssertionError):
torch.compile(lambda: None)
class ReproTestsDevice(torch._dynamo.test_case.TestCase):
def test_sub_alpha_scalar_repro(self, device):

View file

@ -0,0 +1,183 @@
# Owner(s): ["module: higher order operators"]
# flake8: noqa: B950
import contextlib
import logging
import unittest
import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from torch._higher_order_ops import InvokeQuant
from torch._inductor.pattern_matcher import (
Arg,
CallFunction,
Ignored,
Match,
PatternMatcherPass,
register_graph_pattern,
)
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
invoke_quant_tracer = InvokeQuant()
@skipIfTorchDynamo("Not a torch._dynamo test")
class TestInvokeQuant(TestCase):
backend = ""
def test_simple(self):
def gn(x, y):
return (torch.mul(x, y) + y,)
def fn(x, y):
return invoke_quant_tracer(
gn, (x, y), scheme="nf4", quant_options=invoke_quant_tracer
)[0]
x = torch.randn(8, requires_grad=False)
y = torch.randn(8, requires_grad=False)
ref = gn(x, y)[0]
x_clone = x.clone().detach().requires_grad_(False)
y_clone = y.clone().detach().requires_grad_(False)
res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
self.assertEqual(ref, res)
def test_construct_inline(self):
def gn(x, y):
return (torch.mul(x, y) + y,)
def fn(x, y):
return InvokeQuant(codegen_low_precision=False)(gn, (x, y), scheme="nf4")[0]
x = torch.randn(8, requires_grad=False)
y = torch.randn(8, requires_grad=False)
ref = gn(x, y)[0]
x_clone = x.clone().detach().requires_grad_(False)
y_clone = y.clone().detach().requires_grad_(False)
res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
self.assertEqual(ref, res)
def test_inline(self):
def gn(x, y):
return (torch.mul(x, y) + y,)
def fn(x, y):
return InvokeQuant()(gn, (x, y), scheme="nf4")[0]
x = torch.randn(8, requires_grad=False)
y = torch.randn(8, requires_grad=False)
ref = gn(x, y)[0]
x_clone = x.clone().detach().requires_grad_(False)
y_clone = y.clone().detach().requires_grad_(False)
res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
self.assertEqual(ref, res)
def test_multiple(self):
torch._logging.set_logs(post_grad_graphs=True)
def gn(x, y):
return torch.mul(x, y) + y
def fn(x, y, z):
o1 = invoke_quant_tracer(gn, (x, y), scheme="nf4")
o2 = invoke_quant_tracer(gn, (y, z), scheme="nf4")
return o1 + o2
x = torch.randn(8, requires_grad=False)
y = torch.randn(8, requires_grad=False)
z = torch.randn(8, requires_grad=False)
ref = fn(x, y, z)
log_context = (
contextlib.nullcontext()
if self.backend != "inductor"
else self.assertLogs(logger="torch._inductor", level=logging.DEBUG)
)
with log_context as log:
res = torch.compile(fn, backend=self.backend)(x, y, z)
self.assertEqual(ref, res)
if self.backend == "inductor":
logs = "\n".join(r.getMessage() for r in log.records)
f = FileCheck()
f.check("AFTER POST GRAD")
f.check("subgraph0").check("subgraph1")
for _ in range(2):
f.check("torch.ops.higher_order.invoke_quant(").check_same("nf4")
f.run(logs)
class TestInvokeQuantEager(TestInvokeQuant):
backend = "eager"
class TestInvokeQuantAotEager(TestInvokeQuant):
backend = "aot_eager"
class TestInvokeQuantInductor(TestInvokeQuant):
backend = "inductor"
def test_pattern_matching(self):
counter = 0
test_pass = PatternMatcherPass()
def my_pass(g):
return test_pass.apply(g)
def gn(x, y):
return torch.mul(x, y) + y
def fn(x, y, z):
return invoke_quant_tracer(gn, (x, y), scheme="nf4") @ z
def fn_no_match(x, y, z):
return invoke_quant_tracer(gn, (x, y)) @ z
x = torch.randn(64, 64, requires_grad=False)
y = torch.randn(64, 64, requires_grad=False)
z = torch.randn(64, 64, requires_grad=False)
@register_graph_pattern(
CallFunction(
torch.ops.aten.mm,
CallFunction(
torch.ops.higher_order.invoke_quant,
Ignored(),
Ignored(),
Ignored(),
scheme="nf4",
),
Arg(),
),
pass_dict=test_pass,
)
def quant_matching(match: Match, *args, **kwargs):
nonlocal counter
counter += 1
with unittest.mock.patch(
"torch._inductor.config.post_grad_custom_pre_pass", my_pass
):
torch.compile(fn)(x, y, z)
self.assertTrue(counter == 1)
torch.compile(fn_no_match)(x, y, z)
self.assertTrue(counter == 1)
del TestInvokeQuant
if __name__ == "__main__":
run_tests()

View file

@ -0,0 +1,102 @@
# Owner(s): ["module: inductor"]
import sympy
import torch
from torch._inductor.codegen.block_analysis import BlockPatternMatcher
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
from torch.testing._internal.inductor_utils import dummy_graph
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
# Some useful symbols
x, y = sympy.symbols("x y")
@instantiate_parametrized_tests
class BlockAnalysisTest(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
# Create a GraphLowering, so we can access V.graph.
cls.graph = dummy_graph()
@parametrize(
"stride,symbol,expr",
[
(5, x, Identity(5 * x)),
(4, y, 4 * Identity(y)),
(3, x, Identity(3) * x),
],
)
def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr):
# Test that we can handle an identity expression in affine indexing.
matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol)
self.assertEqual(matched_stride, stride)
@parametrize(
"dims,strides,symbol,expr",
[
(
(2, 4),
(4, 1),
x,
4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4),
),
(
(3, 9),
(5, 2),
x,
5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9),
),
((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))),
],
)
def test_mod_div_identity(
self,
dims: tuple[int],
strides: tuple[int],
symbol: sympy.Symbol,
expr: sympy.Expr,
):
# Test that we can handle an identity expression in modular indexing.
numel = int(torch.prod(torch.Tensor(dims)))
num_dims = len(dims)
with V.set_graph_handler(self.graph):
match_result = BlockPatternMatcher.match_mod_div_block_expr(
expr, symbol, numel, num_dims
)
# Check the matched block dimensions.
self.assertNotEqual(match_result, None)
matched_dims, matched_strides, matched_block_index_exprs = match_result
self.assertEqual(matched_dims, dims)
self.assertEqual(matched_strides, strides)
@parametrize(
"symbol,expr,subexpr",
[
(x, Identity(x), x),
(x, Identity(x + 5), x),
(y, Identity(x + 2 * y) + 5, 2 * y),
],
)
def test_subexpr_identity(
self,
symbol: sympy.Symbol,
expr: sympy.Expr,
subexpr: sympy.Expr,
):
matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
self.assertEqual(matched_subexpr, subexpr)
if __name__ == "__main__":
run_tests()

View file

@ -2510,9 +2510,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
@supported_platform
def test_strided_backwards(self):
shape = (1, 2, 4096, 64)
Q = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
K = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
V = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
Q = torch.randn(shape, requires_grad=True, device="cuda")
K = torch.randn(shape, requires_grad=True, device="cuda")
V = torch.randn(shape, requires_grad=True, device="cuda")
func = torch.compile(flex_attention, dynamic=True, fullgraph=True)
K_sliced = K[:, :, :-128]

View file

@ -5,19 +5,23 @@ from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides
from torch._inductor.codegen.halide import HalideOverrides
from torch._inductor.codegen.mps import MetalOverrides
from torch._inductor.codegen.triton import TritonKernelOverrides
from torch._inductor.ops_handler import list_ops, OP_NAMES
from torch._inductor.ops_handler import list_ops, OP_NAMES, OpsHandler
from torch._inductor.test_case import TestCase
class TestOpCompleteness(TestCase):
def verify_ops_handler_completeness(self, handler):
op_names = list_ops(handler)
if OP_NAMES == op_names:
return
print(f"Missing ops: {OP_NAMES - op_names}")
print(f"Extra ops: {op_names - OP_NAMES}")
self.assertEqual(", ".join(OP_NAMES - op_names), "")
self.assertEqual(", ".join(op_names - OP_NAMES), "")
for op in OP_NAMES:
self.assertIsNot(
getattr(handler, op),
getattr(OpsHandler, op),
msg=f"{handler} must implement {op}",
)
extra_ops = list_ops(handler) - OP_NAMES
if extra_ops:
raise AssertionError(
f"{handler} has an extra ops: {extra_ops}, add them to OpHandler class or prefix with `_`"
)
def test_triton_overrides(self):
self.verify_ops_handler_completeness(TritonKernelOverrides)

View file

@ -50,7 +50,7 @@ class TestSortAndSelect(TestCase):
return ((b != b) | (a <= b)).all().item()
else:
error( # noqa: F821
raise ValueError(
f'unknown order "{order}", must be "ascending" or "descending"'
)

View file

@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import (
)
from torch.utils._sympy.functions import (
FloorDiv,
Identity,
OpaqueUnaryFn_cos,
simple_floordiv_gcd,
)
@ -34,7 +35,8 @@ from torch.utils._sympy.reference import (
)
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.value_ranges import ValueRanges
from torch._inductor.bounds import ValueRangeAnalysis
UNARY_OPS = [
@ -954,6 +956,17 @@ class TestSingletonInt(TestCase):
self.assertEqual(j1.free_symbols, set())
class TestIdentity(TestCase):
def test_expand_identity(self):
"""
Test removing an identity via expansion.
"""
x = sympy.Symbol("x")
arg = x + sympy.S.One
expr = Identity(arg)
expanded = expr.expand(identity=True)
self.assertEqual(expanded.count(Identity), 0)
self.assertEqual(expanded, arg)
instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)

View file

@ -324,7 +324,7 @@ class TestTransformers(NNTestCase):
encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
except AssertionError:
continue
self.assertFalse(e, "Failed to catch unsupported uint8 type exception") # noqa: F821
self.assertFalse(e, "Failed to catch unsupported uint8 type exception")
test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
encoder.eval()
@ -335,7 +335,7 @@ class TestTransformers(NNTestCase):
encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
except AssertionError as e:
continue
self.assertFalse(e, "Failed to catch unsupported Long type exception") # noqa: F821
self.assertFalse(e, "Failed to catch unsupported Long type exception")
test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()

View file

@ -78,6 +78,9 @@ def create_build_plan() -> list[tuple[str, str]]:
if line.startswith(": &&") and line.endswith("&& :"):
line = line[4:-4]
line = line.replace("-O2", "-g").replace("-O3", "-g")
# Build Metal shaders with debug information
if "xcrun metal " in line and "-frecord-sources" not in line:
line += " -frecord-sources -gline-tables-only"
try:
name = line.split("-o ", 1)[1].split(" ")[0]
rc.append((name, line))

View file

@ -26,7 +26,10 @@ from .exc import IncorrectUsage, unimplemented
from .source import AttrSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import ValueMutationExisting, VariableTracker
from .variables.functions import FunctionDecoratedByContextlibContextManagerVariable
from .variables.functions import (
ContextlibContextManagerLocalGeneratorObjectVariable,
LocalGeneratorObjectVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
NumpyNdarrayVariable,
@ -162,14 +165,20 @@ class PyCodegen:
return
if value.is_realized() and isinstance(
value, FunctionDecoratedByContextlibContextManagerVariable
value, ContextlibContextManagerLocalGeneratorObjectVariable
):
raise IncorrectUsage(
"NYI: Returning a @contextmanager object from a torch.compile function"
)
# Dynamo normally prefers codegen from source to account for aliasing.
if value.source is not None and allow_cache:
if (
value.source is not None
and allow_cache
and not (
value.is_realized() and isinstance(value, LocalGeneratorObjectVariable)
)
):
# There's a corner case for export: for instance, if the computation
# graph is just identity on an input tensor, Dynamo would just emit
# a `LOAD_FAST` from the input source, rather than generating an

View file

@ -52,6 +52,7 @@ skip_code_recursive_on_recompile_limit_hit = True
# raise a hard error if cache limit is hit. If you are on a model where you
# know you've sized the cache correctly, this can help detect problems when
# you regress guards/specialization. This works best when recompile_limit = 1.
# This flag is incompatible with: suppress_errors.
# [@compile_ignored: runtime_behaviour]
fail_on_recompile_limit_hit = False
@ -164,6 +165,7 @@ traceable_tensor_subclasses: set[type[Any]] = set()
# This is a good way to get your model to work one way or another, but you may
# lose optimization opportunities this way. Devs, if your benchmark model is failing
# this way, you should figure out why instead of suppressing it.
# This flag is incompatible with: fail_on_recompile_limit_hit.
suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
# Record and write an execution record of the current frame to a file
@ -417,6 +419,10 @@ enable_cpp_symbolic_shape_guards = False
# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True
# Enable tracing generator functions lazily. If False, Dynamo exhausts a
# generator upon first execution; if True, the generator is accessed lazily.
enable_faithful_generator_behavior = True
# Inline inbuilt nn modules
inline_inbuilt_nn_modules = Config( # type: ignore[var-annotated]
default=True,

View file

@ -833,6 +833,13 @@ def is_inductor_supported():
return False
def check_for_incompatible_configs():
    """Reject Dynamo config combinations that are mutually exclusive.

    Currently verifies that ``config.suppress_errors`` and
    ``config.fail_on_recompile_limit_hit`` are not both enabled.

    Raises:
        AssertionError: if both flags are truthy.
    """
    # Raise explicitly rather than via an ``assert`` statement so the check is
    # still enforced when Python runs with -O (which strips asserts).
    if config.suppress_errors and config.fail_on_recompile_limit_hit:
        raise AssertionError(
            "Dynamo configs suppress_errors and fail_on_recompile_limit_hit "
            "can not both be active at the same time."
        )
def optimize(*args, **kwargs):
def rebuild_ctx():
ca_kwargs_override = config.compiled_autograd_kwargs_override
@ -885,6 +892,7 @@ def _optimize(
...
"""
check_if_dynamo_supported()
check_for_incompatible_configs()
# Note: The hooks object could be global instead of passed around, *however* that would make
# for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
# There is some prior art around this, w/r/t nesting backend calls are enforced to be the same

View file

@ -162,6 +162,17 @@ class AttributeMutationError(Unsupported):
super().__init__(msg)
class InfiniteGeneratorError(Unsupported):
# Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT
def __init__(self, msg: str) -> None:
super().__init__(msg)
class SideEffectsError(Unsupported):
def __init__(self, msg: str) -> None:
super().__init__(msg)
class CondOpArgsMismatchError(ArgsMismatchError):
"""
Internal error from cond() due to arguments mismatch.
@ -267,12 +278,17 @@ class ObservedKeyError(ObservedLookupError):
pass
class ObservedGeneratorExit(ObservedException):
pass
class ObservedAttributeError(ObservedException):
# An AttributeError exception to be raised from inside Dynamo tracing. This can happen on user defined object __getattr__
pass
class ObservedRuntimeError(ObservedException):
# A RuntimeError exception to be raised from inside Dynamo tracing. This can happen on generator.throw(..) method
pass
@ -280,17 +296,32 @@ class ObservedNotImplementedError(ObservedException):
pass
class ObservedTypeError(ObservedException):
# A TypeError exception to be raised from inside Dynamo tracing. This can happen on generator.send(..) method
pass
observed_exception_map = {
StopIteration: ObservedUserStopIteration,
LookupError: ObservedLookupError,
IndexError: ObservedIndexError,
GeneratorExit: ObservedGeneratorExit,
KeyError: ObservedKeyError,
AttributeError: ObservedAttributeError,
RuntimeError: ObservedRuntimeError,
NotImplementedError: ObservedNotImplementedError,
TypeError: ObservedTypeError,
}
def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]:
if exc_type not in observed_exception_map:
observed_exception_map[exc_type] = type(
f"Observed{exc_type.__name__}Error", (ObservedException,), {}
)
return observed_exception_map[exc_type]
def raise_observed_exception(
exc_type: type[Exception],
tx: InstructionTranslatorBase,

View file

@ -1936,6 +1936,9 @@ class SubgraphTracer(fx.Tracer):
# backward recomputation of the checkpoint region doesn't affect its correctness.
self.allow_side_effects_under_checkpoint = False
# True if this tracer is currently tracing (reconstructing) into a Python generator
self.is_reconstructing_generator = False
self.debug_level: int = parent.debug_level + 1 if parent is not None else 0
self._cur_code = None

View file

@ -7,7 +7,7 @@ import warnings
import weakref
from collections.abc import MutableMapping
from types import CellType
from typing import Any, Optional
from typing import Any, Optional, TYPE_CHECKING
import torch.nn
@ -19,7 +19,7 @@ from .bytecode_transformation import (
create_instruction,
)
from .codegen import PyCodegen
from .exc import unimplemented
from .exc import SideEffectsError, unimplemented
from .source import GlobalSource, LocalCellSource, LocalSource, Source
from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new, tuple_new
from .variables.base import (
@ -34,6 +34,10 @@ from .variables.base import (
from .variables.user_defined import FrozenDataClassVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
def _manual_dict_setitem(dict_from, dict_to, mro_index):
# Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have
# to be careful because we don't want to trigger the user defined object
@ -134,6 +138,14 @@ class SideEffects:
and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint
)
def is_reconstructing_generator(self):
output_graph = self.output_graph_weakref()
return (
output_graph
and output_graph.current_tx.output.current_tracer.is_reconstructing_generator
)
def check_allowed_side_effect(self, item):
from torch._dynamo.variables.misc import AutogradFunctionContextVariable
@ -143,6 +155,14 @@ class SideEffects:
return True
if self.should_allow_side_effects_under_checkpoint():
return True
if self.is_reconstructing_generator():
# This is missing the case where one mutates a tensor. See
# test_generator.py::test_reconstruct_generator_tensor_mutation
raise SideEffectsError(
"Cannot reconstruct a generator with variable mutations. "
"Dynamo needs to fully exhaust the generator, which may cause "
"unintended variable modifications."
)
if not is_side_effect_safe(item.mutation_type):
unimplemented(
"HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
@ -842,7 +862,7 @@ class SideEffects:
@contextlib.contextmanager
def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: ignore[name-defined] # noqa: F821
def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"):
assert tx.output.current_tracer.under_activation_checkpoint
orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint
try:
@ -850,3 +870,13 @@ def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: i
yield
finally:
tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val
# Context manager that flags the current tracer as reconstructing a generator.
# While the body runs, `tx.output.current_tracer.is_reconstructing_generator`
# is True; the value it had on entry is restored on exit, even if the body
# raises. The flag is read back via SideEffects.is_reconstructing_generator().
@contextlib.contextmanager
def disallow_side_effects_in_generator(tx: "InstructionTranslator"):
# Captured before the try so the finally clause always has a value to restore.
orig_val = tx.output.current_tracer.is_reconstructing_generator
try:
tx.output.current_tracer.is_reconstructing_generator = True
yield
finally:
tx.output.current_tracer.is_reconstructing_generator = orig_val

View file

@ -85,7 +85,8 @@ from .variables.ctx_manager import (
from .variables.dicts import ConstDictVariable, SetVariable
from .variables.functions import (
BaseUserFunctionVariable,
FunctionDecoratedByContextlibContextManagerVariable,
LocalGeneratorFunctionVariable,
LocalGeneratorObjectVariable,
NestedUserFunctionVariable,
SkipFunctionVariable,
UserFunctionVariable,
@ -290,6 +291,34 @@ def _step_logger():
return torchdynamo_logging.get_step_logger(log)
@contextlib.contextmanager
def save_and_restart_speculation_log(tx: "InstructionTranslatorBase"):
    """
    Run the ``with`` body against an empty speculation log, then restore it.

    When reconstructing a generator after a graph break, we advance it until
    it is fully exhausted. This process adds new entries to the speculation
    log that were not previously observed. Without temporarily clearing the
    speculation log, this could lead to a divergence error.
    """
    saved_entries = tx.speculation_log.entries
    saved_index = tx.speculation_log.index
    try:
        # Start from a blank log for the duration of the body.
        tx.speculation_log.entries = []
        tx.speculation_log.index = 0
        yield
    finally:
        # Put back the original list object and cursor position.
        tx.speculation_log.entries = saved_entries
        tx.speculation_log.index = saved_index
@contextlib.contextmanager
def temporarely_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"):
    """
    Force ``tx.output.should_exit`` to False for the duration of the ``with`` body.

    The value ``should_exit`` had on entry is written back on exit, whether the
    body finishes normally or raises.
    """
    # Capture the current value *before* entering the try block: if this read
    # fails, the ``finally`` clause must not run with ``saved_should_exit``
    # unbound (that would raise UnboundLocalError and hide the real error).
    saved_should_exit = tx.output.should_exit
    try:
        tx.output.should_exit = False
        yield
    finally:
        tx.output.should_exit = saved_should_exit
@dataclasses.dataclass
class BlockStackEntry:
# Current instruction that pushes something to block_stack
@ -922,11 +951,22 @@ class InstructionTranslatorBase(
raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
def inline_generator_function(self, fn, args, kwargs):
    """
    Redirect the call to the generator "call_function".

    ``fn`` is wrapped in a ``LocalGeneratorFunctionVariable`` unless it already
    is one; the wrapper's ``call_function`` result is returned.
    """
    generator_fn = (
        fn
        if isinstance(fn, LocalGeneratorFunctionVariable)
        else LocalGeneratorFunctionVariable(fn)
    )
    return generator_fn.call_function(self, args, kwargs)
def inline_user_function_return(self, fn, args, kwargs):
"""
A call to some user defined function by inlining it.
"""
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
if config.enable_faithful_generator_behavior and is_generator(fn.get_code()):
return self.inline_generator_function(fn, args, kwargs)
else:
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
def get_line_of_code_header(self, lineno=None):
if lineno is None:
@ -1499,6 +1539,15 @@ class InstructionTranslatorBase(
self._raise_exception_variable(inst)
unimplemented("raise ... from ...")
def CLEANUP_THROW(self, inst):
    """
    Handle the CLEANUP_THROW opcode.

    See https://github.com/python/cpython/pull/96010

    The exception on top of the stack is re-raised through ``RERAISE`` unless
    its type is ``StopIteration``, which is reported as unimplemented.
    """
    top_exc = self.stack[-1]
    assert isinstance(top_exc, ExceptionVariable)
    if top_exc.exc_type is not StopIteration:
        self.RERAISE(inst)
        return
    unimplemented("CLEANUP_THROW with StopIteration")
def RERAISE(self, inst):
if sys.version_info >= (3, 11):
# RERAISE is currently supported in a narrow case of `raise ... from None`
@ -3083,7 +3132,20 @@ class InstructionTranslator(InstructionTranslatorBase):
return True
return False
def replace_tos_if_return_is_generator(self):
    """
    Swap a generator object on top of the stack for a list iterator.

    If the stack is non-empty and its top entry is a
    ``LocalGeneratorObjectVariable``, the generator is fully unpacked and the
    stack top is replaced by a ``ListIteratorVariable`` over the unpacked items.
    Otherwise the stack is left untouched.
    """
    if not len(self.stack):
        return
    tos = self.stack[-1]
    if not (tos and isinstance(tos, LocalGeneratorObjectVariable)):
        return
    unpacked_items = tos.force_unpack_var_sequence(self)
    self.stack[-1] = ListIteratorVariable(
        unpacked_items,
        mutation_type=ValueMutationNew(),
    )
def _return(self, inst):
self.replace_tos_if_return_is_generator()
if (
not config.allow_empty_graphs
and self.output.count_calls() == 0
@ -3093,6 +3155,7 @@ class InstructionTranslator(InstructionTranslatorBase):
and not self.one_graph
):
raise exc.SkipFrame("because no content in function call")
self.instruction_pointer = None
_step_logger()(
logging.INFO,
@ -3179,8 +3242,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
func: VariableTracker,
args: list[VariableTracker],
kwargs,
*,
stop_generator_on_yield: bool = False,
):
if isinstance(func, SkipFunctionVariable):
unimplemented("inline with functions in skip files")
@ -3189,7 +3250,8 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
(
UserFunctionVariable,
NestedUserFunctionVariable,
FunctionDecoratedByContextlibContextManagerVariable,
LocalGeneratorFunctionVariable,
LocalGeneratorObjectVariable,
),
)
result = InliningInstructionTranslator.check_inlineable(func)
@ -3254,9 +3316,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
parent.symbolic_globals,
parent.symbolic_torch_function_state,
func,
stop_generator_on_yield=stop_generator_on_yield,
)
else:
# need the line below to make MyPy happy
assert not isinstance(func, LocalGeneratorObjectVariable)
tracer = InliningInstructionTranslator(
parent,
code,
@ -3302,24 +3365,30 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
log.debug("DONE INLINING %s", code)
if is_generator(code):
assert isinstance(self, InliningGeneratorInstructionTranslator)
# The first flag tells us if we consume generators lazily or not
# and the second is if the generator is exhausted.
# In the future, generators should be lazily consumed and the first
# flag (stop_generator_on_yield) will not be needed.
if self.stop_generator_on_yield and self.generator_exhausted:
if config.enable_faithful_generator_behavior or (
isinstance(self, InliningGeneratorInstructionTranslator)
and self.is_generator_from_ctx_manager
):
if (
is_generator(code)
and isinstance(self, InliningGeneratorInstructionTranslator)
and self.generator_exhausted
):
assert isinstance(self, InliningGeneratorInstructionTranslator)
# When the generator returns None, we raise StopIteration
r = self.symbolic_result
assert r.as_python_constant() is None
exc.raise_observed_exception(StopIteration, self)
else:
return self.symbolic_result
else:
if is_generator(code):
assert isinstance(self, InliningGeneratorInstructionTranslator)
assert self.symbolic_result.as_python_constant() is None
return ListIteratorVariable(
self.generated_items,
mutation_type=ValueMutationNew(),
)
else:
return self.symbolic_result
else:
return self.symbolic_result
def __init__(
self,
@ -3438,27 +3507,26 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
generated_items: list[VariableTracker]
# Flag wether or not the InlineGenerator should consume the entire iterator
stop_generator_on_yield: bool
def __init__(self, *args, stop_generator_on_yield: bool = False, **kwargs) -> None:
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.generated_items = []
# In the future, generators should run lazily (i.e. when next(...) is called)
# TODO: Set this to True by default, so that dynamo follows CPython more
# closely
self.stop_generator_on_yield = stop_generator_on_yield
self.generator_exhausted = False
self.is_generator_from_ctx_manager = False
def YIELD_VALUE(self, inst: Instruction):
top = self.pop()
self.generated_items.append(top)
if len(self.generated_items) > MAX_ITERATOR_LIMIT:
unimplemented(
raise exc.InfiniteGeneratorError(
"Too many yield values in generator. Maybe you are inlining an infinite generator. "
f"If not, please report a bug at {PT2_ISSUE_TRACKER_URL}",
)
self.push(ConstantVariable.create(None))
if self.stop_generator_on_yield:
if (
config.enable_faithful_generator_behavior
or self.is_generator_from_ctx_manager
):
self.symbolic_result = top
# Stop tracing
raise YieldValueOp
@ -3500,10 +3568,6 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
self.pop()
self.push(ConstantVariable.create(ex.value))
else:
self.push(val)
# Add the value to yield into generated_items and replace the top of the stack with None
self.YIELD_VALUE(inst)
# Repeat the YIELD_FROM instruction in the next eval loop
assert (
isinstance(self.instruction_pointer, int)
@ -3511,11 +3575,15 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
)
self.instruction_pointer -= 1
self.push(val)
# Add the value to yield into generated_items and replace the top of the stack with None
self.YIELD_VALUE(inst)
def SEND(self, inst):
assert len(self.stack) >= 2
val = self.pop()
tos = self.stack[-1]
if isinstance(tos, ListIteratorVariable) or (
if isinstance(tos, (ListIteratorVariable, LocalGeneratorObjectVariable)) or (
isinstance(tos, UserDefinedObjectVariable)
and isinstance(tos.value, collections.abc.Iterator)
):

View file

@ -32,8 +32,9 @@ from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
from .variables import (
BuiltinVariable,
FunctionalCallVariable,
FunctionDecoratedByContextlibContextManagerVariable,
FunctorchHigherOrderVariable,
LocalGeneratorFunctionVariable,
LocalGeneratorObjectVariable,
NestedUserFunctionVariable,
PolyfilledFunctionVariable,
SkipFunctionVariable,
@ -3209,6 +3210,7 @@ LEGACY_MOD_INLINELIST = {
"torch._higher_order_ops.while_loop",
"torch._higher_order_ops.associative_scan",
"torch._higher_order_ops.scan",
"torch._higher_order_ops._invoke_quant",
"torch._higher_order_ops.utils",
"torch.nn.attention.flex_attention",
"torch.ao.quantization.pt2e.export_utils",
@ -3619,7 +3621,8 @@ def check_verbose(obj, is_inlined_call=False):
UserFunctionVariable,
UserMethodVariable,
NestedUserFunctionVariable,
FunctionDecoratedByContextlibContextManagerVariable,
LocalGeneratorFunctionVariable,
LocalGeneratorObjectVariable,
),
):
try:
@ -3639,7 +3642,14 @@ def check_verbose(obj, is_inlined_call=False):
# Consult the central trace rules defined in torch._dynamo.trace_rules.
reasons: set[str] = set()
rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons)
if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
if issubclass(
rule,
(
UserFunctionVariable,
LocalGeneratorFunctionVariable,
PolyfilledFunctionVariable,
),
):
return SkipResult(
False,
f"inlined according trace_rules.lookup {reasons.pop()}",

View file

@ -39,6 +39,8 @@ from .functions import (
FunctionDecoratedByContextlibContextManagerVariable,
FunctoolsPartialVariable,
FunctoolsWrapsVariable,
LocalGeneratorFunctionVariable,
LocalGeneratorObjectVariable,
NestedUserFunctionVariable,
PolyfilledFunctionVariable,
SkipFunctionVariable,

View file

@ -773,7 +773,13 @@ class BuiltinVariable(VariableTracker):
tx, [v.realize() for v in args], kwargs
)
if inspect.isclass(fn) and issubclass(fn, Exception):
if inspect.isclass(fn) and (
issubclass(fn, Exception)
# GeneratorExit doesn't inherit from Exception
# >>> issubclass(GeneratorExit, Exception)
# False
or fn is GeneratorExit
):
def create_exception_class_object(
tx: "InstructionTranslator", args, kwargs
@ -1425,6 +1431,13 @@ class BuiltinVariable(VariableTracker):
mutation_type=ValueMutationNew(),
)
def _call_iter_tuple_generator(self, tx, obj, *args, **kwargs):
    """
    Build the list/tuple variable for ``self.fn`` from a generator variable.

    The container class is looked up with ``BaseListVariable.cls_for(self.fn)``
    and filled with every item obtained by force-unpacking ``obj``. Extra
    positional and keyword arguments are accepted but unused.
    """
    container_cls = variables.BaseListVariable.cls_for(self.fn)
    # Force-unpacking exhausts the generator.
    unpacked_items = list(obj.force_unpack_var_sequence(tx))
    return container_cls(unpacked_items, mutation_type=ValueMutationNew())
def _call_tuple_list(self, tx, obj=None, *args, **kwargs):
if isinstance(obj, variables.IteratorVariable):
cls = variables.BaseListVariable.cls_for(self.fn)
@ -1432,6 +1445,8 @@ class BuiltinVariable(VariableTracker):
list(obj.force_unpack_var_sequence(tx)),
mutation_type=ValueMutationNew(),
)
elif isinstance(obj, variables.LocalGeneratorObjectVariable):
return self._call_iter_tuple_generator(tx, obj, *args, **kwargs)
else:
return self._call_iter_tuple_list(tx, obj, *args, **kwargs)

View file

@ -4,17 +4,30 @@ import builtins
import functools
import inspect
import itertools
import sys
import types
from collections.abc import Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import Never
from unittest.mock import patch
import torch
from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import raise_observed_exception, unimplemented, Unsupported
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
get_dynamo_observed_exception,
handle_observed_exception,
IncorrectUsage,
InfiniteGeneratorError,
ObservedException,
ObservedGeneratorExit,
ObservedUserStopIteration,
raise_observed_exception,
SkipFrame,
unimplemented,
Unsupported,
)
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import (
@ -378,61 +391,426 @@ class BuiltinMethodVariable(BaseUserFunctionVariable):
return obj_vt.call_method(tx, name, args, kwargs)
class FunctionDecoratedByContextlibContextManagerVariable(BaseUserFunctionVariable):
# TODO(guilherme): replace this with a generic GeneratorFunctionVariable
class LocalGeneratorObjectVariable(VariableTracker):
    """
    VariableTracker for a generator object.

    Holds the generator's code object, its globals, and the inline tracer that
    is resumed every time the generator is advanced (``__next__``, ``send``,
    ``throw``, ``close``).
    """

    def __init__(
        self,
        code: types.CodeType,
        f_globals,
        inline_tracer: Optional["InstructionTranslator"],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.code = code
        self.f_globals = f_globals
        # May be None; it is then built lazily by _get_inline_tracer().
        self.inline_tracer = inline_tracer

    def get_code(self):
        return self.code

    def get_filename(self):
        return self.get_code().co_filename

    def get_name(self):
        return self.get_code().co_name

    def get_function(self):
        raise NotImplementedError

    def has_self(self):
        return False

    def __name__(self):
        return self.get_name()

    def __str__(self):
        return f"{self.__class__.__name__}({self.get_name()})"

    __repr__ = __str__

    def reconstruct(self, codegen):
        """
        Reconstruct the generator as a list iterator over its remaining items.

        If the generator is not exhausted yet, it is fully advanced first (with
        a fresh speculation log, generator side effects disallowed and output
        graph writes temporarily allowed) and the collected items are cached on
        ``self.remaining_items``.
        """
        from torch._dynamo.side_effects import disallow_side_effects_in_generator
        from torch._dynamo.symbolic_convert import (
            InstructionTranslator,
            save_and_restart_speculation_log,
            temporarely_allow_writes_to_output_graph,
        )

        tx = InstructionTranslator.current_tx()
        save = save_and_restart_speculation_log(tx)
        disallow = disallow_side_effects_in_generator(tx)
        temp = temporarely_allow_writes_to_output_graph(tx)

        with save, disallow, temp:
            tracer = self._get_inline_tracer(tx)
            if not tracer.generator_exhausted:
                self.remaining_items = self.force_unpack_var_sequence(tx)
            # ``remaining_items`` is only assigned above. A generator that was
            # already exhausted before the first reconstruct() has nothing left
            # to yield, so fall back to an empty list instead of failing with
            # AttributeError.
            remaining_items = getattr(self, "remaining_items", [])
            variables.ListIteratorVariable(remaining_items).reconstruct(codegen)

    def bind_args(self, tx, args, kwargs):
        # NOTE(review): ``self.fn`` is not assigned anywhere in this class;
        # confirm which attribute this is meant to delegate to.
        return self.fn.bind_args(tx, args, kwargs)

    def get_globals(self):
        return self.f_globals

    def python_type(self):
        return types.GeneratorType

    def _get_inline_tracer(self, tx):
        """Return the inline tracer, building it on first use."""
        from torch._dynamo.symbolic_convert import InliningInstructionTranslator

        if self.inline_tracer is None:
            self.inline_tracer = InliningInstructionTranslator.build_inline_tracer(
                tx, self, [], {}
            )
        return self.inline_tracer

    def next_variable(self, tx):
        """
        Advance the generator by one step and return the yielded variable.

        Raises an observed ``StopIteration`` when the generator is already
        exhausted. ``Unsupported`` raised while tracing the generator body marks
        the generator's code to be skipped and is converted to ``SkipFrame``.
        """
        tracer = self._get_inline_tracer(tx)

        if self._is_generator_exhausted():
            raise_observed_exception(StopIteration, tx)

        try:
            # Hierarchically, tx can be seen as the parent of the inline tracer
            # created on call_function. Any exception needs to be propagated to tx
            # for Dynamo to behave correctly
            with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
                return tracer.inline_call_()
        except ObservedException as e:
            tx.exn_vt_stack.extend(tracer.exn_vt_stack)
            raise e
        except InfiniteGeneratorError:
            # test/dynamo/test_misc.py::test_iterator_limit
            raise
        except Unsupported as e:
            torch._C._dynamo.eval_frame.skip_code(self.get_code())
            raise SkipFrame from e
        finally:
            counters["unimplemented"] |= counters["inline_call"]

    def has_unpack_var_sequence(self, tx):
        return False

    def has_force_unpack_var_sequence(self, tx) -> builtins.bool:
        return True

    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
        """Advance the generator until it stops and return every yielded item."""
        result = []
        while True:
            try:
                result.append(self.next_variable(tx))
            except ObservedUserStopIteration:
                handle_observed_exception(tx)
                break
        return result

    def _setup_exception(self, tx, exc):
        """Push ``exc`` on the inline tracer's stack and raise it inside the tracer."""
        tracer = self._get_inline_tracer(tx)
        tracer.push(exc)
        try:
            tracer._raise_exception_variable(None)
        except ObservedException as e:
            # if no handler is available (i.e. user code doesn't catch it), the
            # exception is raised again.
            tracer.exception_handler(e)

    def _is_generator_just_started(self):
        return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0

    def _is_generator_exhausted(self):
        return getattr(self.inline_tracer, "generator_exhausted", False)

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Dispatch a method call on the generator object.

        ``__next__``, ``__iter__``, ``send``, ``close`` and ``throw`` are handled
        here; every other name is delegated to the base class.
        """
        if name == "__next__":
            return self.next_variable(tx)
        elif name == "__iter__":
            # iter(gen) returns itself
            return self
        elif name == "send":
            # Sends a value into the generator function. Returns the next value
            # yielded by the generator, or raises StopIteration if the generator
            # exits without yielding another value
            if self._is_generator_just_started() and len(args):
                # can't send non-None value to a just-started generator
                # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen
                if not all(
                    isinstance(arg, ConstantVariable) and arg.value is None
                    for arg in args
                ):
                    raise_observed_exception(TypeError, tx)
            tracer = self._get_inline_tracer(tx)
            tracer.push_many(args)
            return self.next_variable(tx)
        elif name == "close":
            # * Raises a GeneratorExit at the point where the generator function was paused.
            # * If the generator function catches the exception and returns a
            # value, this value is returned from close() - Python 3.13+
            # * If the generator function is already closed, or raises GeneratorExit
            # (by not catching the exception), close() returns None.
            # * If the generator yields a value, a RuntimeError is raised.
            # * If the generator raises any other exception, it is propagated to the caller.
            # * If the generator has already exited due to an exception or normal
            # exit, close() returns None and has no other effect.

            # Return None if close is called on a just-started generator
            # See test GeneratorCloseCpythonTests::test_close_not_started
            tracer = self._get_inline_tracer(tx)
            if self._is_generator_just_started() or self._is_generator_exhausted():
                tracer.generator_exhausted = True
                return variables.ConstantVariable(None)

            # Raise GeneratorExit to see if user code catches it. Any other exception
            # is propagated to the parent frame.
            try:
                self._setup_exception(
                    tx, variables.ExceptionVariable(GeneratorExit, ())
                )
                # There's an extra block on Python 3.12+ to handle StopIteration
                # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397
                #
                #   1           0 RETURN_GENERATOR
                #               2 POP_TOP
                #               4 RESUME                   0
                #   2           6 LOAD_CONST               1 (1)
                #               8 YIELD_VALUE              1
                #              10 RESUME                   1
                #              12 POP_TOP
                #              14 RETURN_CONST             0 (None)
                #         >>   16 CALL_INTRINSIC_1         3 (INTRINSIC_STOPITERATION_ERROR)
                #              18 RERAISE                  1
                # ExceptionTable:
                #   4 to 14 -> 16 [0] lasti
                if (
                    sys.version_info >= (3, 12)
                    and tracer.next_instruction.opname == "CALL_INTRINSIC_1"
                ):
                    tracer.generator_exhausted = True
                    return variables.ConstantVariable(None)
            except ObservedGeneratorExit:
                # If it doesn't catch, we just return None, as per the text above
                tracer.generator_exhausted = True
                return variables.ConstantVariable(None)

            try:
                # Raise RuntimeError if the generator yields any other value
                if self.next_variable(tx):
                    raise_observed_exception(RuntimeError, tx)
            except ObservedGeneratorExit:
                tracer.generator_exhausted = True
                return variables.ConstantVariable(None)
            except ObservedUserStopIteration:
                # In Python 3.13+, one can capture GeneratorExit and return a value
                # See test_generator.py::test_close_capture_GeneratorExit_return
                # https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26
                # https://github.com/python/cpython/pull/104771
                assert tracer.symbolic_result is not None
                return tracer.symbolic_result
        elif name == "throw":
            # * Raises an exception at the point where the generator was paused, and
            # returns the next value yielded by the generator.
            # * If the generator exits without yielding, raise StopIteration
            # * If the generator function does not catch the passed-in exception,
            # or raises a different exception, then that exception propagates to the caller.

            if len(args) > 1:
                raise IncorrectUsage(
                    "the (type, exc, tb) signature of throw() is deprecated, "
                    "use the single-arg signature instead."
                )

            # Setup the exception table and jump target in case of try...finally
            tracer = self._get_inline_tracer(tx)
            try:
                self._setup_exception(tx, args[0])
            except ObservedException:
                # propagate the exception back to the parent caller
                tx.exn_vt_stack.extend(tracer.exn_vt_stack)
                raise

            retval = self.next_variable(tx)

            # The exception raised before is still active. We need to check the exception
            # table one more time to find the next target. But why? Lets walk
            # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M
            #
            #     z = 0
            #     def whoo():
            #         global z
            #         z = 0
            #         try:
            #             yield 1
            #         except ValueError:
            #             yield 2
            #         finally:
            #             z += 1
            #         z += 10
            #
            #     gen = whoo()
            #     next(gen)
            #     gen.throw(ValueError)
            #     print('z', z)  -> z = 1
            #
            #              ...
            #         >>   58 PUSH_EXC_INFO
            #
            #   8          60 LOAD_GLOBAL              2 (ValueError)
            #              70 CHECK_EXC_MATCH
            #              72 POP_JUMP_IF_FALSE        7 (to 88)
            #              74 POP_TOP
            #
            #   9          76 LOAD_CONST               3 (2)
            #              78 YIELD_VALUE              3      <------ ValueError is still active here
            #              80 RESUME                   1
            #              82 POP_TOP
            #              84 POP_EXCEPT
            #              86 jump_backward           34 (to 20)
            #              ...
            #
            #     ExceptionTable:
            #     4 to 8 -> 124 [0] lasti
            #     12 to 18 -> 58 [0]
            #     20 to 56 -> 124 [0] lasti
            #     58 to 82 -> 90 [1] lasti     <------ move to 90
            #     84 to 86 -> 96 [0]
            #     88 to 88 -> 90 [1] lasti
            #     90 to 94 -> 96 [0]
            #     96 to 116 -> 118 [1] lasti
            #     118 to 122 -> 124 [0] lasti
            #
            # In this scenario, a generator can yield after `throw()` is called. Even
            # after the exception is raised a few lines above, it remains active
            # within the `78 YIELD_VALUE` instruction. When the generator resumes
            # after the second yield on instruction `80 RESUME`, we cannot simply
            # return the control flow to the next instruction. Instead, one must
            # check the exception table (or equivalent) to find the next target
            # In this case, it says the instruction pointer must be moved to 90.
            #
            # Without this step, if we let the trace proceed to the next
            # instruction, it would follow the control flow where the exception
            # raised by `throw()` was handled and swallowed, potentially leading
            # to incorrect behavior.
            exc_type = type("__InternalThrowException", (Exception,), {})

            try:
                self._setup_exception(tx, variables.ExceptionVariable(exc_type, ()))
                self.next_variable(tx)
            except get_dynamo_observed_exception(exc_type):
                # We should get back the exception raised before.
                pass
            except ObservedException:
                # Propagate anything else back to the parent caller
                # NOTE(review): only the exception stack is copied to the
                # parent here; the exception itself is not re-raised. Confirm
                # this is intentional.
                tx.exn_vt_stack.extend(tracer.exn_vt_stack)
            else:
                raise_observed_exception(RuntimeError, tracer)
            return retval

        # Anything not handled above is delegated to the base class; its result
        # must be handed back to the caller instead of being dropped.
        return super().call_method(tx, name, args, kwargs)
class ContextlibContextManagerLocalGeneratorObjectVariable(
    LocalGeneratorObjectVariable
):
    """
    .. note::

        This is only used when the function is annotated with @contextlib.contextmanager

        It is a special case of a generator function, as we do not allow returning a
        context manager from a torch.compile function.
    """
class LocalGeneratorFunctionVariable(BaseUserFunctionVariable):
"""functions that behaves like iterators
.. note::
This is only used when the function is annotated with @contextlib.contextmanager
This is a wrapper around (Nested)UserFunctionVariable
"""
def __init__(self, vt: VariableTracker, **kwargs):
def __init__(
self,
vt: VariableTracker,
*,
generator_cls=LocalGeneratorObjectVariable,
**kwargs,
):
super().__init__(**kwargs)
self.vt = vt
self.inline_tracer = None
self.generator_cls = generator_cls
def __getattr__(self, name):
if name in self.__class__.__dict__.keys():
return getattr(self, name)
return getattr(self.vt, name)
def _build_inline_tracer(self, tx, args, kwargs):
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
return InliningInstructionTranslator.build_inline_tracer(
tx,
self,
args,
kwargs,
)
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
from torch._dynamo.bytecode_transformation import is_generator
assert is_generator(self.vt.get_code())
assert is_generator(self.get_code())
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
inline_tracer = self._build_inline_tracer(tx, args, kwargs)
code = self.vt.get_code()
f_globals = self.vt.get_globals()
self.inline_tracer = InliningInstructionTranslator.build_inline_tracer(
tx,
self,
[*self.self_args(), *args],
kwargs,
stop_generator_on_yield=True,
# calling a generator returns a generator object
return self.generator_cls(
code,
f_globals,
inline_tracer,
source=self.source,
)
return self
def next_variable(self, tx):
from torch._dynamo import exc
class FunctionDecoratedByContextlibContextManagerVariable(
LocalGeneratorFunctionVariable
):
"""
.. note::
tracer = self.inline_tracer
This is only used when the function is annotated with @contextlib.contextmanager
"""
try:
# Hierarchically, tx can be seen as the parent of the inline tracer
# created on call_function. Any exception needs to be propagated to tx
# for Dynamo to behave correctly
with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
return tracer.inline_call_().next_variable(tx)
except exc.ObservedException as e:
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
raise e
def __init__(self, vt, **kwargs):
super().__init__(
vt,
generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable,
**kwargs,
)
def _build_inline_tracer(self, tx, args, kwargs):
# NOTE: This only exists to not break support for context manager when
# config.enable_faithful_generator_behavior = False and
# config.enable_trace_contextlib = True. In case the former is false,
# Dynamo should still be able to trace through @contextmanager functions
tracer = super()._build_inline_tracer(tx, args, kwargs)
assert isinstance(
tracer,
torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator,
)
tracer.is_generator_from_ctx_manager = True
return tracer
class UserMethodVariable(UserFunctionVariable):

View file

@ -234,6 +234,7 @@ class SuperVariable(VariableTracker):
class ExceptionVariable(VariableTracker):
# The ExceptionVariable corresponds to the BaseException class in Python
def __init__(self, exc_type, args, **kwargs) -> None:
super().__init__(**kwargs)
self.exc_type = exc_type

View file

@ -493,7 +493,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
):
if not torch._dynamo.config.enable_trace_contextlib:
unimplemented("contextlib.contextmanager")
# Replace UserFunctionVariable by FunctionDecoratedBycontextlibContextManagerVariable
# Wrap UserFunctionVariable in FunctionDecoratedByContextlibContextManagerVariable
# if the function is annotated with @contextlib.contextmanager
# This shouldn't be necessary once generator functions are fully
# supported in dynamo
@ -805,6 +805,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
# of the cmp_eq polyfill function.
return ConstantVariable.create(self.value is other.value)
if torch._dynamo.config.enable_faithful_generator_behavior and isinstance(
self.value, types.GeneratorType
):
unimplemented("Generator as graph argument is not supported")
# check for methods implemented in C++
if isinstance(method, types.FunctionType):
source = (

View file

@ -1822,7 +1822,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
metadata: ViewAndMutationMeta = fw_metadata # type: ignore[assignment]
maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta
num_symints_saved_for_bw = num_symints_saved_for_bw_
_compiled_autograd_should_lift = False
_aot_id = aot_config.aot_id
_lazy_backward_info = lazy_backward_info
@ -1989,7 +1988,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
# https://github.com/pytorch/pytorch/pull/92348/files#r1072962107
class CompiledFunctionBackward(torch.autograd.Function):
# CompiledFunctionBackward is not yet supported in dynamo skipfiles
_compiled_autograd_should_lift = False
_aot_id = aot_config.aot_id
@staticmethod

View file

@ -1,3 +1,8 @@
from torch._higher_order_ops._invoke_quant import (
invoke_quant,
invoke_quant_packed,
InvokeQuant,
)
from torch._higher_order_ops.aoti_call_delegate import aoti_call_delegate
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.auto_functionalize import (
@ -51,6 +56,9 @@ __all__ = [
"executorch_call_delegate",
"call_torchbind",
"run_const_graph",
"InvokeQuant",
"invoke_quant",
"invoke_quant_packed",
"wrap_with_set_grad_enabled",
"wrap_with_autocast",
"wrap_activation_checkpoint",

View file

@ -0,0 +1,72 @@
# mypy: allow-untyped-defs
# need to fix prim_hop_base type annotations first
import dataclasses
from typing import Optional
import torch
from torch._higher_order_ops.prim_hop_base import FunctionWithNoFreeVars, PrimHOPBase
class InvokeQuantTracer(PrimHOPBase):
    """Higher-order op registered under the name ``invoke_quant_packed``."""

    def __init__(self) -> None:
        super().__init__("invoke_quant_packed")

    def __call__(self, subgraph, operands, *, scheme=None, quant_options=None):
        # The subgraph is wrapped in FunctionWithNoFreeVars before being handed
        # to the base class together with the packed operands.
        wrapped_subgraph = FunctionWithNoFreeVars(subgraph)
        hop_kwargs = {"scheme": scheme, "quant_options": quant_options}
        return super().__call__(wrapped_subgraph, operands, **hop_kwargs)


invoke_quant_packed = InvokeQuantTracer()
class InvokeQuantUnpacked(PrimHOPBase):
    """Higher-order op registered under the name ``invoke_quant``.

    Operands are given as individual positional arguments and forwarded to the
    base class as a single tuple.
    """

    def __init__(self) -> None:
        super().__init__("invoke_quant")

    def __call__(self, subgraph, *operands, scheme=None):
        packed_operands = operands
        return super().__call__(subgraph, packed_operands, scheme=scheme)

    def _call_FakeTensorMode(
        self, mode, subgraph, operands, scheme: Optional[str] = None, **kwargs
    ):
        # TODO: this should probably route through FakeTensorMode to reuse caching
        # ``scheme`` is accepted here but not passed on to ``subgraph``.
        with mode:
            result = subgraph(*operands[0], **kwargs)
            return result


invoke_quant = InvokeQuantUnpacked()
@dataclasses.dataclass(frozen=True, repr=True)
class InvokeQuant:
    """
    Invoke a quantization function that will be preserved as a single operator. Preservation
    as a single operator aids in pattern matching and custom lowerings.

    The operation appears as:
        torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=scheme)

    Args:
        codegen_low_precision: Use observed subgraph dtypes for codegen instead of
            upcasting to fp32. Can improve performance for prologue fusion but
            requires careful testing of numerics.
    """

    codegen_low_precision: bool = True

    def __call__(
        self,
        *args,
        scheme: Optional[str] = None,
        **kwargs,
    ):
        # When compiling, route through the packed higher-order op; ``scheme``
        # is only forwarded when it was actually provided.
        if torch._utils.is_compiling():
            if scheme is not None:
                kwargs["scheme"] = scheme
            return invoke_quant_packed(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]

        # Eager path: args[0] is the function, args[1] its packed operands.
        # ``scheme`` is not passed to the function here.
        return args[0](*args[1], **kwargs)

View file

@ -1,6 +1,6 @@
import dataclasses
import itertools
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import sympy
@ -9,6 +9,7 @@ from torch._inductor import config
from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
from torch._inductor.index_propagation import SymPyOps, TypedExpr
from .ops_handler import DefaultHandler
from .virtualized import StoreMode, V
@ -20,7 +21,7 @@ def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol:
return sympy.Symbol(f"unknown_{count}")
class PreservesZeros(SymPyOps):
class PreservesZeros(SymPyOps, DefaultHandler):
"""
For prologue kernels where the loads are masked, does the final store of this kernel preserve
the zeros.
@ -31,41 +32,32 @@ class PreservesZeros(SymPyOps):
self.store_preserves_zeros: Optional[bool] = None
self.dtype_prop = DtypePropagationOpsHandler()
@staticmethod
def load(name: str, index: sympy.Expr) -> TypedExpr:
def load(self, name: str, index: sympy.Expr) -> TypedExpr:
# In prologue fusion, all loads get broadcasted
dtype = V.get_ops_handler().dtype_prop.load(name, index)
dtype = self.dtype_prop.load(name, index)
return TypedExpr(
sympy.Float(0) if dtype.is_floating_point else sympy.Integer(0), dtype
)
@staticmethod
def store(
name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None
self, name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None
) -> None:
self = V.get_ops_handler()
assert isinstance(self, PreservesZeros)
# should only have a single store in prologue
assert self.store_preserves_zeros is None
self.store_preserves_zeros = value.is_constant() and value.expr == 0
@staticmethod
def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
self = V.get_ops_handler()
def indirect_indexing(self, *args: Any, **kwargs: Any) -> sympy.Expr:
return construct_symbol(next(self.count), torch.int32)
def __getattr__(self, name: str) -> Callable[..., Any]:
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
from torch._inductor.codegen.common import OpDecompositions
def inner(*args: Any, **kwargs: Any) -> TypedExpr:
if hasattr(OpDecompositions, name):
return getattr(OpDecompositions, name)(*args, **kwargs).value
if hasattr(OpDecompositions, name):
return getattr(OpDecompositions, name)(*args, **kwargs).value
nonlocal self
dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
return inner
dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool:
@ -88,7 +80,7 @@ class DTypeContainer:
is_scalar: bool = False
class RecordLowPrecisionOps:
class RecordLowPrecisionOps(DefaultHandler):
def __init__(self) -> None:
self.low_precision_numeric_op = False
self.dtype_prop = DtypePropagationOpsHandler()
@ -97,9 +89,8 @@ class RecordLowPrecisionOps:
"constant",
)
@staticmethod
def load(name: str, index: sympy.Expr) -> DTypeContainer:
return DTypeContainer(V.get_ops_handler().dtype_prop.load(name, index))
def load(self, name: str, index: sympy.Expr) -> DTypeContainer:
return DTypeContainer(self.dtype_prop.load(name, index))
@staticmethod
def store(
@ -111,28 +102,25 @@ class RecordLowPrecisionOps:
def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
return sympy.S.Zero
def __getattr__(self, name: str) -> Callable[..., Any]:
def low_prec_float(dtype: torch.dtype) -> bool:
return dtype.is_floating_point and dtype.itemsize < 4
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
if name == "constant":
out = DTypeContainer(torch.float, is_scalar=True)
def inner(*args: Any, **kwargs: Any) -> DTypeContainer:
out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
if name == "constant":
out = DTypeContainer(torch.float, is_scalar=True)
uses_low_prec = any(
isinstance(dtype_cont, DTypeContainer) and low_prec_float(dtype_cont.dtype)
for dtype_cont in itertools.chain((out,), args, kwargs.values())
)
uses_low_prec = any(
isinstance(dtype_cont, DTypeContainer)
and low_prec_float(dtype_cont.dtype)
for dtype_cont in itertools.chain((out,), args, kwargs.values())
)
if uses_low_prec and name not in self.non_numeric_ops:
self.low_precision_numeric_op = True
if uses_low_prec and name not in self.non_numeric_ops:
self.low_precision_numeric_op = True
return out
return out
return inner
def low_prec_float(dtype: torch.dtype) -> bool:
return dtype.is_floating_point and dtype.itemsize < 4
def can_codegen_without_upcasts(

View file

@ -1,14 +1,22 @@
import logging
import operator
from functools import partial
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union
import sympy
from sympy import Expr
import torch
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.value_ranges import (
bound_sympy,
SymPyValueRangeAnalysis,
ValueRanges,
)
from ..utils._sympy.functions import PowByNatural
from ..utils._sympy.numbers import int_oo
from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
from .ops_handler import DefaultHandler, ReductionType, StoreMode
from .utils import cache_on_self, dominated_nodes
from .virtualized import V
@ -139,3 +147,113 @@ class BoundVars:
# assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
self.replacement_vals[name] = bound
return bound
# Ops handler that answers every op with a ValueRanges bound instead of a value.
# Arithmetic rules are inherited from SymPyValueRangeAnalysis; the methods below
# add the ops that need special treatment (memory ops, reductions, dtype casts).
# NOTE(review): `_default` is presumably the hook DefaultHandler calls for any op
# without an explicit method here — confirm against ops_handler.DefaultHandler.
class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler):
# Registers `bool_handler` as the instance-level implementation of each
# boolean operator listed below.
def __init__(self) -> None:
self.name = "ValueRangeAnalysis"
boolean_operators = (
"xor",
"logical_and",
"logical_or",
"logical_not",
)
for op in boolean_operators:
setattr(self, op, self.bool_handler)
# Shared result for all boolean operators: arguments are ignored and the
# full boolean range [false, true] is returned.
@staticmethod
def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]:
# just assuming bools can have both values
return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type]
# Catch-all: any op routed here gets an unknown range; `name`, `args` and
# `kwargs` are not inspected.
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
# many ops are unlikely to show up in optimizable indexing compute,
# so we don't have full coverage
return ValueRanges.unknown()
# Loaded values are not tracked, so their range is unknown.
def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]:
return ValueRanges.unknown()
# Stores produce no value and are ignored by the analysis.
def store(
self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None
) -> None:
return
# Reductions are not modelled; the result range is unknown regardless of
# the dtypes, reduction type or input value.
def reduction(
self,
dtype: torch.dtype,
src_dtype: torch.dtype,
reduction_type: ReductionType,
value: Any,
) -> ValueRanges[Any]:
return ValueRanges.unknown()
# An index expression must already be a ValueRanges (asserted); its bound
# is simply converted to the requested dtype.
@classmethod
def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]:
assert isinstance(index, ValueRanges)
return cls.to_dtype(index, dtype)
# Converts the range `x` to the range it has after a cast to `dtype`.
# `src_dtype` and `use_compute_types` are accepted but not read in this body.
@staticmethod
def to_dtype(
x: Any,
dtype: torch.dtype,
src_dtype: Optional[torch.dtype] = None,
use_compute_types: bool = True,
) -> ValueRanges[Any]:
x = ValueRanges.wrap(x)
# Cast to bool: a single known value maps to (value != 0); an already
# boolean range is returned unchanged; a range that excludes 0 is always
# true; anything else may be either false or true.
if dtype == torch.bool:
if x.is_singleton():
return ValueRanges.wrap(x.lower != 0)
elif x.is_bool:
return x
elif 0 not in x:
return ValueRanges.wrap(sympy.true)
else:
return ValueRanges(sympy.false, sympy.true)
# Helper converting one bound to a sympy number of the target dtype kind.
def cast(x: Any, dtype: torch.dtype) -> sympy.Expr:
# dtype is int or float
if dtype.is_floating_point:
return sympy.Float(x)
else:
# integer infinities are kept as-is rather than converted
if x in (int_oo, -int_oo):
return x
try:
return sympy.Integer(x)
except TypeError:
# inf cannot be cast to Integer
return x
# Cast from bool: a known boolean becomes 0 or 1; an unknown boolean
# becomes the range [0, 1] in the target dtype.
if x.is_bool:
if x.is_singleton():
val = 1 if x.lower else 0
return ValueRanges.wrap(cast(val, dtype))
else:
return ValueRanges(cast(0, dtype), cast(1, dtype))
else:
# int to float or float to int
return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))
# Range of x**2, computed through ValueRanges.convex_min_zero_map with
# PowByNatural(y, 2) as the mapped function.
@staticmethod
def square(x: Any) -> ValueRanges[Any]:
return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))
# Range of -x; negation is passed to ValueRanges.decreasing_map.
@staticmethod
def neg(x: Any) -> ValueRanges[Any]:
return ValueRanges.decreasing_map(x, operator.neg)
# TODO: this is slightly inaccurate because truncdiv operates at integer
# precision, but we're going through float truediv which means we can
# potentially lose precision on the bounds
# Computed as trunc(truediv(a, b)); an unknown quotient is returned as-is
# without truncation.
@classmethod
def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]:
x = cls.truediv(a, b)
if x == ValueRanges.unknown():
return x
return cls.trunc(x)
# Subtraction expressed as a + (-b) so it reuses the add and neg rules.
@classmethod
def sub(cls, a: Any, b: Any) -> ValueRanges[Any]:
return cls.add(a, cls.neg(b))

View file

@ -17,8 +17,8 @@ class BlockPatternMatcher:
Matches block indexing expressions.
"""
@staticmethod
def get_subexpr_involving_symbol(expr: Expr, symbol: Symbol) -> Expr:
@classmethod
def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
"""
Given a sympy expression, return the subexpression comprised only of terms
involving the specified symbol.
@ -26,6 +26,7 @@ class BlockPatternMatcher:
For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`,
this returns `x * 5 + x ** 2`.
"""
expr = cls._preprocess(expr)
return sympy.S.Zero + sum(
term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols
)
@ -42,6 +43,11 @@ class BlockPatternMatcher:
numels.appendleft(numel)
return [*numels]
@staticmethod
def _preprocess(expr: Expr) -> Expr:
    """
    Normalize an expression before pattern matching.

    The expression is expanded with ``identity=True`` so that Identity
    nodes are removed, e.g. ``x + (5 * y)`` becomes ``x + 5 * y``.
    """
    expanded = expr.expand(identity=True)
    return expanded
@classmethod
def match_mod_div_block_expr(
cls,
@ -54,6 +60,7 @@ class BlockPatternMatcher:
Matches modular indexing expressions, converting them to implied block dimensions and strides.
See triton.py for more information.
"""
index = cls._preprocess(index)
# Pattern match to find the strides and offset.
wild = functools.partial(sympy.Wild, exclude=[index_var])
@ -141,3 +148,21 @@ class BlockPatternMatcher:
)
return dims, strides, block_index_exprs
@classmethod
def match_affine_block_expr(
    cls,
    index: Expr,
    index_var: Symbol,
) -> Optional[Expr]:
    """
    Match an index of the form ``stride * index_var``.

    Returns the matched stride, or ``None`` when the preprocessed index
    does not have that shape. The stride wildcard excludes ``index_var``,
    so the stride itself never contains the index variable.
    """
    stride_wild = sympy.Wild("stride", exclude=[index_var])
    matched = cls._preprocess(index).match(index_var * stride_wild)
    return None if matched is None else matched[stride_wild]

View file

@ -37,11 +37,11 @@ from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from .. import config, metrics
from ..dtype_propagation import DtypePropagationOpsHandler
from ..ops_handler import BasicMathOps
from ..ops_handler import BasicMathOpsMixin, DefaultHandler
from ..utils import (
boolean_ops,
DeferredLineBase,
@ -764,7 +764,7 @@ def _all_in_parens(string: str) -> bool:
return True
class OpOverrides(BasicMathOps, OpDecompositions):
class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
@staticmethod
def paren(string: OpVarT) -> OpVarT:
if (
@ -948,6 +948,16 @@ class OpOverrides(BasicMathOps, OpDecompositions):
f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
)
def output(self, *args: OpVarT) -> None:
    """Always raises AssertionError: ``ops.output`` is not expected during codegen."""
    message = f"{type(self).__name__}: ops.output should not appear at codegen time"
    raise AssertionError(message)
def placeholder(self, index: int) -> OpVarT:
    """Always raises AssertionError: ``ops.placeholder`` is not expected during codegen."""
    message = (
        f"{type(self).__name__}: ops.placeholder should not appear at codegen time"
    )
    raise AssertionError(message)
@staticmethod
def _unimplemented(name: str) -> Callable[..., OpVarT]:
def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT:
@ -1225,12 +1235,6 @@ pointwise_overrides_data: dict[str, OverridesData] = dict(
)
if TYPE_CHECKING:
class _typecheck_OpOverrides(OpOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class DeferredLine(DeferredLineBase):
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
@ -2253,84 +2257,84 @@ class KernelTemplate:
raise NotImplementedError
class CSEProxy:
class CSEProxy(DefaultHandler):
name = "CSEProxy"
def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
super().__init__()
from ..bounds import ValueRangeAnalysis
self.vr_analysis = ValueRangeAnalysis()
self.kernel = kernel
self.parent_handler = parent_handler
def __getattr__(self, name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
def inner(*args: Any, **kwargs: Any) -> CSEVariable:
bounds = self._bound_variable(name, *args, **kwargs)
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
bounds = self._bound_variable(name, *args, **kwargs)
value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
dtype_handler = DtypePropagationOpsHandler()
value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
dtype_handler = DtypePropagationOpsHandler()
output_idx = 0
output_idx = 0
def do_cse(v: str) -> CSEVariable:
# cpp backend doesnt set current device - TODO: fix
if V.graph.current_device is not None:
device_str = V.graph.get_current_device_or_throw().type
triton_backend = (
config.cpu_backend == "triton"
if device_str == "cpu"
else config.cuda_backend == "triton"
if device_str != "mps"
else False
)
else:
triton_backend = False
# only triton backend tracks dtype currently
if triton_backend:
if name == "masked":
output_dtype = value.dtype
else:
output_dtype = getattr(
dtype_handler,
name,
)(*args, **kwargs)
else:
# cpp backend doesnt track dtype yet
output_dtype = None
csevar = V.kernel.cse.generate(
V.kernel.compute,
v,
bounds=bounds,
dtype=output_dtype,
def do_cse(v: str) -> CSEVariable:
# cpp backend doesnt set current device - TODO: fix
if V.graph.current_device is not None:
device_str = V.graph.get_current_device_or_throw().type
triton_backend = (
config.cpu_backend == "triton"
if device_str == "cpu"
else config.cuda_backend == "triton"
if device_str != "mps"
else False
)
else:
triton_backend = False
nonlocal output_idx
if config.test_configs.runtime_triton_dtype_assert and triton_backend:
from torch._inductor.codegen.triton import triton_type
# only triton backend tracks dtype currently
if triton_backend:
if name == "masked":
output_dtype = value.dtype
else:
output_dtype = getattr(
dtype_handler,
name,
)(*args, **kwargs)
else:
# cpp backend doesnt track dtype yet
output_dtype = None
# we tree_map over the output, so we need to fetch corresponding dtype
if isinstance(output_dtype, (list, tuple)):
output_dtype = output_dtype[output_idx]
csevar = V.kernel.cse.generate(
V.kernel.compute,
v,
bounds=bounds,
dtype=output_dtype,
)
V.kernel.compute.writeline(
f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
)
output_idx += 1
nonlocal output_idx
if config.test_configs.runtime_triton_dtype_assert and triton_backend:
from torch._inductor.codegen.triton import triton_type
csevar.update_on_args(name, args, kwargs)
# we tree_map over the output, so we need to fetch corresponding dtype
if isinstance(output_dtype, (list, tuple)):
output_dtype = output_dtype[output_idx]
return csevar
V.kernel.compute.writeline(
f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
)
output_idx += 1
return pytree.tree_map(do_cse, value)
csevar.update_on_args(name, args, kwargs)
return inner
return csevar
return pytree.tree_map(do_cse, value)
def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]:
"""
If the variable comes from an FX node, we forward the bound we have already computed
Else, if the variable when codegen'ing another op, we try to compute its bounds
"""
from ..bounds import ValueRangeAnalysis
from ..select_algorithm import TritonTemplateKernel
if isinstance(V.kernel, TritonTemplateKernel):
@ -2568,8 +2572,3 @@ class CSEProxy:
sorter,
sorter_indices,
)
# Use mypy to check protocol implemented correctly
def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
return h

View file

@ -33,7 +33,7 @@ from ..utils import (
sympy_index_symbol,
sympy_subs,
)
from ..virtualized import _ops as ops, OpsHandler, V
from ..virtualized import _ops as ops, V
from .common import (
BackendFeature,
CSEVariable,
@ -563,12 +563,6 @@ class HalideOverrides(OpOverrides):
HalideOverrides._initialize_pointwise_overrides("halide")
if TYPE_CHECKING:
class _typecheck_HalideOverrides(HalideOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class HalideCSEVariable(CSEVariable):
undefined_re = re.compile(r"\b(tmp\d+)\[\?\]")

View file

@ -29,7 +29,7 @@ if TYPE_CHECKING:
import sympy
from ..ops_handler import OpsHandler, ReductionType, StoreMode
from ..ops_handler import ReductionType, StoreMode
from ..scheduler import Scheduler, SchedulerNode
from .common import OpVarT
@ -367,12 +367,6 @@ class MetalOverrides(OpOverrides):
MetalOverrides._initialize_pointwise_overrides("mps")
if TYPE_CHECKING:
class _typecheck_MetalOverrides(MetalOverrides, OpsHandler[Any]):
pass # mypy will error if we got any of the signatures wrong
class MetalKernel(SIMDKernel):
overrides = MetalOverrides # type: ignore[assignment]
suffix = ";"

View file

@ -31,6 +31,7 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_ty
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
from ..codecache import code_hash, get_path, PyCodeCache
from ..ops_handler import DefaultHandler
from ..runtime.benchmarking import benchmarker
from ..runtime.hints import (
AutotuneHint,
@ -60,7 +61,7 @@ from ..utils import (
triton_version_uses_attrs_dict,
upcast_compute_type,
)
from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
from ..virtualized import _ops as ops, ReductionType, StoreMode, V
from ..wrapper_benchmark import get_kernel_category_by_source_code
from .block_analysis import BlockPatternMatcher
from .common import (
@ -1427,12 +1428,6 @@ class TritonKernelOverrides(TritonOverrides):
return (mantissa, exponent)
if TYPE_CHECKING:
class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class HelperFunctions:
"""An ordered set of helper functions."""
@ -1795,7 +1790,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
and self.index_dtype == "tl.int32"
):
def match_strided_block(
def match_affine_block(
index: sympy.Expr, range_tree: IterationRangesRoot
) -> Optional[BlockParameters]:
"""
@ -1804,16 +1799,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
This implies stride (s,), and shape (XBLOCK,).
"""
symbol = range_tree.symbol()
stride = sympy.Wild("stride", exclude=[symbol])
m = index.match(symbol * stride)
if m is None:
stride = BlockPatternMatcher.match_affine_block_expr(
index, range_tree.symbol()
)
if stride is None:
return None
return BlockParameters(
shape=[range_tree.numel],
block_shape=[TritonSymbols.get_block_size(range_tree)],
strides=[m[stride]],
strides=[stride],
offsets=[TritonSymbols.get_block_offset(range_tree)],
)
@ -1922,7 +1917,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
Match a block indexing subexpression involving a single range tree.
"""
for match_func in (
match_strided_block,
match_affine_block,
match_mod_div_block,
):
match = match_func(expr, range_tree)
@ -2872,24 +2867,23 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
dtype_handler = DtypePropagationOpsHandler()
class CSEProxy:
def __getattr__(self, name: str) -> Callable[..., CSEVariable]:
def inner(*args, **kwargs):
nonlocal helper_name
helper_name += f"_{name}"
class CSEProxy(DefaultHandler):
def _default(
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Any:
nonlocal helper_name
helper_name += f"_{name}"
output_dtype = getattr(
dtype_handler,
name,
)(*args, **kwargs)
output_dtype = getattr(
dtype_handler,
name,
)(*args, **kwargs)
return cse.generate(
helper,
getattr(overrides, name)(*args, **kwargs),
dtype=output_dtype,
)
return inner
return cse.generate(
helper,
getattr(overrides, name)(*args, **kwargs),
dtype=output_dtype,
)
with helper.indent(), V.set_ops_handler(CSEProxy()):
outputs = fn(*args)

View file

@ -4,7 +4,7 @@ import itertools
import logging
import re
from collections.abc import Sequence
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
from unittest.mock import patch
import sympy
@ -15,6 +15,7 @@ from torch.utils._ordered_set import OrderedSet
from ..utils._sympy.symbol import make_symbol, SymT
from .codegen.common import index_prevent_reordering
from .ops_handler import DefaultHandler
from .utils import (
get_dtype_size,
reduction_num_outputs,
@ -23,7 +24,7 @@ from .utils import (
sympy_subs,
VarRanges,
)
from .virtualized import OpsHandler, ReductionType, V
from .virtualized import ReductionType, V
T = TypeVar("T")
@ -737,19 +738,16 @@ def canonicalization_prefix() -> str:
# ops handler which computes all the free unbacked symbols for an IR
class FreeUnbackedSymbolsOpsHandler:
class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
symbols: OrderedSet[sympy.Symbol]
def __init__(self) -> None:
self.symbols = OrderedSet()
def __getattr__(self, name: str) -> Callable[..., Any]:
def inner(*args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None:
for a in itertools.chain(args, kwargs.values()):
if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
self.symbols |= free_unbacked_symbols(a)
return inner
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
for a in itertools.chain(args, kwargs.values()):
if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
self.symbols |= free_unbacked_symbols(a)
def indirect_indexing(
self,
@ -791,12 +789,6 @@ class FreeUnbackedSymbolsOpsHandler:
body()
def _typecheck_FreeUnbackedSymbolsOpsHandler(
h: FreeUnbackedSymbolsOpsHandler,
) -> OpsHandler[None]:
return h
def extract_free_unbacked_symbols(
fn: Callable[..., Any],
index: Sequence[sympy.Expr],

View file

@ -368,6 +368,16 @@ class DtypePropagationOpsHandler:
) -> None:
return None
def output(self, *args: DTypeArg) -> None:
    """Always raises AssertionError: ``ops.output`` is not expected by this handler."""
    message = f"{type(self).__name__}: ops.output should not appear here"
    raise AssertionError(message)
def placeholder(self, index: int) -> torch.dtype:
    """Always raises AssertionError: ``ops.placeholder`` is not expected by this handler."""
    message = f"{type(self).__name__}: ops.placeholder should not appear here"
    raise AssertionError(message)
if TYPE_CHECKING:

View file

@ -2,6 +2,7 @@
import functools
import itertools
import logging
import operator
import typing
from collections import Counter
from typing import Any, Union
@ -442,6 +443,81 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule):
remove_redundant_views(gm)
def canonicalize_quant_mapping(gm: torch.fx.GraphModule):
"""
torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'quant_invoke_0_0', (arg0_1, arg1_1));
->
torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4');
"""
graph = gm.graph
invoke_quant_invocations = graph.find_nodes(
op="call_function", target=torch.ops.higher_order.invoke_quant_packed
)
for invoke_quant in invoke_quant_invocations:
kwargs = dict(invoke_quant.kwargs)
quant_options_node = kwargs.pop("quant_options", None)
if quant_options_node is not None:
assert isinstance(quant_options_node, torch.fx.Node)
quant_options = torch._higher_order_ops.InvokeQuant(
*invoke_quant.kwargs["quant_options"].args,
**invoke_quant.kwargs["quant_options"].kwargs,
)
else:
quant_options = None
subgraph, args = invoke_quant.args
with gm.graph.inserting_before(invoke_quant):
invoke_quant_replacement = graph.call_function(
torch._higher_order_ops.invoke_quant,
(subgraph, *args),
kwargs,
)
invoke_quant_replacement.meta.update(subgraph.meta)
invoke_quant_replacement.meta["quant_options"] = quant_options
invoke_quant.replace_all_uses_with(invoke_quant_replacement)
graph.erase_node(invoke_quant)
if quant_options_node and len(quant_options_node.users) == 0:
graph.erase_node(quant_options_node)
first_user = next(iter(invoke_quant_replacement.users))
if (
len(invoke_quant_replacement.users) == 1
and len(subgraph.users) == 1
and first_user.target == operator.getitem
and first_user.args[1] == 0
):
subgraph_graph = getattr(gm, subgraph.target)
output_node = torch._inductor.utils.output_node(subgraph_graph)
assert (
isinstance(output_node.args[0], (list, tuple))
and len(output_node.args[0]) == 1
)
unpacked_output = output_node.args[0][0]
output_node.args = (unpacked_output,)
if "val" in output_node.meta:
output_node.meta["val"] = output_node.meta["val"][0]
subgraph_graph.recompile()
invoke_quant_replacement.meta.update(first_user.meta)
first_user.replace_all_uses_with(invoke_quant_replacement)
graph.erase_node(first_user)
def canonicalize_aten_ir_passes(gm: torch.fx.GraphModule):
"""
Canonicalization passes that will run immediately after aot autograd
tracing. This must be run before all other graph passes.
"""
# Currently the only canonicalization is the invoke_quant rewrite.
canonicalize_quant_mapping(gm)
def joint_graph_passes(graph: torch.fx.GraphModule):
"""
Run FX transformations on the joint forwards+backwards graph.
@ -454,6 +530,15 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
lazy_init()
count = 0
# must occur before other passes
canonicalize_aten_ir_passes(graph)
if config.joint_custom_pre_pass is not None:
GraphTransformObserver(graph, "joint_custom_pre_pass").apply_graph_pass(
config.joint_custom_pre_pass
)
count += 1
from .post_grad import remove_noop_ops
GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops)

View file

@ -21,8 +21,9 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
"""
import itertools
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, overload, Union
from typing import Any, Literal, Optional, overload, Union
from typing_extensions import TypeAlias
import sympy
@ -32,6 +33,7 @@ from torch._prims_common import dtype_to_type, is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from .ops_handler import DefaultHandler
from .sizevars import evaluate_expr
from .utils import generate_assert
from .virtualized import V
@ -185,7 +187,7 @@ class IndexPropVar:
IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]
class IndexPropagation:
class IndexPropagation(DefaultHandler):
"""Ops wrapper that tries to propagate constant and index_expr values through the computation.
This aims to maximize the compile time simplification possible, and convert
@ -247,19 +249,19 @@ class IndexPropagation:
def fallback(
self,
name: Literal["indirect_indexing"],
args: tuple[Any, ...],
args: Sequence[Any],
kwargs: dict[str, Any],
) -> IndexPropVar:
...
@overload
def fallback(
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
) -> IndexPropResult:
...
def fallback(
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
) -> IndexPropResult:
# Fallback to the wrapped handler
new_args = [self.unwrap(a) for a in args]
@ -267,7 +269,7 @@ class IndexPropagation:
return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))
def propagate_sympy(
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
) -> IndexPropResult:
# Build a new SymPy expression from this ops call
def unwrap(a: Union[Any, IndexPropVar]) -> Any:
@ -288,22 +290,19 @@ class IndexPropagation:
return self.fallback(name, args, kwargs)
return IndexPropVar.new_symbolic(new_expr)
def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
if not hasattr(SymPyOps, name):
return self.fallback(name, args, kwargs)
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
if not hasattr(SymPyOps, name):
return self.fallback(name, args, kwargs)
var_arguments = [
a
for a in itertools.chain(args, kwargs.values())
if isinstance(a, IndexPropVar)
]
if not all(v.is_symbolic for v in var_arguments):
return self.fallback(name, args, kwargs)
var_arguments = [
a
for a in itertools.chain(args, kwargs.values())
if isinstance(a, IndexPropVar)
]
if not all(v.is_symbolic for v in var_arguments):
return self.fallback(name, args, kwargs)
return self.propagate_sympy(name, args, kwargs)
return inner
return self.propagate_sympy(name, args, kwargs)
def statically_true(self, e):
"""

View file

@ -75,7 +75,7 @@ from .dependencies import (
var_builder,
)
from .loop_body import LoopBody
from .ops_handler import OpCounterCSE, OpCountResult
from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode
from .runtime.benchmarking import benchmarker
from .runtime.hints import DeviceProperties, ReductionHint
from .utils import (
@ -916,9 +916,9 @@ class Pointwise(Loops):
output_name: Optional[str],
indexer: Callable[[Sequence[Expr]], Never],
vars: Sequence[Expr],
) -> OpsValue:
) -> None:
loader = self.make_loader()
return ops.store(output_name, indexer(vars), loader(vars))
return ops.store(output_name or "unnamed", indexer(vars), loader(vars))
def constant_to_device(self, device: torch.device) -> IRNode:
"""Move this to a given device. Requires that all reads are to constants."""
@ -932,7 +932,7 @@ class Pointwise(Loops):
@ir_dataclass
class Scatter(Pointwise):
output_indexer: Callable[[Sequence[Expr]], Expr]
scatter_mode: Optional[str] = None
scatter_mode: StoreMode = None
def constant_to_device(self, device: torch.device) -> IRNode:
"""Move this to a given device. Requires that all reads are to constants."""
@ -952,8 +952,10 @@ class Scatter(Pointwise):
output_name: Optional[str],
indexer: Callable[[Sequence[Expr]], Never],
vars: Sequence[Expr],
) -> OpsValue:
) -> None:
loader = self.make_loader()
if output_name is None:
output_name = "unnamed"
return ops.store(
output_name,
indexer(self.output_indexer(vars)),
@ -1038,7 +1040,7 @@ def get_reduction_combine_fn(
@ir_dataclass
class Reduction(Loops):
reduction_ranges: Sequence[_IntLike]
reduction_type: str
reduction_type: ReductionType
# self.dtype represents the dst dtype
src_dtype: torch.dtype
reduction_hint: ReductionHint
@ -1065,14 +1067,14 @@ class Reduction(Loops):
indexer: Callable[[Sequence[Expr]], Never],
vars: Sequence[Expr],
reduction_vars: Sequence[Symbol],
) -> OpsValue:
) -> None:
value = ops.reduction(
self.dtype,
self.src_dtype,
self.reduction_type,
self.inner_fn(vars, reduction_vars),
)
return ops.store_reduction(output_name, indexer(vars), value)
return ops.store_reduction(output_name or "unnamed", indexer(vars), value)
def index_length(self) -> int:
return len(self.ranges) + len(self.reduction_ranges)
@ -1110,7 +1112,7 @@ class Reduction(Loops):
inner_fn: Callable[..., OpsValue],
ranges: Sequence[_IntLike],
reduction_ranges: Sequence[_IntLike],
reduction_type: str,
reduction_type: Union[ReductionType, Literal["scan"]],
reduction_numel: Expr,
input_node: Optional[IRNode] = None,
) -> tuple[ReductionHint, _IntLike]:
@ -1196,7 +1198,7 @@ class Reduction(Loops):
inner_fn=inner_fn,
ranges=ranges,
reduction_ranges=reduction_ranges,
reduction_type=reduction_type,
reduction_type=reduction_type if reduction_type != "scan" else "sum",
src_dtype=src_dtype,
reduction_hint=ReductionHint.DEFAULT,
)
@ -1323,7 +1325,7 @@ class Reduction(Loops):
inner_fn: Callable[..., Any],
ranges: Sequence[Expr],
reduction_ranges: Sequence[Expr],
reduction_type: str,
reduction_type: ReductionType,
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
input_node: Optional[IRNode] = None,
) -> TensorBox:
@ -1593,7 +1595,7 @@ class Reduction(Loops):
original_reduction_ranges: Sequence[Expr],
new_ranges: list[Expr],
new_reduction_ranges: list[Integer],
reduction_type: str,
reduction_type: ReductionType,
split: _IntLike,
reduction_hint: ReductionHint,
) -> TensorBox:
@ -1655,7 +1657,7 @@ class Reduction(Loops):
inner_fn: Callable[..., Any],
ranges: Sequence[Expr],
reduction_ranges: Sequence[Expr],
reduction_type: str,
reduction_type: ReductionType,
split: _IntLike,
reduction_hint: ReductionHint,
) -> TensorBox:
@ -1696,7 +1698,7 @@ class Reduction(Loops):
original_reduction_ranges: Sequence[Expr],
new_ranges: list[Integer],
new_reduction_ranges: list[Integer],
reduction_type: str,
reduction_type: ReductionType,
reduction_hint: ReductionHint,
) -> TensorBox:
"""
@ -1735,7 +1737,7 @@ class WelfordReduction(Reduction):
inner_fns: Sequence[Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]],
ranges: Sequence[Integer],
reduction_ranges: Sequence[Integer],
reduction_type: str,
reduction_type: ReductionType,
reduction_hint: ReductionHint,
output_index: int,
) -> None:
@ -1767,7 +1769,7 @@ class WelfordReduction(Reduction):
indexer: Callable[[Sequence[Expr]], Never],
vars: Sequence[Expr],
reduction_vars: Sequence[Symbol],
) -> OpsValue:
) -> None:
values = ops.reduction(
self.dtype,
self.src_dtype,
@ -1775,7 +1777,7 @@ class WelfordReduction(Reduction):
self.inner_fn(vars, reduction_vars),
)
value = values[self.output_index]
return ops.store_reduction(output_name, indexer(vars), value)
return ops.store_reduction(output_name or "unnamed", indexer(vars), value)
@classmethod
def create( # type: ignore[override]
@ -1785,7 +1787,7 @@ class WelfordReduction(Reduction):
inner_fns: Sequence[Callable[..., Any]],
ranges: list[Integer],
reduction_ranges: list[Integer],
reduction_type: str,
reduction_type: ReductionType,
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
) -> Sequence[TensorBox]:
assert reduction_type in ("welford_reduce", "welford_combine")
@ -1911,7 +1913,7 @@ class WelfordReduction(Reduction):
inner_fns: Sequence[Callable[..., Any]],
ranges: list[Integer],
reduction_ranges: list[Integer],
reduction_type: str,
reduction_type: ReductionType,
split: _IntLike,
reduction_hint: ReductionHint,
) -> Sequence[TensorBox]:
@ -2031,11 +2033,13 @@ class Scan(Loops):
indexer: Callable[[Sequence[_IntLike]], Never],
vars: Sequence[Expr],
scan_vars: Sequence[Symbol],
) -> OpsValue:
) -> None:
idx = self.reindex(vars, scan_vars)
values = [inner_fn(idx) for inner_fn in self.inner_fns]
values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
result = ops.scan(self.dtypes, self.combine_fn, values)
return ops.store(output_name, indexer(idx), result[self.output_index])
return ops.store(
output_name or "unnamed", indexer(idx), result[self.output_index]
)
def get_reduction_type(self) -> Optional[str]:
# return self.scan_op
@ -2229,11 +2233,13 @@ class Sort(Loops):
indexer: Callable[[Sequence[Expr]], Expr],
vars: Sequence[Expr],
reduction_vars: Sequence[Expr],
) -> OpsValue:
) -> None:
idx = self.reindex(vars, reduction_vars)
values = [inner_fn(idx) for inner_fn in self.inner_fns]
values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
result = ops.sort(self.dtypes, values, self.stable, self.descending)
return ops.store(output_name, indexer(idx), result[self.output_index])
return ops.store(
output_name or "unnamed", indexer(idx), result[self.output_index]
)
def get_reduction_type(self) -> Optional[str]:
return "sort"
@ -3790,7 +3796,7 @@ class Buffer(IRNode):
def loader(index): # type: ignore[no-untyped-def]
indexer = self.make_indexer()
return ops.load(self.name, indexer(index))
return ops.load(self.name or "unnamed", indexer(index))
return loader
@ -3983,7 +3989,7 @@ class ComputedBuffer(OperationBuffer):
return self.data.make_loader()
return super().make_loader()
def get_store_function(self) -> Callable[..., OpsValue]:
def get_store_function(self) -> Callable[..., None]:
indexer = self.get_layout().as_fixed().make_indexer()
if isinstance(self.data, (Reduction, Scan, Sort)):
return partial(self.data.store_reduction, self.name, indexer)

View file

@ -17,6 +17,7 @@ from torch.utils._sympy.symbol import SymT
from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .ops_handler import DefaultHandler, OpsHandler, WrapperHandler
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
from .virtualized import ops, V
@ -439,179 +440,13 @@ class LoopBodyBlock:
def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]):
self.body = body
def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs):
return tracer.create_proxy(
"call_module",
"get_index",
(body.add_index_expr(expr, mtype, **kwargs),),
{},
)
class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined]
self.name = "CaptureIndexing"
def load(self, name: str, index: sympy.Expr):
index = add_index(index, MemoryUsageType.LOAD, buffer_name=name)
return self._inner.load(name, index)
def load_seed(self, name: str, index: int):
assert isinstance(index, int)
body.add_index_expr(
sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name
)
return self._inner.load_seed(name, index)
def store(self, name, index, value, mode=None):
index = add_index(
index, MemoryUsageType.STORE, buffer_name=name, mode=mode
)
return self._inner.store(name, index, value, mode)
def store_reduction(self, name, index, value):
index = add_index(
index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
)
return self._inner.store_reduction(name, index, value)
def reduction(self, dtype, src_dtype, reduction_type, value):
result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
if "welford" in reduction_type:
return tuple(result[i] for i in range(3))
return result
def index_expr(self, index, dtype):
if isinstance(index, (int, sympy.Integer)):
return self._inner.constant(int(index), dtype)
index = add_index(index, MemoryUsageType.INDEX_EXPR)
return self._inner.index_expr(index, dtype)
def check_bounds(self, index, size, lower, upper):
index = add_index(index, MemoryUsageType.CHECK_BOUNDS)
size = add_index(size, MemoryUsageType.CHECK_BOUNDS)
return self._inner.check_bounds(index, size, lower, upper)
def bucketize(
self,
values: T,
boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
boundary_indices: T,
indexing_dtype: torch.dtype,
right: bool,
sorter: Optional[tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[T] = None,
) -> T:
"""
See [Note: Inductor bucketize op]
"""
boundaries = (
boundaries[0],
add_index(
boundaries[1],
MemoryUsageType.BUCKETIZE,
buffer_name=boundaries[0],
),
add_index(
boundaries[2],
MemoryUsageType.BUCKETIZE,
buffer_name=boundaries[0],
),
add_index(
boundaries[3],
MemoryUsageType.BUCKETIZE,
buffer_name=boundaries[0],
),
)
if sorter is not None:
sorter = (
sorter[0],
add_index(
sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0]
),
)
return self._inner.bucketize(
values,
boundaries,
boundary_indices,
indexing_dtype,
right,
sorter,
sorter_indices,
)
@staticmethod
def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
"""
Recursively capture the masked out body in another LoopBodyBlock
"""
name = self.body.add_submodule(None, "masked_subblock")
self.body.submodules[name] = self.body.bind_masked_shim(name)
self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
return tracer.create_proxy(
"call_module", name, (mask_proxy, other_proxy), {}
)
@staticmethod
def scan(
dtype_proxy,
combine_fn: Callable[
[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]
],
value_proxy,
):
shim = self.body.bind_scan_shim(combine_fn)
name = self.body.add_submodule(shim, "scan")
result = tracer.create_proxy(
"call_module",
name,
(dtype_proxy, value_proxy),
{},
)
# Proxies are iterable, but some methods expect tuples/lists
return tuple(result[i] for i in range(len(value_proxy)))
def sort(self, dtypes, values, stable, descending):
result = self._inner.sort(dtypes, values, stable, descending)
# Proxies are iterable, but some methods expect tuples/lists
return tuple(result[i] for i in range(len(values)))
def frexp(self, value_proxy):
result = self._inner.frexp(value_proxy)
# Proxies are iterable, but some methods expect tuples/lists
return (result[0], result[1])
@staticmethod
def indirect_indexing(index_proxy, size, check=True, wrap_neg=True):
"""
Flow data from tensors into indexing formulas.
Introduce a call_module to update the indexing.
"""
var = self.body.add_indirect(size)
set_indirect = self.body.bind_set_indirect_shim(
var, size, check, wrap_neg
)
tracer.create_proxy(
"call_module",
self.body.add_submodule(set_indirect, f"set_{var}"),
(index_proxy,),
{},
)
return var
@staticmethod
def output(result):
tracer.create_proxy("output", "output", (result,), {})
tracer = LightTracer()
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
from .index_propagation import IndexPropagation
from .sizevars import SimplifyIndexing
handler: Any = CountOps(
SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges),
CaptureIndexing(proxy_ops, body, tracer),
body.op_counts,
)
if config.constant_and_index_propagation:
@ -653,11 +488,187 @@ class LoopBodyBlock:
return copy
# CountOps: wrapper around an ops handler that tallies, per op name, how often
# an op is requested, then defers to the wrapped handler.
# NOTE(review): two variants of the class appear back to back below (a plain
# class using __getattr__, and a DefaultHandler subclass using _default) --
# presumably the before/after sides of a diff; confirm which one is live.
class CountOps:
def __init__(self, inner: Any, counts: collections.Counter[str]):
class CountOps(DefaultHandler):
# inner: the handler that ops are forwarded to.
# counts: Counter keyed by op name; it is incremented in place, so the caller
# observes the updated tallies through the object it passed in.
def __init__(self, inner: OpsHandler[Any], counts: collections.Counter[str]):
self._inner = inner
self._counts = counts
# __getattr__ variant: bump the tally for `name` and hand back the inner
# handler's attribute without calling it.
def __getattr__(self, name):
# _default variant: bump the tally for `name`, then call the same-named
# attribute of the inner handler with the given args/kwargs and return its result.
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
self._counts[name] += 1
return getattr(self._inner, name)
return getattr(self._inner, name)(*args, **kwargs)
class CaptureIndexing(WrapperHandler):
    """
    Ops-handler layer that records indexing into a ``LoopBody``.

    Index expressions passed to memory/indexing ops are registered on ``body``
    (via ``add_index_expr``) and replaced by ``get_index`` ``call_module``
    proxies created on ``tracer`` before the op is forwarded to the wrapped
    handler. Ops that need submodules (masked, scan, indirect indexing) are
    registered on ``body`` and emitted as ``call_module`` proxies as well.
    """

    name = "CaptureIndexing"

    def __init__(
        self,
        inner: OpsHandler[Any],
        body: LoopBody,
        tracer: LightTracer,
    ):
        super().__init__(inner)
        self.tracer = tracer
        self.body = body

    def _add_index(self, expr: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any):
        """Register ``expr`` on the body and return a ``get_index`` proxy for it."""
        registered = self.body.add_index_expr(expr, mtype, **kwargs)
        return self.tracer.create_proxy("call_module", "get_index", (registered,), {})

    def _simplify(self, expr: sympy.Expr) -> sympy.Expr:
        """Simplify ``expr`` using the body's variable ranges."""
        var_ranges = self.body.var_ranges
        return V.graph.sizevars.simplify_with_ranges(expr, var_ranges)

    def load(self, name: str, index: sympy.Expr):
        traced_index = self._add_index(
            self._simplify(index), MemoryUsageType.LOAD, buffer_name=name
        )
        return self._inner.load(name, traced_index)

    def load_seed(self, name: str, index: int):
        assert isinstance(index, int)
        # The seed offset is only recorded on the body; the raw int is forwarded.
        seed_offset = sympy.Integer(index)
        self.body.add_index_expr(
            seed_offset, MemoryUsageType.LOAD_SEED, buffer_name=name
        )
        return self._inner.load_seed(name, index)

    def store(self, name, index, value, mode=None):
        traced_index = self._add_index(
            self._simplify(index),
            MemoryUsageType.STORE,
            buffer_name=name,
            mode=mode,
        )
        return self._inner.store(name, traced_index, value, mode)

    def store_reduction(self, name, index, value):
        traced_index = self._add_index(
            self._simplify(index),
            MemoryUsageType.STORE_REDUCTION,
            buffer_name=name,
        )
        return self._inner.store_reduction(name, traced_index, value)

    def reduction(self, dtype, src_dtype, reduction_type, value):
        reduced = self._inner.reduction(dtype, src_dtype, reduction_type, value)
        if "welford" not in reduction_type:
            return reduced
        # Welford reductions are unpacked into an explicit 3-tuple.
        return (reduced[0], reduced[1], reduced[2])

    def index_expr(self, index, dtype):
        simplified = self._simplify(index)
        if isinstance(simplified, (int, sympy.Integer)):
            # Integer indices become constants rather than tracked index expressions.
            return self._inner.constant(int(simplified), dtype)
        traced_index = self._add_index(simplified, MemoryUsageType.INDEX_EXPR)
        return self._inner.index_expr(traced_index, dtype)

    def check_bounds(self, index, size, lower, upper):
        # Only the index is simplified; the size is registered as given.
        traced_index = self._add_index(
            self._simplify(index), MemoryUsageType.CHECK_BOUNDS
        )
        traced_size = self._add_index(size, MemoryUsageType.CHECK_BOUNDS)
        return self._inner.check_bounds(traced_index, traced_size, lower, upper)

    def bucketize(
        self,
        values: T,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: T,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[T] = None,
    ) -> T:
        """
        See [Note: Inductor bucketize op]
        """

        def track(expr, buffer):
            # Register one bucketize-related expression against ``buffer``.
            return self._add_index(
                expr, MemoryUsageType.BUCKETIZE, buffer_name=buffer
            )

        boundaries = (
            boundaries[0],
            track(boundaries[1], boundaries[0]),
            track(boundaries[2], boundaries[0]),
            track(boundaries[3], boundaries[0]),
        )
        if sorter is not None:
            sorter = (sorter[0], track(sorter[1], sorter[0]))

        return self._inner.bucketize(
            values,
            boundaries,
            boundary_indices,
            indexing_dtype,
            right,
            sorter,
            sorter_indices,
        )

    def masked(self, mask_proxy, masked_body: Callable[..., Any], other_proxy):
        """
        Recursively capture the masked out body in another LoopBodyBlock
        """
        block_name = self.body.add_submodule(None, "masked_subblock")
        self.body.submodules[block_name] = self.body.bind_masked_shim(block_name)
        self.body.subblocks[block_name] = LoopBodyBlock(self.body, masked_body, [])
        call_args = (mask_proxy, other_proxy)
        return self.tracer.create_proxy("call_module", block_name, call_args, {})

    def scan(
        self,
        dtype_proxy,
        combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
        value_proxy,
    ):
        scan_shim = self.body.bind_scan_shim(combine_fn)
        module_name = self.body.add_submodule(scan_shim, "scan")
        traced = self.tracer.create_proxy(
            "call_module",
            module_name,
            (dtype_proxy, value_proxy),
            {},
        )
        # Proxies are iterable, but some methods expect tuples/lists
        return tuple([traced[pos] for pos in range(len(value_proxy))])

    def sort(self, dtypes, values, stable, descending):
        traced = self._inner.sort(dtypes, values, stable, descending)
        # Proxies are iterable, but some methods expect tuples/lists
        return tuple([traced[pos] for pos in range(len(values))])

    def frexp(self, value_proxy):
        traced = self._inner.frexp(value_proxy)
        # Proxies are iterable, but some methods expect tuples/lists
        first, second = traced[0], traced[1]
        return (first, second)

    def indirect_indexing(self, index_proxy, size, check=True, wrap_neg=True):
        """
        Flow data from tensors into indexing formulas.
        Introduce a call_module to update the indexing.
        """
        var = self.body.add_indirect(size)
        setter = self.body.bind_set_indirect_shim(var, size, check, wrap_neg)
        setter_name = self.body.add_submodule(setter, f"set_{var}")
        self.tracer.create_proxy("call_module", setter_name, (index_proxy,), {})
        return var

    def output(self, *result):
        # Emit the graph's output node; the positional results are passed through as a tuple.
        self.tracer.create_proxy("output", "output", result, {})

View file

@ -12,7 +12,7 @@ import os
import warnings
from collections import defaultdict
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
from unittest.mock import patch
@ -81,6 +81,10 @@ from .utils import (
from .virtualized import ops, V
if TYPE_CHECKING:
from .ops_handler import ReductionType
_T = TypeVar("_T")
_P = ParamSpec("_P")
@ -5633,7 +5637,7 @@ def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype):
)
def make_reduction(reduction_type: str, override_return_dtype=None):
def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
def inner(x, axis=None, keepdims=False, *, dtype=None):
kwargs = _make_reduction_inner(
x,
@ -6750,6 +6754,23 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, operands):
return list(map(TensorBox.create, result))
@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None)
def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None):
    """
    Lower ``invoke_quant`` by replaying the subgraph's nodes in the current graph.

    Placeholder nodes are bound to ``operands`` by their position in the node
    list, the output node is resolved through the graph environment, and every
    other node is executed with ``V.graph.run_node``. ``scheme`` is accepted
    but not used here. Returns the value produced for the output node, or
    ``None`` if the subgraph has no output node.
    """
    result = None
    for position, node in enumerate(subgraph_fn.graph_module.graph.nodes):
        if node.op == "output":
            out_args, out_kwargs = V.graph.fetch_args_kwargs_from_env(node)
            result = torch.fx.Interpreter.output(V.graph, node, out_args, out_kwargs)
        elif node.op == "placeholder":
            V.graph.env[node] = operands[position]
        else:
            # todo getattr
            V.graph.env[node] = V.graph.run_node(node)

    return result
@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(combine_fn: ir.Subgraph, xs):
from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph

File diff suppressed because it is too large Load diff

View file

@ -564,10 +564,6 @@ class CompiledFxGraph(OutputCode):
return artifact_path
# Identity helper: returns its argument unchanged. The annotations take a
# CompiledFxGraph and return an OutputCode -- presumably so a static type
# checker verifies the former is usable as the latter (NOTE(review): confirm).
def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
return h
@dataclasses.dataclass
class CompiledAOTI(OutputCode):
"""
@ -591,10 +587,6 @@ class CompiledAOTI(OutputCode):
pass
# Identity helper: returns its argument unchanged. The annotations take a
# CompiledAOTI and return an OutputCode -- presumably so a static type
# checker verifies the former is usable as the latter (NOTE(review): confirm).
def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
return h
@dataclasses.dataclass
class MockFXGraphCacheOutput(OutputCode):
gm: Any = None

View file

@ -742,7 +742,7 @@ class TritonTemplateKernel(TritonKernel):
template_mask = self.template_mask
class StoreOutputSubstitution(V.WrapperHandler): # type: ignore[name-defined]
self.name = name
name = "StoreOutputSubstitution"
def store(
self,

View file

@ -99,6 +99,8 @@ class SizeVarAllocator:
if result is None:
result = self._simplify_with_ranges(expr, var_ranges)
cache[key] = result
if result != expr:
cache[(result, *var_ranges.items())] = result
return result
return simplify_with_ranges

View file

@ -135,7 +135,7 @@ class InputDescriptor:
device: torch.device
class TracingOpsHandler(WrapperHandler[T]):
class TracingOpsHandler(WrapperHandler):
def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None:
parent = tracer.create_proxy("placeholder", "ops", (), {})
super().__init__(parent)
@ -149,8 +149,8 @@ class TracingOpsHandler(WrapperHandler[T]):
def placeholder(self, idx: int) -> torch.fx.Proxy:
return self.placeholders[idx]
def output(self, *args: tuple[object]) -> torch.fx.Node:
return self.tracer.create_node(
def output(self, *args: tuple[object]) -> None:
self.tracer.create_node(
"output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {}
)

View file

@ -59,11 +59,12 @@ from __future__ import annotations
from contextlib import AbstractContextManager, contextmanager
from threading import local
from typing import Any, Callable, Generic, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union
from torch.utils._ordered_set import OrderedSet
from .ops_handler import ( # noqa: F401
DefaultHandler,
KernelFormatterHandler,
MockHandler,
OpsHandler,
@ -154,7 +155,9 @@ class NullKernelHandler(NullHandler):
self.index_dtype = "tl.int64"
_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
_ops: Virtualized[OpsHandler[Any]] = Virtualized(
"ops", cast(type[OpsHandler[Any]], MockHandler)
)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
@ -272,18 +275,15 @@ class OpsValue:
return ops.bitwise_left_shift(self, n)
class OpsWrapper:
class OpsWrapper(DefaultHandler):
"""This wraps any returned IR values into an `OpsValue` instance, so that we
can overload the magic methods for writing mathematical expressions fluently.
"""
def __getattr__(self, name):
def inner(*args, **kwargs):
new_args = [OpsWrapper._unwrap(a) for a in args]
new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))
return inner
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
new_args = [OpsWrapper._unwrap(a) for a in args]
new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))
@staticmethod
def _unwrap(x):
@ -306,7 +306,7 @@ class OpsWrapper:
return _ops.indirect_indexing(index, size, check, wrap_neg)
ops = OpsWrapper()
ops: OpsHandler[Any] = OpsWrapper()
class _V:
@ -314,8 +314,10 @@ class _V:
KernelFormatterHandler = KernelFormatterHandler
WrapperHandler = WrapperHandler
set_ops_handler: Callable[[Any], Any] = _ops._set_handler
get_ops_handler: Callable[[], Any] = _ops._get_handler
set_ops_handler: Callable[
[OpsHandler[Any]], AbstractContextManager[None]
] = _ops._set_handler
get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
get_real_inputs: Callable[[], Any] = _real_inputs._get_handler

View file

@ -11100,8 +11100,8 @@ are designed to work with this function. See the examples below.
Args:
{input}
indices (tensor): the indices into :attr:`input`. Must have long dtype.
dim (int, optional): dimension to select along.
indices (LongTensor): the indices into :attr:`input`. Must have long dtype.
dim (int, optional): dimension to select along. Default: 0
Keyword args:
{out}

View file

@ -331,9 +331,6 @@ class FunctionMeta(type):
name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
)
backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER) # type: ignore[attr-defined]
backward_fn._compiled_autograd_should_lift = attrs.get( # type: ignore[attr-defined]
"_compiled_autograd_should_lift", True
)
backward_fn._bw_module = None # type: ignore[attr-defined]
if getattr(cls, "_lazy_backward_info", None):
backward_fn._bw_module = cls._lazy_backward_info.bw_module # type: ignore[attr-defined]

Some files were not shown because too many files have changed in this diff Show more