diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp
index e854113a7a8..e700ee54061 100644
--- a/test/cpp/jit/test_alias_analysis.cpp
+++ b/test/cpp/jit/test_alias_analysis.cpp
@@ -1238,6 +1238,32 @@ TEST(AliasRegistrationTest, PureWithAnnotationsShouldError) {
       "Tried to register operator foo::rand11(Tensor(a) arg1) -> (Tensor(a)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
 }
 
+TEST(AliasRegistrationTest, AliasMoveAtenListOp) {
+  auto graph = std::make_shared<Graph>();
+  std::unordered_map<std::string, Value*> vmap;
+  auto graph_string = R"IR(
+  graph():
+    %x : Tensor = prim::MakeTestTensor()
+    %8 : int = prim::Constant[value=0]()
+    %5 : int = prim::Constant[value=1]()
+    %4 : int = prim::Constant[value=2]()
+    %y : Tensor[] = prim::ListConstruct(%x)
+    %6 : Tensor = aten::add_(%x, %4, %5)
+    %9 : Tensor = aten::cat(%y, %8)
+    return (%9))IR";
+
+  torch::jit::parseIR(graph_string, graph.get(), vmap);
+  AliasDb aliasDb(graph);
+
+  // Because y.1 has a single use in a single non-aliasing aten op,
+  // x is added to y.1's contained elements instead of the wildcard set.
+  EXPECT_TRUE(!aliasDb.mayAlias(vmap["x"], vmap["9"]));
+
+  // A write to a contained element should prevent the move.
+  EXPECT_TRUE(!aliasDb.moveBeforeTopologicallyValid(
+      vmap["y"]->node(), vmap["9"]->node()));
+}
+
 TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
   auto registry = torch::RegisterOperators().op(
       "foo::rand12(Tensor(a) arg1) -> Tensor(b)",
diff --git a/test/jit/test_remove_mutation.py b/test/jit/test_remove_mutation.py
index ef408e775c3..b747fc06bcd 100644
--- a/test/jit/test_remove_mutation.py
+++ b/test/jit/test_remove_mutation.py
@@ -200,3 +200,44 @@ class TestRemoveMutation(JitTestCase):
         # it is possible to remove the append here but don't currently have the logic for it
         FileCheck().check_not("append").run(graph)
         self.assertEqual(intermediary_use(), fn())
+
+    def test_common_pytorch_list_ops(self):
+        for op in ["cat", "stack", "vstack", "hstack", "dstack"]:
+            class OpMod(torch.nn.Module):
+                def __init__(self, op):
+                    super(OpMod, self).__init__()
+                    self.op = torch_op
+
+                def forward(self):
+                    x = torch.tensor([1, 2, 3, 4])
+                    x.add_(3)
+                    y = [x, x]
+                    return self.op(y) + 3
+
+            torch_op = getattr(torch, op)
+            mod = OpMod(torch_op)
+            mod_script = torch.jit.script(mod)
+            self.run_pass('remove_mutation', mod_script.forward.graph)
+            FileCheck().check_not("aten::add_").run(mod_script.forward.graph)
+            self.assertEqual(mod(), mod_script())
+
+            # test that the output doesn't alias the input
+            for inputs in [torch.rand(2, 2)], [torch.rand(2, 2) for _ in range(2)]:
+                result = torch_op(inputs)
+                sums = [ten.sum() for ten in result]
+
+                for inp in inputs:
+                    inp.fill_(10)
+
+                self.assertEqual(sums, [ten.sum() for ten in result])
+
+
+        @torch.jit.script
+        def test_multiple_uses():
+            x = torch.tensor([1, 2, 3, 4])
+            x.add_(3)
+            y = [x, x]
+            return torch.cat(y), y
+
+        self.run_pass('remove_mutation', test_multiple_uses.graph)
+        FileCheck().check("aten::add_").run(test_multiple_uses.graph)
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index 50b84d8f640..bb5872f35f4 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/ir/alias_analysis.h>
 
 #include <c10/util/flat_hash_map.h>
+#include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/utils/memory.h>
@@ -298,15 +299,10 @@ void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret) const {
     auto it = elementMap_.find(input);
     if (it != elementMap_.end()) {
       auto el = it->second;
-      // Add all memory locations this element may alias.
-      ret |= memoryDAG_->getMemoryLocations(el);
-      // We also consider memory locations of contained values to be "read".
-      for (const auto& type : input->type()->containedTypes()) {
-        if (auto wildcard = getWildcard(type)) {
-          ret |= memoryDAG_->getMemoryLocations(wildcard);
-        }
-      }
+      // Add all memory locations this element may alias and their contained
+      // elements.
+      memoryDAG_->collectAllContainedMemoryLocations(el, ret);
     }
   }
 }
 
