Mirror of https://github.com/saymrwulf/pytorch.git
Commit 125261b2d3: Update [ghstack-poisoned]
116 changed files with 4203 additions and 975 deletions
@@ -758,11 +758,10 @@ const auto sinc_string = jiterator_stringify(
   T sinc(T a) {
     if (a == T(0)) {
       return T(1);
-    } else {
-      constexpr T pi = T(3.14159265358979323846L);
-      T product = pi * a;
-      return std::sin(product) / product;
     }
+    constexpr T pi = T(3.14159265358979323846L);
+    T product = pi * a;
+    return std::sin(product) / product;
   }
 ); // sinc_string
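For reference, the refactored kernel above still computes the normalized sinc, sin(pi*a)/(pi*a) with the removable singularity at zero defined as 1; only the redundant else branch is dropped. A minimal Python sketch of the same function (illustrative only, not part of the patch):

    import math

    def sinc(a: float) -> float:
        # Normalized sinc: sin(pi*a) / (pi*a); define sinc(0) = 1 to fill the 0/0 hole.
        if a == 0.0:
            return 1.0
        product = math.pi * a
        return math.sin(product) / product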
benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh (0 lines changed, Normal file → Executable file)
@@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
 
-add_loop_inductor,compile_time_instruction_count,30150000000,0.015
+add_loop_inductor,compile_time_instruction_count,29630000000,0.015
 
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43980000000,0.025
 
-add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,26240000000,0.015
 
@@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000,
 
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17150000000,0.015
 
@@ -62,4 +62,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000,
 
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10340000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10390000000,0.015
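Each row in this expected-results file has the shape benchmark_name,metric,expected_value,noise_margin; the change simply refreshes the expected instruction counts. A rough sketch of how such an entry is typically interpreted, assuming the last column is a relative tolerance (this illustrates the file format, it is not the actual benchmark_runner logic):

    def within_noise_margin(measured: int, expected: int, rel_tol: float) -> bool:
        # e.g. add_loop_inductor,compile_time_instruction_count,29630000000,0.015
        return abs(measured - expected) <= rel_tol * expected

    assert within_noise_margin(29_700_000_000, 29_630_000_000, 0.015)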
@@ -997,6 +997,7 @@ def define_buck_targets(
        "Config.h": ":generate_aten_config[Config.h]",
    },
    labels = labels,
    visibility = ["PUBLIC"],
)

fb_xplat_cxx_library(
@@ -22,3 +22,15 @@ def define_targets(rules):
             [],
         ),
     )
+
+    rules.cc_library(
+        name = "c10_headers",
+        deps = [
+            "//c10/core:base_headers",
+            "//c10/macros",
+            "//c10/util:base_headers",
+            "//c10/util:bit_cast",
+            "//c10/util:ssize",
+        ],
+        visibility = ["//visibility:public"],
+    )
@@ -90,6 +90,22 @@ def define_targets(rules):
         alwayslink = True,
     )
+
+    rules.cc_library(
+        name = "base_headers",
+        srcs = [],
+        hdrs = rules.glob(
+            [
+                "*.h",
+                "impl/*.h",
+            ],
+            exclude = [
+                "CPUAllocator.h",
+                "impl/alloc_cpu.h",
+            ],
+        ),
+        visibility = ["//visibility:public"],
+    )
 
     rules.filegroup(
         name = "headers",
         srcs = rules.glob(
@@ -101,5 +117,5 @@ def define_targets(rules):
             "alignment.h",
         ],
     ),
-    visibility = ["//c10:__pkg__"],
+    visibility = ["//visibility:public"],
 )
@@ -477,5 +477,31 @@ inline float2 sinc(float2 inp) {
   return float2(re, im) / a2;
 }
+
+template <typename T>
+inline T spherical_bessel_j0(T x) {
+  if (::metal::isinf(x))
+    return T(0.0);
+  T x2 = x * x;
+  T k1 = static_cast<T>(-1.0);
+  T k2 = static_cast<T>(1.0);
+
+  if (::metal::abs(x) < T(0.5)) {
+    return T(1.0) +
+        x2 *
+            (k1 / T(6.0) +
+             x2 *
+                 (k2 / T(120.0) +
+                  x2 *
+                      (k1 / T(5040.0) +
+                       x2 *
+                           (k2 / T(362880.0) +
+                            x2 *
+                                (k1 / T(39916800.0) +
+                                 x2 * (k2 / T(6227020800.0)))))));
+  }
+
+  return ::metal::sin(x) / x;
+}
 
 } // namespace metal
 } // namespace c10
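The new Metal helper evaluates the spherical Bessel function of the first kind, j0(x) = sin(x)/x, switching to the Taylor series 1 - x^2/6 + x^4/120 - x^6/5040 + ... for |x| < 0.5 to avoid the 0/0 form near zero. A plain-Python transcription of the same evaluation scheme, with coefficients taken from the shader above (for illustration only):

    import math

    def spherical_bessel_j0(x: float) -> float:
        if math.isinf(x):
            return 0.0
        x2 = x * x
        if abs(x) < 0.5:
            # Alternating series in x^2, matching the nested k1/k2 terms in the shader.
            return 1.0 + x2 * (
                -1.0 / 6.0
                + x2 * (1.0 / 120.0
                + x2 * (-1.0 / 5040.0
                + x2 * (1.0 / 362880.0
                + x2 * (-1.0 / 39916800.0
                + x2 * (1.0 / 6227020800.0)))))
            )
        return math.sin(x) / x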
@@ -80,6 +80,18 @@ def define_targets(rules):
         ],
     )
+
+    rules.cc_library(
+        name = "base_headers",
+        hdrs = rules.glob(
+            ["*.h"],
+            exclude = [
+                "bit_cast.h",
+                "ssize.h",
+            ],
+        ),
+        visibility = ["//visibility:public"],
+    )
 
     rules.filegroup(
         name = "headers",
         srcs = rules.glob(
@@ -5,13 +5,15 @@ import copy
 import torch
 import torch.nn as nn
-from torch.distributed._tensor import (
-    DeviceMesh,
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed.tensor import (
     distribute_module,
     distribute_tensor,
     DTensor,
     Replicate,
     Shard,
 )
 from torch.nn import functional as F
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -181,6 +183,28 @@ class DistConvolutionOpsTest(DTensorTestBase):
                 f"Too large relative mse for bias tensor, expected less equal 1e-6, got {bias_mse_rel}",
             )
+
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_conv_backward_none_grad_inp(self):
+        device_mesh = init_device_mesh(
+            device_type="cuda", mesh_shape=(self.world_size,)
+        )
+        conv = nn.Conv2d(64, 64, 3, padding=1).train()
+        x = torch.randn(1, 64, 32, 32)
+        x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
+        w = conv.weight
+        w_dt = torch.nn.Parameter(DTensor.from_local(w, device_mesh, [Replicate()]))
+
+        b = conv.bias
+        b_dt = torch.nn.Parameter(DTensor.from_local(b, device_mesh, [Replicate()]))
+
+        res = F.conv2d(x_dt, w_dt, b_dt, padding=1)
+        dres = torch.rand_like(res)
+        res.backward(dres)
+        self.assertTrue(w_dt.grad is not None)
+        self.assertTrue(b_dt.grad is not None)
+        self.assertTrue(x_dt.grad is None)
 
 
 if __name__ == "__main__":
     run_tests()
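Besides the new test, the import hunk above moves this file off the private torch.distributed._tensor module onto the public torch.distributed.tensor namespace (DeviceMesh and init_device_mesh now come from torch.distributed). A compressed sketch of the replication pattern the new test uses; treat it as hypothetical in that it must run under torchrun with at least two GPUs and a process group initialized:

    import torch
    from torch.distributed import init_device_mesh
    from torch.distributed.tensor import DTensor, Replicate

    # Assumes the mesh size matches the number of launched processes.
    mesh = init_device_mesh("cuda", mesh_shape=(2,))
    x = torch.randn(1, 64, 32, 32)
    x_dt = DTensor.from_local(x, mesh, [Replicate()])  # same local tensor replicated on every rank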
@@ -2007,7 +2007,7 @@ class DistributedDataParallelTest(
         replica_devices = [dev0]
         # Tells _test_grad_layout to construct ConvNet with all layers on this process's first assigned device.
         layer_devs = dev0
-        local_batch_size = 8
+        local_batch_size = 16
         self._test_grad_layout(replica_devices, layer_devs, local_batch_size)
 
     @requires_nccl()
@@ -2021,7 +2021,7 @@ class DistributedDataParallelTest(
         replica_devices = None
         # Tells _test_grad_layout to constructs this process's ConvNet on 2 devices, with 2 layers on each device.
         layer_devs = [dev0] * 2 + [dev1] * 2
-        local_batch_size = 8
+        local_batch_size = 16
         self._test_grad_layout(replica_devices, layer_devs, local_batch_size)
 
     @requires_nccl()
@@ -1744,10 +1744,11 @@ class GraphModule(torch.nn.Module):
 
 class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase):
     def setUp(self):
+        self._prev = torch._dynamo.config.enable_trace_contextlib
         torch._dynamo.config.enable_trace_contextlib = True
 
     def tearDown(self):
-        torch._dynamo.config.enable_trace_contextlib = False
+        torch._dynamo.config.enable_trace_contextlib = self._prev
 
     def test_ctx_basic0(self):
         @contextlib.contextmanager
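The setUp/tearDown fix swaps a hard-coded False for restoring whatever value was active before the test, so the suite no longer clobbers a caller's configuration. The same save-and-restore pattern on a toy config object (illustrative only, not the dynamo test itself):

    import unittest

    class _ToyConfig:
        enable_trace_contextlib = False

    class ConfigPatchingTest(unittest.TestCase):
        def setUp(self):
            self._prev = _ToyConfig.enable_trace_contextlib
            _ToyConfig.enable_trace_contextlib = True

        def tearDown(self):
            # Restore the pre-test value instead of assuming it was False.
            _ToyConfig.enable_trace_contextlib = self._prev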
@@ -2236,7 +2237,7 @@ class GraphModule(torch.nn.Module):
         eager = EagerAndRecordGraphs()
         out = torch.compile(backend=eager, fullgraph=False)(fn)(x)
         self.assertEqual(expected, out)
-        self.assertEqual(len(eager.graphs), 1)
+        self.assertEqual(len(eager.graphs), 0)
 
     def test_graph_break_before_and_after___enter__(self):
         @contextlib.contextmanager
@@ -2262,7 +2263,7 @@ class GraphModule(torch.nn.Module):
         eager = EagerAndRecordGraphs()
         out = torch.compile(backend=eager, fullgraph=False)(fn)(x)
         self.assertEqual(expected, out)
-        self.assertEqual(len(eager.graphs), 1)
+        self.assertEqual(len(eager.graphs), 0)
 
     def test_graph_break_before___enter___and_disable___exit__(self):
         @contextlib.contextmanager
@@ -2292,7 +2293,7 @@ class GraphModule(torch.nn.Module):
         eager = EagerAndRecordGraphs()
         out = torch.compile(backend=eager, fullgraph=False)(fn)(x)
         self.assertEqual(expected, out)
-        self.assertEqual(len(eager.graphs), 1)
+        self.assertEqual(len(eager.graphs), 0)
 
     def test_disable___enter__(self):
         def h(x):
@@ -2573,7 +2574,7 @@ class GraphModule(torch.nn.Module):
         eager = EagerAndRecordGraphs()
         out = torch.compile(backend=eager, fullgraph=False)(fn)(x)
         self.assertEqual(expected, out)
-        self.assertEqual(len(eager.graphs), 1)
+        self.assertEqual(len(eager.graphs), 0)
 
     def test_dynamo_disable_ctx(self):
         @contextlib.contextmanager
@@ -2623,7 +2624,7 @@ class GraphModule(torch.nn.Module):
         eager = EagerAndRecordGraphs()
         out = torch.compile(backend=eager, fullgraph=False, dynamic=False)(f)(x)
         self.assertEqual(expected, out)
-        self.assertEqual(len(eager.graphs), 3)
+        self.assertEqual(len(eager.graphs), 2)
 
     @parametrize("name", ("suppress", "stdout", "stderr"))
     def test_contextlib_suppress(self, name):
@@ -404,6 +404,21 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(ref[0], res[0])
         self.assertEqual(ref[1], res[1])
+
+    def test_raise_GeneratorExit(self):
+        # GeneratorExit does not inherit from Exception
+        @torch.compile(backend="eager", fullgraph=True)
+        def fn(t):
+            try:
+                raise GeneratorExit
+            except Exception:
+                return t.sin()
+            except BaseException:
+                return t.cos()
+
+        t = torch.randn(2)
+        y = fn(t)
+        self.assertEqual(y, t.cos())
 
     def test_speculation_exception(self):
         log = SpeculationLog()
         log.next("fake", 555, "fake", Instruction(1, "fake", 1, 1))
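The new test leans on a plain CPython fact: GeneratorExit derives from BaseException but not from Exception, so only the except BaseException arm runs and the compiled function returns t.cos(). The behaviour in isolation:

    # GeneratorExit is a BaseException but not an Exception.
    assert issubclass(GeneratorExit, BaseException)
    assert not issubclass(GeneratorExit, Exception)

    def probe():
        try:
            raise GeneratorExit
        except Exception:
            return "Exception arm"
        except BaseException:
            return "BaseException arm"

    assert probe() == "BaseException arm"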
test/dynamo/test_generator.py (Normal file, 1809 lines); file diff suppressed because it is too large.
@@ -7000,6 +7000,12 @@ class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase):
     )
     def test_hops_compile(self, device, dtype, op, backend):
         # Ensure HOPs can be compiled
+
+        if backend == "aot_eager" and op.name == "invoke_quant":
+            raise unittest.SkipTest(
+                "TODO: partitioner fails. migrate canonicalization to aot eager backend"
+            )
+
         sample_inputs_itr = op.sample_inputs(
             device, dtype, requires_grad=op.supports_autograd
         )
@@ -9579,21 +9579,6 @@ def ___make_guard_fn():
         ):
             compiled_fn(x)
 
-        # FIXME(XuehaiPan): do not inline infinite generator if it does not raise errors in eager mode
-        def fn(x):
-            def gen():
-                while True:
-                    yield x
-
-            return list(zip(range(10), gen()))
-
-        x = torch.randn([0, 1, 2, 3, 4, 5])
-        compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        with self.assertRaisesRegex(
-            torch._dynamo.exc.Unsupported, "infinite generator"
-        ):
-            compiled_fn(x)
-
     def test_itertools_islice(self):
         counters.clear()
 
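The deleted test expected Dynamo to reject this pattern as an "infinite generator", but, as the FIXME notes, the eager program is perfectly finite: zip stops as soon as range(10) is exhausted and only ten values are ever pulled from the generator. In plain Python:

    def gen(x):
        while True:
            yield x

    # zip drives both iterables and stops at the shorter one, so this terminates.
    pairs = list(zip(range(10), gen("value")))
    assert len(pairs) == 10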
@@ -1418,9 +1418,9 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         self.assertTrue(same(opt_model(a, b, c, d), correct))
 
         if torch._dynamo.config.assume_static_by_default:
-            self.assertExpectedInline(cnt.frame_count, """4""")
+            self.assertExpectedInline(cnt.frame_count, """2""")
         else:
-            self.assertExpectedInline(cnt.frame_count, """5""")
+            self.assertExpectedInline(cnt.frame_count, """3""")
 
     def test_hf_model_output(self):
         ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10))
@@ -6510,6 +6510,27 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
         ).sum()
         self.assertEqual(actual, expected)
+
+    def test_incompatible_configs(self):
+        with torch._dynamo.config.patch(
+            suppress_errors=False, fail_on_recompile_limit_hit=False
+        ):
+            torch.compile(lambda: None)
+
+        with torch._dynamo.config.patch(
+            suppress_errors=True, fail_on_recompile_limit_hit=False
+        ):
+            torch.compile(lambda: None)
+
+        with torch._dynamo.config.patch(
+            suppress_errors=False, fail_on_recompile_limit_hit=True
+        ):
+            torch.compile(lambda: None)
+
+        with torch._dynamo.config.patch(
+            suppress_errors=True, fail_on_recompile_limit_hit=True
+        ), self.assertRaises(AssertionError):
+            torch.compile(lambda: None)
 
 
 class ReproTestsDevice(torch._dynamo.test_case.TestCase):
     def test_sub_alpha_scalar_repro(self, device):
test/higher_order_ops/test_invoke_quant.py (Normal file, 183 lines)
@@ -0,0 +1,183 @@
# Owner(s): ["module: higher order operators"]
|
||||
# flake8: noqa: B950
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._functorch
|
||||
import torch._inductor
|
||||
import torch._inductor.decomposition
|
||||
from torch._higher_order_ops import InvokeQuant
|
||||
from torch._inductor.pattern_matcher import (
|
||||
Arg,
|
||||
CallFunction,
|
||||
Ignored,
|
||||
Match,
|
||||
PatternMatcherPass,
|
||||
register_graph_pattern,
|
||||
)
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
|
||||
|
||||
|
||||
invoke_quant_tracer = InvokeQuant()
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Not a torch._dynamo test")
|
||||
class TestInvokeQuant(TestCase):
|
||||
backend = ""
|
||||
|
||||
def test_simple(self):
|
||||
def gn(x, y):
|
||||
return (torch.mul(x, y) + y,)
|
||||
|
||||
def fn(x, y):
|
||||
return invoke_quant_tracer(
|
||||
gn, (x, y), scheme="nf4", quant_options=invoke_quant_tracer
|
||||
)[0]
|
||||
|
||||
x = torch.randn(8, requires_grad=False)
|
||||
y = torch.randn(8, requires_grad=False)
|
||||
ref = gn(x, y)[0]
|
||||
|
||||
x_clone = x.clone().detach().requires_grad_(False)
|
||||
y_clone = y.clone().detach().requires_grad_(False)
|
||||
res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_construct_inline(self):
|
||||
def gn(x, y):
|
||||
return (torch.mul(x, y) + y,)
|
||||
|
||||
def fn(x, y):
|
||||
return InvokeQuant(codegen_low_precision=False)(gn, (x, y), scheme="nf4")[0]
|
||||
|
||||
x = torch.randn(8, requires_grad=False)
|
||||
y = torch.randn(8, requires_grad=False)
|
||||
ref = gn(x, y)[0]
|
||||
|
||||
x_clone = x.clone().detach().requires_grad_(False)
|
||||
y_clone = y.clone().detach().requires_grad_(False)
|
||||
res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_inline(self):
|
||||
def gn(x, y):
|
||||
return (torch.mul(x, y) + y,)
|
||||
|
||||
def fn(x, y):
|
||||
return InvokeQuant()(gn, (x, y), scheme="nf4")[0]
|
||||
|
||||
x = torch.randn(8, requires_grad=False)
|
||||
y = torch.randn(8, requires_grad=False)
|
||||
ref = gn(x, y)[0]
|
||||
|
||||
x_clone = x.clone().detach().requires_grad_(False)
|
||||
y_clone = y.clone().detach().requires_grad_(False)
|
||||
res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_multiple(self):
|
||||
torch._logging.set_logs(post_grad_graphs=True)
|
||||
|
||||
def gn(x, y):
|
||||
return torch.mul(x, y) + y
|
||||
|
||||
def fn(x, y, z):
|
||||
o1 = invoke_quant_tracer(gn, (x, y), scheme="nf4")
|
||||
o2 = invoke_quant_tracer(gn, (y, z), scheme="nf4")
|
||||
return o1 + o2
|
||||
|
||||
x = torch.randn(8, requires_grad=False)
|
||||
y = torch.randn(8, requires_grad=False)
|
||||
z = torch.randn(8, requires_grad=False)
|
||||
ref = fn(x, y, z)
|
||||
|
||||
log_context = (
|
||||
contextlib.nullcontext()
|
||||
if self.backend != "inductor"
|
||||
else self.assertLogs(logger="torch._inductor", level=logging.DEBUG)
|
||||
)
|
||||
|
||||
with log_context as log:
|
||||
res = torch.compile(fn, backend=self.backend)(x, y, z)
|
||||
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
if self.backend == "inductor":
|
||||
logs = "\n".join(r.getMessage() for r in log.records)
|
||||
f = FileCheck()
|
||||
f.check("AFTER POST GRAD")
|
||||
f.check("subgraph0").check("subgraph1")
|
||||
for _ in range(2):
|
||||
f.check("torch.ops.higher_order.invoke_quant(").check_same("nf4")
|
||||
f.run(logs)
|
||||
|
||||
|
||||
class TestInvokeQuantEager(TestInvokeQuant):
|
||||
backend = "eager"
|
||||
|
||||
|
||||
class TestInvokeQuantAotEager(TestInvokeQuant):
|
||||
backend = "aot_eager"
|
||||
|
||||
|
||||
class TestInvokeQuantInductor(TestInvokeQuant):
|
||||
backend = "inductor"
|
||||
|
||||
def test_pattern_matching(self):
|
||||
counter = 0
|
||||
|
||||
test_pass = PatternMatcherPass()
|
||||
|
||||
def my_pass(g):
|
||||
return test_pass.apply(g)
|
||||
|
||||
def gn(x, y):
|
||||
return torch.mul(x, y) + y
|
||||
|
||||
def fn(x, y, z):
|
||||
return invoke_quant_tracer(gn, (x, y), scheme="nf4") @ z
|
||||
|
||||
def fn_no_match(x, y, z):
|
||||
return invoke_quant_tracer(gn, (x, y)) @ z
|
||||
|
||||
x = torch.randn(64, 64, requires_grad=False)
|
||||
y = torch.randn(64, 64, requires_grad=False)
|
||||
z = torch.randn(64, 64, requires_grad=False)
|
||||
|
||||
@register_graph_pattern(
|
||||
CallFunction(
|
||||
torch.ops.aten.mm,
|
||||
CallFunction(
|
||||
torch.ops.higher_order.invoke_quant,
|
||||
Ignored(),
|
||||
Ignored(),
|
||||
Ignored(),
|
||||
scheme="nf4",
|
||||
),
|
||||
Arg(),
|
||||
),
|
||||
pass_dict=test_pass,
|
||||
)
|
||||
def quant_matching(match: Match, *args, **kwargs):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
|
||||
with unittest.mock.patch(
|
||||
"torch._inductor.config.post_grad_custom_pre_pass", my_pass
|
||||
):
|
||||
torch.compile(fn)(x, y, z)
|
||||
self.assertTrue(counter == 1)
|
||||
|
||||
torch.compile(fn_no_match)(x, y, z)
|
||||
self.assertTrue(counter == 1)
|
||||
|
||||
|
||||
del TestInvokeQuant
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
test/inductor/test_block_analysis.py (Normal file, 102 lines)
@@ -0,0 +1,102 @@
# Owner(s): ["module: inductor"]
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor.codegen.block_analysis import BlockPatternMatcher
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import dummy_graph
|
||||
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
|
||||
|
||||
|
||||
# Some useful symbols
|
||||
x, y = sympy.symbols("x y")
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class BlockAnalysisTest(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
# Create a GraphLowering, so we can access V.graph.
|
||||
cls.graph = dummy_graph()
|
||||
|
||||
@parametrize(
|
||||
"stride,symbol,expr",
|
||||
[
|
||||
(5, x, Identity(5 * x)),
|
||||
(4, y, 4 * Identity(y)),
|
||||
(3, x, Identity(3) * x),
|
||||
],
|
||||
)
|
||||
def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr):
|
||||
# Test that we can handle an identity expression in affine indexing.
|
||||
matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol)
|
||||
self.assertEqual(matched_stride, stride)
|
||||
|
||||
@parametrize(
|
||||
"dims,strides,symbol,expr",
|
||||
[
|
||||
(
|
||||
(2, 4),
|
||||
(4, 1),
|
||||
x,
|
||||
4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4),
|
||||
),
|
||||
(
|
||||
(3, 9),
|
||||
(5, 2),
|
||||
x,
|
||||
5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9),
|
||||
),
|
||||
((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))),
|
||||
],
|
||||
)
|
||||
def test_mod_div_identity(
|
||||
self,
|
||||
dims: tuple[int],
|
||||
strides: tuple[int],
|
||||
symbol: sympy.Symbol,
|
||||
expr: sympy.Expr,
|
||||
):
|
||||
# Test that we can handle an identity expression in modular indexing.
|
||||
numel = int(torch.prod(torch.Tensor(dims)))
|
||||
num_dims = len(dims)
|
||||
with V.set_graph_handler(self.graph):
|
||||
match_result = BlockPatternMatcher.match_mod_div_block_expr(
|
||||
expr, symbol, numel, num_dims
|
||||
)
|
||||
|
||||
# Check the matched block dimensions.
|
||||
self.assertNotEqual(match_result, None)
|
||||
matched_dims, matched_strides, matched_block_index_exprs = match_result
|
||||
self.assertEqual(matched_dims, dims)
|
||||
self.assertEqual(matched_strides, strides)
|
||||
|
||||
@parametrize(
|
||||
"symbol,expr,subexpr",
|
||||
[
|
||||
(x, Identity(x), x),
|
||||
(x, Identity(x + 5), x),
|
||||
(y, Identity(x + 2 * y) + 5, 2 * y),
|
||||
],
|
||||
)
|
||||
def test_subexpr_identity(
|
||||
self,
|
||||
symbol: sympy.Symbol,
|
||||
expr: sympy.Expr,
|
||||
subexpr: sympy.Expr,
|
||||
):
|
||||
matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
|
||||
self.assertEqual(matched_subexpr, subexpr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
@@ -2510,9 +2510,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
     @supported_platform
     def test_strided_backwards(self):
        shape = (1, 2, 4096, 64)
-        Q = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
-        K = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
-        V = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
+        Q = torch.randn(shape, requires_grad=True, device="cuda")
+        K = torch.randn(shape, requires_grad=True, device="cuda")
+        V = torch.randn(shape, requires_grad=True, device="cuda")
         func = torch.compile(flex_attention, dynamic=True, fullgraph=True)
 
         K_sliced = K[:, :, :-128]
@@ -5,19 +5,23 @@ from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides
 from torch._inductor.codegen.halide import HalideOverrides
 from torch._inductor.codegen.mps import MetalOverrides
 from torch._inductor.codegen.triton import TritonKernelOverrides
-from torch._inductor.ops_handler import list_ops, OP_NAMES
+from torch._inductor.ops_handler import list_ops, OP_NAMES, OpsHandler
 from torch._inductor.test_case import TestCase
 
 
 class TestOpCompleteness(TestCase):
     def verify_ops_handler_completeness(self, handler):
-        op_names = list_ops(handler)
-        if OP_NAMES == op_names:
-            return
-        print(f"Missing ops: {OP_NAMES - op_names}")
-        print(f"Extra ops: {op_names - OP_NAMES}")
-        self.assertEqual(", ".join(OP_NAMES - op_names), "")
-        self.assertEqual(", ".join(op_names - OP_NAMES), "")
+        for op in OP_NAMES:
+            self.assertIsNot(
+                getattr(handler, op),
+                getattr(OpsHandler, op),
+                msg=f"{handler} must implement {op}",
+            )
+        extra_ops = list_ops(handler) - OP_NAMES
+        if extra_ops:
+            raise AssertionError(
+                f"{handler} has an extra ops: {extra_ops}, add them to OpHandler class or prefix with `_`"
+            )
 
     def test_triton_overrides(self):
         self.verify_ops_handler_completeness(TritonKernelOverrides)
@@ -50,7 +50,7 @@ class TestSortAndSelect(TestCase):
                 return ((b != b) | (a <= b)).all().item()
 
             else:
-                error(  # noqa: F821
+                raise ValueError(
                     f'unknown order "{order}", must be "ascending" or "descending"'
                 )
 
@@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import (
 )
 from torch.utils._sympy.functions import (
     FloorDiv,
+    Identity,
     OpaqueUnaryFn_cos,
     simple_floordiv_gcd,
 )
@@ -34,7 +35,8 @@ from torch.utils._sympy.reference import (
 )
 from torch.utils._sympy.singleton_int import SingletonInt
 from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
-from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
+from torch.utils._sympy.value_ranges import ValueRanges
+from torch._inductor.bounds import ValueRangeAnalysis
 
 
 UNARY_OPS = [
@@ -954,6 +956,17 @@ class TestSingletonInt(TestCase):
 
         self.assertEqual(j1.free_symbols, set())
 
+class TestIdentity(TestCase):
+    def test_expand_identity(self):
+        """
+        Test removing an identity via expansion.
+        """
+        x = sympy.Symbol("x")
+        arg = x + sympy.S.One
+        expr = Identity(arg)
+        expanded = expr.expand(identity=True)
+        self.assertEqual(expanded.count(Identity), 0)
+        self.assertEqual(expanded, arg)
 
 instantiate_parametrized_tests(TestValueRanges)
 instantiate_parametrized_tests(TestSympyInterp)
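The new TestIdentity case exercises the expand(identity=True) hint on the Identity wrapper from torch.utils._sympy.functions: expansion strips the wrapper and leaves the underlying expression. A standalone repetition of that check (it assumes the same internal Identity helper the test imports):

    import sympy
    from torch.utils._sympy.functions import Identity

    x = sympy.Symbol("x")
    wrapped = Identity(x + sympy.S.One)
    expanded = wrapped.expand(identity=True)
    assert expanded == x + sympy.S.One   # the Identity wrapper is gone
    assert expanded.count(Identity) == 0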
@@ -324,7 +324,7 @@ class TestTransformers(NNTestCase):
                 encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
             except AssertionError:
                 continue
-            self.assertFalse(e, "Failed to catch unsupported uint8 type exception")  # noqa: F821
+            self.assertFalse(e, "Failed to catch unsupported uint8 type exception")
 
         test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
         encoder.eval()
@@ -335,7 +335,7 @@ class TestTransformers(NNTestCase):
                 encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
             except AssertionError as e:
                 continue
-            self.assertFalse(e, "Failed to catch unsupported Long type exception")  # noqa: F821
+            self.assertFalse(e, "Failed to catch unsupported Long type exception")
 
         test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
         l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
@@ -78,6 +78,9 @@ def create_build_plan() -> list[tuple[str, str]]:
         if line.startswith(": &&") and line.endswith("&& :"):
             line = line[4:-4]
         line = line.replace("-O2", "-g").replace("-O3", "-g")
+        # Build Metal shaders with debug information
+        if "xcrun metal " in line and "-frecord-sources" not in line:
+            line += " -frecord-sources -gline-tables-only"
         try:
             name = line.split("-o ", 1)[1].split(" ")[0]
             rc.append((name, line))
@ -26,7 +26,10 @@ from .exc import IncorrectUsage, unimplemented
|
|||
from .source import AttrSource, Source
|
||||
from .utils import is_safe_constant, rot_n_helper
|
||||
from .variables.base import ValueMutationExisting, VariableTracker
|
||||
from .variables.functions import FunctionDecoratedByContextlibContextManagerVariable
|
||||
from .variables.functions import (
|
||||
ContextlibContextManagerLocalGeneratorObjectVariable,
|
||||
LocalGeneratorObjectVariable,
|
||||
)
|
||||
from .variables.nn_module import NNModuleVariable
|
||||
from .variables.tensor import (
|
||||
NumpyNdarrayVariable,
|
||||
|
|
@ -162,14 +165,20 @@ class PyCodegen:
|
|||
return
|
||||
|
||||
if value.is_realized() and isinstance(
|
||||
value, FunctionDecoratedByContextlibContextManagerVariable
|
||||
value, ContextlibContextManagerLocalGeneratorObjectVariable
|
||||
):
|
||||
raise IncorrectUsage(
|
||||
"NYI: Returning a @contextmanager object from a torch.compile function"
|
||||
)
|
||||
|
||||
# Dynamo normally prefers codegen from source to account for aliasing.
|
||||
if value.source is not None and allow_cache:
|
||||
if (
|
||||
value.source is not None
|
||||
and allow_cache
|
||||
and not (
|
||||
value.is_realized() and isinstance(value, LocalGeneratorObjectVariable)
|
||||
)
|
||||
):
|
||||
# There's a corner case for export: for instance, if the computation
|
||||
# graph is just identity on an input tensor, Dynamo would just emit
|
||||
# a `LOAD_FAST` from the input source, rather than generating an
|
||||
|
|
|
|||
|
|
@@ -52,6 +52,7 @@ skip_code_recursive_on_recompile_limit_hit = True
 # raise a hard error if cache limit is hit. If you are on a model where you
 # know you've sized the cache correctly, this can help detect problems when
 # you regress guards/specialization. This works best when recompile_limit = 1.
+# This flag is incompatible with: suppress_errors.
 # [@compile_ignored: runtime_behaviour]
 fail_on_recompile_limit_hit = False
 
@@ -164,6 +165,7 @@ traceable_tensor_subclasses: set[type[Any]] = set()
 # This is a good way to get your model to work one way or another, but you may
 # lose optimization opportunities this way. Devs, if your benchmark model is failing
 # this way, you should figure out why instead of suppressing it.
+# This flag is incompatible with: fail_on_recompile_limit_hit.
 suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
 
 # Record and write an execution record of the current frame to a file
@@ -417,6 +419,10 @@ enable_cpp_symbolic_shape_guards = False
 # Enable tracing through contextlib.contextmanager
 enable_trace_contextlib = True
 
+# Enable tracing generator functions lazily. If False, Dynamo will exhaust
+# generators upon first execution. And if True, the generator will be accessed lazily
+enable_faithful_generator_behavior = True
+
 # Inline inbuilt nn modules
 inline_inbuilt_nn_modules = Config(  # type: ignore[var-annotated]
     default=True,
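The two new comments mark suppress_errors and fail_on_recompile_limit_hit as mutually exclusive; with the check added in torch/_dynamo/eval_frame.py (next hunk), creating a compiled callable while both are set trips an AssertionError. A usage sketch mirroring the new test_incompatible_configs test (with this patch applied; exact message text may differ):

    import torch

    with torch._dynamo.config.patch(
        suppress_errors=True, fail_on_recompile_limit_hit=True
    ):
        try:
            torch.compile(lambda: None)  # rejected by check_for_incompatible_configs
        except AssertionError:
            print("incompatible dynamo configs rejected")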
@@ -833,6 +833,13 @@ def is_inductor_supported():
         return False
 
 
+def check_for_incompatible_configs():
+    # Some of the configs should be mutually exclusive
+    assert not (
+        config.suppress_errors and config.fail_on_recompile_limit_hit
+    ), "Dynamo configs suppress_error and fail_on_recompile_limit_hit can not both be active at the same time."
+
+
 def optimize(*args, **kwargs):
     def rebuild_ctx():
         ca_kwargs_override = config.compiled_autograd_kwargs_override
@@ -885,6 +892,7 @@ def _optimize(
     ...
     """
     check_if_dynamo_supported()
+    check_for_incompatible_configs()
     # Note: The hooks object could be global instead of passed around, *however* that would make
     # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
     # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
@@ -162,6 +162,17 @@ class AttributeMutationError(Unsupported):
         super().__init__(msg)
 
 
+class InfiniteGeneratorError(Unsupported):
+    # Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT
+    def __init__(self, msg: str) -> None:
+        super().__init__(msg)
+
+
+class SideEffectsError(Unsupported):
+    def __init__(self, msg: str) -> None:
+        super().__init__(msg)
+
+
 class CondOpArgsMismatchError(ArgsMismatchError):
     """
     Internal error from cond() due to arguments mismatch.
@ -267,12 +278,17 @@ class ObservedKeyError(ObservedLookupError):
|
|||
pass
|
||||
|
||||
|
||||
class ObservedGeneratorExit(ObservedException):
|
||||
pass
|
||||
|
||||
|
||||
class ObservedAttributeError(ObservedException):
|
||||
# An AttributeError exception to be raised from inside Dynamo tracing. This can happen on user defined object __getattr__
|
||||
pass
|
||||
|
||||
|
||||
class ObservedRuntimeError(ObservedException):
|
||||
# A RuntimeError exception to be raised from inside Dynamo tracing. This can happen on generator.throw(..) method
|
||||
pass
|
||||
|
||||
|
||||
|
|
@@ -280,17 +296,32 @@ class ObservedNotImplementedError(ObservedException):
     pass
 
 
+class ObservedTypeError(ObservedException):
+    # A TypeError exception to be raised from inside Dynamo tracing. This can happen on generator.send(..) method
+    pass
+
+
 observed_exception_map = {
     StopIteration: ObservedUserStopIteration,
     LookupError: ObservedLookupError,
     IndexError: ObservedIndexError,
+    GeneratorExit: ObservedGeneratorExit,
     KeyError: ObservedKeyError,
     AttributeError: ObservedAttributeError,
     RuntimeError: ObservedRuntimeError,
     NotImplementedError: ObservedNotImplementedError,
+    TypeError: ObservedTypeError,
 }
 
 
+def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]:
+    if exc_type not in observed_exception_map:
+        observed_exception_map[exc_type] = type(
+            f"Observed{exc_type.__name__}Error", (ObservedException,), {}
+        )
+    return observed_exception_map[exc_type]
+
+
 def raise_observed_exception(
     exc_type: type[Exception],
     tx: InstructionTranslatorBase,
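observed_exception_map now seeds only the common builtin exception types, and get_dynamo_observed_exception manufactures (and caches) a wrapper class on demand with type(). The same dynamic-subclass idiom in isolation, outside of Dynamo:

    class ObservedException(Exception):
        pass

    _cache = {}

    def get_observed(exc_type):
        # Build "Observed<Name>Error" lazily and memoize it, as the Dynamo helper does.
        if exc_type not in _cache:
            _cache[exc_type] = type(
                f"Observed{exc_type.__name__}Error", (ObservedException,), {}
            )
        return _cache[exc_type]

    assert get_observed(ValueError) is get_observed(ValueError)
    assert issubclass(get_observed(ValueError), ObservedException)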
@@ -1936,6 +1936,9 @@ class SubgraphTracer(fx.Tracer):
         # backward recomputation of the checkpoint region doesn't affect its correctness.
         self.allow_side_effects_under_checkpoint = False
 
+        # True if this tracer is currently tracing (reconstructing) into a Python generator
+        self.is_reconstructing_generator = False
+
         self.debug_level: int = parent.debug_level + 1 if parent is not None else 0
 
         self._cur_code = None
@ -7,7 +7,7 @@ import warnings
|
|||
import weakref
|
||||
from collections.abc import MutableMapping
|
||||
from types import CellType
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import torch.nn
|
||||
|
||||
|
|
@ -19,7 +19,7 @@ from .bytecode_transformation import (
|
|||
create_instruction,
|
||||
)
|
||||
from .codegen import PyCodegen
|
||||
from .exc import unimplemented
|
||||
from .exc import SideEffectsError, unimplemented
|
||||
from .source import GlobalSource, LocalCellSource, LocalSource, Source
|
||||
from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new, tuple_new
|
||||
from .variables.base import (
|
||||
|
|
@ -34,6 +34,10 @@ from .variables.base import (
|
|||
from .variables.user_defined import FrozenDataClassVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
def _manual_dict_setitem(dict_from, dict_to, mro_index):
|
||||
# Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have
|
||||
# to be careful because we don't want to trigger the user defined object
|
||||
|
|
@ -134,6 +138,14 @@ class SideEffects:
|
|||
and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint
|
||||
)
|
||||
|
||||
def is_reconstructing_generator(self):
|
||||
output_graph = self.output_graph_weakref()
|
||||
|
||||
return (
|
||||
output_graph
|
||||
and output_graph.current_tx.output.current_tracer.is_reconstructing_generator
|
||||
)
|
||||
|
||||
def check_allowed_side_effect(self, item):
|
||||
from torch._dynamo.variables.misc import AutogradFunctionContextVariable
|
||||
|
||||
|
|
@ -143,6 +155,14 @@ class SideEffects:
|
|||
return True
|
||||
if self.should_allow_side_effects_under_checkpoint():
|
||||
return True
|
||||
if self.is_reconstructing_generator():
|
||||
# This is missing the case where one mutates a tensor. See
|
||||
# test_generator.py::test_reconstruct_generator_tensor_mutation
|
||||
raise SideEffectsError(
|
||||
"Cannot reconstruct a generator with variable mutations. "
|
||||
"Dynamo needs to fully exhaust the generator, which may cause "
|
||||
"unintended variable modifications."
|
||||
)
|
||||
if not is_side_effect_safe(item.mutation_type):
|
||||
unimplemented(
|
||||
"HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
|
||||
|
|
@ -842,7 +862,7 @@ class SideEffects:
|
|||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: ignore[name-defined] # noqa: F821
|
||||
def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"):
|
||||
assert tx.output.current_tracer.under_activation_checkpoint
|
||||
orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint
|
||||
try:
|
||||
|
|
@ -850,3 +870,13 @@ def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: i
|
|||
yield
|
||||
finally:
|
||||
tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def disallow_side_effects_in_generator(tx: "InstructionTranslator"):
|
||||
orig_val = tx.output.current_tracer.is_reconstructing_generator
|
||||
try:
|
||||
tx.output.current_tracer.is_reconstructing_generator = True
|
||||
yield
|
||||
finally:
|
||||
tx.output.current_tracer.is_reconstructing_generator = orig_val
|
||||
|
|
|
|||
|
|
@ -85,7 +85,8 @@ from .variables.ctx_manager import (
|
|||
from .variables.dicts import ConstDictVariable, SetVariable
|
||||
from .variables.functions import (
|
||||
BaseUserFunctionVariable,
|
||||
FunctionDecoratedByContextlibContextManagerVariable,
|
||||
LocalGeneratorFunctionVariable,
|
||||
LocalGeneratorObjectVariable,
|
||||
NestedUserFunctionVariable,
|
||||
SkipFunctionVariable,
|
||||
UserFunctionVariable,
|
||||
|
|
@ -290,6 +291,34 @@ def _step_logger():
|
|||
return torchdynamo_logging.get_step_logger(log)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def save_and_restart_speculation_log(tx: "InstructionTranslatorBase"):
|
||||
# When reconstructing a generator after a graph break, we advance it until
|
||||
# it is fully exhausted. This process adds new entries to the speculation
|
||||
# log that were not previously observed. Without temporarily clearing the
|
||||
# speculation log, this could lead to a divergence error.
|
||||
|
||||
entries = tx.speculation_log.entries
|
||||
index = tx.speculation_log.index
|
||||
try:
|
||||
tx.speculation_log.entries = []
|
||||
tx.speculation_log.index = 0
|
||||
yield
|
||||
finally:
|
||||
tx.speculation_log.entries = entries
|
||||
tx.speculation_log.index = index
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temporarely_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"):
|
||||
try:
|
||||
tmp = tx.output.should_exit
|
||||
tx.output.should_exit = False
|
||||
yield
|
||||
finally:
|
||||
tx.output.should_exit = tmp
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BlockStackEntry:
|
||||
# Current instruction that pushes something to block_stack
|
||||
|
|
@ -922,11 +951,22 @@ class InstructionTranslatorBase(
|
|||
raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
|
||||
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
|
||||
|
||||
def inline_generator_function(self, fn, args, kwargs):
|
||||
"""
|
||||
Redirect the call to the generator "call_function"
|
||||
"""
|
||||
if not isinstance(fn, LocalGeneratorFunctionVariable):
|
||||
fn = LocalGeneratorFunctionVariable(fn)
|
||||
return fn.call_function(self, args, kwargs)
|
||||
|
||||
def inline_user_function_return(self, fn, args, kwargs):
|
||||
"""
|
||||
A call to some user defined function by inlining it.
|
||||
"""
|
||||
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
|
||||
if config.enable_faithful_generator_behavior and is_generator(fn.get_code()):
|
||||
return self.inline_generator_function(fn, args, kwargs)
|
||||
else:
|
||||
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
|
||||
|
||||
def get_line_of_code_header(self, lineno=None):
|
||||
if lineno is None:
|
||||
|
|
@ -1499,6 +1539,15 @@ class InstructionTranslatorBase(
|
|||
self._raise_exception_variable(inst)
|
||||
unimplemented("raise ... from ...")
|
||||
|
||||
def CLEANUP_THROW(self, inst):
|
||||
# https://github.com/python/cpython/pull/96010
|
||||
tos = self.stack[-1]
|
||||
assert isinstance(tos, ExceptionVariable)
|
||||
if tos.exc_type is StopIteration:
|
||||
unimplemented("CLEANUP_THROW with StopIteration")
|
||||
else:
|
||||
self.RERAISE(inst)
|
||||
|
||||
def RERAISE(self, inst):
|
||||
if sys.version_info >= (3, 11):
|
||||
# RERAISE is currently supported in a narrow case of `raise ... from None`
|
||||
|
|
@ -3083,7 +3132,20 @@ class InstructionTranslator(InstructionTranslatorBase):
|
|||
return True
|
||||
return False
|
||||
|
||||
def replace_tos_if_return_is_generator(self):
|
||||
if (
|
||||
len(self.stack)
|
||||
and (tos := self.stack[-1])
|
||||
and isinstance(tos, LocalGeneratorObjectVariable)
|
||||
):
|
||||
self.stack[-1] = ListIteratorVariable(
|
||||
tos.force_unpack_var_sequence(self),
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
|
||||
def _return(self, inst):
|
||||
self.replace_tos_if_return_is_generator()
|
||||
|
||||
if (
|
||||
not config.allow_empty_graphs
|
||||
and self.output.count_calls() == 0
|
||||
|
|
@ -3093,6 +3155,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
|||
and not self.one_graph
|
||||
):
|
||||
raise exc.SkipFrame("because no content in function call")
|
||||
|
||||
self.instruction_pointer = None
|
||||
_step_logger()(
|
||||
logging.INFO,
|
||||
|
|
@ -3179,8 +3242,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
func: VariableTracker,
|
||||
args: list[VariableTracker],
|
||||
kwargs,
|
||||
*,
|
||||
stop_generator_on_yield: bool = False,
|
||||
):
|
||||
if isinstance(func, SkipFunctionVariable):
|
||||
unimplemented("inline with functions in skip files")
|
||||
|
|
@ -3189,7 +3250,8 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
(
|
||||
UserFunctionVariable,
|
||||
NestedUserFunctionVariable,
|
||||
FunctionDecoratedByContextlibContextManagerVariable,
|
||||
LocalGeneratorFunctionVariable,
|
||||
LocalGeneratorObjectVariable,
|
||||
),
|
||||
)
|
||||
result = InliningInstructionTranslator.check_inlineable(func)
|
||||
|
|
@ -3254,9 +3316,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
parent.symbolic_globals,
|
||||
parent.symbolic_torch_function_state,
|
||||
func,
|
||||
stop_generator_on_yield=stop_generator_on_yield,
|
||||
)
|
||||
else:
|
||||
# need the line below to make MyPy happy
|
||||
assert not isinstance(func, LocalGeneratorObjectVariable)
|
||||
tracer = InliningInstructionTranslator(
|
||||
parent,
|
||||
code,
|
||||
|
|
@ -3302,24 +3365,30 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
|
||||
log.debug("DONE INLINING %s", code)
|
||||
|
||||
if is_generator(code):
|
||||
assert isinstance(self, InliningGeneratorInstructionTranslator)
|
||||
# The first flag tells us if we consume generators lazily or not
|
||||
# and the second is if the generator is exhausted.
|
||||
# In the future, generators should be lazily consumed and the first
|
||||
# flag (stop_generator_on_yield) will not be needed.
|
||||
if self.stop_generator_on_yield and self.generator_exhausted:
|
||||
if config.enable_faithful_generator_behavior or (
|
||||
isinstance(self, InliningGeneratorInstructionTranslator)
|
||||
and self.is_generator_from_ctx_manager
|
||||
):
|
||||
if (
|
||||
is_generator(code)
|
||||
and isinstance(self, InliningGeneratorInstructionTranslator)
|
||||
and self.generator_exhausted
|
||||
):
|
||||
assert isinstance(self, InliningGeneratorInstructionTranslator)
|
||||
# When the generator returns None, we raise StopIteration
|
||||
r = self.symbolic_result
|
||||
assert r.as_python_constant() is None
|
||||
exc.raise_observed_exception(StopIteration, self)
|
||||
else:
|
||||
return self.symbolic_result
|
||||
else:
|
||||
if is_generator(code):
|
||||
assert isinstance(self, InliningGeneratorInstructionTranslator)
|
||||
assert self.symbolic_result.as_python_constant() is None
|
||||
return ListIteratorVariable(
|
||||
self.generated_items,
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
else:
|
||||
return self.symbolic_result
|
||||
else:
|
||||
return self.symbolic_result
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -3438,27 +3507,26 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
||||
generated_items: list[VariableTracker]
|
||||
# Flag wether or not the InlineGenerator should consume the entire iterator
|
||||
stop_generator_on_yield: bool
|
||||
|
||||
def __init__(self, *args, stop_generator_on_yield: bool = False, **kwargs) -> None:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.generated_items = []
|
||||
# In the future, generators should run lazily (i.e. when next(...) is called)
|
||||
# TODO: Set this to True by default, so that dynamo follows CPython more
|
||||
# closely
|
||||
self.stop_generator_on_yield = stop_generator_on_yield
|
||||
self.generator_exhausted = False
|
||||
self.is_generator_from_ctx_manager = False
|
||||
|
||||
def YIELD_VALUE(self, inst: Instruction):
|
||||
top = self.pop()
|
||||
self.generated_items.append(top)
|
||||
if len(self.generated_items) > MAX_ITERATOR_LIMIT:
|
||||
unimplemented(
|
||||
raise exc.InfiniteGeneratorError(
|
||||
"Too many yield values in generator. Maybe you are inlining an infinite generator. "
|
||||
f"If not, please report a bug at {PT2_ISSUE_TRACKER_URL}",
|
||||
)
|
||||
self.push(ConstantVariable.create(None))
|
||||
if self.stop_generator_on_yield:
|
||||
if (
|
||||
config.enable_faithful_generator_behavior
|
||||
or self.is_generator_from_ctx_manager
|
||||
):
|
||||
self.symbolic_result = top
|
||||
# Stop tracing
|
||||
raise YieldValueOp
|
||||
|
|
@ -3500,10 +3568,6 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
|||
self.pop()
|
||||
self.push(ConstantVariable.create(ex.value))
|
||||
else:
|
||||
self.push(val)
|
||||
# Add the value to yield into generated_items and replace the top of the stack with None
|
||||
self.YIELD_VALUE(inst)
|
||||
|
||||
# Repeat the YIELD_FROM instruction in the next eval loop
|
||||
assert (
|
||||
isinstance(self.instruction_pointer, int)
|
||||
|
|
@ -3511,11 +3575,15 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
|||
)
|
||||
self.instruction_pointer -= 1
|
||||
|
||||
self.push(val)
|
||||
# Add the value to yield into generated_items and replace the top of the stack with None
|
||||
self.YIELD_VALUE(inst)
|
||||
|
||||
def SEND(self, inst):
|
||||
assert len(self.stack) >= 2
|
||||
val = self.pop()
|
||||
tos = self.stack[-1]
|
||||
if isinstance(tos, ListIteratorVariable) or (
|
||||
if isinstance(tos, (ListIteratorVariable, LocalGeneratorObjectVariable)) or (
|
||||
isinstance(tos, UserDefinedObjectVariable)
|
||||
and isinstance(tos.value, collections.abc.Iterator)
|
||||
):
|
||||
|
|
|
|||
|
|
@ -32,8 +32,9 @@ from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
|
|||
from .variables import (
|
||||
BuiltinVariable,
|
||||
FunctionalCallVariable,
|
||||
FunctionDecoratedByContextlibContextManagerVariable,
|
||||
FunctorchHigherOrderVariable,
|
||||
LocalGeneratorFunctionVariable,
|
||||
LocalGeneratorObjectVariable,
|
||||
NestedUserFunctionVariable,
|
||||
PolyfilledFunctionVariable,
|
||||
SkipFunctionVariable,
|
||||
|
|
@ -3209,6 +3210,7 @@ LEGACY_MOD_INLINELIST = {
|
|||
"torch._higher_order_ops.while_loop",
|
||||
"torch._higher_order_ops.associative_scan",
|
||||
"torch._higher_order_ops.scan",
|
||||
"torch._higher_order_ops._invoke_quant",
|
||||
"torch._higher_order_ops.utils",
|
||||
"torch.nn.attention.flex_attention",
|
||||
"torch.ao.quantization.pt2e.export_utils",
|
||||
|
|
@ -3619,7 +3621,8 @@ def check_verbose(obj, is_inlined_call=False):
|
|||
UserFunctionVariable,
|
||||
UserMethodVariable,
|
||||
NestedUserFunctionVariable,
|
||||
FunctionDecoratedByContextlibContextManagerVariable,
|
||||
LocalGeneratorFunctionVariable,
|
||||
LocalGeneratorObjectVariable,
|
||||
),
|
||||
):
|
||||
try:
|
||||
|
|
@ -3639,7 +3642,14 @@ def check_verbose(obj, is_inlined_call=False):
|
|||
# Consult the central trace rules defined in torch._dynamo.trace_rules.
|
||||
reasons: set[str] = set()
|
||||
rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons)
|
||||
if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
|
||||
if issubclass(
|
||||
rule,
|
||||
(
|
||||
UserFunctionVariable,
|
||||
LocalGeneratorFunctionVariable,
|
||||
PolyfilledFunctionVariable,
|
||||
),
|
||||
):
|
||||
return SkipResult(
|
||||
False,
|
||||
f"inlined according trace_rules.lookup {reasons.pop()}",
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ from .functions import (
|
|||
FunctionDecoratedByContextlibContextManagerVariable,
|
||||
FunctoolsPartialVariable,
|
||||
FunctoolsWrapsVariable,
|
||||
LocalGeneratorFunctionVariable,
|
||||
LocalGeneratorObjectVariable,
|
||||
NestedUserFunctionVariable,
|
||||
PolyfilledFunctionVariable,
|
||||
SkipFunctionVariable,
|
||||
|
|
|
|||
|
|
@ -773,7 +773,13 @@ class BuiltinVariable(VariableTracker):
|
|||
tx, [v.realize() for v in args], kwargs
|
||||
)
|
||||
|
||||
if inspect.isclass(fn) and issubclass(fn, Exception):
|
||||
if inspect.isclass(fn) and (
|
||||
issubclass(fn, Exception)
|
||||
# GeneratorExit doesn't inherit from Exception
|
||||
# >>> issubclass(GeneratorExit, Exception)
|
||||
# False
|
||||
or fn is GeneratorExit
|
||||
):
|
||||
|
||||
def create_exception_class_object(
|
||||
tx: "InstructionTranslator", args, kwargs
|
||||
|
|
@ -1425,6 +1431,13 @@ class BuiltinVariable(VariableTracker):
|
|||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
|
||||
def _call_iter_tuple_generator(self, tx, obj, *args, **kwargs):
|
||||
cls = variables.BaseListVariable.cls_for(self.fn)
|
||||
return cls(
|
||||
list(obj.force_unpack_var_sequence(tx)), # exhaust generator
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
|
||||
def _call_tuple_list(self, tx, obj=None, *args, **kwargs):
|
||||
if isinstance(obj, variables.IteratorVariable):
|
||||
cls = variables.BaseListVariable.cls_for(self.fn)
|
||||
|
|
@ -1432,6 +1445,8 @@ class BuiltinVariable(VariableTracker):
|
|||
list(obj.force_unpack_var_sequence(tx)),
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
elif isinstance(obj, variables.LocalGeneratorObjectVariable):
|
||||
return self._call_iter_tuple_generator(tx, obj, *args, **kwargs)
|
||||
else:
|
||||
return self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,17 +4,30 @@ import builtins
|
|||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import sys
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
|
||||
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import Never
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from .. import polyfills, variables
|
||||
from ..bytecode_transformation import create_call_function, create_rot_n
|
||||
from ..exc import raise_observed_exception, unimplemented, Unsupported
|
||||
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
|
||||
from ..exc import (
|
||||
get_dynamo_observed_exception,
|
||||
handle_observed_exception,
|
||||
IncorrectUsage,
|
||||
InfiniteGeneratorError,
|
||||
ObservedException,
|
||||
ObservedGeneratorExit,
|
||||
ObservedUserStopIteration,
|
||||
raise_observed_exception,
|
||||
SkipFrame,
|
||||
unimplemented,
|
||||
Unsupported,
|
||||
)
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
|
||||
from ..utils import (
|
||||
|
|
@ -378,61 +391,426 @@ class BuiltinMethodVariable(BaseUserFunctionVariable):
|
|||
return obj_vt.call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
class FunctionDecoratedByContextlibContextManagerVariable(BaseUserFunctionVariable):
|
||||
# TODO(guilherme): replace this with a generic GeneratorFunctionVariable
|
||||
class LocalGeneratorObjectVariable(VariableTracker):
|
||||
def __init__(
|
||||
self,
|
||||
code: types.CodeType,
|
||||
f_globals,
|
||||
inline_tracer: Optional["InstructionTranslator"],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.code = code
|
||||
self.f_globals = f_globals
|
||||
self.inline_tracer = inline_tracer
|
||||
|
||||
def get_code(self):
|
||||
return self.code
|
||||
|
||||
def get_filename(self):
|
||||
return self.get_code().co_filename
|
||||
|
||||
def get_name(self):
|
||||
return self.get_code().co_name
|
||||
|
||||
def get_function(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def has_self(self):
|
||||
return False
|
||||
|
||||
def __name__(self):
|
||||
return self.get_name()
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.__class__.__name__}({self.get_name()})"
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
from torch._dynamo.side_effects import disallow_side_effects_in_generator
|
||||
from torch._dynamo.symbolic_convert import (
|
||||
InstructionTranslator,
|
||||
save_and_restart_speculation_log,
|
||||
temporarely_allow_writes_to_output_graph,
|
||||
)
|
||||
|
||||
tx = InstructionTranslator.current_tx()
|
||||
save = save_and_restart_speculation_log(tx)
|
||||
disallow = disallow_side_effects_in_generator(tx)
|
||||
temp = temporarely_allow_writes_to_output_graph(tx)
|
||||
|
||||
with save, disallow, temp:
|
||||
tracer = self._get_inline_tracer(tx)
|
||||
if not tracer.generator_exhausted:
|
||||
self.remaining_items = self.force_unpack_var_sequence(tx)
|
||||
variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen)
|
||||
|
||||
def bind_args(self, tx, args, kwargs):
|
||||
return self.fn.bind_args(tx, args, kwargs)
|
||||
|
||||
def get_globals(self):
|
||||
return self.f_globals
|
||||
|
||||
def python_type(self):
|
||||
return types.GeneratorType
|
||||
|
||||
def _get_inline_tracer(self, tx):
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
if self.inline_tracer is None:
|
||||
self.inline_tracer = InliningInstructionTranslator.build_inline_tracer(
|
||||
tx, self, [], {}
|
||||
)
|
||||
return self.inline_tracer
|
||||
|
||||
def next_variable(self, tx):
|
||||
tracer = self._get_inline_tracer(tx)
|
||||
|
||||
if self._is_generator_exhausted():
|
||||
raise_observed_exception(StopIteration, tx)
|
||||
|
||||
try:
|
||||
# Hierarchically, tx can be seen as the parent of the inline tracer
|
||||
# created on call_function. Any exception needs to be propagated to tx
|
||||
# for Dynamo to behave correctly
|
||||
with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
|
||||
return tracer.inline_call_()
|
||||
except ObservedException as e:
|
||||
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
|
||||
raise e
|
||||
except InfiniteGeneratorError:
|
||||
# test/dynamo/test_misc.py::test_iterator_limit
|
||||
raise
|
||||
except Unsupported as e:
|
||||
torch._C._dynamo.eval_frame.skip_code(self.get_code())
|
||||
raise SkipFrame from e
|
||||
finally:
|
||||
counters["unimplemented"] |= counters["inline_call"]
|
||||
|
||||
def has_unpack_var_sequence(self, tx):
|
||||
return False
|
||||
|
||||
def has_force_unpack_var_sequence(self, tx) -> builtins.bool:
|
||||
return True
|
||||
|
||||
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
|
||||
result = []
|
||||
while True:
|
||||
try:
|
||||
result.append(self.next_variable(tx))
|
||||
except ObservedUserStopIteration:
|
||||
handle_observed_exception(tx)
|
||||
break
|
||||
return result
|
||||
|
||||
def _setup_exception(self, tx, exc):
|
||||
tracer = self._get_inline_tracer(tx)
|
||||
tracer.push(exc)
|
||||
try:
|
||||
tracer._raise_exception_variable(None)
|
||||
except ObservedException as e:
|
||||
# if no handler is available (i.e. user code doesn't catch it), the
|
||||
# exception is raised again.
|
||||
tracer.exception_handler(e)
|
||||
|
||||
def _is_generator_just_started(self):
|
||||
return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0
|
||||
|
||||
def _is_generator_exhausted(self):
|
||||
return getattr(self.inline_tracer, "generator_exhausted", False)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: "List[VariableTracker]",
|
||||
kwargs: "Dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
if name == "__next__":
|
||||
return self.next_variable(tx)
|
||||
elif name == "__iter__":
|
||||
# iter(gen) returns itself
|
||||
return self
|
||||
elif name == "send":
|
||||
# Sends a value into the generator function. Returns the next value
|
||||
# yielded by the generator, or raises StopIteration if the generator
|
||||
# exits without yielding another value
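# For reference, a sketch of how plain CPython send() behaves (illustrative
# only, not traced code):
#   def gen():
#       x = yield 1
#       yield x + 1
#   g = gen(); next(g)   # -> 1
#   g.send(41)           # -> 42
#   g.send(None)         # raises StopIteration (generator returned)
#   gen().send(1)        # TypeError: can't send non-None value to a
#                        # just-started generator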
|
||||
if self._is_generator_just_started() and len(args):
|
||||
# can't send non-None value to a just-started generator
|
||||
# Test: GeneratorCPythonTests.test_send_non_none_to_new_gen
|
||||
if not all(
|
||||
isinstance(arg, ConstantVariable) and arg.value is None
|
||||
for arg in args
|
||||
):
|
||||
raise_observed_exception(TypeError, tx)
|
||||
tracer = self._get_inline_tracer(tx)
|
||||
tracer.push_many(args)
|
||||
return self.next_variable(tx)
|
||||
elif name == "close":
|
||||
# * Raises a GeneratorExit at the point where the generator function was paused.
|
||||
# * If the generator function catches the exception and returns a
|
||||
# value, this value is returned from close() - Python 3.13+
|
||||
# * If the generator function is already closed, or raises GeneratorExit
|
||||
# (by not catching the exception), close() returns None.
|
||||
# * If the generator yields a value, a RuntimeError is raised.
|
||||
# * If the generator raises any other exception, it is propagated to the caller.
|
||||
# * If the generator has already exited due to an exception or normal
|
||||
# exit, close() returns None and has no other effect.
|
||||
|
||||
# Return None if close is called on a just-started generator
|
||||
# See test GeneratorCloseCpythonTests::test_close_not_started
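# For reference, plain CPython close() behaves roughly like this (illustrative
# sketch only):
#   def gen():
#       try:
#           yield 1
#       finally:
#           print("cleanup")
#   g = gen(); next(g)   # -> 1
#   g.close()            # raises GeneratorExit at the yield, runs finally, -> None
#   g.close()            # already closed: no effect, -> None
#   gen().close()        # just-started generator: -> None immediately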
|
||||
|
||||
tracer = self._get_inline_tracer(tx)
|
||||
if self._is_generator_just_started() or self._is_generator_exhausted():
|
||||
tracer.generator_exhausted = True
|
||||
return variables.ConstantVariable(None)
|
||||
|
||||
# Raise GeneratorExit to see if user code catches it. Any other exception
|
||||
# is propagated to the parent frame.
|
||||
try:
|
||||
self._setup_exception(
|
||||
tx, variables.ExceptionVariable(GeneratorExit, ())
|
||||
)
|
||||
# There's an extra block on Python 3.12+ to handle StopIteration
|
||||
# see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397
|
||||
#
|
||||
# 1 0 RETURN_GENERATOR
|
||||
# 2 POP_TOP
|
||||
# 4 RESUME 0
|
||||
|
||||
# 2 6 LOAD_CONST 1 (1)
|
||||
# 8 YIELD_VALUE 1
|
||||
# 10 RESUME 1
|
||||
# 12 POP_TOP
|
||||
# 14 RETURN_CONST 0 (None)
|
||||
# >> 16 CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR)
|
||||
# 18 RERAISE 1
|
||||
# ExceptionTable:
|
||||
# 4 to 14 -> 16 [0] lasti
|
||||
if (
|
||||
sys.version_info >= (3, 12)
|
||||
and tracer.next_instruction.opname == "CALL_INTRINSIC_1"
|
||||
):
|
||||
tracer.generator_exhausted = True
|
||||
return variables.ConstantVariable(None)
|
||||
except ObservedGeneratorExit:
|
||||
# If it doesn't catch, we just return None, as per the text above
|
||||
tracer.generator_exhausted = True
|
||||
return variables.ConstantVariable(None)
|
||||
|
||||
try:
|
||||
# Raise RuntimeError if the generator yields any other value
|
||||
if self.next_variable(tx):
|
||||
raise_observed_exception(RuntimeError, tx)
|
||||
except ObservedGeneratorExit:
|
||||
tracer.generator_exhausted = True
|
||||
return variables.ConstantVariable(None)
|
||||
except ObservedUserStopIteration:
|
||||
# In Python 3.13+, one can capture GeneratorExit and return a value
|
||||
# See test_generator.py::test_close_capture_GeneratorExit_return
|
||||
# https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26
|
||||
# https://github.com/python/cpython/pull/104771
|
||||
assert tracer.symbolic_result is not None
|
||||
return tracer.symbolic_result
|
||||
elif name == "throw":
|
||||
# * Raises an exception at the point where the generator was paused, and
|
||||
# returns the next value yielded by the generator.
|
||||
# * If the generator exits without yielding, raise StopIteration
|
||||
# * If the generator function does not catch the passed-in exception,
|
||||
# or raises a different exception, then that exception propagates to the caller.
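# For reference, plain CPython throw() behaves roughly like this (illustrative
# sketch only):
#   def gen():
#       try:
#           yield 1
#       except ValueError:
#           yield 2
#   g = gen(); next(g)     # -> 1
#   g.throw(ValueError)    # caught inside the generator -> 2
#   g.throw(KeyError)      # not caught -> KeyError propagates to the caller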
|
||||
|
||||
if len(args) > 1:
|
||||
raise IncorrectUsage(
|
||||
"the (type, exc, tb) signature of throw() is deprecated, "
|
||||
"use the single-arg signature instead."
|
||||
)
|
||||
|
||||
# Setup the exception table and jump target in case of try...finally
|
||||
tracer = self._get_inline_tracer(tx)
|
||||
try:
|
||||
self._setup_exception(tx, args[0])
|
||||
except ObservedException:
|
||||
# propagate the exception back to the parent caller
|
||||
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
|
||||
raise
|
||||
|
||||
retval = self.next_variable(tx)
|
||||
|
||||
# The exception raised before is still active. We need to check the exception
|
||||
# table one more time to find the next target. But why? Let’s walk
|
||||
# through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M
|
||||
#
|
||||
# z = 0
|
||||
# def whoo():
|
||||
# global z
|
||||
# z = 0
|
||||
# try:
|
||||
# yield 1
|
||||
# except ValueError:
|
||||
# yield 2
|
||||
# finally:
|
||||
# z += 1
|
||||
# z += 10
|
||||
#
|
||||
# gen = whoo()
|
||||
# next(gen)
|
||||
# gen.throw(ValueError)
|
||||
# print('z', z) -> z = 1
|
||||
#
|
||||
# ...
|
||||
# >> 58 PUSH_EXC_INFO
|
||||
#
|
||||
# 8 60 LOAD_GLOBAL 2 (ValueError)
|
||||
# 70 CHECK_EXC_MATCH
|
||||
# 72 POP_JUMP_IF_FALSE 7 (to 88)
|
||||
# 74 POP_TOP
|
||||
#
|
||||
# 9 76 LOAD_CONST 3 (2)
|
||||
# 78 YIELD_VALUE 3 <------ ValueError is still active here
|
||||
# 80 RESUME 1
|
||||
# 82 POP_TOP
|
||||
# 84 POP_EXCEPT
|
||||
# 86 jump_backward 34 (to 20)
|
||||
# ...
|
||||
#
|
||||
# ExceptionTable:
|
||||
# 4 to 8 -> 124 [0] lasti
|
||||
# 12 to 18 -> 58 [0]
|
||||
# 20 to 56 -> 124 [0] lasti
|
||||
# 58 to 82 -> 90 [1] lasti <------ move to 90
|
||||
# 84 to 86 -> 96 [0]
|
||||
# 88 to 88 -> 90 [1] lasti
|
||||
# 90 to 94 -> 96 [0]
|
||||
# 96 to 116 -> 118 [1] lasti
|
||||
# 118 to 122 -> 124 [0] lasti
|
||||
#
|
||||
# In this scenario, a generator can yield after `throw()` is called. Even
|
||||
# after the exception is raised a few lines above, it remains active
|
||||
# within the `78 YIELD_VALUE` instruction. When the generator resumes
|
||||
# after the second yield on instruction `80 RESUME`, we cannot simply
|
||||
# return the control flow to the next instruction. Instead, one must
|
||||
# check the exception table (or equivalent) to find the next target
|
||||
# In this case, it says the instruction pointer must be moved to 90.
|
||||
#
|
||||
# Without this step, if we let the trace proceed to the next
|
||||
# instruction, it would follow the control flow where the exception
|
||||
# raised by `throw()` was handled and swallowed, potentially leading
|
||||
# to incorrect behavior.
|
||||
exc_type = type("__InternalThrowException", (Exception,), {})
|
||||
|
||||
try:
|
||||
self._setup_exception(tx, variables.ExceptionVariable(exc_type, ()))
|
||||
self.next_variable(tx)
|
||||
except get_dynamo_observed_exception(exc_type):
|
||||
# We should get back the exception raised before.
|
||||
pass
|
||||
except ObservedException:
|
||||
# Propagate anything else back to the parent caller
|
||||
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
|
||||
else:
|
||||
raise_observed_exception(RuntimeError, tracer)
|
||||
return retval
|
||||
|
||||
super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
class ContextlibContextManagerLocalGeneratorObjectVariable(
|
||||
LocalGeneratorObjectVariable
|
||||
):
|
||||
"""
|
||||
.. note::
|
||||
|
||||
This is only used when the function is annotated with @contextlib.contextmanager
|
||||
|
||||
It is a special case of a generator function, as we do not allow returning a context manager
|
||||
from a torch.compile function.
|
||||
"""
|
||||
|
||||
|
||||
class LocalGeneratorFunctionVariable(BaseUserFunctionVariable):
|
||||
"""functions that behaves like iterators
|
||||
|
||||
.. note::
|
||||
|
||||
This is only used when the function is annotated with @contextlib.contextmanager
|
||||
This is a wrapper around (Nested)UserFunctionVariable
|
||||
"""
|
||||
|
||||
def __init__(self, vt: VariableTracker, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vt: VariableTracker,
|
||||
*,
|
||||
generator_cls=LocalGeneratorObjectVariable,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.vt = vt
|
||||
self.inline_tracer = None
|
||||
self.generator_cls = generator_cls
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self.__class__.__dict__.keys():
|
||||
return getattr(self, name)
|
||||
return getattr(self.vt, name)
|
||||
|
||||
def _build_inline_tracer(self, tx, args, kwargs):
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
return InliningInstructionTranslator.build_inline_tracer(
|
||||
tx,
|
||||
self,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
def call_function(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
        from torch._dynamo.bytecode_transformation import is_generator

        assert is_generator(self.vt.get_code())
        from torch._dynamo.symbolic_convert import InliningInstructionTranslator

        self.inline_tracer = InliningInstructionTranslator.build_inline_tracer(
            tx,
            self,
            [*self.self_args(), *args],
            kwargs,
            stop_generator_on_yield=True,
        )
        return self

        assert is_generator(self.get_code())

        inline_tracer = self._build_inline_tracer(tx, args, kwargs)
        code = self.vt.get_code()
        f_globals = self.vt.get_globals()

        # calling a generator returns a generator object
        return self.generator_cls(
            code,
            f_globals,
            inline_tracer,
            source=self.source,
        )
|
||||
|
||||
    def next_variable(self, tx):
        from torch._dynamo import exc

        tracer = self.inline_tracer

        try:
            # Hierarchically, tx can be seen as the parent of the inline tracer
            # created on call_function. Any exception needs to be propagated to tx
            # for Dynamo to behave correctly
            with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
                return tracer.inline_call_().next_variable(tx)
        except exc.ObservedException as e:
            tx.exn_vt_stack.extend(tracer.exn_vt_stack)
            raise e


class FunctionDecoratedByContextlibContextManagerVariable(
    LocalGeneratorFunctionVariable
):
    """
    .. note::

        This is only used when the function is annotated with @contextlib.contextmanager
    """

    def __init__(self, vt, **kwargs):
        super().__init__(
            vt,
            generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable,
            **kwargs,
        )
|
||||
|
||||
def _build_inline_tracer(self, tx, args, kwargs):
|
||||
# NOTE: This only exists to not break support for context manager when
|
||||
# config.enable_faithful_generator_behavior = False and
|
||||
# config.enable_trace_contextlib = True. In case the former is false,
|
||||
# Dynamo should still be able to trace through @contextmanager functions
|
||||
tracer = super()._build_inline_tracer(tx, args, kwargs)
|
||||
assert isinstance(
|
||||
tracer,
|
||||
torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator,
|
||||
)
|
||||
tracer.is_generator_from_ctx_manager = True
|
||||
return tracer
|
||||
|
||||
|
||||
class UserMethodVariable(UserFunctionVariable):
|
@ -234,6 +234,7 @@ class SuperVariable(VariableTracker):
|
|||
|
||||
|
||||
class ExceptionVariable(VariableTracker):
|
||||
# The ExceptionVariable corresponds to the BaseException class in Python
|
||||
def __init__(self, exc_type, args, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.exc_type = exc_type
|
||||
|
|
|
|||
|
|
@ -493,7 +493,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
):
|
||||
if not torch._dynamo.config.enable_trace_contextlib:
|
||||
unimplemented("contextlib.contextmanager")
|
||||
# Replace UserFunctionVariable by FunctionDecoratedBycontextlibContextManagerVariable
|
||||
# Wrap UserFunctionVariable in FunctionDecoratedByContextlibContextManagerVariable
|
||||
# if the function is annotated with @contextlib.contextmanager
|
||||
# This shouldn't be necessary once generator functions are fully
|
||||
# supported in dynamo
|
||||
|
|
@ -805,6 +805,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
# of the cmp_eq polyfill function.
|
||||
return ConstantVariable.create(self.value is other.value)
|
||||
|
||||
if torch._dynamo.config.enable_faithful_generator_behavior and isinstance(
|
||||
self.value, types.GeneratorType
|
||||
):
|
||||
unimplemented("Generator as graph argument is not supported")
|
||||
|
||||
# check for methods implemented in C++
|
||||
if isinstance(method, types.FunctionType):
|
||||
source = (
|
||||
|
|
|
|||
|
|
@ -1822,7 +1822,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
|||
metadata: ViewAndMutationMeta = fw_metadata # type: ignore[assignment]
|
||||
maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta
|
||||
num_symints_saved_for_bw = num_symints_saved_for_bw_
|
||||
_compiled_autograd_should_lift = False
|
||||
_aot_id = aot_config.aot_id
|
||||
_lazy_backward_info = lazy_backward_info
|
||||
|
||||
|
|
@ -1989,7 +1988,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
|||
# https://github.com/pytorch/pytorch/pull/92348/files#r1072962107
|
||||
class CompiledFunctionBackward(torch.autograd.Function):
|
||||
# CompiledFunctionBackward is not yet supported in dynamo skipfiles
|
||||
_compiled_autograd_should_lift = False
|
||||
_aot_id = aot_config.aot_id
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
from torch._higher_order_ops._invoke_quant import (
|
||||
invoke_quant,
|
||||
invoke_quant_packed,
|
||||
InvokeQuant,
|
||||
)
|
||||
from torch._higher_order_ops.aoti_call_delegate import aoti_call_delegate
|
||||
from torch._higher_order_ops.associative_scan import associative_scan
|
||||
from torch._higher_order_ops.auto_functionalize import (
|
||||
|
|
@ -51,6 +56,9 @@ __all__ = [
|
|||
"executorch_call_delegate",
|
||||
"call_torchbind",
|
||||
"run_const_graph",
|
||||
"InvokeQuant",
|
||||
"invoke_quant",
|
||||
"invoke_quant_packed",
|
||||
"wrap_with_set_grad_enabled",
|
||||
"wrap_with_autocast",
|
||||
"wrap_activation_checkpoint",
|
||||
|
|
|
|||
72
torch/_higher_order_ops/_invoke_quant.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
# mypy: allow-untyped-defs
|
||||
# need to fix prim_hop_base type annotations first
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.prim_hop_base import FunctionWithNoFreeVars, PrimHOPBase
|
||||
|
||||
|
||||
class InvokeQuantTracer(PrimHOPBase):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("invoke_quant_packed")
|
||||
|
||||
def __call__(self, subgraph, operands, *, scheme=None, quant_options=None):
|
||||
subgraph = FunctionWithNoFreeVars(subgraph)
|
||||
return super().__call__(
|
||||
subgraph, operands, scheme=scheme, quant_options=quant_options
|
||||
)
|
||||
|
||||
|
||||
invoke_quant_packed = InvokeQuantTracer()
|
||||
|
||||
|
||||
class InvokeQuantUnpacked(PrimHOPBase):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("invoke_quant")
|
||||
|
||||
def __call__(self, subgraph, *operands, scheme=None):
|
||||
return super().__call__(subgraph, operands, scheme=scheme)
|
||||
|
||||
def _call_FakeTensorMode(
|
||||
self, mode, subgraph, operands, scheme: Optional[str] = None, **kwargs
|
||||
):
|
||||
# TODO: this should probably route through FakeTensorMode to reuse caching
|
||||
with mode:
|
||||
return subgraph(*operands[0], **kwargs)
|
||||
|
||||
|
||||
invoke_quant = InvokeQuantUnpacked()
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, repr=True)
|
||||
class InvokeQuant:
|
||||
"""
|
||||
Invoke a quantization function that will be preserved as a single operator. Preservation
|
||||
as a single operator aids in pattern matching and custom lowerings.
|
||||
|
||||
The operation appears as:
|
||||
torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=scheme)
|
||||
|
||||
Args:
|
||||
codegen_low_precision: Use observed subgraph dtypes for codegen instead of
|
||||
upcasting to fp32. Can improve performance for prologue fusion but
|
||||
requires careful testing of numerics.
|
||||
"""
|
||||
|
||||
codegen_low_precision: bool = True
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*args,
|
||||
scheme: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if not torch._utils.is_compiling():
|
||||
return args[0](*args[1], **kwargs)
|
||||
|
||||
if scheme is not None:
|
||||
kwargs["scheme"] = scheme
|
||||
|
||||
return invoke_quant_packed(*args, **kwargs, quant_options=self) # type: ignore[call-arg]
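A minimal usage sketch of the InvokeQuant entry point defined above (illustrative only; the subgraph and tensors below are made-up placeholders, not part of this diff):

import torch
from torch._higher_order_ops import InvokeQuant

w_int8 = torch.randint(-128, 128, (64, 64), dtype=torch.int8)
scale = torch.rand(64, 1, dtype=torch.float16)
x = torch.rand(64, 8, dtype=torch.float16)

def dequant_mm(w, s, inp):
    # made-up quantized subgraph: dequantize, then matmul
    return (w.to(torch.float16) * s) @ inp

quant = InvokeQuant(codegen_low_precision=True)

# Eager: simply calls dequant_mm(w_int8, scale, x). Under torch.compile, it is
# captured as invoke_quant_packed and later canonicalized into a single
# torch.ops.higher_order.invoke_quant node (see the joint-graph pass below).
out = quant(dequant_mm, (w_int8, scale, x), scheme="nf4")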
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import dataclasses
|
||||
import itertools
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import sympy
|
||||
|
||||
|
|
@ -9,6 +9,7 @@ from torch._inductor import config
|
|||
from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
|
||||
from torch._inductor.index_propagation import SymPyOps, TypedExpr
|
||||
|
||||
from .ops_handler import DefaultHandler
|
||||
from .virtualized import StoreMode, V
|
||||
|
||||
|
||||
|
|
@ -20,7 +21,7 @@ def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol:
|
|||
return sympy.Symbol(f"unknown_{count}")
|
||||
|
||||
|
||||
class PreservesZeros(SymPyOps):
|
||||
class PreservesZeros(SymPyOps, DefaultHandler):
|
||||
"""
|
||||
For prologue kernels where the loads are masked, does the final store of this kernel preserve
|
||||
the zeros.
|
||||
|
|
@ -31,41 +32,32 @@ class PreservesZeros(SymPyOps):
|
|||
self.store_preserves_zeros: Optional[bool] = None
|
||||
self.dtype_prop = DtypePropagationOpsHandler()
|
||||
|
||||
@staticmethod
|
||||
def load(name: str, index: sympy.Expr) -> TypedExpr:
|
||||
def load(self, name: str, index: sympy.Expr) -> TypedExpr:
|
||||
# In prologue fusion, all loads get broadcasted
|
||||
dtype = V.get_ops_handler().dtype_prop.load(name, index)
|
||||
dtype = self.dtype_prop.load(name, index)
|
||||
return TypedExpr(
|
||||
sympy.Float(0) if dtype.is_floating_point else sympy.Integer(0), dtype
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def store(
|
||||
name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None
|
||||
self, name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None
|
||||
) -> None:
|
||||
self = V.get_ops_handler()
|
||||
assert isinstance(self, PreservesZeros)
|
||||
# should only have a single store in prologue
|
||||
assert self.store_preserves_zeros is None
|
||||
self.store_preserves_zeros = value.is_constant() and value.expr == 0
|
||||
|
||||
@staticmethod
|
||||
def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
|
||||
self = V.get_ops_handler()
|
||||
def indirect_indexing(self, *args: Any, **kwargs: Any) -> sympy.Expr:
|
||||
return construct_symbol(next(self.count), torch.int32)
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Any]:
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
from torch._inductor.codegen.common import OpDecompositions
|
||||
|
||||
def inner(*args: Any, **kwargs: Any) -> TypedExpr:
|
||||
if hasattr(OpDecompositions, name):
|
||||
return getattr(OpDecompositions, name)(*args, **kwargs).value
|
||||
if hasattr(OpDecompositions, name):
|
||||
return getattr(OpDecompositions, name)(*args, **kwargs).value
|
||||
|
||||
nonlocal self
|
||||
dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
|
||||
return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
|
||||
|
||||
return inner
|
||||
dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
|
||||
return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
|
||||
|
||||
|
||||
def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool:
|
||||
|
|
@ -88,7 +80,7 @@ class DTypeContainer:
|
|||
is_scalar: bool = False
|
||||
|
||||
|
||||
class RecordLowPrecisionOps:
|
||||
class RecordLowPrecisionOps(DefaultHandler):
|
||||
def __init__(self) -> None:
|
||||
self.low_precision_numeric_op = False
|
||||
self.dtype_prop = DtypePropagationOpsHandler()
|
||||
|
|
@ -97,9 +89,8 @@ class RecordLowPrecisionOps:
|
|||
"constant",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load(name: str, index: sympy.Expr) -> DTypeContainer:
|
||||
return DTypeContainer(V.get_ops_handler().dtype_prop.load(name, index))
|
||||
def load(self, name: str, index: sympy.Expr) -> DTypeContainer:
|
||||
return DTypeContainer(self.dtype_prop.load(name, index))
|
||||
|
||||
@staticmethod
|
||||
def store(
|
||||
|
|
@ -111,28 +102,25 @@ class RecordLowPrecisionOps:
|
|||
def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
|
||||
return sympy.S.Zero
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Any]:
|
||||
def low_prec_float(dtype: torch.dtype) -> bool:
|
||||
return dtype.is_floating_point and dtype.itemsize < 4
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
|
||||
out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
|
||||
if name == "constant":
|
||||
out = DTypeContainer(torch.float, is_scalar=True)
|
||||
|
||||
def inner(*args: Any, **kwargs: Any) -> DTypeContainer:
|
||||
out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
|
||||
out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
|
||||
if name == "constant":
|
||||
out = DTypeContainer(torch.float, is_scalar=True)
|
||||
uses_low_prec = any(
|
||||
isinstance(dtype_cont, DTypeContainer) and low_prec_float(dtype_cont.dtype)
|
||||
for dtype_cont in itertools.chain((out,), args, kwargs.values())
|
||||
)
|
||||
|
||||
uses_low_prec = any(
|
||||
isinstance(dtype_cont, DTypeContainer)
|
||||
and low_prec_float(dtype_cont.dtype)
|
||||
for dtype_cont in itertools.chain((out,), args, kwargs.values())
|
||||
)
|
||||
if uses_low_prec and name not in self.non_numeric_ops:
|
||||
self.low_precision_numeric_op = True
|
||||
|
||||
if uses_low_prec and name not in self.non_numeric_ops:
|
||||
self.low_precision_numeric_op = True
|
||||
return out
|
||||
|
||||
return out
|
||||
|
||||
return inner
|
||||
def low_prec_float(dtype: torch.dtype) -> bool:
|
||||
return dtype.is_floating_point and dtype.itemsize < 4
|
||||
|
||||
|
||||
def can_codegen_without_upcasts(
|
||||
|
|
|
|||
|
|
@ -1,14 +1,22 @@
|
|||
import logging
|
||||
import operator
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import sympy
|
||||
from sympy import Expr
|
||||
|
||||
import torch
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.value_ranges import (
|
||||
bound_sympy,
|
||||
SymPyValueRangeAnalysis,
|
||||
ValueRanges,
|
||||
)
|
||||
|
||||
from ..utils._sympy.functions import PowByNatural
|
||||
from ..utils._sympy.numbers import int_oo
|
||||
from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
|
||||
from .ops_handler import DefaultHandler, ReductionType, StoreMode
|
||||
from .utils import cache_on_self, dominated_nodes
|
||||
from .virtualized import V
|
||||
|
||||
|
|
@ -139,3 +147,113 @@ class BoundVars:
|
|||
# assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
|
||||
self.replacement_vals[name] = bound
|
||||
return bound
|
||||
|
||||
|
||||
class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler):
|
||||
def __init__(self) -> None:
|
||||
self.name = "ValueRangeAnalysis"
|
||||
boolean_operators = (
|
||||
"xor",
|
||||
"logical_and",
|
||||
"logical_or",
|
||||
"logical_not",
|
||||
)
|
||||
for op in boolean_operators:
|
||||
setattr(self, op, self.bool_handler)
|
||||
|
||||
@staticmethod
|
||||
def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]:
|
||||
# just assuming bools can have both values
|
||||
return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type]
|
||||
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
# many ops are unlikely to show up in optimizable indexing compute,
|
||||
# so we don't have full coverage
|
||||
return ValueRanges.unknown()
|
||||
|
||||
def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]:
|
||||
return ValueRanges.unknown()
|
||||
|
||||
def store(
|
||||
self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None
|
||||
) -> None:
|
||||
return
|
||||
|
||||
def reduction(
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
src_dtype: torch.dtype,
|
||||
reduction_type: ReductionType,
|
||||
value: Any,
|
||||
) -> ValueRanges[Any]:
|
||||
return ValueRanges.unknown()
|
||||
|
||||
@classmethod
|
||||
def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]:
|
||||
assert isinstance(index, ValueRanges)
|
||||
return cls.to_dtype(index, dtype)
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(
|
||||
x: Any,
|
||||
dtype: torch.dtype,
|
||||
src_dtype: Optional[torch.dtype] = None,
|
||||
use_compute_types: bool = True,
|
||||
) -> ValueRanges[Any]:
|
||||
x = ValueRanges.wrap(x)
|
||||
|
||||
if dtype == torch.bool:
|
||||
if x.is_singleton():
|
||||
return ValueRanges.wrap(x.lower != 0)
|
||||
elif x.is_bool:
|
||||
return x
|
||||
elif 0 not in x:
|
||||
return ValueRanges.wrap(sympy.true)
|
||||
else:
|
||||
return ValueRanges(sympy.false, sympy.true)
|
||||
|
||||
def cast(x: Any, dtype: torch.dtype) -> sympy.Expr:
|
||||
# dtype is int or float
|
||||
if dtype.is_floating_point:
|
||||
return sympy.Float(x)
|
||||
else:
|
||||
if x in (int_oo, -int_oo):
|
||||
return x
|
||||
try:
|
||||
return sympy.Integer(x)
|
||||
except TypeError:
|
||||
# inf cannot be cast to Integer
|
||||
return x
|
||||
|
||||
if x.is_bool:
|
||||
if x.is_singleton():
|
||||
val = 1 if x.lower else 0
|
||||
return ValueRanges.wrap(cast(val, dtype))
|
||||
else:
|
||||
return ValueRanges(cast(0, dtype), cast(1, dtype))
|
||||
else:
|
||||
# int to float or float to int
|
||||
return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))
|
||||
|
||||
@staticmethod
|
||||
def square(x: Any) -> ValueRanges[Any]:
|
||||
return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))
|
||||
|
||||
@staticmethod
|
||||
def neg(x: Any) -> ValueRanges[Any]:
|
||||
return ValueRanges.decreasing_map(x, operator.neg)
|
||||
|
||||
# TODO: this is slightly inaccurate because truncdiv operates at integer
|
||||
# precision, but we're going through float truediv which means we can
|
||||
# potentially lose precision on the bounds
|
||||
@classmethod
|
||||
def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]:
|
||||
x = cls.truediv(a, b)
|
||||
if x == ValueRanges.unknown():
|
||||
return x
|
||||
|
||||
return cls.trunc(x)
|
||||
|
||||
@classmethod
|
||||
def sub(cls, a: Any, b: Any) -> ValueRanges[Any]:
|
||||
return cls.add(a, cls.neg(b))
|
@ -17,8 +17,8 @@ class BlockPatternMatcher:
|
|||
Matches block indexing expressions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_subexpr_involving_symbol(expr: Expr, symbol: Symbol) -> Expr:
|
||||
@classmethod
|
||||
def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
|
||||
"""
|
||||
Given a sympy expression, return the subexpression comprised only of terms
|
||||
involving the specified symbol.
|
||||
|
|
@ -26,6 +26,7 @@ class BlockPatternMatcher:
|
|||
For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`,
|
||||
this returns `x * 5 + x ** 2`.
|
||||
"""
|
||||
expr = cls._preprocess(expr)
|
||||
return sympy.S.Zero + sum(
|
||||
term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols
|
||||
)
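# e.g. (restating the docstring example as an illustrative call):
#   x, y = sympy.symbols("x y")
#   BlockPatternMatcher.get_subexpr_involving_symbol(x * 5 + x**2 + y * 2 + 5, x)
#   # -> x**2 + 5*x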
|
||||
|
|
@ -42,6 +43,11 @@ class BlockPatternMatcher:
|
|||
numels.appendleft(numel)
|
||||
return [*numels]
|
||||
|
||||
@staticmethod
|
||||
def _preprocess(expr: Expr) -> Expr:
|
||||
# Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y.
|
||||
return expr.expand(identity=True)
|
||||
|
||||
@classmethod
|
||||
def match_mod_div_block_expr(
|
||||
cls,
|
||||
|
|
@ -54,6 +60,7 @@ class BlockPatternMatcher:
|
|||
Matches modular indexing expressions, converting them to implied block dimensions and strides.
|
||||
See triton.py for more information.
|
||||
"""
|
||||
index = cls._preprocess(index)
|
||||
|
||||
# Pattern match to find the strides and offset.
|
||||
wild = functools.partial(sympy.Wild, exclude=[index_var])
|
||||
|
|
@ -141,3 +148,21 @@ class BlockPatternMatcher:
|
|||
)
|
||||
|
||||
return dims, strides, block_index_exprs
|
||||
|
||||
@classmethod
|
||||
def match_affine_block_expr(
|
||||
cls,
|
||||
index: Expr,
|
||||
index_var: Symbol,
|
||||
) -> Optional[Expr]:
|
||||
"""
|
||||
Matches simple expressions of the form stride * index, returning the
|
||||
stride.
|
||||
"""
|
||||
index = cls._preprocess(index)
|
||||
stride = sympy.Wild("stride", exclude=[index_var])
|
||||
m = index.match(index_var * stride)
|
||||
if m is None:
|
||||
return None
|
||||
|
||||
return m[stride]
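Illustrative behavior of this matcher (hedged sketch; values made up):

#   xindex = sympy.Symbol("xindex")
#   BlockPatternMatcher.match_affine_block_expr(3 * xindex, xindex)   # -> 3
#   BlockPatternMatcher.match_affine_block_expr(xindex + 7, xindex)   # -> None (not of the form stride * index)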
|
@ -37,11 +37,11 @@ from torch.utils._ordered_set import OrderedSet
|
|||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
|
||||
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
|
||||
|
||||
from .. import config, metrics
|
||||
from ..dtype_propagation import DtypePropagationOpsHandler
|
||||
from ..ops_handler import BasicMathOps
|
||||
from ..ops_handler import BasicMathOpsMixin, DefaultHandler
|
||||
from ..utils import (
|
||||
boolean_ops,
|
||||
DeferredLineBase,
|
||||
|
|
@ -764,7 +764,7 @@ def _all_in_parens(string: str) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
class OpOverrides(BasicMathOps, OpDecompositions):
|
||||
class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
|
||||
@staticmethod
|
||||
def paren(string: OpVarT) -> OpVarT:
|
||||
if (
|
||||
|
|
@ -948,6 +948,16 @@ class OpOverrides(BasicMathOps, OpDecompositions):
|
|||
f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
|
||||
)
|
||||
|
||||
def output(self, *args: OpVarT) -> None:
|
||||
raise AssertionError(
|
||||
f"{type(self).__name__}: ops.output should not appear at codegen time"
|
||||
)
|
||||
|
||||
def placeholder(self, index: int) -> OpVarT:
|
||||
raise AssertionError(
|
||||
f"{type(self).__name__}: ops.placeholder should not appear at codegen time"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _unimplemented(name: str) -> Callable[..., OpVarT]:
|
||||
def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT:
|
||||
|
|
@ -1225,12 +1235,6 @@ pointwise_overrides_data: dict[str, OverridesData] = dict(
|
|||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_OpOverrides(OpOverrides, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class DeferredLine(DeferredLineBase):
|
||||
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
|
||||
|
||||
|
|
@ -2253,84 +2257,84 @@ class KernelTemplate:
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class CSEProxy:
|
||||
class CSEProxy(DefaultHandler):
|
||||
name = "CSEProxy"
|
||||
|
||||
def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
|
||||
super().__init__()
|
||||
from ..bounds import ValueRangeAnalysis
|
||||
|
||||
self.vr_analysis = ValueRangeAnalysis()
|
||||
self.kernel = kernel
|
||||
self.parent_handler = parent_handler
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
|
||||
def inner(*args: Any, **kwargs: Any) -> CSEVariable:
|
||||
bounds = self._bound_variable(name, *args, **kwargs)
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
bounds = self._bound_variable(name, *args, **kwargs)
|
||||
|
||||
value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
|
||||
dtype_handler = DtypePropagationOpsHandler()
|
||||
value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
|
||||
dtype_handler = DtypePropagationOpsHandler()
|
||||
|
||||
output_idx = 0
|
||||
output_idx = 0
|
||||
|
||||
def do_cse(v: str) -> CSEVariable:
|
||||
# cpp backend doesnt set current device - TODO: fix
|
||||
if V.graph.current_device is not None:
|
||||
device_str = V.graph.get_current_device_or_throw().type
|
||||
triton_backend = (
|
||||
config.cpu_backend == "triton"
|
||||
if device_str == "cpu"
|
||||
else config.cuda_backend == "triton"
|
||||
if device_str != "mps"
|
||||
else False
|
||||
)
|
||||
else:
|
||||
triton_backend = False
|
||||
|
||||
# only triton backend tracks dtype currently
|
||||
if triton_backend:
|
||||
if name == "masked":
|
||||
output_dtype = value.dtype
|
||||
else:
|
||||
output_dtype = getattr(
|
||||
dtype_handler,
|
||||
name,
|
||||
)(*args, **kwargs)
|
||||
else:
|
||||
# cpp backend doesnt track dtype yet
|
||||
output_dtype = None
|
||||
|
||||
csevar = V.kernel.cse.generate(
|
||||
V.kernel.compute,
|
||||
v,
|
||||
bounds=bounds,
|
||||
dtype=output_dtype,
|
||||
def do_cse(v: str) -> CSEVariable:
|
||||
# cpp backend doesnt set current device - TODO: fix
|
||||
if V.graph.current_device is not None:
|
||||
device_str = V.graph.get_current_device_or_throw().type
|
||||
triton_backend = (
|
||||
config.cpu_backend == "triton"
|
||||
if device_str == "cpu"
|
||||
else config.cuda_backend == "triton"
|
||||
if device_str != "mps"
|
||||
else False
|
||||
)
|
||||
else:
|
||||
triton_backend = False
|
||||
|
||||
nonlocal output_idx
|
||||
if config.test_configs.runtime_triton_dtype_assert and triton_backend:
|
||||
from torch._inductor.codegen.triton import triton_type
|
||||
# only triton backend tracks dtype currently
|
||||
if triton_backend:
|
||||
if name == "masked":
|
||||
output_dtype = value.dtype
|
||||
else:
|
||||
output_dtype = getattr(
|
||||
dtype_handler,
|
||||
name,
|
||||
)(*args, **kwargs)
|
||||
else:
|
||||
# cpp backend doesnt track dtype yet
|
||||
output_dtype = None
|
||||
|
||||
# we tree_map over the output, so we need to fetch corresponding dtype
|
||||
if isinstance(output_dtype, (list, tuple)):
|
||||
output_dtype = output_dtype[output_idx]
|
||||
csevar = V.kernel.cse.generate(
|
||||
V.kernel.compute,
|
||||
v,
|
||||
bounds=bounds,
|
||||
dtype=output_dtype,
|
||||
)
|
||||
|
||||
V.kernel.compute.writeline(
|
||||
f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
|
||||
)
|
||||
output_idx += 1
|
||||
nonlocal output_idx
|
||||
if config.test_configs.runtime_triton_dtype_assert and triton_backend:
|
||||
from torch._inductor.codegen.triton import triton_type
|
||||
|
||||
csevar.update_on_args(name, args, kwargs)
|
||||
# we tree_map over the output, so we need to fetch corresponding dtype
|
||||
if isinstance(output_dtype, (list, tuple)):
|
||||
output_dtype = output_dtype[output_idx]
|
||||
|
||||
return csevar
|
||||
V.kernel.compute.writeline(
|
||||
f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
|
||||
)
|
||||
output_idx += 1
|
||||
|
||||
return pytree.tree_map(do_cse, value)
|
||||
csevar.update_on_args(name, args, kwargs)
|
||||
|
||||
return inner
|
||||
return csevar
|
||||
|
||||
return pytree.tree_map(do_cse, value)
|
||||
|
||||
def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]:
|
||||
"""
|
||||
If the variable comes from an FX node, we forward the bound we have already computed
|
||||
Else, if the variable is created while codegen'ing another op, we try to compute its bounds
|
||||
"""
|
||||
from ..bounds import ValueRangeAnalysis
|
||||
from ..select_algorithm import TritonTemplateKernel
|
||||
|
||||
if isinstance(V.kernel, TritonTemplateKernel):
|
||||
|
|
@ -2568,8 +2572,3 @@ class CSEProxy:
|
|||
sorter,
|
||||
sorter_indices,
|
||||
)
|
||||
|
||||
|
||||
# Use mypy to check protocol implemented correctly
|
||||
def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
|
||||
return h
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from ..utils import (
|
|||
sympy_index_symbol,
|
||||
sympy_subs,
|
||||
)
|
||||
from ..virtualized import _ops as ops, OpsHandler, V
|
||||
from ..virtualized import _ops as ops, V
|
||||
from .common import (
|
||||
BackendFeature,
|
||||
CSEVariable,
|
||||
|
|
@ -563,12 +563,6 @@ class HalideOverrides(OpOverrides):
|
|||
HalideOverrides._initialize_pointwise_overrides("halide")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_HalideOverrides(HalideOverrides, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class HalideCSEVariable(CSEVariable):
|
||||
undefined_re = re.compile(r"\b(tmp\d+)\[\?\]")
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
|||
|
||||
import sympy
|
||||
|
||||
from ..ops_handler import OpsHandler, ReductionType, StoreMode
|
||||
from ..ops_handler import ReductionType, StoreMode
|
||||
from ..scheduler import Scheduler, SchedulerNode
|
||||
from .common import OpVarT
|
||||
|
||||
|
|
@ -367,12 +367,6 @@ class MetalOverrides(OpOverrides):
|
|||
MetalOverrides._initialize_pointwise_overrides("mps")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_MetalOverrides(MetalOverrides, OpsHandler[Any]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class MetalKernel(SIMDKernel):
|
||||
overrides = MetalOverrides # type: ignore[assignment]
|
||||
suffix = ";"
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_ty
|
|||
from ...utils._sympy.value_ranges import ValueRanges
|
||||
from .. import config, ir, metrics
|
||||
from ..codecache import code_hash, get_path, PyCodeCache
|
||||
from ..ops_handler import DefaultHandler
|
||||
from ..runtime.benchmarking import benchmarker
|
||||
from ..runtime.hints import (
|
||||
AutotuneHint,
|
||||
|
|
@ -60,7 +61,7 @@ from ..utils import (
|
|||
triton_version_uses_attrs_dict,
|
||||
upcast_compute_type,
|
||||
)
|
||||
from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
|
||||
from ..virtualized import _ops as ops, ReductionType, StoreMode, V
|
||||
from ..wrapper_benchmark import get_kernel_category_by_source_code
|
||||
from .block_analysis import BlockPatternMatcher
|
||||
from .common import (
|
||||
|
|
@ -1427,12 +1428,6 @@ class TritonKernelOverrides(TritonOverrides):
|
|||
return (mantissa, exponent)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class HelperFunctions:
|
||||
"""An ordered set of helper functions."""
|
||||
|
||||
|
|
@ -1795,7 +1790,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
and self.index_dtype == "tl.int32"
|
||||
):
|
||||
|
||||
def match_strided_block(
|
||||
def match_affine_block(
|
||||
index: sympy.Expr, range_tree: IterationRangesRoot
|
||||
) -> Optional[BlockParameters]:
|
||||
"""
|
||||
|
|
@ -1804,16 +1799,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
|
||||
This implies stride (s,), and shape (XBLOCK,).
|
||||
"""
|
||||
symbol = range_tree.symbol()
|
||||
stride = sympy.Wild("stride", exclude=[symbol])
|
||||
m = index.match(symbol * stride)
|
||||
if m is None:
|
||||
stride = BlockPatternMatcher.match_affine_block_expr(
|
||||
index, range_tree.symbol()
|
||||
)
|
||||
if stride is None:
|
||||
return None
|
||||
|
||||
return BlockParameters(
|
||||
shape=[range_tree.numel],
|
||||
block_shape=[TritonSymbols.get_block_size(range_tree)],
|
||||
strides=[m[stride]],
|
||||
strides=[stride],
|
||||
offsets=[TritonSymbols.get_block_offset(range_tree)],
|
||||
)
|
||||
|
||||
|
|
@ -1922,7 +1917,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
Match a block indexing subexpression involving a single range tree.
|
||||
"""
|
||||
for match_func in (
|
||||
match_strided_block,
|
||||
match_affine_block,
|
||||
match_mod_div_block,
|
||||
):
|
||||
match = match_func(expr, range_tree)
|
||||
|
|
@ -2872,24 +2867,23 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
|
||||
dtype_handler = DtypePropagationOpsHandler()
|
||||
|
||||
class CSEProxy:
|
||||
def __getattr__(self, name: str) -> Callable[..., CSEVariable]:
|
||||
def inner(*args, **kwargs):
|
||||
nonlocal helper_name
|
||||
helper_name += f"_{name}"
|
||||
class CSEProxy(DefaultHandler):
|
||||
def _default(
|
||||
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> Any:
|
||||
nonlocal helper_name
|
||||
helper_name += f"_{name}"
|
||||
|
||||
output_dtype = getattr(
|
||||
dtype_handler,
|
||||
name,
|
||||
)(*args, **kwargs)
|
||||
output_dtype = getattr(
|
||||
dtype_handler,
|
||||
name,
|
||||
)(*args, **kwargs)
|
||||
|
||||
return cse.generate(
|
||||
helper,
|
||||
getattr(overrides, name)(*args, **kwargs),
|
||||
dtype=output_dtype,
|
||||
)
|
||||
|
||||
return inner
|
||||
return cse.generate(
|
||||
helper,
|
||||
getattr(overrides, name)(*args, **kwargs),
|
||||
dtype=output_dtype,
|
||||
)
|
||||
|
||||
with helper.indent(), V.set_ops_handler(CSEProxy()):
|
||||
outputs = fn(*args)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import itertools
|
|||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
|
||||
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
|
|
@ -15,6 +15,7 @@ from torch.utils._ordered_set import OrderedSet
|
|||
|
||||
from ..utils._sympy.symbol import make_symbol, SymT
|
||||
from .codegen.common import index_prevent_reordering
|
||||
from .ops_handler import DefaultHandler
|
||||
from .utils import (
|
||||
get_dtype_size,
|
||||
reduction_num_outputs,
|
||||
|
|
@ -23,7 +24,7 @@ from .utils import (
|
|||
sympy_subs,
|
||||
VarRanges,
|
||||
)
|
||||
from .virtualized import OpsHandler, ReductionType, V
|
||||
from .virtualized import ReductionType, V
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
|
@ -737,19 +738,16 @@ def canonicalization_prefix() -> str:
|
|||
|
||||
|
||||
# ops handler which computes all the free unbacked symbols for an IR
|
||||
class FreeUnbackedSymbolsOpsHandler:
|
||||
class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
|
||||
symbols: OrderedSet[sympy.Symbol]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.symbols = OrderedSet()
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Any]:
|
||||
def inner(*args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None:
|
||||
for a in itertools.chain(args, kwargs.values()):
|
||||
if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
|
||||
self.symbols |= free_unbacked_symbols(a)
|
||||
|
||||
return inner
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
for a in itertools.chain(args, kwargs.values()):
|
||||
if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
|
||||
self.symbols |= free_unbacked_symbols(a)
|
||||
|
||||
def indirect_indexing(
|
||||
self,
|
||||
|
|
@ -791,12 +789,6 @@ class FreeUnbackedSymbolsOpsHandler:
|
|||
body()
|
||||
|
||||
|
||||
def _typecheck_FreeUnbackedSymbolsOpsHandler(
|
||||
h: FreeUnbackedSymbolsOpsHandler,
|
||||
) -> OpsHandler[None]:
|
||||
return h
|
||||
|
||||
|
||||
def extract_free_unbacked_symbols(
|
||||
fn: Callable[..., Any],
|
||||
index: Sequence[sympy.Expr],
|
||||
|
|
|
|||
|
|
@ -368,6 +368,16 @@ class DtypePropagationOpsHandler:
|
|||
) -> None:
|
||||
return None
|
||||
|
||||
def output(self, *args: DTypeArg) -> None:
|
||||
raise AssertionError(
|
||||
f"{type(self).__name__}: ops.output should not appear here"
|
||||
)
|
||||
|
||||
def placeholder(self, index: int) -> torch.dtype:
|
||||
raise AssertionError(
|
||||
f"{type(self).__name__}: ops.placeholder should not appear here"
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import operator
|
||||
import typing
|
||||
from collections import Counter
|
||||
from typing import Any, Union
|
||||
|
|
@ -442,6 +443,81 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule):
|
|||
remove_redundant_views(gm)
|
||||
|
||||
|
||||
def canonicalize_quant_mapping(gm: torch.fx.GraphModule):
|
||||
"""
|
||||
|
||||
|
||||
torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'quant_invoke_0_0', (arg0_1, arg1_1));
|
||||
->
|
||||
torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4');
|
||||
"""
|
||||
graph = gm.graph
|
||||
invoke_quant_invocations = graph.find_nodes(
|
||||
op="call_function", target=torch.ops.higher_order.invoke_quant_packed
|
||||
)
|
||||
for invoke_quant in invoke_quant_invocations:
|
||||
kwargs = dict(invoke_quant.kwargs)
|
||||
|
||||
quant_options_node = kwargs.pop("quant_options", None)
|
||||
if quant_options_node is not None:
|
||||
assert isinstance(quant_options_node, torch.fx.Node)
|
||||
quant_options = torch._higher_order_ops.InvokeQuant(
|
||||
*invoke_quant.kwargs["quant_options"].args,
|
||||
**invoke_quant.kwargs["quant_options"].kwargs,
|
||||
)
|
||||
else:
|
||||
quant_options = None
|
||||
|
||||
subgraph, args = invoke_quant.args
|
||||
with gm.graph.inserting_before(invoke_quant):
|
||||
invoke_quant_replacement = graph.call_function(
|
||||
torch._higher_order_ops.invoke_quant,
|
||||
(subgraph, *args),
|
||||
kwargs,
|
||||
)
|
||||
invoke_quant_replacement.meta.update(subgraph.meta)
|
||||
invoke_quant_replacement.meta["quant_options"] = quant_options
|
||||
|
||||
invoke_quant.replace_all_uses_with(invoke_quant_replacement)
|
||||
graph.erase_node(invoke_quant)
|
||||
|
||||
if quant_options_node and len(quant_options_node.users) == 0:
|
||||
graph.erase_node(quant_options_node)
|
||||
|
||||
first_user = next(iter(invoke_quant_replacement.users))
|
||||
|
||||
if (
|
||||
len(invoke_quant_replacement.users) == 1
|
||||
and len(subgraph.users) == 1
|
||||
and first_user.target == operator.getitem
|
||||
and first_user.args[1] == 0
|
||||
):
|
||||
subgraph_graph = getattr(gm, subgraph.target)
|
||||
output_node = torch._inductor.utils.output_node(subgraph_graph)
|
||||
assert (
|
||||
isinstance(output_node.args[0], (list, tuple))
|
||||
and len(output_node.args[0]) == 1
|
||||
)
|
||||
|
||||
unpacked_output = output_node.args[0][0]
|
||||
output_node.args = (unpacked_output,)
|
||||
if "val" in output_node.meta:
|
||||
output_node.meta["val"] = output_node.meta["val"][0]
|
||||
subgraph_graph.recompile()
|
||||
|
||||
invoke_quant_replacement.meta.update(first_user.meta)
|
||||
first_user.replace_all_uses_with(invoke_quant_replacement)
|
||||
graph.erase_node(first_user)
|
||||
|
||||
|
||||
def canonicalize_aten_ir_passes(gm: torch.fx.GraphModule):
|
||||
"""
|
||||
Canonicalization passes that will run immediately after aot autograd
|
||||
tracing. This must be run before all other graph passes.
|
||||
"""
|
||||
canonicalize_quant_mapping(gm)
|
||||
|
||||
|
||||
def joint_graph_passes(graph: torch.fx.GraphModule):
|
||||
"""
|
||||
Run FX transformations on the joint forwards+backwards graph.
|
||||
|
|
@ -454,6 +530,15 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
|
|||
lazy_init()
|
||||
count = 0
|
||||
|
||||
# must occur before other passes
|
||||
canonicalize_aten_ir_passes(graph)
|
||||
|
||||
if config.joint_custom_pre_pass is not None:
|
||||
GraphTransformObserver(graph, "joint_custom_pre_pass").apply_graph_pass(
|
||||
config.joint_custom_pre_pass
|
||||
)
|
||||
count += 1
|
||||
|
||||
from .post_grad import remove_noop_ops
|
||||
|
||||
GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
|
||||
|
|
|
|||
|
|
@ -21,8 +21,9 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
|
|||
|
||||
"""
|
||||
import itertools
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal, Optional, overload, Union
|
||||
from typing import Any, Literal, Optional, overload, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import sympy
|
||||
|
|
@ -32,6 +33,7 @@ from torch._prims_common import dtype_to_type, is_integer_dtype
|
|||
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
|
||||
|
||||
from .ops_handler import DefaultHandler
|
||||
from .sizevars import evaluate_expr
|
||||
from .utils import generate_assert
|
||||
from .virtualized import V
|
||||
|
|
@ -185,7 +187,7 @@ class IndexPropVar:
|
|||
IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]
|
||||
|
||||
|
||||
class IndexPropagation:
|
||||
class IndexPropagation(DefaultHandler):
|
||||
"""Ops wrapper that tries to propagate constant and index_expr values through the computation.
|
||||
|
||||
This aims to maximize the compile time simplification possible, and convert
|
||||
|
|
@ -247,19 +249,19 @@ class IndexPropagation:
|
|||
def fallback(
|
||||
self,
|
||||
name: Literal["indirect_indexing"],
|
||||
args: tuple[Any, ...],
|
||||
args: Sequence[Any],
|
||||
kwargs: dict[str, Any],
|
||||
) -> IndexPropVar:
|
||||
...
|
||||
|
||||
@overload
|
||||
def fallback(
|
||||
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
|
||||
) -> IndexPropResult:
|
||||
...
|
||||
|
||||
def fallback(
|
||||
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
|
||||
) -> IndexPropResult:
|
||||
# Fallback to the wrapped handler
|
||||
new_args = [self.unwrap(a) for a in args]
|
||||
|
|
@ -267,7 +269,7 @@ class IndexPropagation:
|
|||
return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))
|
||||
|
||||
def propagate_sympy(
|
||||
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
|
||||
) -> IndexPropResult:
|
||||
# Build a new SymPy expression from this ops call
|
||||
def unwrap(a: Union[Any, IndexPropVar]) -> Any:
|
||||
|
|
@ -288,22 +290,19 @@ class IndexPropagation:
|
|||
return self.fallback(name, args, kwargs)
|
||||
return IndexPropVar.new_symbolic(new_expr)
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
|
||||
def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
|
||||
if not hasattr(SymPyOps, name):
|
||||
return self.fallback(name, args, kwargs)
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
if not hasattr(SymPyOps, name):
|
||||
return self.fallback(name, args, kwargs)
|
||||
|
||||
var_arguments = [
|
||||
a
|
||||
for a in itertools.chain(args, kwargs.values())
|
||||
if isinstance(a, IndexPropVar)
|
||||
]
|
||||
if not all(v.is_symbolic for v in var_arguments):
|
||||
return self.fallback(name, args, kwargs)
|
||||
var_arguments = [
|
||||
a
|
||||
for a in itertools.chain(args, kwargs.values())
|
||||
if isinstance(a, IndexPropVar)
|
||||
]
|
||||
if not all(v.is_symbolic for v in var_arguments):
|
||||
return self.fallback(name, args, kwargs)
|
||||
|
||||
return self.propagate_sympy(name, args, kwargs)
|
||||
|
||||
return inner
|
||||
return self.propagate_sympy(name, args, kwargs)
|
||||
|
||||
def statically_true(self, e):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ from .dependencies import (
|
|||
var_builder,
|
||||
)
|
||||
from .loop_body import LoopBody
|
||||
from .ops_handler import OpCounterCSE, OpCountResult
|
||||
from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode
|
||||
from .runtime.benchmarking import benchmarker
|
||||
from .runtime.hints import DeviceProperties, ReductionHint
|
||||
from .utils import (
|
||||
|
|
@ -916,9 +916,9 @@ class Pointwise(Loops):
|
|||
output_name: Optional[str],
|
||||
indexer: Callable[[Sequence[Expr]], Never],
|
||||
vars: Sequence[Expr],
|
||||
) -> OpsValue:
|
||||
) -> None:
|
||||
loader = self.make_loader()
|
||||
return ops.store(output_name, indexer(vars), loader(vars))
|
||||
return ops.store(output_name or "unnamed", indexer(vars), loader(vars))
|
||||
|
||||
def constant_to_device(self, device: torch.device) -> IRNode:
|
||||
"""Move this to a given device. Requires that all reads are to constants."""
|
||||
|
|
@ -932,7 +932,7 @@ class Pointwise(Loops):
@ir_dataclass
class Scatter(Pointwise):
    output_indexer: Callable[[Sequence[Expr]], Expr]
    scatter_mode: Optional[str] = None
    scatter_mode: StoreMode = None

    def constant_to_device(self, device: torch.device) -> IRNode:
        """Move this to a given device. Requires that all reads are to constants."""

@ -952,8 +952,10 @@ class Scatter(Pointwise):
        output_name: Optional[str],
        indexer: Callable[[Sequence[Expr]], Never],
        vars: Sequence[Expr],
    ) -> OpsValue:
    ) -> None:
        loader = self.make_loader()
        if output_name is None:
            output_name = "unnamed"
        return ops.store(
            output_name,
            indexer(self.output_indexer(vars)),

@ -1038,7 +1040,7 @@ def get_reduction_combine_fn(
@ir_dataclass
class Reduction(Loops):
    reduction_ranges: Sequence[_IntLike]
    reduction_type: str
    reduction_type: ReductionType
    # self.dtype represents the dst dtype
    src_dtype: torch.dtype
    reduction_hint: ReductionHint

@ -1065,14 +1067,14 @@ class Reduction(Loops):
        indexer: Callable[[Sequence[Expr]], Never],
        vars: Sequence[Expr],
        reduction_vars: Sequence[Symbol],
    ) -> OpsValue:
    ) -> None:
        value = ops.reduction(
            self.dtype,
            self.src_dtype,
            self.reduction_type,
            self.inner_fn(vars, reduction_vars),
        )
        return ops.store_reduction(output_name, indexer(vars), value)
        return ops.store_reduction(output_name or "unnamed", indexer(vars), value)

    def index_length(self) -> int:
        return len(self.ranges) + len(self.reduction_ranges)

@ -1110,7 +1112,7 @@ class Reduction(Loops):
        inner_fn: Callable[..., OpsValue],
        ranges: Sequence[_IntLike],
        reduction_ranges: Sequence[_IntLike],
        reduction_type: str,
        reduction_type: Union[ReductionType, Literal["scan"]],
        reduction_numel: Expr,
        input_node: Optional[IRNode] = None,
    ) -> tuple[ReductionHint, _IntLike]:

@ -1196,7 +1198,7 @@ class Reduction(Loops):
            inner_fn=inner_fn,
            ranges=ranges,
            reduction_ranges=reduction_ranges,
            reduction_type=reduction_type,
            reduction_type=reduction_type if reduction_type != "scan" else "sum",
            src_dtype=src_dtype,
            reduction_hint=ReductionHint.DEFAULT,
        )

@ -1323,7 +1325,7 @@ class Reduction(Loops):
        inner_fn: Callable[..., Any],
        ranges: Sequence[Expr],
        reduction_ranges: Sequence[Expr],
        reduction_type: str,
        reduction_type: ReductionType,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
        input_node: Optional[IRNode] = None,
    ) -> TensorBox:

@ -1593,7 +1595,7 @@ class Reduction(Loops):
        original_reduction_ranges: Sequence[Expr],
        new_ranges: list[Expr],
        new_reduction_ranges: list[Integer],
        reduction_type: str,
        reduction_type: ReductionType,
        split: _IntLike,
        reduction_hint: ReductionHint,
    ) -> TensorBox:

@ -1655,7 +1657,7 @@ class Reduction(Loops):
        inner_fn: Callable[..., Any],
        ranges: Sequence[Expr],
        reduction_ranges: Sequence[Expr],
        reduction_type: str,
        reduction_type: ReductionType,
        split: _IntLike,
        reduction_hint: ReductionHint,
    ) -> TensorBox:

@ -1696,7 +1698,7 @@ class Reduction(Loops):
        original_reduction_ranges: Sequence[Expr],
        new_ranges: list[Integer],
        new_reduction_ranges: list[Integer],
        reduction_type: str,
        reduction_type: ReductionType,
        reduction_hint: ReductionHint,
    ) -> TensorBox:
        """

@ -1735,7 +1737,7 @@ class WelfordReduction(Reduction):
        inner_fns: Sequence[Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]],
        ranges: Sequence[Integer],
        reduction_ranges: Sequence[Integer],
        reduction_type: str,
        reduction_type: ReductionType,
        reduction_hint: ReductionHint,
        output_index: int,
    ) -> None:

@ -1767,7 +1769,7 @@ class WelfordReduction(Reduction):
        indexer: Callable[[Sequence[Expr]], Never],
        vars: Sequence[Expr],
        reduction_vars: Sequence[Symbol],
    ) -> OpsValue:
    ) -> None:
        values = ops.reduction(
            self.dtype,
            self.src_dtype,

@ -1775,7 +1777,7 @@ class WelfordReduction(Reduction):
            self.inner_fn(vars, reduction_vars),
        )
        value = values[self.output_index]
        return ops.store_reduction(output_name, indexer(vars), value)
        return ops.store_reduction(output_name or "unnamed", indexer(vars), value)

    @classmethod
    def create(  # type: ignore[override]

@ -1785,7 +1787,7 @@ class WelfordReduction(Reduction):
        inner_fns: Sequence[Callable[..., Any]],
        ranges: list[Integer],
        reduction_ranges: list[Integer],
        reduction_type: str,
        reduction_type: ReductionType,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
    ) -> Sequence[TensorBox]:
        assert reduction_type in ("welford_reduce", "welford_combine")

@ -1911,7 +1913,7 @@ class WelfordReduction(Reduction):
        inner_fns: Sequence[Callable[..., Any]],
        ranges: list[Integer],
        reduction_ranges: list[Integer],
        reduction_type: str,
        reduction_type: ReductionType,
        split: _IntLike,
        reduction_hint: ReductionHint,
    ) -> Sequence[TensorBox]:

@ -2031,11 +2033,13 @@ class Scan(Loops):
        indexer: Callable[[Sequence[_IntLike]], Never],
        vars: Sequence[Expr],
        scan_vars: Sequence[Symbol],
    ) -> OpsValue:
    ) -> None:
        idx = self.reindex(vars, scan_vars)
        values = [inner_fn(idx) for inner_fn in self.inner_fns]
        values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
        result = ops.scan(self.dtypes, self.combine_fn, values)
        return ops.store(output_name, indexer(idx), result[self.output_index])
        return ops.store(
            output_name or "unnamed", indexer(idx), result[self.output_index]
        )

    def get_reduction_type(self) -> Optional[str]:
        # return self.scan_op

@ -2229,11 +2233,13 @@ class Sort(Loops):
        indexer: Callable[[Sequence[Expr]], Expr],
        vars: Sequence[Expr],
        reduction_vars: Sequence[Expr],
    ) -> OpsValue:
    ) -> None:
        idx = self.reindex(vars, reduction_vars)
        values = [inner_fn(idx) for inner_fn in self.inner_fns]
        values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
        result = ops.sort(self.dtypes, values, self.stable, self.descending)
        return ops.store(output_name, indexer(idx), result[self.output_index])
        return ops.store(
            output_name or "unnamed", indexer(idx), result[self.output_index]
        )

    def get_reduction_type(self) -> Optional[str]:
        return "sort"

@ -3790,7 +3796,7 @@ class Buffer(IRNode):

        def loader(index):  # type: ignore[no-untyped-def]
            indexer = self.make_indexer()
            return ops.load(self.name, indexer(index))
            return ops.load(self.name or "unnamed", indexer(index))

        return loader

@ -3983,7 +3989,7 @@ class ComputedBuffer(OperationBuffer):
            return self.data.make_loader()
        return super().make_loader()

    def get_store_function(self) -> Callable[..., OpsValue]:
    def get_store_function(self) -> Callable[..., None]:
        indexer = self.get_layout().as_fixed().make_indexer()
        if isinstance(self.data, (Reduction, Scan, Sort)):
            return partial(self.data.store_reduction, self.name, indexer)
@ -17,6 +17,7 @@ from torch.utils._sympy.symbol import SymT

from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .ops_handler import DefaultHandler, OpsHandler, WrapperHandler
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
from .virtualized import ops, V

@ -439,179 +440,13 @@ class LoopBodyBlock:

    def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]):
        self.body = body

        def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs):
            return tracer.create_proxy(
                "call_module",
                "get_index",
                (body.add_index_expr(expr, mtype, **kwargs),),
                {},
            )

        class CaptureIndexing(V.WrapperHandler):  # type: ignore[name-defined]
            self.name = "CaptureIndexing"

            def load(self, name: str, index: sympy.Expr):
                index = add_index(index, MemoryUsageType.LOAD, buffer_name=name)
                return self._inner.load(name, index)

            def load_seed(self, name: str, index: int):
                assert isinstance(index, int)
                body.add_index_expr(
                    sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name
                )
                return self._inner.load_seed(name, index)

            def store(self, name, index, value, mode=None):
                index = add_index(
                    index, MemoryUsageType.STORE, buffer_name=name, mode=mode
                )
                return self._inner.store(name, index, value, mode)

            def store_reduction(self, name, index, value):
                index = add_index(
                    index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
                )
                return self._inner.store_reduction(name, index, value)

            def reduction(self, dtype, src_dtype, reduction_type, value):
                result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
                if "welford" in reduction_type:
                    return tuple(result[i] for i in range(3))
                return result

            def index_expr(self, index, dtype):
                if isinstance(index, (int, sympy.Integer)):
                    return self._inner.constant(int(index), dtype)
                index = add_index(index, MemoryUsageType.INDEX_EXPR)
                return self._inner.index_expr(index, dtype)

            def check_bounds(self, index, size, lower, upper):
                index = add_index(index, MemoryUsageType.CHECK_BOUNDS)
                size = add_index(size, MemoryUsageType.CHECK_BOUNDS)
                return self._inner.check_bounds(index, size, lower, upper)

            def bucketize(
                self,
                values: T,
                boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
                boundary_indices: T,
                indexing_dtype: torch.dtype,
                right: bool,
                sorter: Optional[tuple[str, sympy.Expr]] = None,
                sorter_indices: Optional[T] = None,
            ) -> T:
                """
                See [Note: Inductor bucketize op]
                """
                boundaries = (
                    boundaries[0],
                    add_index(
                        boundaries[1],
                        MemoryUsageType.BUCKETIZE,
                        buffer_name=boundaries[0],
                    ),
                    add_index(
                        boundaries[2],
                        MemoryUsageType.BUCKETIZE,
                        buffer_name=boundaries[0],
                    ),
                    add_index(
                        boundaries[3],
                        MemoryUsageType.BUCKETIZE,
                        buffer_name=boundaries[0],
                    ),
                )
                if sorter is not None:
                    sorter = (
                        sorter[0],
                        add_index(
                            sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0]
                        ),
                    )

                return self._inner.bucketize(
                    values,
                    boundaries,
                    boundary_indices,
                    indexing_dtype,
                    right,
                    sorter,
                    sorter_indices,
                )

            @staticmethod
            def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
                """
                Recursively capture the masked out body in another LoopBodyBlock
                """
                name = self.body.add_submodule(None, "masked_subblock")
                self.body.submodules[name] = self.body.bind_masked_shim(name)
                self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
                return tracer.create_proxy(
                    "call_module", name, (mask_proxy, other_proxy), {}
                )

            @staticmethod
            def scan(
                dtype_proxy,
                combine_fn: Callable[
                    [tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]
                ],
                value_proxy,
            ):
                shim = self.body.bind_scan_shim(combine_fn)
                name = self.body.add_submodule(shim, "scan")
                result = tracer.create_proxy(
                    "call_module",
                    name,
                    (dtype_proxy, value_proxy),
                    {},
                )
                # Proxies are iterable, but some methods expect tuples/lists
                return tuple(result[i] for i in range(len(value_proxy)))

            def sort(self, dtypes, values, stable, descending):
                result = self._inner.sort(dtypes, values, stable, descending)
                # Proxies are iterable, but some methods expect tuples/lists
                return tuple(result[i] for i in range(len(values)))

            def frexp(self, value_proxy):
                result = self._inner.frexp(value_proxy)
                # Proxies are iterable, but some methods expect tuples/lists
                return (result[0], result[1])

            @staticmethod
            def indirect_indexing(index_proxy, size, check=True, wrap_neg=True):
                """
                Flow data from tensors into indexing formulas.
                Introduce a call_module to update the indexing.
                """

                var = self.body.add_indirect(size)
                set_indirect = self.body.bind_set_indirect_shim(
                    var, size, check, wrap_neg
                )
                tracer.create_proxy(
                    "call_module",
                    self.body.add_submodule(set_indirect, f"set_{var}"),
                    (index_proxy,),
                    {},
                )
                return var

            @staticmethod
            def output(result):
                tracer.create_proxy("output", "output", (result,), {})

        tracer = LightTracer()
        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})

        from .index_propagation import IndexPropagation
        from .sizevars import SimplifyIndexing

        handler: Any = CountOps(
            SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges),
            CaptureIndexing(proxy_ops, body, tracer),
            body.op_counts,
        )
        if config.constant_and_index_propagation:
@ -653,11 +488,187 @@ class LoopBodyBlock:
        return copy


class CountOps:
    def __init__(self, inner: Any, counts: collections.Counter[str]):
class CountOps(DefaultHandler):
    def __init__(self, inner: OpsHandler[Any], counts: collections.Counter[str]):
        self._inner = inner
        self._counts = counts

    def __getattr__(self, name):
    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        self._counts[name] += 1
        return getattr(self._inner, name)
        return getattr(self._inner, name)(*args, **kwargs)


class CaptureIndexing(WrapperHandler):
    name = "CaptureIndexing"

    def __init__(
        self,
        inner: OpsHandler[Any],
        body: LoopBody,
        tracer: LightTracer,
    ):
        super().__init__(inner)
        self.body = body
        self.tracer = tracer

    def _add_index(self, expr: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any):
        return self.tracer.create_proxy(
            "call_module",
            "get_index",
            (self.body.add_index_expr(expr, mtype, **kwargs),),
            {},
        )

    def _simplify(self, expr: sympy.Expr) -> sympy.Expr:
        return V.graph.sizevars.simplify_with_ranges(expr, self.body.var_ranges)

    def load(self, name: str, index: sympy.Expr):
        index = self._simplify(index)
        index = self._add_index(index, MemoryUsageType.LOAD, buffer_name=name)
        return self._inner.load(name, index)

    def load_seed(self, name: str, index: int):
        assert isinstance(index, int)
        self.body.add_index_expr(
            sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name
        )
        return self._inner.load_seed(name, index)

    def store(self, name, index, value, mode=None):
        index = self._simplify(index)
        index = self._add_index(
            index, MemoryUsageType.STORE, buffer_name=name, mode=mode
        )
        return self._inner.store(name, index, value, mode)

    def store_reduction(self, name, index, value):
        index = self._simplify(index)
        index = self._add_index(
            index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
        )
        return self._inner.store_reduction(name, index, value)

    def reduction(self, dtype, src_dtype, reduction_type, value):
        result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
        if "welford" in reduction_type:
            return tuple(result[i] for i in range(3))
        return result

    def index_expr(self, index, dtype):
        index = self._simplify(index)
        if isinstance(index, (int, sympy.Integer)):
            return self._inner.constant(int(index), dtype)
        index = self._add_index(index, MemoryUsageType.INDEX_EXPR)
        return self._inner.index_expr(index, dtype)

    def check_bounds(self, index, size, lower, upper):
        index = self._simplify(index)
        index = self._add_index(index, MemoryUsageType.CHECK_BOUNDS)
        size = self._add_index(size, MemoryUsageType.CHECK_BOUNDS)
        return self._inner.check_bounds(index, size, lower, upper)

    def bucketize(
        self,
        values: T,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: T,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[T] = None,
    ) -> T:
        """
        See [Note: Inductor bucketize op]
        """
        boundaries = (
            boundaries[0],
            self._add_index(
                boundaries[1],
                MemoryUsageType.BUCKETIZE,
                buffer_name=boundaries[0],
            ),
            self._add_index(
                boundaries[2],
                MemoryUsageType.BUCKETIZE,
                buffer_name=boundaries[0],
            ),
            self._add_index(
                boundaries[3],
                MemoryUsageType.BUCKETIZE,
                buffer_name=boundaries[0],
            ),
        )
        if sorter is not None:
            sorter = (
                sorter[0],
                self._add_index(
                    sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0]
                ),
            )

        return self._inner.bucketize(
            values,
            boundaries,
            boundary_indices,
            indexing_dtype,
            right,
            sorter,
            sorter_indices,
        )

    def masked(self, mask_proxy, masked_body: Callable[..., Any], other_proxy):
        """
        Recursively capture the masked out body in another LoopBodyBlock
        """
        name = self.body.add_submodule(None, "masked_subblock")
        self.body.submodules[name] = self.body.bind_masked_shim(name)
        self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
        return self.tracer.create_proxy(
            "call_module", name, (mask_proxy, other_proxy), {}
        )

    def scan(
        self,
        dtype_proxy,
        combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
        value_proxy,
    ):
        shim = self.body.bind_scan_shim(combine_fn)
        name = self.body.add_submodule(shim, "scan")
        result = self.tracer.create_proxy(
            "call_module",
            name,
            (dtype_proxy, value_proxy),
            {},
        )
        # Proxies are iterable, but some methods expect tuples/lists
        return tuple(result[i] for i in range(len(value_proxy)))

    def sort(self, dtypes, values, stable, descending):
        result = self._inner.sort(dtypes, values, stable, descending)
        # Proxies are iterable, but some methods expect tuples/lists
        return tuple(result[i] for i in range(len(values)))

    def frexp(self, value_proxy):
        result = self._inner.frexp(value_proxy)
        # Proxies are iterable, but some methods expect tuples/lists
        return (result[0], result[1])

    def indirect_indexing(self, index_proxy, size, check=True, wrap_neg=True):
        """
        Flow data from tensors into indexing formulas.
        Introduce a call_module to update the indexing.
        """

        var = self.body.add_indirect(size)
        set_indirect = self.body.bind_set_indirect_shim(var, size, check, wrap_neg)
        self.tracer.create_proxy(
            "call_module",
            self.body.add_submodule(set_indirect, f"set_{var}"),
            (index_proxy,),
            {},
        )
        return var

    def output(self, *result):
        self.tracer.create_proxy("output", "output", result, {})
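Both CountOps here and OpsWrapper in virtualized.py (further down) move from a dynamic __getattr__ to DefaultHandler's _default hook. A minimal sketch of that contract under names of my own (SketchDefaultHandler, OpCountingHandler, EchoOps are illustrative, not Inductor's classes; the real DefaultHandler is assumed to generate concrete methods rather than relying on __getattr__):

import collections
from typing import Any

class SketchDefaultHandler:
    # Any op not defined explicitly is routed through _default with its name.
    def __getattr__(self, name: str):
        def inner(*args: Any, **kwargs: Any) -> Any:
            return self._default(name, args, kwargs)
        return inner

    def _default(self, name: str, args: tuple, kwargs: dict) -> Any:
        raise NotImplementedError(name)

class OpCountingHandler(SketchDefaultHandler):
    # Counts every op by name, then forwards it to the wrapped handler.
    def __init__(self, inner: Any, counts: collections.Counter) -> None:
        self._inner = inner
        self._counts = counts

    def _default(self, name: str, args: tuple, kwargs: dict) -> Any:
        self._counts[name] += 1
        return getattr(self._inner, name)(*args, **kwargs)

class EchoOps:
    def load(self, buf: str, idx: int) -> str:
        return f"load({buf}[{idx}])"

    def add(self, a: str, b: str) -> str:
        return f"add({a}, {b})"

counts: collections.Counter = collections.Counter()
h = OpCountingHandler(EchoOps(), counts)
h.add(h.load("x", 0), h.load("y", 0))
print(counts)  # Counter({'load': 2, 'add': 1})

The practical payoff in these hunks is that a single typed _default(name, args, kwargs) is easier to annotate with OpsHandler[Any] and easier to wrap than a __getattr__ that returns fresh closures.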
@ -12,7 +12,7 @@ import os
import warnings
from collections import defaultdict
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
from unittest.mock import patch

@ -81,6 +81,10 @@ from .utils import (
from .virtualized import ops, V


if TYPE_CHECKING:
    from .ops_handler import ReductionType


_T = TypeVar("_T")
_P = ParamSpec("_P")

@ -5633,7 +5637,7 @@ def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype):
    )


def make_reduction(reduction_type: str, override_return_dtype=None):
def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
    def inner(x, axis=None, keepdims=False, *, dtype=None):
        kwargs = _make_reduction_inner(
            x,

@ -6750,6 +6754,23 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, operands):
    return list(map(TensorBox.create, result))


@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None)
def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None):
    output = None
    for i, node in enumerate(subgraph_fn.graph_module.graph.nodes):
        if node.op == "placeholder":
            V.graph.env[node] = operands[i]
            continue
        # todo getattr
        elif node.op == "output":
            args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
            output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)
        else:
            V.graph.env[node] = V.graph.run_node(node)

    return output


@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(combine_fn: ir.Subgraph, xs):
    from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph
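invoke_quant_tracer evaluates the subgraph's FX nodes by hand: placeholders are bound to the incoming operands, the output node collects the result, and everything else is run through the graph interpreter. A self-contained sketch of that walk outside Inductor (MulAdd and demo_interpret are illustrative names, and the sketch only handles the node kinds this trace produces):

import torch
import torch.fx

class MulAdd(torch.nn.Module):
    def forward(self, x, y):
        return x * y + y

gm = torch.fx.symbolic_trace(MulAdd())

def demo_interpret(gm: torch.fx.GraphModule, *operands):
    env = {}            # maps each fx.Node to its computed value
    n_placeholders = 0
    output = None

    def value_of(a):
        return env[a] if isinstance(a, torch.fx.Node) else a

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            env[node] = operands[n_placeholders]
            n_placeholders += 1
        elif node.op == "output":
            output = value_of(node.args[0])
        elif node.op == "call_function":
            args = tuple(value_of(a) for a in node.args)
            kwargs = {k: value_of(v) for k, v in node.kwargs.items()}
            env[node] = node.target(*args, **kwargs)
        else:
            raise NotImplementedError(node.op)
    return output

print(demo_interpret(gm, torch.ones(2), torch.full((2,), 3.0)))  # tensor([6., 6.])

The lowering above does the same thing, except that non-placeholder nodes go through V.graph.run_node, so the subgraph is lowered within the outer GraphLowering's environment.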
(File diff suppressed because it is too large.)
@ -564,10 +564,6 @@ class CompiledFxGraph(OutputCode):
        return artifact_path


def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
    return h


@dataclasses.dataclass
class CompiledAOTI(OutputCode):
    """

@ -591,10 +587,6 @@ class CompiledAOTI(OutputCode):
        pass


def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
    return h


@dataclasses.dataclass
class MockFXGraphCacheOutput(OutputCode):
    gm: Any = None
@ -742,7 +742,7 @@ class TritonTemplateKernel(TritonKernel):
        template_mask = self.template_mask

        class StoreOutputSubstitution(V.WrapperHandler):  # type: ignore[name-defined]
            self.name = name
            name = "StoreOutputSubstitution"

            def store(
                self,
@ -99,6 +99,8 @@ class SizeVarAllocator:
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
                if result != expr:
                    cache[(result, *var_ranges.items())] = result
            return result

        return simplify_with_ranges
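The two added lines make the cache idempotent: after simplifying expr once, the simplified result is also stored under its own key, so simplifying an already-simplified expression becomes a dictionary hit instead of another sympy pass. A small stand-alone sketch of the same pattern (make_cached_simplifier and simplify_once are illustrative, not the SizeVarAllocator API):

import sympy

def make_cached_simplifier(simplify_once):
    cache: dict = {}

    def simplify(expr, var_ranges: dict):
        key = (expr, *var_ranges.items())
        result = cache.get(key)
        if result is None:
            result = simplify_once(expr, var_ranges)
            cache[key] = result
            if result != expr:
                # Fixed point: re-simplifying the result is now a cache hit.
                cache[(result, *var_ranges.items())] = result
        return result

    return simplify

x = sympy.Symbol("x")
simplify = make_cached_simplifier(lambda e, vr: sympy.simplify(e))
print(simplify((x + x) / 2, {}))  # x (computed once)
print(simplify(x, {}))            # x (served from the extra cache entry)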
@ -135,7 +135,7 @@ class InputDescriptor:
    device: torch.device


class TracingOpsHandler(WrapperHandler[T]):
class TracingOpsHandler(WrapperHandler):
    def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None:
        parent = tracer.create_proxy("placeholder", "ops", (), {})
        super().__init__(parent)

@ -149,8 +149,8 @@ class TracingOpsHandler(WrapperHandler[T]):
    def placeholder(self, idx: int) -> torch.fx.Proxy:
        return self.placeholders[idx]

    def output(self, *args: tuple[object]) -> torch.fx.Node:
        return self.tracer.create_node(
    def output(self, *args: tuple[object]) -> None:
        self.tracer.create_node(
            "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {}
        )
@ -59,11 +59,12 @@ from __future__ import annotations

from contextlib import AbstractContextManager, contextmanager
from threading import local
from typing import Any, Callable, Generic, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union

from torch.utils._ordered_set import OrderedSet

from .ops_handler import (  # noqa: F401
    DefaultHandler,
    KernelFormatterHandler,
    MockHandler,
    OpsHandler,

@ -154,7 +155,9 @@ class NullKernelHandler(NullHandler):
        self.index_dtype = "tl.int64"


_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
_ops: Virtualized[OpsHandler[Any]] = Virtualized(
    "ops", cast(type[OpsHandler[Any]], MockHandler)
)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)

@ -272,18 +275,15 @@ class OpsValue:
        return ops.bitwise_left_shift(self, n)


class OpsWrapper:
class OpsWrapper(DefaultHandler):
    """This wraps any returned IR values into an `OpsValue` instance, so that we
    can overload the magic methods for writing mathematical expressions fluently.
    """

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            new_args = [OpsWrapper._unwrap(a) for a in args]
            new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
            return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))

        return inner
    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        new_args = [OpsWrapper._unwrap(a) for a in args]
        new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
        return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))

    @staticmethod
    def _unwrap(x):

@ -306,7 +306,7 @@ class OpsWrapper:
        return _ops.indirect_indexing(index, size, check, wrap_neg)


ops = OpsWrapper()
ops: OpsHandler[Any] = OpsWrapper()


class _V:

@ -314,8 +314,10 @@ class _V:
    KernelFormatterHandler = KernelFormatterHandler
    WrapperHandler = WrapperHandler

    set_ops_handler: Callable[[Any], Any] = _ops._set_handler
    get_ops_handler: Callable[[], Any] = _ops._get_handler
    set_ops_handler: Callable[
        [OpsHandler[Any]], AbstractContextManager[None]
    ] = _ops._set_handler
    get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
    set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
    set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
    get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
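These annotations describe the virtualized-handler pattern: _ops is a thread-local slot holding the current OpsHandler, with a default (MockHandler) and a context-manager setter. A minimal sketch of such a slot under assumed names (Slot, MockOps, RealOps are mine, not the Virtualized class itself):

import threading
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")

class Slot(Generic[T]):
    """A thread-local handler slot with a default factory and a scoped setter."""

    def __init__(self, name: str, default: Callable[[], T]) -> None:
        self._name = name
        self._default = default
        self._local = threading.local()

    def _get_handler(self) -> T:
        current = getattr(self._local, "value", None)
        return current if current is not None else self._default()

    def _set_handler(self, value: T) -> AbstractContextManager[None]:
        @contextmanager
        def ctx():
            prior = getattr(self._local, "value", None)
            self._local.value = value
            try:
                yield
            finally:
                self._local.value = prior
        return ctx()

class MockOps:
    def add(self, a, b):
        return f"mock_add({a}, {b})"

class RealOps:
    def add(self, a, b):
        return a + b

_ops: Slot[Any] = Slot("ops", MockOps)
print(_ops._get_handler().add(1, 2))       # mock_add(1, 2)
with _ops._set_handler(RealOps()):
    print(_ops._get_handler().add(1, 2))   # 3

The cast(type[OpsHandler[Any]], MockHandler) in the hunk above exists purely to convince the type checker that the default factory produces an OpsHandler; nothing changes at runtime.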
@ -11100,8 +11100,8 @@ are designed to work with this function. See the examples below.

Args:
    {input}
    indices (tensor): the indices into :attr:`input`. Must have long dtype.
    dim (int, optional): dimension to select along.
    indices (LongTensor): the indices into :attr:`input`. Must have long dtype.
    dim (int, optional): dimension to select along. Default: 0

Keyword args:
    {out}
@ -331,9 +331,6 @@ class FunctionMeta(type):
            name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
        )
        backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER)  # type: ignore[attr-defined]
        backward_fn._compiled_autograd_should_lift = attrs.get(  # type: ignore[attr-defined]
            "_compiled_autograd_should_lift", True
        )
        backward_fn._bw_module = None  # type: ignore[attr-defined]
        if getattr(cls, "_lazy_backward_info", None):
            backward_fn._bw_module = cls._lazy_backward_info.bw_module  # type: ignore[attr-defined]
(Some files were not shown because too many files have changed in this diff.)