mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Remove precise tuple construct flag (#71121)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71121 Test Plan: Imported from OSS Reviewed By: d1jang Differential Revision: D33515234 Pulled By: eellison fbshipit-source-id: 57cfe171b583a6bb4d3493a34b159061e97a11b8
This commit is contained in:
parent
47ad6628f1
commit
9bccb31306
5 changed files with 18 additions and 29 deletions
|
|
@ -4,6 +4,7 @@
|
|||
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
||||
#include <torch/csrc/jit/ir/alias_analysis.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
|
||||
#include <torch/csrc/jit/runtime/custom_operator.h>
|
||||
#include <torch/csrc/jit/runtime/graph_iterator.h>
|
||||
#include <torch/csrc/utils/memory.h>
|
||||
|
|
@ -1474,7 +1475,7 @@ TEST(
|
|||
|
||||
torch::jit::parseIR(graph_string, graph.get(), vmap);
|
||||
AliasDb aliasDb(
|
||||
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
|
||||
graph, /*isFrozen=*/false);
|
||||
|
||||
EXPECT_TRUE(!aliasDb.mayAlias(vmap["x"], vmap["y"]));
|
||||
EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["x"]));
|
||||
|
|
@ -1483,22 +1484,23 @@ TEST(
|
|||
|
||||
TEST(
|
||||
AliasRegistrationTest,
|
||||
WildcareAliasForTupleConstructWithSingleUseAsGraphOutputWithDisablePreciseTupleContainerAnalysis) {
|
||||
RecursiveSubgraphTupleContainment) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
auto graph_string = R"IR(
|
||||
graph():
|
||||
%x : Tensor = prim::MakeTestTensor()
|
||||
%y : Tensor = prim::MakeTestTensor()
|
||||
%z : (Tensor) = prim::TupleConstruct(%x, %y)
|
||||
%z : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
|
||||
return (%z))IR";
|
||||
|
||||
torch::jit::parseIR(graph_string, graph.get(), vmap);
|
||||
// enablePreciseTupleContainerAnalysis = false.
|
||||
auto node = vmap["z"]->node();
|
||||
auto subgraph = SubgraphUtils::createSingletonSubgraph(node, prim::FunctionalGraph);
|
||||
AliasDb aliasDb(graph);
|
||||
|
||||
EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["x"]));
|
||||
EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["y"]));
|
||||
EXPECT_TRUE(aliasDb.mayContainAlias(subgraph->output(), vmap["x"]));
|
||||
EXPECT_TRUE(aliasDb.mayContainAlias(subgraph->output(), vmap["y"]));
|
||||
EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["y"]));
|
||||
}
|
||||
|
||||
|
|
@ -1519,7 +1521,7 @@ TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) {
|
|||
|
||||
torch::jit::parseIR(graph_string, graph.get(), vmap);
|
||||
AliasDb aliasDb(
|
||||
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
|
||||
graph, /*isFrozen=*/false);
|
||||
|
||||
EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["y"]));
|
||||
EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["z"]));
|
||||
|
|
@ -1551,8 +1553,7 @@ TEST(AliasRegistrationTest, ATenSplitIntListAliasCheck) {
|
|||
return (%d))IR";
|
||||
|
||||
torch::jit::parseIR(graph_string, graph.get(), vmap);
|
||||
AliasDb aliasDb(
|
||||
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
|
||||
AliasDb aliasDb(graph, /*isFrozen=*/false);
|
||||
|
||||
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"]));
|
||||
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"]));
|
||||
|
|
@ -1578,8 +1579,7 @@ TEST(AliasRegistrationTest, ATenSplitIntAliasCheck) {
|
|||
return (%d))IR";
|
||||
|
||||
torch::jit::parseIR(graph_string, graph.get(), vmap);
|
||||
AliasDb aliasDb(
|
||||
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
|
||||
AliasDb aliasDb(graph, /*isFrozen=*/false);
|
||||
|
||||
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"]));
|
||||
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"]));
|
||||
|
|
|
|||
|
|
@ -207,13 +207,9 @@ struct AliasDb::WriteRegistry {
|
|||
std::unordered_set<Node*> writesToAllWildcards_;
|
||||
};
|
||||
|
||||
AliasDb::AliasDb(
|
||||
std::shared_ptr<Graph> graph,
|
||||
bool isFrozen,
|
||||
bool enablePreciseTupleContainerAnalysis)
|
||||
AliasDb::AliasDb(std::shared_ptr<Graph> graph, bool isFrozen)
|
||||
: graph_(std::move(graph)),
|
||||
isFrozen_(isFrozen),
|
||||
enablePreciseTupleContainerAnalysis_(enablePreciseTupleContainerAnalysis),
|
||||
memoryDAGBuilder_(std::make_unique<MemoryDAGBuilder>()),
|
||||
writeRegistry_(std::make_unique<AliasDb::WriteRegistry>()) {
|
||||
analyze(graph_);
|
||||
|
|
@ -1146,9 +1142,6 @@ bool AliasDb::functionalNonEscapingListUse(const Use& use) const {
|
|||
}
|
||||
|
||||
bool AliasDb::functionalNonEscapingTupleUse(const Use& use) const {
|
||||
if (!enablePreciseTupleContainerAnalysis_) {
|
||||
return false;
|
||||
}
|
||||
Node* n = use.user;
|
||||
size_t offset = use.offset;
|
||||
Value* container = n->inputs().at(offset);
|
||||
|
|
@ -1156,7 +1149,9 @@ bool AliasDb::functionalNonEscapingTupleUse(const Use& use) const {
|
|||
return false;
|
||||
}
|
||||
// TODO(T97387453): Cover more ops that do not let escape tuples' elements.
|
||||
return use.user->kind() == prim::Return;
|
||||
bool in_return_outputs = use.user->kind() == prim::Return;
|
||||
bool not_in_nested_subgraph = use.user->owningBlock() == graph_->block();
|
||||
return in_return_outputs && not_in_nested_subgraph;
|
||||
}
|
||||
|
||||
// List or dict or tuple construct: create an aliasing element for the actual
|
||||
|
|
|
|||
|
|
@ -45,8 +45,7 @@ class AliasDb {
|
|||
public:
|
||||
TORCH_API explicit AliasDb(
|
||||
std::shared_ptr<Graph> graphi,
|
||||
bool isFrozen = false,
|
||||
bool enablePreciseTupleContainerAnalysis = false);
|
||||
bool isFrozen = false);
|
||||
TORCH_API ~AliasDb();
|
||||
|
||||
// There are limitations to what effects the alias analysis can track. Two
|
||||
|
|
@ -248,9 +247,6 @@ class AliasDb {
|
|||
// internally.
|
||||
bool isFrozen_;
|
||||
|
||||
// Enable precise treatment of prim::TupleConstruct.
|
||||
bool enablePreciseTupleContainerAnalysis_ = false;
|
||||
|
||||
// The points-to graph that stores aliasing relationships
|
||||
std::unique_ptr<MemoryDAGBuilder> memoryDAGBuilder_;
|
||||
std::unique_ptr<MemoryDAG> memoryDAG_;
|
||||
|
|
|
|||
|
|
@ -534,8 +534,7 @@ StaticModule::StaticModule(
|
|||
|
||||
// Create ProcessedFunction instances first to freeze their addresses to pass
|
||||
// to ProcessedNode.
|
||||
AliasDb alias_db(
|
||||
graph_, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
|
||||
AliasDb alias_db(graph_, /*isFrozen=*/false);
|
||||
GRAPH_DEBUG("AliasDb: ", alias_db.toString());
|
||||
|
||||
// Construct constant and function nodes
|
||||
|
|
|
|||
|
|
@ -744,8 +744,7 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
|
|||
|
||||
AliasDb alias_db(
|
||||
graph,
|
||||
/*isFrozen=*/false,
|
||||
/*enablePreciseTupleContainerAnalysis=*/true);
|
||||
/*isFrozen=*/false);
|
||||
const std::vector<Value*> graph_outputs(
|
||||
graph->outputs().begin(), graph->outputs().end());
|
||||
auto nodes = graph->nodes();
|
||||
|
|
|
|||
Loading…
Reference in a new issue