Traverse sub-blocks in JIT passes (#5329)

* Traverse sub-blocks in JIT passes

* Add an extra check to prevent cross-block fusion
This commit is contained in:
Adam Paszke 2018-02-23 02:32:31 +01:00 committed by Zachary DeVito
parent b6854ee012
commit c2a3d85a07
10 changed files with 91 additions and 39 deletions

View file

@ -140,16 +140,17 @@ struct TreeToken {
}
};
void BatchMM(std::shared_ptr<Graph>& graph) {
void BatchMMBlock(Block* block) {
enum class Side { LHS, RHS };
static const Symbol mm_kind = "mm"_sym;
static const Symbol add_kind = "add"_sym;
static const Symbol cat_kind = "cat"_sym;
static const Symbol dim_sym = "dim"_sym;
auto graph = block->owningGraph();
// Look for trees in the graph
// Look for trees in the block
std::unordered_map<Node*, TreeToken> tokens;
for (auto node : graph->nodes()) {
for (auto node : block->nodes()) {
if (node->kind() == mm_kind) {
tokens[node] = TreeToken::fromMM(node);
} else if (node->kind() == add_kind) {
@ -170,6 +171,10 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
if (auto token = TreeToken::unify(node, lhs_it->second, rhs_it->second))
tokens[node] = token;
}
} else {
for (auto block : node->blocks()) {
BatchMMBlock(block);
}
}
}
@ -202,7 +207,11 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
root.node->output()->replaceAllUsesWith(batch_mm->output());
// NB: don't bother with cleaning up after yourself. We'll use DCE for that.
}
EliminateDeadCode(graph);
EliminateDeadCode(block);
}
// Graph-level entry point for the batch-mm pass: forwards to BatchMMBlock on
// the block returned by graph->block(). BatchMMBlock handles nested blocks by
// recursing into them itself.
void BatchMM(std::shared_ptr<Graph>& graph) {
BatchMMBlock(graph->block());
}
}}

View file

@ -109,9 +109,9 @@ struct EqualNodeCSE {
// The function implements common subexpression elimination.
// Since the nodes are visited in topological order, one pass is enough.
void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
void EliminateCommonSubexpression(Block * block) {
std::unordered_set<Node*, HashNodeCSE, EqualNodeCSE> subexprs;
for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++ it) {
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
auto node = *it;
if (node->kind() == kPythonOp
|| node->kind() == kCppOp
@ -136,4 +136,8 @@ void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
}
}
// Graph-level overload: runs common subexpression elimination on the block
// returned by graph->block().
void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
EliminateCommonSubexpression(graph->block());
}
}}

View file

