mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[JIT] improve alias analysis for list constructs (#39111)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39111 In our present alias analysis, we consider any Value that enter another container as entering the heap, and thus aliasing all other heap values of the same type. There are a number of advantages to this approach: - it is not to hard to maintain the aliasDb implementation - it is much easier from an op schema perspective - there are many composite list ops registered internally and externally that would be tricky to register and get right if we did something more complicated - It limits the size of the AliasDb, because a container of size 10 only contains a single memory dag element instead of 10 elements. The downside is that we have are unable to handle the simple and extremely common case of a list of tensors being used in an ATen op. In an example like: ``` def foo(input): x = torch.tensor([1, 2, 3, 4]) y = [x, x] input.add_(1) return torch.cat(y) ``` we will consider x to be written to. any write to any wildcard element (an element that enters a tuple, an element that is taken from a list) will mark x as written to. This can be limiting for our ability to create a functional subset and fuse graphs - as a result, 4 of TorchVision classification models could not be functionalized. Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D23828003 Pulled By: eellison fbshipit-source-id: 9109fcb6f2ca20ca897cae71683530285da9d537
This commit is contained in:
parent
9fc7a942f0
commit
ae286d81e0
4 changed files with 133 additions and 11 deletions
|
|
@ -1238,6 +1238,32 @@ TEST(AliasRegistrationTest, PureWithAnnotationsShouldError) {
|
|||
"Tried to register operator foo::rand11(Tensor(a) arg1) -> (Tensor(a)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
|
||||
}
|
||||
|
||||
TEST(AliasRegistrationTest, AliasMoveAtenListOp) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
auto graph_string = R"IR(
|
||||
graph():
|
||||
%x : Tensor = prim::MakeTestTensor()
|
||||
%8 : int = prim::Constant[value=0]()
|
||||
%5 : int = prim::Constant[value=1]()
|
||||
%4 : int = prim::Constant[value=2]()
|
||||
%y : Tensor[] = prim::ListConstruct(%x)
|
||||
%6 : Tensor = aten::add_(%x, %4, %5)
|
||||
%9 : Tensor = aten::cat(%y, %8)
|
||||
return (%9))IR";
|
||||
|
||||
torch::jit::parseIR(graph_string, graph.get(), vmap);
|
||||
AliasDb aliasDb(graph);
|
||||
|
||||
// bc y.1 has a single used in a single non-aliasing aten op,
|
||||
// x is added to y.1 contained elements instead of wildcard set
|
||||
EXPECT_TRUE(!aliasDb.mayAlias(vmap["x"], vmap["9"]));
|
||||
|
||||
// write to contained element should prevent move
|
||||
EXPECT_TRUE(!aliasDb.moveBeforeTopologicallyValid(
|
||||
vmap["y"]->node(), vmap["9"]->node()));
|
||||
}
|
||||
|
||||
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
|
||||
auto registry = torch::RegisterOperators().op(
|
||||
"foo::rand12(Tensor(a) arg1) -> Tensor(b)",
|
||||
|
|
|
|||
|
|
@ -200,3 +200,44 @@ class TestRemoveMutation(JitTestCase):
|
|||
# it is possible to remove the append here but don't currently have the logic for it
|
||||
FileCheck().check_not("append").run(graph)
|
||||
self.assertEqual(intermediary_use(), fn())
|
||||
|
||||
def test_common_pytorch_list_ops(self):
|
||||
for op in ["cat", "stack", "vstack", "hstack", "dstack"]:
|
||||
class OpMod(torch.nn.Module):
|
||||
def __init__(self, op):
|
||||
super(OpMod, self).__init__()
|
||||
self.op = torch_op
|
||||
|
||||
def forward(self):
|
||||
x = torch.tensor([1, 2, 3, 4])
|
||||
x.add_(3)
|
||||
y = [x, x]
|
||||
return self.op(y) + 3
|
||||
|
||||
torch_op = getattr(torch, op)
|
||||
mod = OpMod(torch_op)
|
||||
mod_script = torch.jit.script(mod)
|
||||
self.run_pass('remove_mutation', mod_script.forward.graph)
|
||||
FileCheck().check_not("aten::add_").run(mod_script.forward.graph)
|
||||
self.assertEqual(mod(), mod_script())
|
||||
|
||||
# test that the output doesnt alias the input
|
||||
for inputs in [torch.rand(2, 2)], [torch.rand(2, 2) for _ in range(2)]:
|
||||
result = torch_op(inputs)
|
||||
sums = [ten.sum() for ten in result]
|
||||
|
||||
for inp in inputs:
|
||||
inp.fill_(10)
|
||||
|
||||
self.assertEqual(sums, [ten.sum() for ten in result])
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def test_multiple_uses():
|
||||
x = torch.tensor([1, 2, 3, 4])
|
||||
x.add_(3)
|
||||
y = [x, x]
|
||||
return torch.cat(y), y
|
||||
|
||||
self.run_pass('remove_mutation', mod_script.forward.graph)
|
||||
FileCheck().check("aten::add_").run(test_multiple_uses.graph)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#include <torch/csrc/jit/ir/alias_analysis.h>
|
||||
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
#include <torch/csrc/utils/memory.h>
|
||||
|
||||
|
|
@ -298,15 +299,10 @@ void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret) const {
|
|||
auto it = elementMap_.find(input);
|
||||
if (it != elementMap_.end()) {
|
||||
auto el = it->second;
|
||||
// Add all memory locations this element may alias.
|
||||
ret |= memoryDAG_->getMemoryLocations(el);
|
||||
|
||||
// We also consider memory locations of contained values to be "read".
|
||||
for (const auto& type : input->type()->containedTypes()) {
|
||||
if (auto wildcard = getWildcard(type)) {
|
||||
ret |= memoryDAG_->getMemoryLocations(wildcard);
|
||||
}
|
||||
}
|
||||
// Add all memory locations this element may alias and their contained
|
||||
// elements
|
||||
memoryDAG_->collectAllContainedMemoryLocations(el, ret);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -878,6 +874,44 @@ void AliasDb::analyzeConservative(Node* node) {
|
|||
}
|
||||
}
|
||||
|
||||
bool AliasDb::functionalNonEscapingListUse(const Use& use) const {
|
||||
Node* n = use.user;
|
||||
size_t offset = use.offset;
|
||||
Value* container = n->inputs().at(offset);
|
||||
|
||||
// only consider aten op uses of lists
|
||||
if (!container->type()->cast<ListType>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/*
|
||||
in the general case, we consider any Value that enters another container as
|
||||
entering the heap, and thus aliasing all other heap values of the same type.
|
||||
the advantage of this approach are:
|
||||
- there are many composite list/container ops that would be tricky to
|
||||
schematize if we did something more complicated
|
||||
- limits the size of the AliasDb, because a container of size 10 only contains
|
||||
1 memory dag element instead of 10
|
||||
- we do not need to worry about adding contained elements to the wildcard set
|
||||
when a container escapes the graph.
|
||||
The downside of this approach is we are unable to handle the common case of a
|
||||
list constructed and passed into an aten op. Here, optimize for a set of
|
||||
common ops where the output does not alias the list or the list elements
|
||||
*/
|
||||
|
||||
switch (use.user->kind()) {
|
||||
case aten::cat:
|
||||
case aten::broadcast_tensors:
|
||||
case aten::stack:
|
||||
case aten::vstack:
|
||||
case aten::hstack:
|
||||
case aten::dstack:
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// List or dict or tuple: construct: create an aliasing element for the actual
|
||||
// container, then mark all inputs as wildcards, since they've gone inside the
|
||||
// container. Then, add the wildcard sets of appropriate type to the contained
|
||||
|
|
@ -895,6 +929,20 @@ void AliasDb::analyzeContainerConstruct(Node* node) {
|
|||
|
||||
TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
|
||||
auto container = node->output();
|
||||
|
||||
// optimization:
|
||||
// if a list is only used once in an aten op, and the op output
|
||||
// doesn't alias the input, then we can add all inputs to the list's
|
||||
// contained elements instead of the wildcard set.
|
||||
if (container->uses().size() == 1 &&
|
||||
functionalNonEscapingListUse(container->uses().at(0))) {
|
||||
giveFreshAlias(container, false);
|
||||
for (Value* v : node->inputs()) {
|
||||
addToContainedElements(v, container);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
giveFreshAlias(container);
|
||||
auto container_elem = elementMap_.at(container);
|
||||
for (auto input : node->inputs()) {
|
||||
|
|
@ -1068,7 +1116,9 @@ void AliasDb::createValue(const Value* value) {
|
|||
elementMap_[value] = new_elem;
|
||||
}
|
||||
|
||||
void AliasDb::giveFreshAlias(const Value* value) {
|
||||
void AliasDb::giveFreshAlias(
|
||||
const Value* value,
|
||||
bool add_wildcard_to_contained_elems) {
|
||||
auto maybe_mut_type = getMutableTypePtr(value->type());
|
||||
if (!maybe_mut_type) {
|
||||
return;
|
||||
|
|
@ -1082,7 +1132,9 @@ void AliasDb::giveFreshAlias(const Value* value) {
|
|||
|
||||
auto new_elem = memoryDAGBuilder_->makeFreshValue(value);
|
||||
elementMap_[value] = new_elem;
|
||||
addContainedTypesToFreshElement(new_elem, *maybe_mut_type);
|
||||
if (add_wildcard_to_contained_elems) {
|
||||
addContainedTypesToFreshElement(new_elem, *maybe_mut_type);
|
||||
}
|
||||
}
|
||||
|
||||
Element* AliasDb::getOrCreateElement(const Value* value) {
|
||||
|
|
|
|||
|
|
@ -205,10 +205,13 @@ class AliasDb {
|
|||
const Value* element,
|
||||
const Value* container);
|
||||
void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
|
||||
void giveFreshAlias(const Value* value);
|
||||
void giveFreshAlias(
|
||||
const Value* value,
|
||||
bool add_wildcard_to_contained_elems = true);
|
||||
Element* getOrCreateElement(const Value* value);
|
||||
|
||||
c10::optional<TypePtr> getMutableTypePtr(const TypePtr& type) const;
|
||||
bool functionalNonEscapingListUse(const Use& use) const;
|
||||
|
||||
bool isContainerType(const TypePtr& type) const;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue