diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_inference.csv index b641f92d924..826ceb1c9d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_inference.csv @@ -1,7 +1,7 @@ name,accuracy,graph_breaks AlbertForMaskedLM,pass,0 AlbertForQuestionAnswering,pass,0 -AllenaiLongformerBase,pass,152 +AllenaiLongformerBase,pass,136 BartForCausalLM,pass,0 BertForMaskedLM,pass,0 BertForQuestionAnswering,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_training.csv index 51f19caf1f4..e8fb2a431f4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_training.csv @@ -1,7 +1,7 @@ name,accuracy,graph_breaks AlbertForMaskedLM,pass,7 AlbertForQuestionAnswering,pass,7 -AllenaiLongformerBase,pass,160 +AllenaiLongformerBase,pass,144 BartForCausalLM,pass,7 BertForMaskedLM,pass,7 BertForQuestionAnswering,pass,7 diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index d6f6805f207..e94d7219408 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -34,7 +34,7 @@ from torch._dynamo.testing import ( unsupported, ) -from torch._dynamo.utils import CompileProfiler, ifdyn, ifunspec +from torch._dynamo.utils import CompileProfiler, ifdyn, ifdynstaticdefault, ifunspec from torch.ao.quantization import MinMaxObserver from torch.ao.quantization.fake_quantize import FakeQuantize from torch.ao.quantization.qconfig import QConfig @@ -154,7 +154,11 @@ class MiscTests(torch._dynamo.test_case.TestCase): return o torch._dynamo.testing.standard_test( - self, unpack4, 2, expected_ops=5, expected_ops_dynamic=8 + 
self, + unpack4, + 2, + expected_ops=5, + expected_ops_dynamic=ifdynstaticdefault(6, 7), ) def test_unpack5(self): @@ -167,7 +171,11 @@ class MiscTests(torch._dynamo.test_case.TestCase): return o torch._dynamo.testing.standard_test( - self, unpack5, 2, expected_ops=5, expected_ops_dynamic=8 + self, + unpack5, + 2, + expected_ops=5, + expected_ops_dynamic=ifdynstaticdefault(6, 7), ) def test_matmul1(self): @@ -191,7 +199,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): return x + y torch._dynamo.testing.standard_test( - self, fn, 1, expected_ops=1, expected_ops_dynamic=11 + self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 11) ) def test_shape_int_inplace_binops(self): @@ -207,7 +215,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): return x + p torch._dynamo.testing.standard_test( - self, fn, 1, expected_ops=1, expected_ops_dynamic=10 + self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10) ) def test_int_shape_inplace_binops(self): @@ -231,7 +239,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): return x + y torch._dynamo.testing.standard_test( - self, fn, 1, expected_ops=1, expected_ops_dynamic=10 + self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10) ) def test_int_int_comparisons(self): @@ -276,7 +284,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): # expect for dynamic: size, index, 6 comparison ops, add torch._dynamo.testing.standard_test( - self, fn, 1, expected_ops=1, expected_ops_dynamic=9 + self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9) ) def test_int_shape_comparisons(self): @@ -301,7 +309,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): # expect for dynamic: size, index, 6 comparison ops, add torch._dynamo.testing.standard_test( - self, fn, 1, expected_ops=1, expected_ops_dynamic=9 + self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9) ) def test_param_shape_binops(self): @@ -333,7 +341,12 @@ class 
MiscTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(ref, res)) self.assertEqual(counts.frame_count, 1) - expected_op_count = 13 if torch._dynamo.testing.config.dynamic_shapes else 1 + + expected_op_count = ( + ifdynstaticdefault(3, 12) + if torch._dynamo.testing.config.dynamic_shapes + else 1 + ) self.assertEqual(counts.op_count, expected_op_count) def test_user_defined_binop(self): @@ -358,7 +371,11 @@ class MiscTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(ref, res)) self.assertEqual(counts.frame_count, 1) - expected_op_count = 4 if torch._dynamo.testing.config.dynamic_shapes else 1 + expected_op_count = ( + ifdynstaticdefault(2, 4) + if torch._dynamo.testing.config.dynamic_shapes + else 1 + ) self.assertEqual(counts.op_count, expected_op_count) def test_compare_shapes_eq(self): @@ -511,16 +528,19 @@ class MiscTests(torch._dynamo.test_case.TestCase): return _fn - # expect for dynamic: - # 2 * (size, getitem) ops + - # 1 add op + - # 4 * 2 min / max ops + - # 4 final add ops = 17 torch._dynamo.testing.standard_test( - self, get_test_fn(func=min), 2, expected_ops=1, expected_ops_dynamic=17 + self, + get_test_fn(func=min), + 2, + expected_ops=1, + expected_ops_dynamic=ifdynstaticdefault(3, 14), ) torch._dynamo.testing.standard_test( - self, get_test_fn(func=max), 2, expected_ops=1, expected_ops_dynamic=17 + self, + get_test_fn(func=max), + 2, + expected_ops=1, + expected_ops_dynamic=ifdynstaticdefault(3, 17), ) def test_config_obj(self): @@ -773,7 +793,11 @@ class MiscTests(torch._dynamo.test_case.TestCase): ) return torch._dynamo.testing.standard_test( - self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=8 + self, + fn=fn, + nargs=1, + expected_ops=5, + expected_ops_dynamic=ifdynstaticdefault(6, 8), ) @patch.object(torch._dynamo.config, "dynamic_shapes", True) @@ -916,7 +940,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): opt_fn = torch._dynamo.optimize(cnts)(fn) self.assertEqual(opt_fn(2), [2, 3] * 4) 
self.assertEqual(cnts.frame_count, ifunspec(1, 0)) - self.assertEqual(cnts.op_count, ifunspec(14, 0)) + self.assertEqual(cnts.op_count, ifunspec(2, 0)) def test_tuple_mul(self): def fn(count): @@ -927,7 +951,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): opt_fn = torch._dynamo.optimize(cnts)(fn) self.assertEqual(opt_fn(2), (2, 3) * 4) self.assertEqual(cnts.frame_count, ifunspec(1, 0)) - self.assertEqual(cnts.op_count, ifunspec(14, 0)) + self.assertEqual(cnts.op_count, ifunspec(ifdynstaticdefault(2, 2), 0)) def test_tuple_mul_with_shape(self): def fn(a): @@ -937,7 +961,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): # expect 3 ops post folding for dynamic case: size, index, add torch._dynamo.testing.standard_test( - self, fn, 1, expected_ops=1, expected_ops_dynamic=3 + self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 3) ) def test_tuple_iadd_with_shape(self): @@ -951,7 +975,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): # expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic torch._dynamo.testing.standard_test( - self, fn, 1, expected_ops=4, expected_ops_dynamic=12 + self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(8, 12) ) def test_list_iadd_with_shape(self): @@ -964,8 +988,9 @@ class MiscTests(torch._dynamo.test_case.TestCase): return output # expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic + torch._dynamo.testing.standard_test( - self, fn, 1, expected_ops=6, expected_ops_dynamic=18 + self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(12, 18) ) def test_user_getattr1(self): diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py index a52b87b21c6..95f39ec8fe9 100644 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -163,7 +163,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase): cache_fail_test( a, a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)), - "tensor 'L['a']' 
strides mismatch at index 0. expected 20, actual 1", + "tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1", ) cache_fail_test( a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2" diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 00f545bbb5e..386f70d8fbd 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -31,7 +31,7 @@ import torch.library from torch import nn from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import rand_strided, requires_static_shapes, same -from torch._dynamo.utils import ifdyn, ifunspec +from torch._dynamo.utils import ifdyn, ifdynstaticdefault, ifunspec from torch.nn import functional as F @@ -890,7 +890,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0)) self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, ifdyn(6, 1)) + self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(3, 6), 1)) def _reformer(self, nopython): input = torch.randn([1, 64, 256]) @@ -981,7 +981,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(opt_fn(input2), correct2)) self.assertEqual(cnt.frame_count, 2) - self.assertEqual(cnt.op_count, ifunspec(42, ifdyn(38, 4))) + self.assertEqual(cnt.op_count, ifunspec(37, ifdyn(20, 4))) def test_hf_t5_forward(self): input = torch.randn([1, 2048, 512]) @@ -992,7 +992,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(opt_model(input), correct)) self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, ifdyn(13, 11)) + self.assertEqual(cnt.op_count, ifdyn(12, 11)) def test_module_in_skipfiles(self): model = nn.Linear(10, 10) @@ -1283,7 +1283,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): opt_fn = torch._dynamo.optimize_assert(cnt)(fn) self.assertTrue(same(opt_fn(x), correct)) self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, ifdyn(28, 14)) + 
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(21, 27), 14)) def test_recursive_map(self): # https://github.com/pytorch/torchdynamo/issues/132 diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 4894578be10..dd766a04586 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1368,6 +1368,13 @@ def ifdyn(count1, count2): return count2 +def ifdynstaticdefault(count1, count2): + if torch._dynamo.config.assume_static_by_default: + return count1 + else: + return count2 + + def ifunspec(count1, count2): if torch._dynamo.config.dynamic_shapes and not torch._dynamo.config.specialize_int: return count1 diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 3afe664ebfc..1183d6d53f5 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -901,7 +901,10 @@ class VariableBuilder: ) ) fake_tensor_value = None - example_value = unspec_var.proxy.node.meta["example_value"] + if isinstance(unspec_var, ConstantVariable): + example_value = unspec_var.value + else: + example_value = unspec_var.proxy.node.meta["example_value"] if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor): fake_tensor_value = example_value proxy.node.meta["grapharg"] = GraphArg( diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index d8dfd770523..f92ab23de65 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1263,11 +1263,16 @@ class BuiltinVariable(VariableTracker): sym_num=None, ) + if isinstance(left, ConstantVariable) and isinstance(right, ConstantVariable): + return ConstantVariable(op(left.value, right.value)) + _unimplemented() # and_ is a constant fold function, so we only get here if constant fold is not valid def call_and_(self, tx, a, b): - if isinstance(a, SymNodeVariable) and isinstance(b, SymNodeVariable): + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, 
ConstantVariable) + ) ): return SymNodeVariable.create( tx, tx.output.create_proxy( @@ -1280,7 +1285,9 @@ class BuiltinVariable(VariableTracker): # or_ is a constant fold function, so we only get here if constant fold is not valid def call_or_(self, tx, a, b): - if isinstance(a, SymNodeVariable) and isinstance(b, SymNodeVariable): + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): return SymNodeVariable.create( tx, tx.output.create_proxy( diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 67a9bbbc8de..3e5a7b017c2 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -4,6 +4,8 @@ import operator import types from typing import Dict, List +import sympy + import torch.fx import torch.random from torch.fx.experimental.symbolic_shapes import guard_scalar, SymTypes @@ -238,8 +240,13 @@ class TensorVariable(VariableTracker): length = self.size[0] else: dyn_length = self.call_method(tx, "size", [ConstantVariable(0)], {}) - assert isinstance(dyn_length, SymNodeVariable) - length = dyn_length.evaluate_expr(tx.output) + # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through + # symbolic_shapes, but that end up as int/sympy.Integer + assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable)) + if isinstance(dyn_length, SymNodeVariable): + length = dyn_length.evaluate_expr(tx.output) + else: + length = dyn_length.value idxes = range(length) return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes] @@ -495,6 +502,10 @@ class SymNodeVariable(VariableTracker): if sym_num is None: sym_num = get_fake_value(proxy.node, tx) proxy.node.meta["example_value"] = sym_num + + if isinstance(sym_num, (sympy.Integer, int)): + return ConstantVariable(int(sym_num)) + return SymNodeVariable(proxy, sym_num, **options) def __init__(self, proxy, sym_num, **kwargs):