From 35ad2d8586c54ca3a1a24b39279dde3cba4924e4 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 24 Jul 2020 13:32:00 -0700 Subject: [PATCH] Revert "[jit] fix tuple alias analysis (#41992)" This reverts commit 8aa878fc935564bdd1e4fc00d7f34381a746b504. --- test/jit/test_freezing.py | 4 ++++ torch/csrc/jit/ir/alias_analysis.cpp | 25 +++++++------------------ torch/csrc/jit/ir/alias_analysis.h | 1 - 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index dfbc0ce8214..018470bf851 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -794,6 +794,9 @@ class TestFreezing(JitTestCase): expected = m_s.forward(inp) self.assertEqual(out, expected) + # Check attribute a is preserved. Alias analysis detects that 'a' has output writers. + # In this example, 'a' is not mutated. However, we do not track which sub + # values of a composite ivalue is mutated. def test_freeze_module_with_aliased_attr2(self): class FreezeMe(nn.Module): def __init__(self): @@ -812,6 +815,7 @@ class TestFreezing(JitTestCase): m_s = torch.jit.script(m) m_s.eval() m_f = torch._C._freeze_module(m_s._c) + self.assertTrue(m_f.hasattr('a')) inp = torch.tensor([5]) out = m_f.forward(inp) expected = m.forward(inp) diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index c82019a420e..9e7963f5153 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -498,7 +498,6 @@ void AliasDb::analyzeImpl(Node* node) { case prim::tolist: return analyzeCreator(node); case prim::TupleConstruct: - return analyzeTupleConstruct(node); case prim::DictConstruct: case prim::ListConstruct: return analyzeContainerConstruct(node); @@ -865,22 +864,6 @@ void AliasDb::analyzeConservative(Node* node) { } } -void AliasDb::analyzeTupleConstruct(Node* node) { - TORCH_INTERNAL_ASSERT(node->kind() == prim::TupleConstruct); - // tuples which contain immutable types are immutable - if (!isMutableTypeInternal(node->output())) { - return; - } - - giveFreshAlias(node->output()); - - for (const auto& input : node->inputs()) { - if (isMutableTypeInternal(input)) { - addToContainedElements(input, node->output()); - } - } -} - // List or dict or tuple: construct: create an aliasing element for the actual // container, then mark all inputs as wildcards, since they've gone inside the // container. Then, add the wildcard sets of appropriate type to the contained @@ -888,7 +871,13 @@ void AliasDb::analyzeTupleConstruct(Node* node) { void AliasDb::analyzeContainerConstruct(Node* node) { TORCH_INTERNAL_ASSERT( node->kind() == prim::ListConstruct || - node->kind() == prim::DictConstruct); + node->kind() == prim::DictConstruct || + node->kind() == prim::TupleConstruct); + + // tuples which contain immutable types are immutable + if (!isMutableTypeInternal(node->output())) { + return; + } TORCH_INTERNAL_ASSERT(node->outputs().size() == 1); auto container = node->output(); diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h index 707dfc12a28..e3e69185891 100644 --- a/torch/csrc/jit/ir/alias_analysis.h +++ b/torch/csrc/jit/ir/alias_analysis.h @@ -194,7 +194,6 @@ class AliasDb { void analyzeSetAttr(Node* node); void analyzeConservative(Node* node); void analyzeContainerConstruct(Node* node); - void analyzeTupleConstruct(Node* node); bool tryRegisteredAnalysis(Node* node); /**