Traverse sub-blocks in JIT passes (#5329)
* Traverse sub-blocks in JIT passes
* Add an extra check to prevent cross-block fusion
commit c2a3d85a07
parent b6854ee012

10 changed files with 91 additions and 39 deletions
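The C++ passes touched here all apply the same refactoring shape: the old entry point that took a Graph becomes a one-line wrapper over the graph's top-level Block, and a per-Block worker visits the block's nodes and recurses into each node's sub-blocks (the bodies of control-flow nodes such as If). A minimal self-contained sketch of that shape, using toy stand-ins for the IR types (MyPass, MyPassBlock, and the struct layouts here are illustrative only, not the real torch/csrc/jit API):

    #include <memory>
    #include <vector>

    // Toy stand-ins for the JIT IR; the real Graph/Block/Node live in
    // torch/csrc/jit/ir.h and are considerably richer.
    struct Block;
    struct Node {
      std::vector<std::unique_ptr<Block>> blocks;  // sub-blocks, e.g. the arms of an If
    };
    struct Block {
      std::vector<std::unique_ptr<Node>> nodes;    // nodes in topological order
    };
    struct Graph {
      Block top;
      Block* block() { return &top; }              // the graph's top-level block
    };

    // Per-block worker: rewrite this block's nodes, then recurse into sub-blocks.
    void MyPassBlock(Block* block) {
      for (auto& node : block->nodes) {
        // ... per-node rewriting would go here ...
        for (auto& sub_block : node->blocks) {
          MyPassBlock(sub_block.get());
        }
      }
    }

    // The old Graph-taking signature survives as a thin wrapper.
    void MyPass(std::shared_ptr<Graph>& graph) {
      MyPassBlock(graph->block());
    }

    int main() {
      auto g = std::make_shared<Graph>();
      MyPass(g);  // a no-op on an empty graph, but exercises the traversal
      return 0;
    }

The graph fuser needs one extra ingredient on top of this pattern: a guard that stops fusion from crossing block boundaries (see the note next to its isFusable hunk below).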
torch/csrc/jit/passes/batch_mm.cpp

@@ -140,16 +140,17 @@ struct TreeToken {
   }
 };
 
-void BatchMM(std::shared_ptr<Graph>& graph) {
+void BatchMMBlock(Block* block) {
   enum class Side { LHS, RHS };
   static const Symbol mm_kind = "mm"_sym;
   static const Symbol add_kind = "add"_sym;
   static const Symbol cat_kind = "cat"_sym;
   static const Symbol dim_sym = "dim"_sym;
+  auto graph = block->owningGraph();
 
-  // Look for trees in the graph
+  // Look for trees in the block
   std::unordered_map<Node*, TreeToken> tokens;
-  for (auto node : graph->nodes()) {
+  for (auto node : block->nodes()) {
     if (node->kind() == mm_kind) {
       tokens[node] = TreeToken::fromMM(node);
     } else if (node->kind() == add_kind) {
@@ -170,6 +171,10 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
         if (auto token = TreeToken::unify(node, lhs_it->second, rhs_it->second))
           tokens[node] = token;
       }
+    } else {
+      for (auto block : node->blocks()) {
+        BatchMMBlock(block);
+      }
     }
   }
 
@@ -202,7 +207,11 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
     root.node->output()->replaceAllUsesWith(batch_mm->output());
     // NB: don't bother with cleaning up after yourself. We'll use DCE for that.
   }
-  EliminateDeadCode(graph);
+  EliminateDeadCode(block);
 }
 
+void BatchMM(std::shared_ptr<Graph>& graph) {
+  BatchMMBlock(graph->block());
+}
+
 }}

torch/csrc/jit/passes/common_subexpression_elimination.cpp

@@ -109,9 +109,9 @@ struct EqualNodeCSE {
 
 // The function implements common subexpression elimination.
 // Since the nodes are visited in topological order, one pass is enough.
-void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
+void EliminateCommonSubexpression(Block * block) {
   std::unordered_set<Node*, HashNodeCSE, EqualNodeCSE> subexprs;
-  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++ it) {
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
     auto node = *it;
     if (node->kind() == kPythonOp
         || node->kind() == kCppOp
@@ -136,4 +136,8 @@ void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
   }
 }
 
+void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
+  EliminateCommonSubexpression(graph->block());
+}
+
 }}

torch/csrc/jit/passes/create_autodiff_subgraphs.cpp

@@ -16,12 +16,13 @@ namespace {
 // right before nodes[0] (i.e. it will not create cycles and all uses of
 // new node will be after this position).
 // prereq: nodes are in topological order
-void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
+void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) {
   JIT_ASSERT(nodes.size() > 0);
   std::unordered_map<Value*, Value*> value_map;
+  Graph * graph = block->owningGraph();
 
   auto new_graph = std::make_shared<Graph>();
-  Node * group_node = g.create(group_node_kind, 0);
+  Node * group_node = graph->create(group_node_kind, 0);
   group_node->g_(kSubgraph, new_graph);
 
   auto getOrCreateInput = [&](Value * v) {
@@ -71,8 +72,7 @@ void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
 
 }
 
-void CreateAutodiffSubgraphs(Graph & g, size_t threshold) {
-
+void CreateAutodiffSubgraphs(Block * block, size_t threshold) {
   // This implementation is not optimal, but it is simple.
   // It just scans through the list in order looking for runs of
   // differentiable ops, and then grouping them together when
@@ -88,22 +88,29 @@ void CreateAutodiffSubgraphs(Graph & g, size_t threshold) {
   // and group maximal groups
 
   std::vector<Node*> groupable;
-  for(auto n : g.nodes()) { // Note: nodes() iterator stays valid since it is
+  for(Node * node : block->nodes()) { // Note: nodes() iterator stays valid since it is
                             // always pointing _after_ the nodes that mergeNodes
                             // mutates.
-    if(isDifferentiable(n)) {
-      groupable.push_back(n);
+    if(isDifferentiable(node)) {
+      groupable.push_back(node);
     } else {
       if(groupable.size() >= threshold) {
-        mergeNodes(g, kGraphExecutor, groupable);
+        mergeNodes(block, kGraphExecutor, groupable);
      }
      groupable.clear();
+      for (Block * sub_block : node->blocks()) {
+        CreateAutodiffSubgraphs(sub_block, threshold);
+      }
    }
  }
  if(groupable.size() >= threshold) {
-    mergeNodes(g, kGraphExecutor, groupable);
+    mergeNodes(block, kGraphExecutor, groupable);
  }
 }
 
+void CreateAutodiffSubgraphs(Graph & graph, size_t threshold) {
+  CreateAutodiffSubgraphs(graph.block(), threshold);
+}
+
 
 }}

torch/csrc/jit/passes/dead_code_elimination.cpp

@@ -3,10 +3,16 @@
 namespace torch { namespace jit {
 
 void EliminateDeadCode(std::shared_ptr<Graph>& graph) {
-  auto nodes = graph->nodes().reverse();
+  EliminateDeadCode(graph->block());
+}
+
+void EliminateDeadCode(Block *block) {
+  auto nodes = block->nodes().reverse();
   for (auto it = nodes.begin(); it != nodes.end(); it++) {
     auto node = *it;
-    if(!node->hasUses())
+    for (Block * block : node->blocks())
+      EliminateDeadCode(block);
+    if (!node->hasUses())
       it.destroyCurrent();
   }
 }
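
Two details of the rewritten pass are easy to miss: the walk is still in reverse, so consumers are destroyed before their producers are examined, and each node's sub-blocks are cleaned up before the owning node's own hasUses() test. The new function again, with editorial comments (the comments are not part of the commit):

    void EliminateDeadCode(Block *block) {
      auto nodes = block->nodes().reverse();    // reverse order: consumers die before producers
      for (auto it = nodes.begin(); it != nodes.end(); it++) {
        auto node = *it;
        for (Block * block : node->blocks())    // first clean up any nested blocks
          EliminateDeadCode(block);
        if (!node->hasUses())                   // then drop the node itself if nothing
          it.destroyCurrent();                  // reads its outputs
      }
    }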

torch/csrc/jit/passes/dead_code_elimination.h

@@ -5,5 +5,6 @@
 namespace torch { namespace jit {
 
 void EliminateDeadCode(std::shared_ptr<Graph>& graph);
+void EliminateDeadCode(Block *block);
 
 }}

torch/csrc/jit/passes/graph_fuser.cpp

@@ -78,7 +78,7 @@ bool isSimpleMap(Node *node) {
 }
 
 struct GraphFuser {
-  std::shared_ptr<Graph>& graph;
+  Block * block;
 
   // Used to order nodes so we always consider producer-consumer fusions
   // in reverse topological order.
@@ -87,8 +87,8 @@ struct GraphFuser {
   // Newly generated nodes will copy the location where they are inserted.
   std::unordered_map<Node*,size_t> topological_index;
 
-  GraphFuser(std::shared_ptr<Graph>& graph)
-    : graph(graph) {}
+  GraphFuser(Block * block)
+    : block(block) {}
 
   int getDevice(Node * node) {
     if(node->kind() == kFusionGroup) {
@@ -122,6 +122,7 @@ struct GraphFuser {
     return true;
   }
   bool isFusable(Node * node) {
+    if (node->owningBlock() != block) return false;
     if (node->kind() == kFusionGroup) return true;
     return isSimpleMap(node) && allFloatIO(node);
   }
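
The new first line of isFusable is the "extra check to prevent cross-block fusion" named in the commit message. Each GraphFuser instance now works on a single Block, so the guard keeps it from pulling a node that lives in a nested sub-block (one arm of an If, say) into a fusion group in the enclosing block, which would change when that node executes. The same function with explanatory comments added (the comments are editorial, not part of the commit):

    bool isFusable(Node * node) {
      // Never fuse a node owned by another block: moving it into a fusion
      // group here would hoist it across a control-flow boundary. Nested
      // blocks are fused by their own GraphFuser instance instead.
      if (node->owningBlock() != block) return false;
      if (node->kind() == kFusionGroup) return true;   // existing groups can keep growing
      return isSimpleMap(node) && allFloatIO(node);    // otherwise: elementwise float ops only
    }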

@@ -210,7 +211,7 @@ struct GraphFuser {
 
     // Clone all nodes
     for (auto inner : producer_subgraph->nodes()) {
-      Node * outer = graph->createClone(inner, [&](Value * k) -> Value* {
+      Node * outer = block->owningGraph()->createClone(inner, [&](Value * k) -> Value* {
        return inner_to_outer.at(k);
      });
      outer->insertBefore(producer_group);
@@ -291,7 +292,7 @@ struct GraphFuser {
   // turn consumer node n into a fusion group with just n inside
   // to prepare for fusion and replace uses of n with the new group
   Node * createSingletonFusionGroup(Node * n) {
-    auto group = graph->createFusionGroup(getDevice(n));
+    auto group = block->owningGraph()->createFusionGroup(getDevice(n));
     // propogate position information for the new node so we can always
     // have a valid mapping
     topological_index[group] = topological_index[n];
@@ -397,7 +398,7 @@ struct GraphFuser {
     // TODO: Perhaps we should use cloneFrom now, as it seems unlikely
     // to copy select nodes now that we have refactored to have a Value
     // distinct from Node.
-    Node * input_chunk = graph->create(chunk->kind(), 0);
+    Node * input_chunk = block->owningGraph()->create(chunk->kind(), 0);
     input_chunk->copyAttributes(*chunk);
     input_chunk->addInput(input);
     insertAt(&insertion_point, input_chunk);
@@ -416,7 +417,7 @@ struct GraphFuser {
     // apply the op to each chunk of the chunked operands,
     // and then rewrite the graph to use them!
     for (auto chunk_sel : chunk->outputs()) {
-      Node * chunked_op = graph->create(producer_for_chunk_node->kind());
+      Node * chunked_op = block->owningGraph()->create(producer_for_chunk_node->kind());
       chunked_op->copyAttributes(*producer_for_chunk_node);
       // Invariant: mappable operators always produce contiguous output
       chunked_op->output()->setType(chunk_sel->type()->cast<TensorType>()->contiguous());
@@ -433,7 +434,7 @@ struct GraphFuser {
 
   // returns where to continue scanning, and whether any fusion was made
   std::pair<graph_node_list::iterator, bool> scanNode(Node * consumer) {
-    auto stage_guard = graph->setStageTemporary(consumer->stage());
+    auto stage_guard = block->owningGraph()->setStageTemporary(consumer->stage());
     if(isFusableAsExitNode(consumer)) {
       // handle inputs in reverse topological order as well...
       // otherwise in f(a,a+b) it will appear a is used twice if we consider
@@ -465,15 +466,14 @@ struct GraphFuser {
   }
 
   void run() {
-    for(auto p : graph->inputs()) {
+    for(auto p : block->inputs()) {
      topological_index[p->node()] = 0;
    }
    size_t i = 1;
-    auto nodes = graph->nodes();
-    for(auto consumer : nodes) {
+    for(auto consumer : block->nodes()) {
      topological_index[consumer] = i++;
    }
-    topological_index[graph->return_node()] = i++;
+    topological_index[block->return_node()] = i++;
 
     // Run the pass until no changes are made.
     // This is neccessary, because the algorithm can miss out on certain fusion
@@ -494,19 +494,24 @@ struct GraphFuser {
     bool any_changed = true;
     while (any_changed) {
       any_changed = false;
-      for (auto it = nodes.rbegin(); it != nodes.rend();) {
+      for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
        bool changed;
        std::tie(it, changed) = scanNode(*it);
        any_changed |= changed;
      }
    }
+    for (Node * node : block->nodes()) {
+      for (Block * sub_block : node->blocks()) {
+        GraphFuser(sub_block).run();
+      }
+    }
  }
 };
 
 } // anonymous namespace
 
 void FuseGraph(std::shared_ptr<Graph>& graph) {
-  GraphFuser(graph).run();
+  GraphFuser(graph->block()).run();
 }
 
 }}
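
A design note on run() (my reading of the diff, not text from the commit): each block is driven to a fixed point by the while (any_changed) loop before any recursion happens, and every sub-block then gets a fresh GraphFuser, so each instance's topological_index only ever describes the nodes of one block:

    // At the end of GraphFuser::run(), once this block has reached a fixed point:
    for (Node * node : block->nodes()) {
      for (Block * sub_block : node->blocks()) {
        GraphFuser(sub_block).run();  // a fresh fuser (and index) per sub-block
      }
    }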

torch/csrc/jit/passes/inplace_check.cpp

@@ -2,8 +2,8 @@
 
 namespace torch { namespace jit {
 
-void CheckInplace(std::shared_ptr<Graph>& graph) {
-  for (auto node : graph->nodes()) {
+void CheckInplace(Block * block) {
+  for (auto node : block->nodes()) {
     if (node->kind() == kPythonOp && node->hasAttribute(kinplace)) {
       if (node->i(kinplace)) {
         throw std::runtime_error(std::string("inplace ") +
@@ -14,4 +14,8 @@ void CheckInplace(std::shared_ptr<Graph>& graph) {
   }
 }
 
+void CheckInplace(std::shared_ptr<Graph>& graph) {
+  CheckInplace(graph->block());
+}
+
 }} // namespace torch::jit

torch/csrc/jit/passes/peephole.cpp

@@ -10,10 +10,12 @@ namespace torch { namespace jit {
 // - Simply x.t().t() to x
 //
 // TODO: Decide what kind of fixed point strategy we will have
-void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
-  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
+void PeepholeOptimize(Block * block) {
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
     auto* n = *it;
 
+    // XXX: remember that if you want to simplify an expression by combining multiple nodes
+    // into a different one, then you need to check that they all belong to the given block
     switch (n->kind()) {
       case kexpand:
         // Eliminate redundant expand
@@ -40,7 +42,15 @@ void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
       }
       break;
    }
+
+    for (Block * sub_block : n->blocks()) {
+      PeepholeOptimize(sub_block);
+    }
  }
 }
 
+void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
+  PeepholeOptimize(graph->block());
+}
+
 }}

torch/csrc/jit/passes/shape_analysis.cpp

@@ -151,6 +151,15 @@ void PropagateShapeOnNode(Node * node) {
 
 }
 
+void PropagateShapeOnBlock(Block * block) {
+  for (Node * node : block->nodes()) {
+    PropagateShapeOnNode(node);
+    for (Block * sub_block : node->blocks()) {
+      PropagateShapeOnBlock(sub_block);
+    }
+  }
+}
+
 }
 
 void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec) {
@@ -158,9 +167,7 @@ void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec) {
   for(size_t i = 0; i < spec.size(); ++i) {
     graph.inputs()[i]->setType(spec.tensorInfo(i));
   }
-  for(auto n : graph.nodes()) {
-    PropagateShapeOnNode(n);
-  }
+  PropagateShapeOnBlock(graph.block());
 }
 
 }}

torch/jit/frontend.py

@@ -12,7 +12,6 @@ PY2 = sys.version_info[0] == 2
 _reserved_prefix = '__jit'
 _identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits)
 
-# TODO: populate those
 pretty_node_names = {
     ast.For: "for loops",
     ast.Delete: "del statements",