diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp
index 593b301f9f6..5e6f3a379e9 100644
--- a/torch/csrc/jit/passes/batch_mm.cpp
+++ b/torch/csrc/jit/passes/batch_mm.cpp
@@ -140,16 +140,17 @@ struct TreeToken {
   }
 };
 
-void BatchMM(std::shared_ptr<Graph>& graph) {
+void BatchMMBlock(Block* block) {
   enum class Side { LHS, RHS };
   static const Symbol mm_kind = "mm"_sym;
   static const Symbol add_kind = "add"_sym;
   static const Symbol cat_kind = "cat"_sym;
   static const Symbol dim_sym = "dim"_sym;
+  auto graph = block->owningGraph();
 
-  // Look for trees in the graph
+  // Look for trees in the block
   std::unordered_map<Node*, TreeToken> tokens;
-  for (auto node : graph->nodes()) {
+  for (auto node : block->nodes()) {
    if (node->kind() == mm_kind) {
      tokens[node] = TreeToken::fromMM(node);
    } else if (node->kind() == add_kind) {
@@ -170,6 +171,10 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
        if (auto token = TreeToken::unify(node, lhs_it->second, rhs_it->second))
          tokens[node] = token;
      }
+    } else {
+      for (auto block : node->blocks()) {
+        BatchMMBlock(block);
+      }
    }
  }
 
@@ -202,7 +207,11 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
     root.node->output()->replaceAllUsesWith(batch_mm->output());
     // NB: don't bother with cleaning up after yourself. We'll use DCE for that.
   }
-  EliminateDeadCode(graph);
+  EliminateDeadCode(block);
+}
+
+void BatchMM(std::shared_ptr<Graph>& graph) {
+  BatchMMBlock(graph->block());
 }
 
 }}
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
index c829042d225..9f1f14943ad 100644
--- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
@@ -109,9 +109,9 @@ struct EqualNodeCSE {
 // The function implements common subexpression elimination.
 // Since the nodes are visited in topological order, one pass is enough.
-void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
+void EliminateCommonSubexpression(Block * block) {
   std::unordered_set<Node*, HashNodeCSE, EqualNodeCSE> subexprs;
-  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++ it) {
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
     auto node = *it;
     if (node->kind() == kPythonOp
         || node->kind() == kCppOp
@@ -136,4 +136,8 @@ void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
   }
 }
 
+void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
+  EliminateCommonSubexpression(graph->block());
+}
+
 }}
diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
index a09433a0584..9b1759946b9 100644
--- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
+++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
@@ -16,12 +16,13 @@ namespace {
 // right before nodes[0] (i.e. it will not create cycles and all uses of
 // new node will be after this position).
 // prereq: nodes are in topological order
-void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
+void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) {
   JIT_ASSERT(nodes.size() > 0);
   std::unordered_map<Value*, Value*> value_map;
+  Graph * graph = block->owningGraph();
   auto new_graph = std::make_shared<Graph>();
-  Node * group_node = g.create(group_node_kind, 0);
+  Node * group_node = graph->create(group_node_kind, 0);
   group_node->g_(kSubgraph, new_graph);
 
   auto getOrCreateInput = [&](Value * v) {
@@ -71,8 +72,7 @@ void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
 }
 
-void CreateAutodiffSubgraphs(Graph & g, size_t threshold) {
-
+void CreateAutodiffSubgraphs(Block * block, size_t threshold) {
   // This implementation is not optimal, but it is simple.
   // It just scans through the list in order looking for runs of
   // differentiable ops, and then grouping them together when
@@ -88,22 +88,29 @@ void CreateAutodiffSubgraphs(Graph & g, size_t threshold) {
   // and group maximal groups
   std::vector<Node*> groupable;
-  for(auto n : g.nodes()) { // Note: nodes() iterator stays valid since it is
+  for(Node * node : block->nodes()) { // Note: nodes() iterator stays valid since it is
                             // always pointing _after_ the nodes that mergeNodes
                             // mutates.
-    if(isDifferentiable(n)) {
-      groupable.push_back(n);
+    if(isDifferentiable(node)) {
+      groupable.push_back(node);
     } else {
       if(groupable.size() >= threshold) {
-        mergeNodes(g, kGraphExecutor, groupable);
+        mergeNodes(block, kGraphExecutor, groupable);
       }
       groupable.clear();
+      for (Block * sub_block : node->blocks()) {
+        CreateAutodiffSubgraphs(sub_block, threshold);
+      }
     }
   }
   if(groupable.size() >= threshold) {
-    mergeNodes(g, kGraphExecutor, groupable);
+    mergeNodes(block, kGraphExecutor, groupable);
   }
 }
 
+void CreateAutodiffSubgraphs(Graph & graph, size_t threshold) {
+  CreateAutodiffSubgraphs(graph.block(), threshold);
+}
+
 }}
diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp
index 968d93fc8e4..4d5d817d314 100644
--- a/torch/csrc/jit/passes/dead_code_elimination.cpp
+++ b/torch/csrc/jit/passes/dead_code_elimination.cpp
@@ -3,10 +3,16 @@ namespace torch { namespace jit {
 
 void EliminateDeadCode(std::shared_ptr<Graph>& graph) {
-  auto nodes = graph->nodes().reverse();
+  EliminateDeadCode(graph->block());
+}
+
+void EliminateDeadCode(Block *block) {
+  auto nodes = block->nodes().reverse();
   for (auto it = nodes.begin(); it != nodes.end(); it++) {
     auto node = *it;
-    if(!node->hasUses())
+    for (Block * block : node->blocks())
+      EliminateDeadCode(block);
+    if (!node->hasUses())
       it.destroyCurrent();
   }
 }
diff --git a/torch/csrc/jit/passes/dead_code_elimination.h b/torch/csrc/jit/passes/dead_code_elimination.h
index ab83951d050..438e33d7ebc 100644
--- a/torch/csrc/jit/passes/dead_code_elimination.h
+++ b/torch/csrc/jit/passes/dead_code_elimination.h
@@ -5,5 +5,6 @@ namespace torch { namespace jit {
 
 void EliminateDeadCode(std::shared_ptr<Graph>& graph);
+void EliminateDeadCode(Block *block);
 
 }}
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index a0fa6d97c7f..731afcaf028 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -78,7 +78,7 @@ bool isSimpleMap(Node *node) {
 }
 
 struct GraphFuser {
-  std::shared_ptr<Graph>& graph;
+  Block * block;
 
   // Used to order nodes so we always consider producer-consumer fusions
   // in reverse topological order.
@@ -87,8 +87,8 @@
   // Newly generated nodes will copy the location where they are inserted.
   std::unordered_map<Node*, size_t> topological_index;
 
-  GraphFuser(std::shared_ptr<Graph>& graph)
-    : graph(graph) {}
+  GraphFuser(Block * block)
+    : block(block) {}
 
   int getDevice(Node * node) {
     if(node->kind() == kFusionGroup) {
@@ -122,6 +122,7 @@
     return true;
   }
   bool isFusable(Node * node) {
+    if (node->owningBlock() != block) return false;
     if (node->kind() == kFusionGroup) return true;
     return isSimpleMap(node) && allFloatIO(node);
   }
@@ -210,7 +211,7 @@
 
     // Clone all nodes
     for (auto inner : producer_subgraph->nodes()) {
-      Node * outer = graph->createClone(inner, [&](Value * k) -> Value* {
+      Node * outer = block->owningGraph()->createClone(inner, [&](Value * k) -> Value* {
         return inner_to_outer.at(k);
       });
       outer->insertBefore(producer_group);
@@ -291,7 +292,7 @@
   // turn consumer node n into a fusion group with just n inside
   // to prepare for fusion and replace uses of n with the new group
   Node * createSingletonFusionGroup(Node * n) {
-    auto group = graph->createFusionGroup(getDevice(n));
+    auto group = block->owningGraph()->createFusionGroup(getDevice(n));
     // propogate position information for the new node so we can always
     // have a valid mapping
     topological_index[group] = topological_index[n];
@@ -397,7 +398,7 @@ struct GraphFuser {
     // TODO: Perhaps we should use cloneFrom now, as it seems unlikely
     // to copy select nodes now that we have refactored to have a Value
     // distinct from Node.
-    Node * input_chunk = graph->create(chunk->kind(), 0);
+    Node * input_chunk = block->owningGraph()->create(chunk->kind(), 0);
     input_chunk->copyAttributes(*chunk);
     input_chunk->addInput(input);
     insertAt(&insertion_point, input_chunk);
@@ -416,7 +417,7 @@
     // apply the op to each chunk of the chunked operands,
     // and then rewrite the graph to use them!
     for (auto chunk_sel : chunk->outputs()) {
-      Node * chunked_op = graph->create(producer_for_chunk_node->kind());
+      Node * chunked_op = block->owningGraph()->create(producer_for_chunk_node->kind());
       chunked_op->copyAttributes(*producer_for_chunk_node);
       // Invariant: mappable operators always produce contiguous output
       chunked_op->output()->setType(chunk_sel->type()->cast<TensorType>()->contiguous());
@@ -433,7 +434,7 @@
 
   // returns where to continue scanning, and whether any fusion was made
   std::pair<graph_node_list::iterator, bool> scanNode(Node * consumer) {
-    auto stage_guard = graph->setStageTemporary(consumer->stage());
+    auto stage_guard = block->owningGraph()->setStageTemporary(consumer->stage());
     if(isFusableAsExitNode(consumer)) {
       // handle inputs in reverse topological order as well...
       // otherwise in f(a,a+b) it will appear a is used twice if we consider
@@ -465,15 +466,14 @@
   }
 
   void run() {
-    for(auto p : graph->inputs()) {
+    for(auto p : block->inputs()) {
       topological_index[p->node()] = 0;
     }
     size_t i = 1;
-    auto nodes = graph->nodes();
-    for(auto consumer : nodes) {
+    for(auto consumer : block->nodes()) {
       topological_index[consumer] = i++;
     }
-    topological_index[graph->return_node()] = i++;
+    topological_index[block->return_node()] = i++;
 
     // Run the pass until no changes are made.
     // This is neccessary, because the algorithm can miss out on certain fusion
@@ -494,19 +494,24 @@
     bool any_changed = true;
     while (any_changed) {
       any_changed = false;
-      for (auto it = nodes.rbegin(); it != nodes.rend();) {
+      for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
         bool changed;
         std::tie(it, changed) = scanNode(*it);
         any_changed |= changed;
       }
     }
+    for (Node * node : block->nodes()) {
+      for (Block * sub_block : node->blocks()) {
+        GraphFuser(sub_block).run();
+      }
+    }
   }
 };
 
 } // anonymous namespace
 
 void FuseGraph(std::shared_ptr<Graph>& graph) {
-  GraphFuser(graph).run();
+  GraphFuser(graph->block()).run();
 }
 
 }}
diff --git a/torch/csrc/jit/passes/inplace_check.cpp b/torch/csrc/jit/passes/inplace_check.cpp
index 2852509e86e..18a78529f6b 100644
--- a/torch/csrc/jit/passes/inplace_check.cpp
+++ b/torch/csrc/jit/passes/inplace_check.cpp
@@ -2,8 +2,8 @@
 
 namespace torch { namespace jit {
 
-void CheckInplace(std::shared_ptr<Graph>& graph) {
-  for (auto node : graph->nodes()) {
+void CheckInplace(Block * block) {
+  for (auto node : block->nodes()) {
     if (node->kind() == kPythonOp && node->hasAttribute(kinplace)) {
       if (node->i(kinplace)) {
         throw std::runtime_error(std::string("inplace ") +
@@ -14,4 +14,8 @@ void CheckInplace(std::shared_ptr<Graph>& graph) {
   }
 }
 
+void CheckInplace(std::shared_ptr<Graph>& graph) {
+  CheckInplace(graph->block());
+}
+
 }} // namespace torch::jit
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index 26789307272..051b7968960 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -10,10 +10,12 @@ namespace torch { namespace jit {
 // - Simply x.t().t() to x
 //
 // TODO: Decide what kind of fixed point strategy we will have
-void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
-  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
+void PeepholeOptimize(Block * block) {
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
     auto* n = *it;
+    // XXX: remember that if you want to simplify an expression by combining multiple nodes
+    // into a different one, then you need to check that they all belong to the given block
     switch (n->kind()) {
       case kexpand:
         // Eliminate redundant expand
@@ -40,7 +42,15 @@
         }
         break;
       }
+
+    for (Block * sub_block : n->blocks()) {
+      PeepholeOptimize(sub_block);
+    }
   }
 }
 
+void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
+  PeepholeOptimize(graph->block());
+}
+
 }}
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index f2d404a85bc..d6f8e9993d3 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -151,6 +151,15 @@ void PropagateShapeOnNode(Node * node) {
 
 }
 
+void PropagateShapeOnBlock(Block * block) {
+  for (Node * node : block->nodes()) {
+    PropagateShapeOnNode(node);
+    for (Block * sub_block : node->blocks()) {
+      PropagateShapeOnBlock(sub_block);
+    }
+  }
+}
+
 }
 
 void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec) {
@@ -158,9 +167,7 @@ void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec) {
   for(size_t i = 0; i < spec.size(); ++i) {
     graph.inputs()[i]->setType(spec.tensorInfo(i));
   }
-  for(auto n : graph.nodes()) {
-    PropagateShapeOnNode(n);
-  }
+  PropagateShapeOnBlock(graph.block());
 }
 
 }}
diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py
index af37b390dc6..1f8cf306e20 100644
--- a/torch/jit/frontend.py
+++ b/torch/jit/frontend.py
@@ -12,7 +12,6 @@ PY2 = sys.version_info[0] == 2
 _reserved_prefix = '__jit'
 _identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits)
 
-# TODO: populate those
 pretty_node_names = {
     ast.For: "for loops",
     ast.Delete: "del statements",