From 8aa878fc935564bdd1e4fc00d7f34381a746b504 Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Fri, 24 Jul 2020 08:05:20 -0700 Subject: [PATCH] [jit] fix tuple alias analysis (#41992) Previously when analyzing a TupleConstruct, we ignored the aliasing information of the inputs and simply marked all elements of the returned tuple as wildcards. But since we can fully reason about the contents of a tuple statically, we should be able to assign them aliasing information. This analysis was not only incomplete but produced incorrect results, since if `a` is not a wildcard, `a noalias wilcard`. So if we looked at `tuple(a)` and reported the aliasing info as `tuple(wildcard)`, then `tuple[0] noalias a`, which is...wrong. --- test/jit/test_freezing.py | 4 ---- torch/csrc/jit/ir/alias_analysis.cpp | 25 ++++++++++++++++++------- torch/csrc/jit/ir/alias_analysis.h | 1 + 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 018470bf851..dfbc0ce8214 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -794,9 +794,6 @@ class TestFreezing(JitTestCase): expected = m_s.forward(inp) self.assertEqual(out, expected) - # Check attribute a is preserved. Alias analysis detects that 'a' has output writers. - # In this example, 'a' is not mutated. However, we do not track which sub - # values of a composite ivalue is mutated. def test_freeze_module_with_aliased_attr2(self): class FreezeMe(nn.Module): def __init__(self): @@ -815,7 +812,6 @@ class TestFreezing(JitTestCase): m_s = torch.jit.script(m) m_s.eval() m_f = torch._C._freeze_module(m_s._c) - self.assertTrue(m_f.hasattr('a')) inp = torch.tensor([5]) out = m_f.forward(inp) expected = m.forward(inp) diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 9e7963f5153..c82019a420e 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -498,6 +498,7 @@ void AliasDb::analyzeImpl(Node* node) { case prim::tolist: return analyzeCreator(node); case prim::TupleConstruct: + return analyzeTupleConstruct(node); case prim::DictConstruct: case prim::ListConstruct: return analyzeContainerConstruct(node); @@ -864,6 +865,22 @@ void AliasDb::analyzeConservative(Node* node) { } } +void AliasDb::analyzeTupleConstruct(Node* node) { + TORCH_INTERNAL_ASSERT(node->kind() == prim::TupleConstruct); + // tuples which contain immutable types are immutable + if (!isMutableTypeInternal(node->output())) { + return; + } + + giveFreshAlias(node->output()); + + for (const auto& input : node->inputs()) { + if (isMutableTypeInternal(input)) { + addToContainedElements(input, node->output()); + } + } +} + // List or dict or tuple: construct: create an aliasing element for the actual // container, then mark all inputs as wildcards, since they've gone inside the // container. Then, add the wildcard sets of appropriate type to the contained @@ -871,13 +888,7 @@ void AliasDb::analyzeConservative(Node* node) { void AliasDb::analyzeContainerConstruct(Node* node) { TORCH_INTERNAL_ASSERT( node->kind() == prim::ListConstruct || - node->kind() == prim::DictConstruct || - node->kind() == prim::TupleConstruct); - - // tuples which contain immutable types are immutable - if (!isMutableTypeInternal(node->output())) { - return; - } + node->kind() == prim::DictConstruct); TORCH_INTERNAL_ASSERT(node->outputs().size() == 1); auto container = node->output(); diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h index e3e69185891..707dfc12a28 100644 --- a/torch/csrc/jit/ir/alias_analysis.h +++ b/torch/csrc/jit/ir/alias_analysis.h @@ -194,6 +194,7 @@ class AliasDb { void analyzeSetAttr(Node* node); void analyzeConservative(Node* node); void analyzeContainerConstruct(Node* node); + void analyzeTupleConstruct(Node* node); bool tryRegisteredAnalysis(Node* node); /**