[jit] fix tuple alias analysis (#41992)

Previously when analyzing a TupleConstruct, we ignored the aliasing
information of the inputs and simply marked all elements of the returned
tuple as wildcards. But since we can fully reason about the contents of
a tuple statically, we should be able to assign them aliasing
information.

This analysis was not only incomplete but produced incorrect results,
since if `a` is not a wildcard, `a noalias wilcard`. So if we looked at
`tuple(a)` and reported the aliasing info as `tuple(wildcard)`, then
`tuple[0] noalias a`, which is...wrong.
This commit is contained in:
Michael Suo 2020-07-24 08:05:20 -07:00 committed by GitHub
parent 7c7c9c3aa6
commit 8aa878fc93
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 11 deletions

View file

@ -794,9 +794,6 @@ class TestFreezing(JitTestCase):
expected = m_s.forward(inp)
self.assertEqual(out, expected)
# Check attribute a is preserved. Alias analysis detects that 'a' has output writers.
# In this example, 'a' is not mutated. However, we do not track which sub
# values of a composite ivalue is mutated.
def test_freeze_module_with_aliased_attr2(self):
class FreezeMe(nn.Module):
def __init__(self):
@ -815,7 +812,6 @@ class TestFreezing(JitTestCase):
m_s = torch.jit.script(m)
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
self.assertTrue(m_f.hasattr('a'))
inp = torch.tensor([5])
out = m_f.forward(inp)
expected = m.forward(inp)

View file

@ -498,6 +498,7 @@ void AliasDb::analyzeImpl(Node* node) {
case prim::tolist:
return analyzeCreator(node);
case prim::TupleConstruct:
return analyzeTupleConstruct(node);
case prim::DictConstruct:
case prim::ListConstruct:
return analyzeContainerConstruct(node);
@ -864,6 +865,22 @@ void AliasDb::analyzeConservative(Node* node) {
}
}
void AliasDb::analyzeTupleConstruct(Node* node) {
TORCH_INTERNAL_ASSERT(node->kind() == prim::TupleConstruct);
// tuples which contain immutable types are immutable
if (!isMutableTypeInternal(node->output())) {
return;
}
giveFreshAlias(node->output());
for (const auto& input : node->inputs()) {
if (isMutableTypeInternal(input)) {
addToContainedElements(input, node->output());
}
}
}
// List or dict or tuple: construct: create an aliasing element for the actual
// container, then mark all inputs as wildcards, since they've gone inside the
// container. Then, add the wildcard sets of appropriate type to the contained
@ -871,13 +888,7 @@ void AliasDb::analyzeConservative(Node* node) {
void AliasDb::analyzeContainerConstruct(Node* node) {
TORCH_INTERNAL_ASSERT(
node->kind() == prim::ListConstruct ||
node->kind() == prim::DictConstruct ||
node->kind() == prim::TupleConstruct);
// tuples which contain immutable types are immutable
if (!isMutableTypeInternal(node->output())) {
return;
}
node->kind() == prim::DictConstruct);
TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
auto container = node->output();

View file

@ -194,6 +194,7 @@ class AliasDb {
void analyzeSetAttr(Node* node);
void analyzeConservative(Node* node);
void analyzeContainerConstruct(Node* node);
void analyzeTupleConstruct(Node* node);
bool tryRegisteredAnalysis(Node* node);
/**