Remove precise tuple construct flag (#71121)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71121

Test Plan: Imported from OSS

Reviewed By: d1jang

Differential Revision: D33515234

Pulled By: eellison

fbshipit-source-id: 57cfe171b583a6bb4d3493a34b159061e97a11b8
This commit is contained in:
Elias Ellison 2022-01-11 22:09:58 -08:00 committed by Facebook GitHub Bot
parent 47ad6628f1
commit 9bccb31306
5 changed files with 18 additions and 29 deletions

View file

@ -4,6 +4,7 @@
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/utils/memory.h>
@ -1474,7 +1475,7 @@ TEST(
torch::jit::parseIR(graph_string, graph.get(), vmap);
AliasDb aliasDb(
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
graph, /*isFrozen=*/false);
EXPECT_TRUE(!aliasDb.mayAlias(vmap["x"], vmap["y"]));
EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["x"]));
@ -1483,22 +1484,23 @@ TEST(
TEST(
AliasRegistrationTest,
WildcareAliasForTupleConstructWithSingleUseAsGraphOutputWithDisablePreciseTupleContainerAnalysis) {
RecursiveSubgraphTupleContainment) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
auto graph_string = R"IR(
graph():
%x : Tensor = prim::MakeTestTensor()
%y : Tensor = prim::MakeTestTensor()
%z : (Tensor) = prim::TupleConstruct(%x, %y)
%z : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
return (%z))IR";
torch::jit::parseIR(graph_string, graph.get(), vmap);
// enablePreciseTupleContainerAnalysis = false.
auto node = vmap["z"]->node();
auto subgraph = SubgraphUtils::createSingletonSubgraph(node, prim::FunctionalGraph);
AliasDb aliasDb(graph);
EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["x"]));
EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["y"]));
EXPECT_TRUE(aliasDb.mayContainAlias(subgraph->output(), vmap["x"]));
EXPECT_TRUE(aliasDb.mayContainAlias(subgraph->output(), vmap["y"]));
EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["y"]));
}
@ -1519,7 +1521,7 @@ TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) {
torch::jit::parseIR(graph_string, graph.get(), vmap);
AliasDb aliasDb(
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
graph, /*isFrozen=*/false);
EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["y"]));
EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["z"]));
@ -1551,8 +1553,7 @@ TEST(AliasRegistrationTest, ATenSplitIntListAliasCheck) {
return (%d))IR";
torch::jit::parseIR(graph_string, graph.get(), vmap);
AliasDb aliasDb(
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
AliasDb aliasDb(graph, /*isFrozen=*/false);
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"]));
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"]));
@ -1578,8 +1579,7 @@ TEST(AliasRegistrationTest, ATenSplitIntAliasCheck) {
return (%d))IR";
torch::jit::parseIR(graph_string, graph.get(), vmap);
AliasDb aliasDb(
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
AliasDb aliasDb(graph, /*isFrozen=*/false);
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"]));
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"]));

View file

@ -207,13 +207,9 @@ struct AliasDb::WriteRegistry {
std::unordered_set<Node*> writesToAllWildcards_;
};
AliasDb::AliasDb(
std::shared_ptr<Graph> graph,
bool isFrozen,
bool enablePreciseTupleContainerAnalysis)
AliasDb::AliasDb(std::shared_ptr<Graph> graph, bool isFrozen)
: graph_(std::move(graph)),
isFrozen_(isFrozen),
enablePreciseTupleContainerAnalysis_(enablePreciseTupleContainerAnalysis),
memoryDAGBuilder_(std::make_unique<MemoryDAGBuilder>()),
writeRegistry_(std::make_unique<AliasDb::WriteRegistry>()) {
analyze(graph_);
@ -1146,9 +1142,6 @@ bool AliasDb::functionalNonEscapingListUse(const Use& use) const {
}
bool AliasDb::functionalNonEscapingTupleUse(const Use& use) const {
if (!enablePreciseTupleContainerAnalysis_) {
return false;
}
Node* n = use.user;
size_t offset = use.offset;
Value* container = n->inputs().at(offset);
@ -1156,7 +1149,9 @@ bool AliasDb::functionalNonEscapingTupleUse(const Use& use) const {
return false;
}
// TODO(T97387453): Cover more ops that do not let escape tuples' elements.
return use.user->kind() == prim::Return;
bool in_return_outputs = use.user->kind() == prim::Return;
bool not_in_nested_subgraph = use.user->owningBlock() == graph_->block();
return in_return_outputs && not_in_nested_subgraph;
}
// List or dict or tuple construct: create an aliasing element for the actual

View file

@ -45,8 +45,7 @@ class AliasDb {
public:
TORCH_API explicit AliasDb(
std::shared_ptr<Graph> graphi,
bool isFrozen = false,
bool enablePreciseTupleContainerAnalysis = false);
bool isFrozen = false);
TORCH_API ~AliasDb();
// There are limitations to what effects the alias analysis can track. Two
@ -248,9 +247,6 @@ class AliasDb {
// internally.
bool isFrozen_;
// Enable precise treatment of prim::TupleConstruct.
bool enablePreciseTupleContainerAnalysis_ = false;
// The points-to graph that stores aliasing relationships
std::unique_ptr<MemoryDAGBuilder> memoryDAGBuilder_;
std::unique_ptr<MemoryDAG> memoryDAG_;

View file

@ -534,8 +534,7 @@ StaticModule::StaticModule(
// Create ProcessedFunction instances first to freeze their addresses to pass
// to ProcessedNode.
AliasDb alias_db(
graph_, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
AliasDb alias_db(graph_, /*isFrozen=*/false);
GRAPH_DEBUG("AliasDb: ", alias_db.toString());
// Construct constant and function nodes

View file

@ -744,8 +744,7 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
AliasDb alias_db(
graph,
/*isFrozen=*/false,
/*enablePreciseTupleContainerAnalysis=*/true);
/*isFrozen=*/false);
const std::vector<Value*> graph_outputs(
graph->outputs().begin(), graph->outputs().end());
auto nodes = graph->nodes();