Produce constant variables in cases where a SymNode is created with a constant (#100144)

` AOT_DYNAMIC_SHAPES=1 TORCHDYNAMO_DYNAMIC_SHAPES=1  benchmarks/dynamo/huggingface.py --performance  --training --amp --backend eager --disable-cudagraphs --device cuda --only AllenaiLongformerBase --explain`

Looks promising!

Goes from:

Dynamo produced 173 graphs covering 2760 ops with 160 graph breaks (14 unique)

To:

Dynamo produced 6 graphs covering 2298 ops with 15 graph breaks (7 unique)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100144
Approved by: https://github.com/ezyang
This commit is contained in:
Michael Voznesensky 2023-05-01 18:12:08 +00:00 committed by PyTorch MergeBot
parent 0cf6e74fa9
commit aafc6ce8cc
9 changed files with 89 additions and 36 deletions

View file

@ -1,7 +1,7 @@
name,accuracy,graph_breaks
AlbertForMaskedLM,pass,0
AlbertForQuestionAnswering,pass,0
AllenaiLongformerBase,pass,152
AllenaiLongformerBase,pass,136
BartForCausalLM,pass,0
BertForMaskedLM,pass,0
BertForQuestionAnswering,pass,0

1 name accuracy graph_breaks
2 AlbertForMaskedLM pass 0
3 AlbertForQuestionAnswering pass 0
4 AllenaiLongformerBase pass 152 136
5 BartForCausalLM pass 0
6 BertForMaskedLM pass 0
7 BertForQuestionAnswering pass 0

View file

@ -1,7 +1,7 @@
name,accuracy,graph_breaks
AlbertForMaskedLM,pass,7
AlbertForQuestionAnswering,pass,7
AllenaiLongformerBase,pass,160
AllenaiLongformerBase,pass,144
BartForCausalLM,pass,7
BertForMaskedLM,pass,7
BertForQuestionAnswering,pass,7

1 name accuracy graph_breaks
2 AlbertForMaskedLM pass 7
3 AlbertForQuestionAnswering pass 7
4 AllenaiLongformerBase pass 160 144
5 BartForCausalLM pass 7
6 BertForMaskedLM pass 7
7 BertForQuestionAnswering pass 7

View file

@ -34,7 +34,7 @@ from torch._dynamo.testing import (
unsupported,
)
from torch._dynamo.utils import CompileProfiler, ifdyn, ifunspec
from torch._dynamo.utils import CompileProfiler, ifdyn, ifdynstaticdefault, ifunspec
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.qconfig import QConfig
@ -154,7 +154,11 @@ class MiscTests(torch._dynamo.test_case.TestCase):
return o
torch._dynamo.testing.standard_test(
self, unpack4, 2, expected_ops=5, expected_ops_dynamic=8
self,
unpack4,
2,
expected_ops=5,
expected_ops_dynamic=ifdynstaticdefault(6, 7),
)
def test_unpack5(self):
@ -167,7 +171,11 @@ class MiscTests(torch._dynamo.test_case.TestCase):
return o
torch._dynamo.testing.standard_test(
self, unpack5, 2, expected_ops=5, expected_ops_dynamic=8
self,
unpack5,
2,
expected_ops=5,
expected_ops_dynamic=ifdynstaticdefault(6, 7),
)
def test_matmul1(self):
@ -191,7 +199,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
return x + y
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=11
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 11)
)
def test_shape_int_inplace_binops(self):
@ -207,7 +215,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
return x + p
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=10
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
)
def test_int_shape_inplace_binops(self):
@ -231,7 +239,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
return x + y
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=10
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
)
def test_int_int_comparisons(self):
@ -276,7 +284,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
# expect for dynamic: size, index, 6 comparison ops, add
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=9
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
)
def test_int_shape_comparisons(self):
@ -301,7 +309,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
# expect for dynamic: size, index, 6 comparison ops, add
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=9
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
)
def test_param_shape_binops(self):
@ -333,7 +341,12 @@ class MiscTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(ref, res))
self.assertEqual(counts.frame_count, 1)
expected_op_count = 13 if torch._dynamo.testing.config.dynamic_shapes else 1
expected_op_count = (
ifdynstaticdefault(3, 12)
if torch._dynamo.testing.config.dynamic_shapes
else 1
)
self.assertEqual(counts.op_count, expected_op_count)
def test_user_defined_binop(self):
@ -358,7 +371,11 @@ class MiscTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(ref, res))
self.assertEqual(counts.frame_count, 1)
expected_op_count = 4 if torch._dynamo.testing.config.dynamic_shapes else 1
expected_op_count = (
ifdynstaticdefault(2, 4)
if torch._dynamo.testing.config.dynamic_shapes
else 1
)
self.assertEqual(counts.op_count, expected_op_count)
def test_compare_shapes_eq(self):
@ -511,16 +528,19 @@ class MiscTests(torch._dynamo.test_case.TestCase):
return _fn
# expect for dynamic:
# 2 * (size, getitem) ops +
# 1 add op +
# 4 * 2 min / max ops +
# 4 final add ops = 17
torch._dynamo.testing.standard_test(
self, get_test_fn(func=min), 2, expected_ops=1, expected_ops_dynamic=17
self,
get_test_fn(func=min),
2,
expected_ops=1,
expected_ops_dynamic=ifdynstaticdefault(3, 14),
)
torch._dynamo.testing.standard_test(
self, get_test_fn(func=max), 2, expected_ops=1, expected_ops_dynamic=17
self,
get_test_fn(func=max),
2,
expected_ops=1,
expected_ops_dynamic=ifdynstaticdefault(3, 17),
)
def test_config_obj(self):
@ -773,7 +793,11 @@ class MiscTests(torch._dynamo.test_case.TestCase):
)
return torch._dynamo.testing.standard_test(
self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=8
self,
fn=fn,
nargs=1,
expected_ops=5,
expected_ops_dynamic=ifdynstaticdefault(6, 8),
)
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@ -916,7 +940,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(opt_fn(2), [2, 3] * 4)
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
self.assertEqual(cnts.op_count, ifunspec(14, 0))
self.assertEqual(cnts.op_count, ifunspec(2, 0))
def test_tuple_mul(self):
def fn(count):
@ -927,7 +951,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(opt_fn(2), (2, 3) * 4)
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
self.assertEqual(cnts.op_count, ifunspec(14, 0))
self.assertEqual(cnts.op_count, ifunspec(ifdynstaticdefault(2, 2), 0))
def test_tuple_mul_with_shape(self):
def fn(a):
@ -937,7 +961,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
# expect 3 ops post folding for dynamic case: size, index, add
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=3
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 3)
)
def test_tuple_iadd_with_shape(self):
@ -951,7 +975,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
# expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=4, expected_ops_dynamic=12
self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(8, 12)
)
def test_list_iadd_with_shape(self):
@ -964,8 +988,9 @@ class MiscTests(torch._dynamo.test_case.TestCase):
return output
# expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=6, expected_ops_dynamic=18
self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(12, 18)
)
def test_user_getattr1(self):

View file

@ -163,7 +163,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
cache_fail_test(
a,
a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
"tensor 'L['a']' strides mismatch at index 0. expected 20, actual 1",
"tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1",
)
cache_fail_test(
a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2"

View file

@ -31,7 +31,7 @@ import torch.library
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided, requires_static_shapes, same
from torch._dynamo.utils import ifdyn, ifunspec
from torch._dynamo.utils import ifdyn, ifdynstaticdefault, ifunspec
from torch.nn import functional as F
@ -890,7 +890,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, ifdyn(6, 1))
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(3, 6), 1))
def _reformer(self, nopython):
input = torch.randn([1, 64, 256])
@ -981,7 +981,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(opt_fn(input2), correct2))
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, ifunspec(42, ifdyn(38, 4)))
self.assertEqual(cnt.op_count, ifunspec(37, ifdyn(20, 4)))
def test_hf_t5_forward(self):
input = torch.randn([1, 2048, 512])
@ -992,7 +992,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(opt_model(input), correct))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, ifdyn(13, 11))
self.assertEqual(cnt.op_count, ifdyn(12, 11))
def test_module_in_skipfiles(self):
model = nn.Linear(10, 10)
@ -1283,7 +1283,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
self.assertTrue(same(opt_fn(x), correct))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, ifdyn(28, 14))
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(21, 27), 14))
def test_recursive_map(self):
# https://github.com/pytorch/torchdynamo/issues/132

View file

@ -1368,6 +1368,13 @@ def ifdyn(count1, count2):
return count2
def ifdynstaticdefault(count1, count2):
if torch._dynamo.config.assume_static_by_default:
return count1
else:
return count2
def ifunspec(count1, count2):
if torch._dynamo.config.dynamic_shapes and not torch._dynamo.config.specialize_int:
return count1

View file

@ -901,7 +901,10 @@ class VariableBuilder:
)
)
fake_tensor_value = None
example_value = unspec_var.proxy.node.meta["example_value"]
if isinstance(unspec_var, ConstantVariable):
example_value = unspec_var.value
else:
example_value = unspec_var.proxy.node.meta["example_value"]
if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor):
fake_tensor_value = example_value
proxy.node.meta["grapharg"] = GraphArg(

View file

@ -1263,11 +1263,16 @@ class BuiltinVariable(VariableTracker):
sym_num=None,
)
if isinstance(left, ConstantVariable) and isinstance(right, ConstantVariable):
return ConstantVariable(op(left.value, right.value))
_unimplemented()
# and_ is a constant fold function, so we only get here if constant fold is not valid
def call_and_(self, tx, a, b):
if isinstance(a, SymNodeVariable) and isinstance(b, SymNodeVariable):
if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
b, (SymNodeVariable, ConstantVariable)
):
return SymNodeVariable.create(
tx,
tx.output.create_proxy(
@ -1280,7 +1285,9 @@ class BuiltinVariable(VariableTracker):
# or_ is a constant fold function, so we only get here if constant fold is not valid
def call_or_(self, tx, a, b):
if isinstance(a, SymNodeVariable) and isinstance(b, SymNodeVariable):
if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
b, (SymNodeVariable, ConstantVariable)
):
return SymNodeVariable.create(
tx,
tx.output.create_proxy(

View file

@ -4,6 +4,8 @@ import operator
import types
from typing import Dict, List
import sympy
import torch.fx
import torch.random
from torch.fx.experimental.symbolic_shapes import guard_scalar, SymTypes
@ -238,8 +240,13 @@ class TensorVariable(VariableTracker):
length = self.size[0]
else:
dyn_length = self.call_method(tx, "size", [ConstantVariable(0)], {})
assert isinstance(dyn_length, SymNodeVariable)
length = dyn_length.evaluate_expr(tx.output)
            # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through
            # symbolic_shapes that end up as int/sympy.Integer
assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
if isinstance(dyn_length, SymNodeVariable):
length = dyn_length.evaluate_expr(tx.output)
else:
length = dyn_length.value
idxes = range(length)
return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]
@ -495,6 +502,10 @@ class SymNodeVariable(VariableTracker):
if sym_num is None:
sym_num = get_fake_value(proxy.node, tx)
proxy.node.meta["example_value"] = sym_num
if isinstance(sym_num, (sympy.Integer, int)):
return ConstantVariable(int(sym_num))
return SymNodeVariable(proxy, sym_num, **options)
def __init__(self, proxy, sym_num, **kwargs):