@@ -878,6 +874,44 @@ void AliasDb::analyzeConservative(Node* node) {
   }
 }
 
+bool AliasDb::functionalNonEscapingListUse(const Use& use) const {
+  Node* n = use.user;
+  size_t offset = use.offset;
+  Value* container = n->inputs().at(offset);
+
+  // only consider aten op uses of lists
+  if (!container->type()->cast<ListType>()) {
+    return false;
+  }
+
+  /*
+  in the general case, we consider any Value that enters another container as
+  entering the heap, and thus aliasing all other heap values of the same type.
+  the advantages of this approach are:
+  - there are many composite list/container ops that would be tricky to
+  schematize if we did something more complicated
+  - limits the size of the AliasDb, because a container of size 10 only contains
+  1 memory dag element instead of 10
+  - we do not need to worry about adding contained elements to the wildcard set
+  when a container escapes the graph.
+  The downside of this approach is we are unable to handle the common case of a
+  list constructed and passed into an aten op. Here, optimize for a set of
+  common ops where the output does not alias the list or the list elements
+  */
+
+  switch (use.user->kind()) {
+    case aten::cat:
+    case aten::broadcast_tensors:
+    case aten::stack:
+    case aten::vstack:
+    case aten::hstack:
+    case aten::dstack:
+      return true;
+  }
+
+  return false;
+}
+
 // List or dict or tuple: construct: create an aliasing element for the actual
 // container, then mark all inputs as wildcards, since they've gone inside the
 // container. Then, add the wildcard sets of appropriate type to the contained
@@ -895,6 +929,20 @@ void AliasDb::analyzeContainerConstruct(Node* node) {
 
   TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
   auto container = node->output();
+
+  // optimization:
+  // if a list is only used once in an aten op, and the op output
+  // doesn't alias the input, then we can add all inputs to the list's
+  // contained elements instead of the wildcard set.
+  if (container->uses().size() == 1 &&
+      functionalNonEscapingListUse(container->uses().at(0))) {
+    giveFreshAlias(container, false);
+    for (Value* v : node->inputs()) {
+      addToContainedElements(v, container);
+    }
+    return;
+  }
+
   giveFreshAlias(container);
   auto container_elem = elementMap_.at(container);
   for (auto input : node->inputs()) {
@@ -1068,7 +1116,9 @@ void AliasDb::createValue(const Value* value) {
   elementMap_[value] = new_elem;
 }
 
-void AliasDb::giveFreshAlias(const Value* value) {
+void AliasDb::giveFreshAlias(
+    const Value* value,
+    bool add_wildcard_to_contained_elems) {
   auto maybe_mut_type = getMutableTypePtr(value->type());
   if (!maybe_mut_type) {
     return;
@@ -1082,7 +1132,9 @@ void AliasDb::giveFreshAlias(
 
   auto new_elem = memoryDAGBuilder_->makeFreshValue(value);
   elementMap_[value] = new_elem;
-  addContainedTypesToFreshElement(new_elem, *maybe_mut_type);
+  if (add_wildcard_to_contained_elems) {
+    addContainedTypesToFreshElement(new_elem, *maybe_mut_type);
+  }
 }
 
 Element* AliasDb::getOrCreateElement(const Value* value) {
diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h
index e3e69185891..b20654b1f6b 100644
--- a/torch/csrc/jit/ir/alias_analysis.h
+++ b/torch/csrc/jit/ir/alias_analysis.h
@@ -205,10 +205,13 @@ class AliasDb {
       const Value* element,
       const Value* container);
   void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
-  void giveFreshAlias(const Value* value);
+  void giveFreshAlias(
+      const Value* value,
+      bool add_wildcard_to_contained_elems = true);
   Element* getOrCreateElement(const Value* value);
   c10::optional<TypePtr> getMutableTypePtr(const TypePtr& type) const;
+  bool functionalNonEscapingListUse(const Use& use) const;
   bool isContainerType(const TypePtr& type) const;