diff --git a/test/cpp/jit/test_memory_dag.cpp b/test/cpp/jit/test_memory_dag.cpp index f3ff28f941a..f5d9fbf7c22 100644 --- a/test/cpp/jit/test_memory_dag.cpp +++ b/test/cpp/jit/test_memory_dag.cpp @@ -36,7 +36,7 @@ TEST(MemoryDAGTest, Basic) { t->makePointerTo(e, a); t->makePointerTo(e, f); - auto dag = std::make_unique(std::move(t)); + auto dag = std::move(*t).createMemoryDAG(); /** * Test mayAlias() @@ -69,7 +69,7 @@ TEST(MemoryDAGTest, Basic) { auto c = t->makeFreshValue(cValue); t->addToContainedElements(a, c); - auto dag = std::make_unique(std::move(t)); + auto dag = std::move(*t).createMemoryDAG(); EXPECT_TRUE(dag->mayContainAlias(a, b)); EXPECT_TRUE(dag->mayContainAlias(b, a)); @@ -99,7 +99,7 @@ TEST(MemoryDAGTest, Basic) { auto d = t->makeFreshValue(dValue); t->addToContainedElements(b, d); - auto dag = std::make_unique(std::move(t)); + auto dag = std::move(*t).createMemoryDAG(); EXPECT_TRUE(dag->mayContainAlias(b, d)); EXPECT_TRUE(dag->mayContainAlias(d, b)); @@ -126,7 +126,7 @@ TEST(MemoryDAGTest, Basic) { t->addToContainedElements(f, e); - auto dag = std::make_unique(std::move(t)); + auto dag = std::move(*t).createMemoryDAG(); for (auto elem : {a, b, c, d}) { EXPECT_FALSE(dag->mayContainAlias(f, elem)); EXPECT_FALSE(dag->mayContainAlias(e, elem)); diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index fe8bb67f364..29953ecd19a 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -228,7 +228,7 @@ AliasDb::AliasDb( writeRegistry_(std::make_unique()) { analyze(graph_); - memoryDAG_ = std::make_unique(std::move(memoryDAGBuilder_)); + memoryDAG_ = std::move(*memoryDAGBuilder_).createMemoryDAG(); memoryDAGBuilder_ = nullptr; // to make further access a hard error memoryDAG_->setWildcards( diff --git a/torch/csrc/jit/passes/utils/memory_dag.h b/torch/csrc/jit/passes/utils/memory_dag.h index c455d3413e7..f3068588dae 100644 --- a/torch/csrc/jit/passes/utils/memory_dag.h +++ b/torch/csrc/jit/passes/utils/memory_dag.h @@ -19,40 +19,52 @@ typedef c10::SparseBitVector<256> MemoryLocations; namespace torch { namespace jit { -struct Element; struct Value; -class MemoryDAG; using AliasTypeSet = std::vector; -/** - * Helper to build up the points-to graph. - * - * We separate the "building" into a different class because it allows us to - * cache internally to MemoryDAG without worrying about how the DAG structure - * is mutated. - */ -class TORCH_API MemoryDAGBuilder { - public: - MemoryDAGBuilder() = default; - MemoryDAGBuilder(const MemoryDAGBuilder&) = delete; - MemoryDAGBuilder& operator=(const MemoryDAGBuilder&) = delete; +// `Element` represents a vertex in the points-to graph. It represents +// anything that could have an aliasing relationship--mostly IR +// `Value`s, but also wildcards or the type inside a container (e.g. `T` +// in `List[T]`) +struct Element { + Element(const Value* value_, unsigned index_); + // wildcard constructor + explicit Element(unsigned index_); + // Index into the owning DAG's bit vector that represents this element. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + unsigned index; + + // All elements that this element *may* point to. It's possible to have + // multiple elements that you might point to due to control flow/complex ops + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + MemoryLocations pointsTo; + // Backreference for points-to. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + MemoryLocations pointedFrom; + + // Elements can contain other elements (e.g. List[Tensor]) + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + MemoryLocations containedElements; + + // The values that this element corresponds to. May be empty if this element + // doesn't represent a first-class value. + // This is for debug information only. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::unordered_set values; + + private: // Make `from` point at `to`. void makePointerTo(Element* from, Element* to); - void addToContainedElements(Element* contained, Element* container); + friend class MemoryDAG; + // We memoize the results of `getMemoryLocations` to speed up queries. + // A nullopt means that this cache is not yet populated. Since `MemoryDAG` is + // immutable, this cache should never need to be invalidated. + mutable c10::optional cachedMemoryLocations_; - // Make a fresh Element (i.e. an Element that doesn't point to anything) and - // return it. - Element* makeFreshValue(const Value* v); - - friend MemoryDAG; - - private: - // `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses - // the map to construct the `MemoryDAG` - std::vector> indexToElementMap_; + mutable c10::optional cachedAllContainedMemoryLocations_; }; // class MemoryDAG @@ -72,8 +84,8 @@ class TORCH_API MemoryDAGBuilder { // which memory locations an element may point to. class TORCH_API MemoryDAG { public: - explicit MemoryDAG(std::unique_ptr builder) - : indexToElementMap_(std::move(builder->indexToElementMap_)) {} + explicit MemoryDAG(std::vector> indexToElementMap) + : indexToElementMap_(std::move(indexToElementMap)) {} // explicitly delete copy constructor because otherwise windows build is // confused for an exported class see // https://stackoverflow.com/a/51033485/105137 @@ -127,49 +139,38 @@ class TORCH_API MemoryDAG { std::vector> indexToElementMap_; }; -// `Element` represents a vertex in the points-to graph. It represents -// anything that could have an aliasing relationship--mostly IR -// `Value`s, but also wildcards or the type inside a container (e.g. `T` -// in `List[T]`) -struct Element { - Element(const Value* value_, unsigned index_); - // wildcard constructor - explicit Element(unsigned index_); +/** + * Helper to build up the points-to graph. + * + * We separate the "building" into a different class because it allows us to + * cache internally to MemoryDAG without worrying about how the DAG structure + * is mutated. + */ +class TORCH_API MemoryDAGBuilder { + public: + MemoryDAGBuilder() = default; + MemoryDAGBuilder(const MemoryDAGBuilder&) = delete; + MemoryDAGBuilder& operator=(const MemoryDAGBuilder&) = delete; - // Index into the owning DAG's bit vector that represents this element. - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - unsigned index; - - // All elements that this element *may* point to. It's possible to have - // multiple elements that you might point to due to control flow/complex ops - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - MemoryLocations pointsTo; - // Backreference for points-to. - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - MemoryLocations pointedFrom; - - // Elements can contain other elements (e.g. List[Tensor]) - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - MemoryLocations containedElements; - - // The values that this element corresponds to. May be empty if this element - // doesn't represent a first-class value. - // This is for debug information only. - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::unordered_set values; - - private: // Make `from` point at `to`. void makePointerTo(Element* from, Element* to); - friend class MemoryDAG; - // We memoize the results of `getMemoryLocations` to speed up queries. - // A nullopt means that this cache is not yet populated. Since `MemoryDAG` is - // immutable, this cache should never need to be invalidated. - mutable c10::optional cachedMemoryLocations_; + void addToContainedElements(Element* contained, Element* container); - mutable c10::optional cachedAllContainedMemoryLocations_; + std::unique_ptr createMemoryDAG() && { + return std::make_unique(std::move(indexToElementMap_)); + } + + // Make a fresh Element (i.e. an Element that doesn't point to anything) and + // return it. + Element* makeFreshValue(const Value* v); + + friend MemoryDAG; + + private: + // `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses + // the map to construct the `MemoryDAG` + std::vector> indexToElementMap_; }; - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 56c30a7c33f..5cf330875fb 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -240,7 +240,6 @@ class TORCH_API StaticRuntimeMetadata : public torch::CustomClassHolder { /// class MemoryPlanner; class StaticNodeInfo; -class ProcessedFunction; class ProcessedNode; class StaticRuntime; @@ -259,6 +258,42 @@ struct TORCH_API SROperatorObserver { }; #endif +class TORCH_API ProcessedFunction { + public: + ProcessedFunction( + Node* node, + bool enable_out_variant, + bool check_memory_overlap); + + enum class Kind : uint8_t { + kOutVariant, + kNativeFunction, + kInterpreterFallback, + }; + + void run(ProcessedNode* pnode) const { + return f_(pnode); + } + + Kind kind() const { + return kind_; + } + + bool checkMemoryOverlap() const { + return check_memory_overlap_; + } + + size_t num_outputs() const { + return num_outputs_; + } + + private: + SROperator f_; + Kind kind_{ProcessedFunction::Kind::kOutVariant}; + bool check_memory_overlap_{false}; + size_t num_outputs_{0}; +}; + // A `BlockInfo` instance stores all of the shared state that each // `BlockRunner` will need to access. Most of this information is // read-only and shared between threads. @@ -778,42 +813,6 @@ class TORCH_API BlockRunner { std::vector nodes_; }; -class TORCH_API ProcessedFunction { - public: - ProcessedFunction( - Node* node, - bool enable_out_variant, - bool check_memory_overlap); - - enum class Kind : uint8_t { - kOutVariant, - kNativeFunction, - kInterpreterFallback, - }; - - void run(ProcessedNode* pnode) const { - return f_(pnode); - } - - Kind kind() const { - return kind_; - } - - bool checkMemoryOverlap() const { - return check_memory_overlap_; - } - - size_t num_outputs() const { - return num_outputs_; - } - - private: - SROperator f_; - Kind kind_{ProcessedFunction::Kind::kOutVariant}; - bool check_memory_overlap_{false}; - size_t num_outputs_{0}; -}; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_API StaticNodeInfo { public: