mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[jit] fix tuple alias analysis (#41992)
Previously when analyzing a TupleConstruct, we ignored the aliasing information of the inputs and simply marked all elements of the returned tuple as wildcards. But since we can fully reason about the contents of a tuple statically, we should be able to assign them aliasing information. This analysis was not only incomplete but produced incorrect results, since if `a` is not a wildcard, `a noalias wilcard`. So if we looked at `tuple(a)` and reported the aliasing info as `tuple(wildcard)`, then `tuple[0] noalias a`, which is...wrong.
This commit is contained in:
parent
7c7c9c3aa6
commit
8aa878fc93
3 changed files with 19 additions and 11 deletions
|
|
@ -794,9 +794,6 @@ class TestFreezing(JitTestCase):
|
|||
expected = m_s.forward(inp)
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
# Check attribute a is preserved. Alias analysis detects that 'a' has output writers.
|
||||
# In this example, 'a' is not mutated. However, we do not track which sub
|
||||
# values of a composite ivalue is mutated.
|
||||
def test_freeze_module_with_aliased_attr2(self):
|
||||
class FreezeMe(nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -815,7 +812,6 @@ class TestFreezing(JitTestCase):
|
|||
m_s = torch.jit.script(m)
|
||||
m_s.eval()
|
||||
m_f = torch._C._freeze_module(m_s._c)
|
||||
self.assertTrue(m_f.hasattr('a'))
|
||||
inp = torch.tensor([5])
|
||||
out = m_f.forward(inp)
|
||||
expected = m.forward(inp)
|
||||
|
|
|
|||
|
|
@ -498,6 +498,7 @@ void AliasDb::analyzeImpl(Node* node) {
|
|||
case prim::tolist:
|
||||
return analyzeCreator(node);
|
||||
case prim::TupleConstruct:
|
||||
return analyzeTupleConstruct(node);
|
||||
case prim::DictConstruct:
|
||||
case prim::ListConstruct:
|
||||
return analyzeContainerConstruct(node);
|
||||
|
|
@ -864,6 +865,22 @@ void AliasDb::analyzeConservative(Node* node) {
|
|||
}
|
||||
}
|
||||
|
||||
void AliasDb::analyzeTupleConstruct(Node* node) {
|
||||
TORCH_INTERNAL_ASSERT(node->kind() == prim::TupleConstruct);
|
||||
// tuples which contain immutable types are immutable
|
||||
if (!isMutableTypeInternal(node->output())) {
|
||||
return;
|
||||
}
|
||||
|
||||
giveFreshAlias(node->output());
|
||||
|
||||
for (const auto& input : node->inputs()) {
|
||||
if (isMutableTypeInternal(input)) {
|
||||
addToContainedElements(input, node->output());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// List or dict or tuple: construct: create an aliasing element for the actual
|
||||
// container, then mark all inputs as wildcards, since they've gone inside the
|
||||
// container. Then, add the wildcard sets of appropriate type to the contained
|
||||
|
|
@ -871,13 +888,7 @@ void AliasDb::analyzeConservative(Node* node) {
|
|||
void AliasDb::analyzeContainerConstruct(Node* node) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
node->kind() == prim::ListConstruct ||
|
||||
node->kind() == prim::DictConstruct ||
|
||||
node->kind() == prim::TupleConstruct);
|
||||
|
||||
// tuples which contain immutable types are immutable
|
||||
if (!isMutableTypeInternal(node->output())) {
|
||||
return;
|
||||
}
|
||||
node->kind() == prim::DictConstruct);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
|
||||
auto container = node->output();
|
||||
|
|
|
|||
|
|
@ -194,6 +194,7 @@ class AliasDb {
|
|||
void analyzeSetAttr(Node* node);
|
||||
void analyzeConservative(Node* node);
|
||||
void analyzeContainerConstruct(Node* node);
|
||||
void analyzeTupleConstruct(Node* node);
|
||||
bool tryRegisteredAnalysis(Node* node);
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Reference in a new issue