[JIT] improve alias analysis for list constructs (#39111)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39111

In our present alias analysis, we consider any Value that enter another container as entering the heap, and thus aliasing all other heap values of the same type. There are a number of advantages to this approach:
- it is not too hard to maintain the aliasDb implementation
- it is much easier from an op schema perspective - there are many composite list ops registered internally and externally that would be tricky to register and get right if we did something more complicated
- It limits the size of the AliasDb, because a container of size 10 only contains a single memory dag element instead of 10 elements.

The downside is that we are unable to handle the simple and extremely common case of a list of tensors being used in an ATen op.

In an example like:

```
 def foo(input):
    x = torch.tensor([1, 2, 3, 4])
    y = [x, x]
    input.add_(1)
    return torch.cat(y)
```

we will consider x to be written to. Any write to any wildcard element (an element that enters a tuple, or an element that is taken from a list) will mark x as written to. This limits our ability to create a functional subset and to fuse graphs — as a result, 4 of the TorchVision classification models could not be functionalized.

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D23828003

Pulled By: eellison

fbshipit-source-id: 9109fcb6f2ca20ca897cae71683530285da9d537
This commit is contained in:
Elias Ellison 2020-09-22 09:37:00 -07:00 committed by Facebook GitHub Bot
parent 9fc7a942f0
commit ae286d81e0
4 changed files with 133 additions and 11 deletions

View file

@ -1238,6 +1238,32 @@ TEST(AliasRegistrationTest, PureWithAnnotationsShouldError) {
"Tried to register operator foo::rand11(Tensor(a) arg1) -> (Tensor(a)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
// Verifies the ListConstruct optimization: when a list has a single use in a
// non-aliasing aten op, its inputs are tracked as precise contained elements
// rather than being dumped into the wildcard set.
TEST(AliasRegistrationTest, AliasMoveAtenListOp) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
auto graph_string = R"IR(
graph():
%x : Tensor = prim::MakeTestTensor()
%8 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=1]()
%4 : int = prim::Constant[value=2]()
%y : Tensor[] = prim::ListConstruct(%x)
%6 : Tensor = aten::add_(%x, %4, %5)
%9 : Tensor = aten::cat(%y, %8)
return (%9))IR";
torch::jit::parseIR(graph_string, graph.get(), vmap);
AliasDb aliasDb(graph);
// Because %y has a single use in a single non-aliasing aten op (aten::cat),
// %x is added to %y's contained elements instead of the wildcard set, so
// the cat output %9 must not alias %x.
EXPECT_TRUE(!aliasDb.mayAlias(vmap["x"], vmap["9"]));
// The aten::add_ writes to %x, a contained element of %y, so moving the
// ListConstruct past that write must be rejected.
EXPECT_TRUE(!aliasDb.moveBeforeTopologicallyValid(
vmap["y"]->node(), vmap["9"]->node()));
}
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
auto registry = torch::RegisterOperators().op(
"foo::rand12(Tensor(a) arg1) -> Tensor(b)",

View file

@ -200,3 +200,44 @@ class TestRemoveMutation(JitTestCase):
# it is possible to remove the append here but don't currently have the logic for it
FileCheck().check_not("append").run(graph)
self.assertEqual(intermediary_use(), fn())
def test_common_pytorch_list_ops(self):
    """Mutation removal on tensors that only flow into a list consumed by a
    single non-aliasing aten op (cat/stack/...): the write can be
    functionalized. When the list has other uses, the mutation must stay.
    """
    for op in ["cat", "stack", "vstack", "hstack", "dstack"]:
        class OpMod(torch.nn.Module):
            def __init__(self, op):
                super(OpMod, self).__init__()
                # Store the constructor argument; previously this read the
                # enclosing scope's `torch_op`, silently ignoring `op`.
                self.op = op

            def forward(self):
                x = torch.tensor([1, 2, 3, 4])
                x.add_(3)
                y = [x, x]
                return self.op(y) + 3

        torch_op = getattr(torch, op)
        mod = OpMod(torch_op)
        mod_script = torch.jit.script(mod)
        # The in-place add_ should be functionalized away.
        self.run_pass('remove_mutation', mod_script.forward.graph)
        FileCheck().check_not("aten::add_").run(mod_script.forward.graph)
        self.assertEqual(mod(), mod_script())

        # test that the output doesn't alias the input
        for inputs in [torch.rand(2, 2)], [torch.rand(2, 2) for _ in range(2)]:
            result = torch_op(inputs)
            sums = [ten.sum() for ten in result]
            for inp in inputs:
                inp.fill_(10)
            self.assertEqual(sums, [ten.sum() for ten in result])

    # Returning y alongside torch.cat(y) gives the list a second use, so the
    # single-use optimization does not apply and the mutation must survive.
    @torch.jit.script
    def test_multiple_uses():
        x = torch.tensor([1, 2, 3, 4])
        x.add_(3)
        y = [x, x]
        return torch.cat(y), y

    # Fix: run the pass on test_multiple_uses.graph (it previously re-ran on
    # mod_script.forward.graph, so this graph was never passed through
    # remove_mutation before the check below).
    self.run_pass('remove_mutation', test_multiple_uses.graph)
    FileCheck().check("aten::add_").run(test_multiple_uses.graph)

View file

@ -1,6 +1,7 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/utils/memory.h>
@ -298,15 +299,10 @@ void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret) const {
auto it = elementMap_.find(input);
if (it != elementMap_.end()) {
auto el = it->second;
// Add all memory locations this element may alias.
ret |= memoryDAG_->getMemoryLocations(el);
// We also consider memory locations of contained values to be "read".
for (const auto& type : input->type()->containedTypes()) {
if (auto wildcard = getWildcard(type)) {
ret |= memoryDAG_->getMemoryLocations(wildcard);
}
}
// Add all memory locations this element may alias and their contained
// elements
memoryDAG_->collectAllContainedMemoryLocations(el, ret);
}
}
@ -878,6 +874,44 @@ void AliasDb::analyzeConservative(Node* node) {
}
}
bool AliasDb::functionalNonEscapingListUse(const Use& use) const {
Node* n = use.user;
size_t offset = use.offset;
Value* container = n->inputs().at(offset);
// only consider aten op uses of lists
if (!container->type()->cast<ListType>()) {
return false;
}
/*
in the general case, we consider any Value that enters another container as
entering the heap, and thus aliasing all other heap values of the same type.
the advantage of this approach are:
- there are many composite list/container ops that would be tricky to
schematize if we did something more complicated
- limits the size of the AliasDb, because a container of size 10 only contains
1 memory dag element instead of 10
- we do not need to worry about adding contained elements to the wildcard set
when a container escapes the graph.
The downside of this approach is we are unable to handle the common case of a
list constructed and passed into an aten op. Here, optimize for a set of
common ops where the output does not alias the list or the list elements
*/
switch (use.user->kind()) {
case aten::cat:
case aten::broadcast_tensors:
case aten::stack:
case aten::vstack:
case aten::hstack:
case aten::dstack:
return true;
}
return false;
}
// List or dict or tuple: construct: create an aliasing element for the actual
// container, then mark all inputs as wildcards, since they've gone inside the
// container. Then, add the wildcard sets of appropriate type to the contained
@ -895,6 +929,20 @@ void AliasDb::analyzeContainerConstruct(Node* node) {
TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
auto container = node->output();
// optimization:
// if a list is only used once in an aten op, and the op output
// doesn't alias the input, then we can add all inputs to the list's
// contained elements instead of the wildcard set.
if (container->uses().size() == 1 &&
functionalNonEscapingListUse(container->uses().at(0))) {
giveFreshAlias(container, false);
for (Value* v : node->inputs()) {
addToContainedElements(v, container);
}
return;
}
giveFreshAlias(container);
auto container_elem = elementMap_.at(container);
for (auto input : node->inputs()) {
@ -1068,7 +1116,9 @@ void AliasDb::createValue(const Value* value) {
elementMap_[value] = new_elem;
}
void AliasDb::giveFreshAlias(const Value* value) {
void AliasDb::giveFreshAlias(
const Value* value,
bool add_wildcard_to_contained_elems) {
auto maybe_mut_type = getMutableTypePtr(value->type());
if (!maybe_mut_type) {
return;
@ -1082,7 +1132,9 @@ void AliasDb::giveFreshAlias(const Value* value) {
auto new_elem = memoryDAGBuilder_->makeFreshValue(value);
elementMap_[value] = new_elem;
addContainedTypesToFreshElement(new_elem, *maybe_mut_type);
if (add_wildcard_to_contained_elems) {
addContainedTypesToFreshElement(new_elem, *maybe_mut_type);
}
}
Element* AliasDb::getOrCreateElement(const Value* value) {

View file

@ -205,10 +205,13 @@ class AliasDb {
const Value* element,
const Value* container);
void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
void giveFreshAlias(const Value* value);
void giveFreshAlias(
const Value* value,
bool add_wildcard_to_contained_elems = true);
Element* getOrCreateElement(const Value* value);
c10::optional<TypePtr> getMutableTypePtr(const TypePtr& type) const;
bool functionalNonEscapingListUse(const Use& use) const;
bool isContainerType(const TypePtr& type) const;