From ae286d81e00f45b81778635c1aa482d64f2ec7bc Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 22 Sep 2020 09:37:00 -0700 Subject: [PATCH] [JIT] improve alias analysis for list constructs (#39111) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39111 In our present alias analysis, we consider any Value that enter another container as entering the heap, and thus aliasing all other heap values of the same type. There are a number of advantages to this approach: - it is not to hard to maintain the aliasDb implementation - it is much easier from an op schema perspective - there are many composite list ops registered internally and externally that would be tricky to register and get right if we did something more complicated - It limits the size of the AliasDb, because a container of size 10 only contains a single memory dag element instead of 10 elements. The downside is that we have are unable to handle the simple and extremely common case of a list of tensors being used in an ATen op. In an example like: ``` def foo(input): x = torch.tensor([1, 2, 3, 4]) y = [x, x] input.add_(1) return torch.cat(y) ``` we will consider x to be written to. any write to any wildcard element (an element that enters a tuple, an element that is taken from a list) will mark x as written to. This can be limiting for our ability to create a functional subset and fuse graphs - as a result, 4 of TorchVision classification models could not be functionalized. Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D23828003 Pulled By: eellison fbshipit-source-id: 9109fcb6f2ca20ca897cae71683530285da9d537 --- test/cpp/jit/test_alias_analysis.cpp | 26 ++++++++++ test/jit/test_remove_mutation.py | 41 ++++++++++++++++ torch/csrc/jit/ir/alias_analysis.cpp | 72 ++++++++++++++++++++++++---- torch/csrc/jit/ir/alias_analysis.h | 5 +- 4 files changed, 133 insertions(+), 11 deletions(-) diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp index e854113a7a8..e700ee54061 100644 --- a/test/cpp/jit/test_alias_analysis.cpp +++ b/test/cpp/jit/test_alias_analysis.cpp @@ -1238,6 +1238,32 @@ TEST(AliasRegistrationTest, PureWithAnnotationsShouldError) { "Tried to register operator foo::rand11(Tensor(a) arg1) -> (Tensor(a)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA"); } +TEST(AliasRegistrationTest, AliasMoveAtenListOp) { + auto graph = std::make_shared(); + std::unordered_map vmap; + auto graph_string = R"IR( + graph(): + %x : Tensor = prim::MakeTestTensor() + %8 : int = prim::Constant[value=0]() + %5 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=2]() + %y : Tensor[] = prim::ListConstruct(%x) + %6 : Tensor = aten::add_(%x, %4, %5) + %9 : Tensor = aten::cat(%y, %8) + return (%9))IR"; + + torch::jit::parseIR(graph_string, graph.get(), vmap); + AliasDb aliasDb(graph); + + // bc y.1 has a single used in a single non-aliasing aten op, + // x is added to y.1 contained elements instead of wildcard set + EXPECT_TRUE(!aliasDb.mayAlias(vmap["x"], vmap["9"])); + + // write to contained element should prevent move + EXPECT_TRUE(!aliasDb.moveBeforeTopologicallyValid( + vmap["y"]->node(), vmap["9"]->node())); +} + TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) { auto registry = torch::RegisterOperators().op( "foo::rand12(Tensor(a) arg1) -> Tensor(b)", diff --git a/test/jit/test_remove_mutation.py b/test/jit/test_remove_mutation.py index ef408e775c3..b747fc06bcd 100644 --- a/test/jit/test_remove_mutation.py +++ b/test/jit/test_remove_mutation.py @@ -200,3 +200,44 @@ class TestRemoveMutation(JitTestCase): # it is possible to remove the append here but don't currently have the logic for it FileCheck().check_not("append").run(graph) self.assertEqual(intermediary_use(), fn()) + + def test_common_pytorch_list_ops(self): + for op in ["cat", "stack", "vstack", "hstack", "dstack"]: + class OpMod(torch.nn.Module): + def __init__(self, op): + super(OpMod, self).__init__() + self.op = torch_op + + def forward(self): + x = torch.tensor([1, 2, 3, 4]) + x.add_(3) + y = [x, x] + return self.op(y) + 3 + + torch_op = getattr(torch, op) + mod = OpMod(torch_op) + mod_script = torch.jit.script(mod) + self.run_pass('remove_mutation', mod_script.forward.graph) + FileCheck().check_not("aten::add_").run(mod_script.forward.graph) + self.assertEqual(mod(), mod_script()) + + # test that the output doesnt alias the input + for inputs in [torch.rand(2, 2)], [torch.rand(2, 2) for _ in range(2)]: + result = torch_op(inputs) + sums = [ten.sum() for ten in result] + + for inp in inputs: + inp.fill_(10) + + self.assertEqual(sums, [ten.sum() for ten in result]) + + + @torch.jit.script + def test_multiple_uses(): + x = torch.tensor([1, 2, 3, 4]) + x.add_(3) + y = [x, x] + return torch.cat(y), y + + self.run_pass('remove_mutation', mod_script.forward.graph) + FileCheck().check("aten::add_").run(test_multiple_uses.graph) diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 50b84d8f640..bb5872f35f4 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -298,15 +299,10 @@ void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret) const { auto it = elementMap_.find(input); if (it != elementMap_.end()) { auto el = it->second; - // Add all memory locations this element may alias. - ret |= memoryDAG_->getMemoryLocations(el); - // We also consider memory locations of contained values to be "read". - for (const auto& type : input->type()->containedTypes()) { - if (auto wildcard = getWildcard(type)) { - ret |= memoryDAG_->getMemoryLocations(wildcard); - } - } + // Add all memory locations this element may alias and their contained + // elements + memoryDAG_->collectAllContainedMemoryLocations(el, ret); } } @@ -878,6 +874,44 @@ void AliasDb::analyzeConservative(Node* node) { } } +bool AliasDb::functionalNonEscapingListUse(const Use& use) const { + Node* n = use.user; + size_t offset = use.offset; + Value* container = n->inputs().at(offset); + + // only consider aten op uses of lists + if (!container->type()->cast()) { + return false; + } + + /* + in the general case, we consider any Value that enters another container as + entering the heap, and thus aliasing all other heap values of the same type. + the advantage of this approach are: + - there are many composite list/container ops that would be tricky to + schematize if we did something more complicated + - limits the size of the AliasDb, because a container of size 10 only contains + 1 memory dag element instead of 10 + - we do not need to worry about adding contained elements to the wildcard set + when a container escapes the graph. + The downside of this approach is we are unable to handle the common case of a + list constructed and passed into an aten op. Here, optimize for a set of + common ops where the output does not alias the list or the list elements + */ + + switch (use.user->kind()) { + case aten::cat: + case aten::broadcast_tensors: + case aten::stack: + case aten::vstack: + case aten::hstack: + case aten::dstack: + return true; + } + + return false; +} + // List or dict or tuple: construct: create an aliasing element for the actual // container, then mark all inputs as wildcards, since they've gone inside the // container. Then, add the wildcard sets of appropriate type to the contained @@ -895,6 +929,20 @@ void AliasDb::analyzeContainerConstruct(Node* node) { TORCH_INTERNAL_ASSERT(node->outputs().size() == 1); auto container = node->output(); + + // optimization: + // if a list is only used once in an aten op, and the op output + // doesn't alias the input, then we can add all inputs to the list's + // contained elements instead of the wildcard set. + if (container->uses().size() == 1 && + functionalNonEscapingListUse(container->uses().at(0))) { + giveFreshAlias(container, false); + for (Value* v : node->inputs()) { + addToContainedElements(v, container); + } + return; + } + giveFreshAlias(container); auto container_elem = elementMap_.at(container); for (auto input : node->inputs()) { @@ -1068,7 +1116,9 @@ void AliasDb::createValue(const Value* value) { elementMap_[value] = new_elem; } -void AliasDb::giveFreshAlias(const Value* value) { +void AliasDb::giveFreshAlias( + const Value* value, + bool add_wildcard_to_contained_elems) { auto maybe_mut_type = getMutableTypePtr(value->type()); if (!maybe_mut_type) { return; @@ -1082,7 +1132,9 @@ void AliasDb::giveFreshAlias(const Value* value) { auto new_elem = memoryDAGBuilder_->makeFreshValue(value); elementMap_[value] = new_elem; - addContainedTypesToFreshElement(new_elem, *maybe_mut_type); + if (add_wildcard_to_contained_elems) { + addContainedTypesToFreshElement(new_elem, *maybe_mut_type); + } } Element* AliasDb::getOrCreateElement(const Value* value) { diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h index e3e69185891..b20654b1f6b 100644 --- a/torch/csrc/jit/ir/alias_analysis.h +++ b/torch/csrc/jit/ir/alias_analysis.h @@ -205,10 +205,13 @@ class AliasDb { const Value* element, const Value* container); void mapAliases(at::ArrayRef to, at::ArrayRef from); - void giveFreshAlias(const Value* value); + void giveFreshAlias( + const Value* value, + bool add_wildcard_to_contained_elems = true); Element* getOrCreateElement(const Value* value); c10::optional getMutableTypePtr(const TypePtr& type) const; + bool functionalNonEscapingListUse(const Use& use) const; bool isContainerType(const TypePtr& type) const;