mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Produce constant variables in cases where a SymNode is created with a constant (#100144)
`AOT_DYNAMIC_SHAPES=1 TORCHDYNAMO_DYNAMIC_SHAPES=1 benchmarks/dynamo/huggingface.py --performance --training --amp --backend eager --disable-cudagraphs --device cuda --only AllenaiLongformerBase --explain` looks promising! It goes from "Dynamo produced 173 graphs covering 2760 ops with 160 graph breaks (14 unique)" to "Dynamo produced 6 graphs covering 2298 ops with 15 graph breaks (7 unique)". Pull Request resolved: https://github.com/pytorch/pytorch/pull/100144 Approved by: https://github.com/ezyang
This commit is contained in:
parent
0cf6e74fa9
commit
aafc6ce8cc
9 changed files with 89 additions and 36 deletions
|
|
@ -1,7 +1,7 @@
|
|||
name,accuracy,graph_breaks
|
||||
AlbertForMaskedLM,pass,0
|
||||
AlbertForQuestionAnswering,pass,0
|
||||
AllenaiLongformerBase,pass,152
|
||||
AllenaiLongformerBase,pass,136
|
||||
BartForCausalLM,pass,0
|
||||
BertForMaskedLM,pass,0
|
||||
BertForQuestionAnswering,pass,0
|
||||
|
|
|
|||
|
|
|
@ -1,7 +1,7 @@
|
|||
name,accuracy,graph_breaks
|
||||
AlbertForMaskedLM,pass,7
|
||||
AlbertForQuestionAnswering,pass,7
|
||||
AllenaiLongformerBase,pass,160
|
||||
AllenaiLongformerBase,pass,144
|
||||
BartForCausalLM,pass,7
|
||||
BertForMaskedLM,pass,7
|
||||
BertForQuestionAnswering,pass,7
|
||||
|
|
|
|||
|
|
|
@ -34,7 +34,7 @@ from torch._dynamo.testing import (
|
|||
unsupported,
|
||||
)
|
||||
|
||||
from torch._dynamo.utils import CompileProfiler, ifdyn, ifunspec
|
||||
from torch._dynamo.utils import CompileProfiler, ifdyn, ifdynstaticdefault, ifunspec
|
||||
from torch.ao.quantization import MinMaxObserver
|
||||
from torch.ao.quantization.fake_quantize import FakeQuantize
|
||||
from torch.ao.quantization.qconfig import QConfig
|
||||
|
|
@ -154,7 +154,11 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
return o
|
||||
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, unpack4, 2, expected_ops=5, expected_ops_dynamic=8
|
||||
self,
|
||||
unpack4,
|
||||
2,
|
||||
expected_ops=5,
|
||||
expected_ops_dynamic=ifdynstaticdefault(6, 7),
|
||||
)
|
||||
|
||||
def test_unpack5(self):
|
||||
|
|
@ -167,7 +171,11 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
return o
|
||||
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, unpack5, 2, expected_ops=5, expected_ops_dynamic=8
|
||||
self,
|
||||
unpack5,
|
||||
2,
|
||||
expected_ops=5,
|
||||
expected_ops_dynamic=ifdynstaticdefault(6, 7),
|
||||
)
|
||||
|
||||
def test_matmul1(self):
|
||||
|
|
@ -191,7 +199,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
return x + y
|
||||
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=11
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 11)
|
||||
)
|
||||
|
||||
def test_shape_int_inplace_binops(self):
|
||||
|
|
@ -207,7 +215,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
return x + p
|
||||
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=10
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
|
||||
)
|
||||
|
||||
def test_int_shape_inplace_binops(self):
|
||||
|
|
@ -231,7 +239,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
return x + y
|
||||
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=10
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
|
||||
)
|
||||
|
||||
def test_int_int_comparisons(self):
|
||||
|
|
@ -276,7 +284,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
# expect for dynamic: size, index, 6 comparison ops, add
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=9
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
|
||||
)
|
||||
|
||||
def test_int_shape_comparisons(self):
|
||||
|
|
@ -301,7 +309,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
# expect for dynamic: size, index, 6 comparison ops, add
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=9
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
|
||||
)
|
||||
|
||||
def test_param_shape_binops(self):
|
||||
|
|
@ -333,7 +341,12 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
self.assertTrue(same(ref, res))
|
||||
self.assertEqual(counts.frame_count, 1)
|
||||
expected_op_count = 13 if torch._dynamo.testing.config.dynamic_shapes else 1
|
||||
|
||||
expected_op_count = (
|
||||
ifdynstaticdefault(3, 12)
|
||||
if torch._dynamo.testing.config.dynamic_shapes
|
||||
else 1
|
||||
)
|
||||
self.assertEqual(counts.op_count, expected_op_count)
|
||||
|
||||
def test_user_defined_binop(self):
|
||||
|
|
@ -358,7 +371,11 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
self.assertTrue(same(ref, res))
|
||||
self.assertEqual(counts.frame_count, 1)
|
||||
expected_op_count = 4 if torch._dynamo.testing.config.dynamic_shapes else 1
|
||||
expected_op_count = (
|
||||
ifdynstaticdefault(2, 4)
|
||||
if torch._dynamo.testing.config.dynamic_shapes
|
||||
else 1
|
||||
)
|
||||
self.assertEqual(counts.op_count, expected_op_count)
|
||||
|
||||
def test_compare_shapes_eq(self):
|
||||
|
|
@ -511,16 +528,19 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
return _fn
|
||||
|
||||
# expect for dynamic:
|
||||
# 2 * (size, getitem) ops +
|
||||
# 1 add op +
|
||||
# 4 * 2 min / max ops +
|
||||
# 4 final add ops = 17
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, get_test_fn(func=min), 2, expected_ops=1, expected_ops_dynamic=17
|
||||
self,
|
||||
get_test_fn(func=min),
|
||||
2,
|
||||
expected_ops=1,
|
||||
expected_ops_dynamic=ifdynstaticdefault(3, 14),
|
||||
)
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, get_test_fn(func=max), 2, expected_ops=1, expected_ops_dynamic=17
|
||||
self,
|
||||
get_test_fn(func=max),
|
||||
2,
|
||||
expected_ops=1,
|
||||
expected_ops_dynamic=ifdynstaticdefault(3, 17),
|
||||
)
|
||||
|
||||
def test_config_obj(self):
|
||||
|
|
@ -773,7 +793,11 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
)
|
||||
|
||||
return torch._dynamo.testing.standard_test(
|
||||
self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=8
|
||||
self,
|
||||
fn=fn,
|
||||
nargs=1,
|
||||
expected_ops=5,
|
||||
expected_ops_dynamic=ifdynstaticdefault(6, 8),
|
||||
)
|
||||
|
||||
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
||||
|
|
@ -916,7 +940,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
self.assertEqual(opt_fn(2), [2, 3] * 4)
|
||||
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
|
||||
self.assertEqual(cnts.op_count, ifunspec(14, 0))
|
||||
self.assertEqual(cnts.op_count, ifunspec(2, 0))
|
||||
|
||||
def test_tuple_mul(self):
|
||||
def fn(count):
|
||||
|
|
@ -927,7 +951,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
self.assertEqual(opt_fn(2), (2, 3) * 4)
|
||||
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
|
||||
self.assertEqual(cnts.op_count, ifunspec(14, 0))
|
||||
self.assertEqual(cnts.op_count, ifunspec(ifdynstaticdefault(2, 2), 0))
|
||||
|
||||
def test_tuple_mul_with_shape(self):
|
||||
def fn(a):
|
||||
|
|
@ -937,7 +961,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
# expect 3 ops post folding for dynamic case: size, index, add
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=3
|
||||
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 3)
|
||||
)
|
||||
|
||||
def test_tuple_iadd_with_shape(self):
|
||||
|
|
@ -951,7 +975,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
# expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, fn, 1, expected_ops=4, expected_ops_dynamic=12
|
||||
self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(8, 12)
|
||||
)
|
||||
|
||||
def test_list_iadd_with_shape(self):
|
||||
|
|
@ -964,8 +988,9 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
return output
|
||||
|
||||
# expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic
|
||||
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, fn, 1, expected_ops=6, expected_ops_dynamic=18
|
||||
self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(12, 18)
|
||||
)
|
||||
|
||||
def test_user_getattr1(self):
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
|
|||
cache_fail_test(
|
||||
a,
|
||||
a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
|
||||
"tensor 'L['a']' strides mismatch at index 0. expected 20, actual 1",
|
||||
"tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1",
|
||||
)
|
||||
cache_fail_test(
|
||||
a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2"
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ import torch.library
|
|||
from torch import nn
|
||||
from torch._dynamo.debug_utils import same_two_models
|
||||
from torch._dynamo.testing import rand_strided, requires_static_shapes, same
|
||||
from torch._dynamo.utils import ifdyn, ifunspec
|
||||
from torch._dynamo.utils import ifdyn, ifdynstaticdefault, ifunspec
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
|
|
@ -890,7 +890,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))
|
||||
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, ifdyn(6, 1))
|
||||
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(3, 6), 1))
|
||||
|
||||
def _reformer(self, nopython):
|
||||
input = torch.randn([1, 64, 256])
|
||||
|
|
@ -981,7 +981,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertTrue(same(opt_fn(input2), correct2))
|
||||
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
self.assertEqual(cnt.op_count, ifunspec(42, ifdyn(38, 4)))
|
||||
self.assertEqual(cnt.op_count, ifunspec(37, ifdyn(20, 4)))
|
||||
|
||||
def test_hf_t5_forward(self):
|
||||
input = torch.randn([1, 2048, 512])
|
||||
|
|
@ -992,7 +992,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertTrue(same(opt_model(input), correct))
|
||||
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, ifdyn(13, 11))
|
||||
self.assertEqual(cnt.op_count, ifdyn(12, 11))
|
||||
|
||||
def test_module_in_skipfiles(self):
|
||||
model = nn.Linear(10, 10)
|
||||
|
|
@ -1283,7 +1283,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
||||
self.assertTrue(same(opt_fn(x), correct))
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, ifdyn(28, 14))
|
||||
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(21, 27), 14))
|
||||
|
||||
def test_recursive_map(self):
|
||||
# https://github.com/pytorch/torchdynamo/issues/132
|
||||
|
|
|
|||
|
|
@ -1368,6 +1368,13 @@ def ifdyn(count1, count2):
|
|||
return count2
|
||||
|
||||
|
||||
def ifdynstaticdefault(count1, count2):
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
return count1
|
||||
else:
|
||||
return count2
|
||||
|
||||
|
||||
def ifunspec(count1, count2):
|
||||
if torch._dynamo.config.dynamic_shapes and not torch._dynamo.config.specialize_int:
|
||||
return count1
|
||||
|
|
|
|||
|
|
@ -901,7 +901,10 @@ class VariableBuilder:
|
|||
)
|
||||
)
|
||||
fake_tensor_value = None
|
||||
example_value = unspec_var.proxy.node.meta["example_value"]
|
||||
if isinstance(unspec_var, ConstantVariable):
|
||||
example_value = unspec_var.value
|
||||
else:
|
||||
example_value = unspec_var.proxy.node.meta["example_value"]
|
||||
if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor):
|
||||
fake_tensor_value = example_value
|
||||
proxy.node.meta["grapharg"] = GraphArg(
|
||||
|
|
|
|||
|
|
@ -1263,11 +1263,16 @@ class BuiltinVariable(VariableTracker):
|
|||
sym_num=None,
|
||||
)
|
||||
|
||||
if isinstance(left, ConstantVariable) and isinstance(right, ConstantVariable):
|
||||
return ConstantVariable(op(left.value, right.value))
|
||||
|
||||
_unimplemented()
|
||||
|
||||
# and_ is a constant fold function, so we only get here if constant fold is not valid
|
||||
def call_and_(self, tx, a, b):
|
||||
if isinstance(a, SymNodeVariable) and isinstance(b, SymNodeVariable):
|
||||
if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
|
||||
b, (SymNodeVariable, ConstantVariable)
|
||||
):
|
||||
return SymNodeVariable.create(
|
||||
tx,
|
||||
tx.output.create_proxy(
|
||||
|
|
@ -1280,7 +1285,9 @@ class BuiltinVariable(VariableTracker):
|
|||
|
||||
# or_ is a constant fold function, so we only get here if constant fold is not valid
|
||||
def call_or_(self, tx, a, b):
|
||||
if isinstance(a, SymNodeVariable) and isinstance(b, SymNodeVariable):
|
||||
if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
|
||||
b, (SymNodeVariable, ConstantVariable)
|
||||
):
|
||||
return SymNodeVariable.create(
|
||||
tx,
|
||||
tx.output.create_proxy(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import operator
|
|||
import types
|
||||
from typing import Dict, List
|
||||
|
||||
import sympy
|
||||
|
||||
import torch.fx
|
||||
import torch.random
|
||||
from torch.fx.experimental.symbolic_shapes import guard_scalar, SymTypes
|
||||
|
|
@ -238,8 +240,13 @@ class TensorVariable(VariableTracker):
|
|||
length = self.size[0]
|
||||
else:
|
||||
dyn_length = self.call_method(tx, "size", [ConstantVariable(0)], {})
|
||||
assert isinstance(dyn_length, SymNodeVariable)
|
||||
length = dyn_length.evaluate_expr(tx.output)
|
||||
# SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through
|
||||
# symbolic_shapes, but that end up as int/sympy.Integer
|
||||
assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
|
||||
if isinstance(dyn_length, SymNodeVariable):
|
||||
length = dyn_length.evaluate_expr(tx.output)
|
||||
else:
|
||||
length = dyn_length.value
|
||||
idxes = range(length)
|
||||
return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]
|
||||
|
||||
|
|
@ -495,6 +502,10 @@ class SymNodeVariable(VariableTracker):
|
|||
if sym_num is None:
|
||||
sym_num = get_fake_value(proxy.node, tx)
|
||||
proxy.node.meta["example_value"] = sym_num
|
||||
|
||||
if isinstance(sym_num, (sympy.Integer, int)):
|
||||
return ConstantVariable(int(sym_num))
|
||||
|
||||
return SymNodeVariable(proxy, sym_num, **options)
|
||||
|
||||
def __init__(self, proxy, sym_num, **kwargs):
|
||||
|
|
|
|||
Loading…
Reference in a new issue