diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp index 681dea3a387..6e33f7dca47 100644 --- a/test/cpp/jit/test_alias_analysis.cpp +++ b/test/cpp/jit/test_alias_analysis.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -1474,7 +1475,7 @@ TEST( torch::jit::parseIR(graph_string, graph.get(), vmap); AliasDb aliasDb( - graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true); + graph, /*isFrozen=*/false); EXPECT_TRUE(!aliasDb.mayAlias(vmap["x"], vmap["y"])); EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["x"])); @@ -1483,22 +1484,23 @@ TEST( TEST( AliasRegistrationTest, - WildcareAliasForTupleConstructWithSingleUseAsGraphOutputWithDisablePreciseTupleContainerAnalysis) { + RecursiveSubgraphTupleContainment) { auto graph = std::make_shared(); std::unordered_map vmap; auto graph_string = R"IR( graph(): %x : Tensor = prim::MakeTestTensor() %y : Tensor = prim::MakeTestTensor() - %z : (Tensor) = prim::TupleConstruct(%x, %y) + %z : (Tensor, Tensor) = prim::TupleConstruct(%x, %y) return (%z))IR"; torch::jit::parseIR(graph_string, graph.get(), vmap); - // enablePreciseTupleContainerAnalysis = false. + auto node = vmap["z"]->node(); + auto subgraph = SubgraphUtils::createSingletonSubgraph(node, prim::FunctionalGraph); AliasDb aliasDb(graph); - EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["x"])); - EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["y"])); + EXPECT_TRUE(aliasDb.mayContainAlias(subgraph->output(), vmap["x"])); + EXPECT_TRUE(aliasDb.mayContainAlias(subgraph->output(), vmap["y"])); EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["y"])); } @@ -1519,7 +1521,7 @@ TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) { torch::jit::parseIR(graph_string, graph.get(), vmap); AliasDb aliasDb( - graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true); + graph, /*isFrozen=*/false); EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["y"])); EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["z"])); @@ -1551,8 +1553,7 @@ TEST(AliasRegistrationTest, ATenSplitIntListAliasCheck) { return (%d))IR"; torch::jit::parseIR(graph_string, graph.get(), vmap); - AliasDb aliasDb( - graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true); + AliasDb aliasDb(graph, /*isFrozen=*/false); EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"])); EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"])); @@ -1578,8 +1579,7 @@ TEST(AliasRegistrationTest, ATenSplitIntAliasCheck) { return (%d))IR"; torch::jit::parseIR(graph_string, graph.get(), vmap); - AliasDb aliasDb( - graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true); + AliasDb aliasDb(graph, /*isFrozen=*/false); EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"])); EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"])); diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 58064f2c14a..57a74824d3e 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -207,13 +207,9 @@ struct AliasDb::WriteRegistry { std::unordered_set writesToAllWildcards_; }; -AliasDb::AliasDb( - std::shared_ptr graph, - bool isFrozen, - bool enablePreciseTupleContainerAnalysis) +AliasDb::AliasDb(std::shared_ptr graph, bool isFrozen) : graph_(std::move(graph)), isFrozen_(isFrozen), - enablePreciseTupleContainerAnalysis_(enablePreciseTupleContainerAnalysis), memoryDAGBuilder_(std::make_unique()), writeRegistry_(std::make_unique()) { analyze(graph_); @@ -1146,9 +1142,6 @@ bool AliasDb::functionalNonEscapingListUse(const Use& use) const { } bool AliasDb::functionalNonEscapingTupleUse(const Use& use) const { - if (!enablePreciseTupleContainerAnalysis_) { - return false; - } Node* n = use.user; size_t offset = use.offset; Value* container = n->inputs().at(offset); @@ -1156,7 +1149,9 @@ bool AliasDb::functionalNonEscapingTupleUse(const Use& use) const { return false; } // TODO(T97387453): Cover more ops that do not let escape tuples' elements. - return use.user->kind() == prim::Return; + bool in_return_outputs = use.user->kind() == prim::Return; + bool not_in_nested_subgraph = use.user->owningBlock() == graph_->block(); + return in_return_outputs && not_in_nested_subgraph; } // List or dict or tuple construct: create an aliasing element for the actual diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h index c2211a09ec5..86fafbee4cb 100644 --- a/torch/csrc/jit/ir/alias_analysis.h +++ b/torch/csrc/jit/ir/alias_analysis.h @@ -45,8 +45,7 @@ class AliasDb { public: TORCH_API explicit AliasDb( std::shared_ptr graphi, - bool isFrozen = false, - bool enablePreciseTupleContainerAnalysis = false); + bool isFrozen = false); TORCH_API ~AliasDb(); // There are limitations to what effects the alias analysis can track. Two @@ -248,9 +247,6 @@ class AliasDb { // internally. bool isFrozen_; - // Enable precise treatment of prim::TupleConstruct. - bool enablePreciseTupleContainerAnalysis_ = false; - // The points-to graph that stores aliasing relationships std::unique_ptr memoryDAGBuilder_; std::unique_ptr memoryDAG_; diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index f912fc8d32b..02319f7c1e2 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -534,8 +534,7 @@ StaticModule::StaticModule( // Create ProcessedFunction instances first to freeze their addresses to pass // to ProcessedNode. - AliasDb alias_db( - graph_, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true); + AliasDb alias_db(graph_, /*isFrozen=*/false); GRAPH_DEBUG("AliasDb: ", alias_db.toString()); // Construct constant and function nodes diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index 4d3f8ae0f8d..3a4366d3b32 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -744,8 +744,7 @@ void FuseListUnpack(std::shared_ptr& graph) { AliasDb alias_db( graph, - /*isFrozen=*/false, - /*enablePreciseTupleContainerAnalysis=*/true); + /*isFrozen=*/false); const std::vector graph_outputs( graph->outputs().begin(), graph->outputs().end()); auto nodes = graph->nodes();