Improve fuser algorithm

This commit is contained in:
Adam Paszke 2017-11-26 15:26:45 -05:00
parent e66c592d10
commit 31af836412
3 changed files with 101 additions and 9 deletions

View file

@ -77,6 +77,9 @@ std::unordered_map<NodeKind, std::string> simple_map_ops = {
{klerp, "${0} + ${weight}*(${1} - ${0})"},
{kclamp, "min(max(${0},${min}),${max})"},
// simple derivatives
{"_sigmoid_backward"_sym, "${0} * ${1} * (1.f - ${1})"},
{"_tanh_backward"_sym, "${0} * (1.f - ${1} * ${1})"},
};
std::vector<bool> TensorDesc::findContiguous(
@ -305,8 +308,8 @@ CompiledFusionFunction::CompiledFusionFunction(const std::string & name, Annotat
if ((prop.major >= 6 && CUDA_VERSION < 8000) ||
(prop.major >= 7 && CUDA_VERSION < 9000)) {
std::stringstream err_string;
err_string << "In CompiledFusionFunction, PyTorch compiled with insufficient CUDA version: "
<< CUDA_VERSION << " for the current GPU device " << prop.name
err_string << "In CompiledFusionFunction, PyTorch compiled with insufficient CUDA version: "
<< CUDA_VERSION << " for the current GPU device " << prop.name
<< " with device capability " << prop.major << "." << prop.minor;
throw std::runtime_error(err_string.str());
}

View file

@ -133,4 +133,8 @@ enum BuiltinSymbol {
const char * symbolToString(Symbol s);
Symbol stringToSymbol(const std::string & s);
// User-defined literal that interns a string literal as a Symbol, so that
// `"_tanh_backward"_sym` is shorthand for `stringToSymbol("_tanh_backward")`.
//
// The second parameter (the literal's length, unused here) must be spelled
// std::size_t: the language only accepts `(const char*, std::size_t)` for a
// string literal operator, and `unsigned long` matches that type only on some
// platforms (it does not on 64-bit Windows or on most 32-bit targets).
inline Symbol operator "" _sym(const char * s, std::size_t) {
  return stringToSymbol(s);
}
}}

View file

@ -62,6 +62,8 @@ std::unordered_set<NodeKind> simple_mappable = {
ktanh,
ktrunc,
kzeros,
"_sigmoid_backward"_sym,
"_tanh_backward"_sym,
};
bool isSimpleMap(Node *node) {
@ -176,7 +178,64 @@ struct GraphFuser {
JIT_ASSERT(n->kind() == kFusionGroup);
return *n->g(kSubgraph);
}
// Merges producer_group into consumer_group. Both nodes are fusion groups.
//
// The merge happens in three phases:
//   1. every node of the producer's subgraph is cloned into the outer graph,
//      right before producer_group;
//   2. users of producer_group's outputs are redirected to the clones and
//      producer_group is destroyed;
//   3. each clone is merged into consumer_group via mergeNodeIntoGroup and
//      then destroyed.
// producer_group must not be used by the caller afterwards.
void mergeFusionGroups(Node *consumer_group, Node *producer_group) {
  auto & producer_subgraph = getSubgraph(producer_group);

  // Translation table from a value inside the producer's subgraph to the
  // matching value in the outer graph. It is seeded by pairing the subgraph
  // inputs with the inputs of the group node, position by position.
  std::unordered_map<Value*, Value*> outer_value_of;
  auto subgraph_inputs = producer_subgraph.inputs();
  auto group_inputs = producer_group->inputs();
  for (std::size_t idx = 0; idx < subgraph_inputs.size(); ++idx) {
    outer_value_of[subgraph_inputs[idx]] = group_inputs[idx];
  }

  // Phase 1: un-fuse the producer. Each clone reads its inputs through the
  // table, and its outputs are recorded there for the nodes that follow.
  std::vector<Node*> unfused_nodes;
  for (auto inner_node : producer_subgraph) {
    Node * clone = graph->createClone(inner_node, [&](Value * inner_value) -> Value* {
      return outer_value_of.at(inner_value);
    });
    clone->insertBefore(producer_group);
    unfused_nodes.push_back(clone);
    auto inner_outs = inner_node->outputs();
    auto clone_outs = clone->outputs();
    for (std::size_t idx = 0; idx < inner_outs.size(); ++idx) {
      outer_value_of[inner_outs[idx]] = clone_outs[idx];
    }
  }

  // Phase 2: point every user of the producer group's i-th output at the outer
  // value that corresponds to the subgraph's i-th output, then drop the group.
  auto producer_results = producer_subgraph.outputs();
  for (std::size_t idx = 0; idx < producer_results.size(); ++idx) {
    Value * replacement = outer_value_of.at(producer_results[idx]);
    producer_group->outputs()[idx]->replaceAllUsesWith(replacement);
  }
  producer_group->destroy();
  producer_group = nullptr; // Fail loudly if anything below touches it again

  // Phase 3: fold the clones into the consumer group, walking them in reverse
  // of the order in which they were created.
  // NOTE(review): presumably this lets users be merged before the nodes they
  // depend on -- confirm against mergeNodeIntoGroup.
  auto & consumer_subgraph = getSubgraph(consumer_group);
  for (std::size_t remaining = unfused_nodes.size(); remaining > 0; --remaining) {
    Node * clone = unfused_nodes[remaining - 1];
    Node * fused = mergeNodeIntoGroup(consumer_group, clone);
    // An output of the clone that still has uses must be exposed as an output
    // of the consumer group so that those uses can be redirected to it.
    auto clone_outs = clone->outputs();
    for (std::size_t idx = 0; idx < clone_outs.size(); ++idx) {
      auto still_used = clone_outs[idx];
      if (still_used->uses().size() != 0) {
        consumer_subgraph.registerOutput(fused->outputs()[idx]);
        auto exported = consumer_group->addOutput();
        still_used->replaceAllUsesWith(exported);
        exported->setType(still_used->typeOption());
      }
    }
    clone->destroy();
  }
}
Node * mergeNodeIntoGroup(Node* group, Node * n) {
JIT_ASSERT(n->kind() != kFusionGroup);
auto & subgraph = getSubgraph(group);
// map from nodes in the surrounding graph to parameters in the fusion
// group's subgraph that correspond to them
@ -245,6 +304,10 @@ struct GraphFuser {
if(group->kind() != kFusionGroup) {
group = createSingletonFusionGroup(consumer);
}
if (producer->node()->kind() == kFusionGroup) {
mergeFusionGroups(group, producer->node());
return group;
}
Node * merged = mergeNodeIntoGroup(group, producer->node());
// remaining uses of this producer can occur because we allow
// fusion in cases where uses remain after the consumer
@ -349,8 +412,8 @@ struct GraphFuser {
return true;
}
// returns where to continue scanning
graph_node_list::iterator scanNode(Node * consumer) {
// returns where to continue scanning, and whether any fusion was made
std::pair<graph_node_list::iterator, bool> scanNode(Node * consumer) {
auto stage_guard = graph->setStageTemporary(consumer->stage());
if(isFusableAsExitNode(consumer)) {
// handle inputs in reverse topological order as well...
@ -369,17 +432,17 @@ struct GraphFuser {
if(tryToMoveChunk(consumer,producer)) {
// the chunk before this consumer was re-arranged to allow fusion,
// we scan this consumer again to perform the fusion
return consumer->reverseIterator();
return std::make_pair(consumer->reverseIterator(), true);
}
if(shouldFuse(consumer, producer)) {
auto fusion_group = fuse(consumer,producer);
// after fusion, consumer moves into a FusionGroup, so inputs is no longer valid
// so we rescan the new FusionGroup for more fusions...
return fusion_group->reverseIterator();
return std::make_pair(fusion_group->reverseIterator(), true);
}
}
}
return ++consumer->reverseIterator();
return std::make_pair(++consumer->reverseIterator(), false);
}
void run() {
@ -393,8 +456,30 @@ struct GraphFuser {
}
topological_index[graph->return_node()] = i++;
for(auto it = nodes.rbegin(); it != nodes.rend();) {
it = scanNode(*it);
// Run the pass until no changes are made.
// This is necessary, because the algorithm can miss out on certain fusion
// opportunities if run only once. Consider this graph:
//
// %1 = f(...)
// %2 = g(%1)
// %3 = h(%1)
// %4 = l(%3)
// return (%4, %2)
//
// where f, g, h, l are simple map ops.
// The first iteration will fuse %4 and %3, and see that %1 is an input, but
// can't be fused, because it has a different use before the fusion group
// in our topological ordering. Then, %2 will be considered, and fused with %1.
// If we do another iteration, the algorithm will consider the fusion of these
// two groups and fix the situation.
bool any_changed = true;
while (any_changed) {
any_changed = false;
for (auto it = nodes.rbegin(); it != nodes.rend();) {
bool changed;
std::tie(it, changed) = scanNode(*it);
any_changed |= changed;
}
}
}
};