@ -16,12 +16,13 @@ namespace {
// right before nodes[0] (i.e. it will not create cycles and all uses of
// new node will be after this position).
// prereq: nodes are in topological order
void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) {
JIT_ASSERT(nodes.size() > 0);
std::unordered_map<Value*, Value*> value_map;
Graph * graph = block->owningGraph();
auto new_graph = std::make_shared<Graph>();
Node * group_node = g.create(group_node_kind, 0);
Node * group_node = graph->create(group_node_kind, 0);
group_node->g_(kSubgraph, new_graph);
auto getOrCreateInput = [&](Value * v) {
@ -71,8 +72,7 @@ void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
}
void CreateAutodiffSubgraphs(Graph & g, size_t threshold) {
void CreateAutodiffSubgraphs(Block * block, size_t threshold) {
// This implementation is not optimal, but it is simple.
// It just scans through the list in order looking for runs of
// differentiable ops, and then grouping them together when
@ -88,22 +88,29 @@ void CreateAutodiffSubgraphs(Graph & g, size_t threshold) {
// and group maximal groups
std::vector<Node*> groupable;
for(auto n : g.nodes()) { // Note: nodes() iterator stays valid since it is
for(Node * node : block->nodes()) { // Note: nodes() iterator stays valid since it is
// always pointing _after_ the nodes that mergeNodes
// mutates.
if(isDifferentiable(n)) {
groupable.push_back(n);
if(isDifferentiable(node)) {
groupable.push_back(node);
} else {
if(groupable.size() >= threshold) {
mergeNodes(g, kGraphExecutor, groupable);
mergeNodes(block, kGraphExecutor, groupable);
}
groupable.clear();
for (Block * sub_block : node->blocks()) {
CreateAutodiffSubgraphs(sub_block, threshold);
}
}
}
if(groupable.size() >= threshold) {
mergeNodes(g, kGraphExecutor, groupable);
mergeNodes(block, kGraphExecutor, groupable);
}
}
// Graph-level overload: applies the Block overload to graph.block(), passing
// the threshold through unchanged.
void CreateAutodiffSubgraphs(Graph & graph, size_t threshold) {
CreateAutodiffSubgraphs(graph.block(), threshold);
}
}}

View file

@ -3,10 +3,16 @@
namespace torch { namespace jit {
void EliminateDeadCode(std::shared_ptr<Graph>& graph) {
auto nodes = graph->nodes().reverse();
EliminateDeadCode(graph->block());
}
void EliminateDeadCode(Block *block) {
auto nodes = block->nodes().reverse();
for (auto it = nodes.begin(); it != nodes.end(); it++) {
auto node = *it;
if(!node->hasUses())
for (Block * block : node->blocks())
EliminateDeadCode(block);
if (!node->hasUses())
it.destroyCurrent();
}
}

View file

@ -5,5 +5,6 @@
namespace torch { namespace jit {
void EliminateDeadCode(std::shared_ptr<Graph>& graph);
void EliminateDeadCode(Block *block);
}}

View file

@ -78,7 +78,7 @@ bool isSimpleMap(Node *node) {
}
struct GraphFuser {
std::shared_ptr<Graph>& graph;
Block * block;
// Used to order nodes so we always consider producer-consumer fusions
// in reverse topological order.
@ -87,8 +87,8 @@ struct GraphFuser {
// Newly generated nodes will copy the location where they are inserted.
std::unordered_map<Node*,size_t> topological_index;
GraphFuser(std::shared_ptr<Graph>& graph)
: graph(graph) {}
GraphFuser(Block * block)
: block(block) {}
int getDevice(Node * node) {
if(node->kind() == kFusionGroup) {
@ -122,6 +122,7 @@ struct GraphFuser {
return true;
}
bool isFusable(Node * node) {
if (node->owningBlock() != block) return false;
if (node->kind() == kFusionGroup) return true;
return isSimpleMap(node) && allFloatIO(node);
}
@ -210,7 +211,7 @@ struct GraphFuser {
// Clone all nodes
for (auto inner : producer_subgraph->nodes()) {
Node * outer = graph->createClone(inner, [&](Value * k) -> Value* {
Node * outer = block->owningGraph()->createClone(inner, [&](Value * k) -> Value* {
return inner_to_outer.at(k);
});
outer->insertBefore(producer_group);
@ -291,7 +292,7 @@ struct GraphFuser {
// turn consumer node n into a fusion group with just n inside
// to prepare for fusion and replace uses of n with the new group
Node * createSingletonFusionGroup(Node * n) {
auto group = graph->createFusionGroup(getDevice(n));
auto group = block->owningGraph()->createFusionGroup(getDevice(n));
// propagate position information for the new node so we can always
// have a valid mapping
topological_index[group] = topological_index[n];
@ -397,7 +398,7 @@ struct GraphFuser {
// TODO: Perhaps we should use cloneFrom now, as it seems unlikely
// to copy select nodes now that we have refactored to have a Value
// distinct from Node.
Node * input_chunk = graph->create(chunk->kind(), 0);
Node * input_chunk = block->owningGraph()->create(chunk->kind(), 0);
input_chunk->copyAttributes(*chunk);
input_chunk->addInput(input);
insertAt(&insertion_point, input_chunk);
@ -416,7 +417,7 @@ struct GraphFuser {
// apply the op to each chunk of the chunked operands,
// and then rewrite the graph to use them!
for (auto chunk_sel : chunk->outputs()) {
Node * chunked_op = graph->create(producer_for_chunk_node->kind());
Node * chunked_op = block->owningGraph()->create(producer_for_chunk_node->kind());
chunked_op->copyAttributes(*producer_for_chunk_node);
// Invariant: mappable operators always produce contiguous output
chunked_op->output()->setType(chunk_sel->type()->cast<TensorType>()->contiguous());
@ -433,7 +434,7 @@ struct GraphFuser {
// returns where to continue scanning, and whether any fusion was made
std::pair<graph_node_list::iterator, bool> scanNode(Node * consumer) {
auto stage_guard = graph->setStageTemporary(consumer->stage());
auto stage_guard = block->owningGraph()->setStageTemporary(consumer->stage());
if(isFusableAsExitNode(consumer)) {
// handle inputs in reverse topological order as well...
// otherwise in f(a,a+b) it will appear a is used twice if we consider
@ -465,15 +466,14 @@ struct GraphFuser {
}
void run() {
for(auto p : graph->inputs()) {
for(auto p : block->inputs()) {
topological_index[p->node()] = 0;
}
size_t i = 1;
auto nodes = graph->nodes();
for(auto consumer : nodes) {
for(auto consumer : block->nodes()) {
topological_index[consumer] = i++;
}
topological_index[graph->return_node()] = i++;
topological_index[block->return_node()] = i++;
// Run the pass until no changes are made.
// This is necessary, because the algorithm can miss out on certain fusion
@ -494,19 +494,24 @@ struct GraphFuser {
bool any_changed = true;
while (any_changed) {
any_changed = false;
for (auto it = nodes.rbegin(); it != nodes.rend();) {
for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
bool changed;
std::tie(it, changed) = scanNode(*it);
any_changed |= changed;
}
}
for (Node * node : block->nodes()) {
for (Block * sub_block : node->blocks()) {
GraphFuser(sub_block).run();
}
}
}
};
} // anonymous namespace
void FuseGraph(std::shared_ptr<Graph>& graph) {
GraphFuser(graph).run();
GraphFuser(graph->block()).run();
}
}}

View file

@ -2,8 +2,8 @@
namespace torch { namespace jit {
void CheckInplace(std::shared_ptr<Graph>& graph) {
for (auto node : graph->nodes()) {
void CheckInplace(Block * block) {
for (auto node : block->nodes()) {
if (node->kind() == kPythonOp && node->hasAttribute(kinplace)) {
if (node->i(kinplace)) {
throw std::runtime_error(std::string("inplace ") +
@ -14,4 +14,8 @@ void CheckInplace(std::shared_ptr<Graph>& graph) {
}
}
// Graph-level overload: runs the inplace check on the block returned by
// graph->block().
void CheckInplace(std::shared_ptr<Graph>& graph) {
CheckInplace(graph->block());
}
}} // namespace torch::jit

View file

@ -10,10 +10,12 @@ namespace torch { namespace jit {
// - Simplify x.t().t() to x
//
// TODO: Decide what kind of fixed point strategy we will have
void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
void PeepholeOptimize(Block * block) {
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
auto* n = *it;
// XXX: remember that if you want to simplify an expression by combining multiple nodes
// into a different one, then you need to check that they all belong to the given block
switch (n->kind()) {
case kexpand:
// Eliminate redundant expand
@ -40,7 +42,15 @@ void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
}
break;
}
for (Block * sub_block : n->blocks()) {
PeepholeOptimize(sub_block);
}
}
}
// Graph-level overload: runs the peephole pass on the block returned by
// graph->block(). The Block overload visits each node's sub-blocks itself.
void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
PeepholeOptimize(graph->block());
}
}}

View file

@ -151,6 +151,15 @@ void PropagateShapeOnNode(Node * node) {
}
// Applies PropagateShapeOnNode to every node of `block` in iteration order.
// Right after a node is processed, each of its own blocks is processed
// recursively, before moving on to the next node.
void PropagateShapeOnBlock(Block * block) {
  for (Node * current : block->nodes()) {
    PropagateShapeOnNode(current);
    for (Block * nested : current->blocks()) {
      PropagateShapeOnBlock(nested);
    }
  }
}
}
void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec) {
@ -158,9 +167,7 @@ void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec) {
for(size_t i = 0; i < spec.size(); ++i) {
graph.inputs()[i]->setType(spec.tensorInfo(i));
}
for(auto n : graph.nodes()) {
PropagateShapeOnNode(n);
}
PropagateShapeOnBlock(graph.block());
}
}}

View file

@ -12,7 +12,6 @@ PY2 = sys.version_info[0] == 2
_reserved_prefix = '__jit'
_identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits)
# TODO: populate those
pretty_node_names = {
ast.For: "for loops",
ast.Delete: "del statements",