mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Improve fuser algorithm
This commit is contained in:
parent
e66c592d10
commit
31af836412
3 changed files with 101 additions and 9 deletions
|
|
@ -77,6 +77,9 @@ std::unordered_map<NodeKind, std::string> simple_map_ops = {
|
|||
{klerp, "${0} + ${weight}*(${1} - ${0})"},
|
||||
{kclamp, "min(max(${0},${min}),${max})"},
|
||||
|
||||
// simple derivatives
|
||||
{"_sigmoid_backward"_sym, "${0} * ${1} * (1.f - ${1})"},
|
||||
{"_tanh_backward"_sym, "${0} * (1.f - ${1} * ${1})"},
|
||||
};
|
||||
|
||||
std::vector<bool> TensorDesc::findContiguous(
|
||||
|
|
@ -305,8 +308,8 @@ CompiledFusionFunction::CompiledFusionFunction(const std::string & name, Annotat
|
|||
if ((prop.major >= 6 && CUDA_VERSION < 8000) ||
|
||||
(prop.major >= 7 && CUDA_VERSION < 9000)) {
|
||||
std::stringstream err_string;
|
||||
err_string << "In CompiledFusionFunction, PyTorch compiled with insufficient CUDA version: "
|
||||
<< CUDA_VERSION << " for the current GPU device " << prop.name
|
||||
err_string << "In CompiledFusionFunction, PyTorch compiled with insufficient CUDA version: "
|
||||
<< CUDA_VERSION << " for the current GPU device " << prop.name
|
||||
<< " with device capability " << prop.major << "." << prop.minor;
|
||||
throw std::runtime_error(err_string.str());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,4 +133,8 @@ enum BuiltinSymbol {
|
|||
const char * symbolToString(Symbol s);
|
||||
Symbol stringToSymbol(const std::string & s);
|
||||
|
||||
inline Symbol operator "" _sym(const char * s, unsigned long) {
|
||||
return stringToSymbol(s);
|
||||
}
|
||||
|
||||
}}
|
||||
|
|
|
|||
|
|
@ -62,6 +62,8 @@ std::unordered_set<NodeKind> simple_mappable = {
|
|||
ktanh,
|
||||
ktrunc,
|
||||
kzeros,
|
||||
"_sigmoid_backward"_sym,
|
||||
"_tanh_backward"_sym,
|
||||
};
|
||||
|
||||
bool isSimpleMap(Node *node) {
|
||||
|
|
@ -176,7 +178,64 @@ struct GraphFuser {
|
|||
JIT_ASSERT(n->kind() == kFusionGroup);
|
||||
return *n->g(kSubgraph);
|
||||
}
|
||||
|
||||
void mergeFusionGroups(Node *consumer_group, Node *producer_group) {
|
||||
// Now we have two fusion groups!
|
||||
// Revert the fusion - place all inner nodes of producer back in the outer graph.
|
||||
std::vector<Node*> temporary_nodes;
|
||||
auto producer_subgraph = &getSubgraph(producer_group);
|
||||
|
||||
// Initialize a map of inner graph values to outer graph values
|
||||
std::unordered_map<Value*, Value*> inner_to_outer;
|
||||
auto inner_inputs = producer_subgraph->inputs();
|
||||
auto outer_inputs = producer_group->inputs();
|
||||
for (std::size_t i = 0; i < inner_inputs.size(); ++i) {
|
||||
inner_to_outer[inner_inputs[i]] = outer_inputs[i];
|
||||
}
|
||||
|
||||
// Clone all nodes
|
||||
for (auto inner : *producer_subgraph) {
|
||||
Node * outer = graph->createClone(inner, [&](Value * k) -> Value* {
|
||||
return inner_to_outer.at(k);
|
||||
});
|
||||
outer->insertBefore(producer_group);
|
||||
temporary_nodes.emplace_back(outer);
|
||||
auto inner_outputs = inner->outputs();
|
||||
auto outer_outputs = outer->outputs();
|
||||
for (std::size_t i = 0; i < inner_outputs.size(); ++i)
|
||||
inner_to_outer[inner_outputs[i]] = outer_outputs[i];
|
||||
}
|
||||
|
||||
// Replace uses of producer_group outputs and destroy the producer
|
||||
auto subgraph_outputs = producer_subgraph->outputs();
|
||||
for (std::size_t i = 0; i < subgraph_outputs.size(); ++i) {
|
||||
auto outer_output = inner_to_outer.at(subgraph_outputs[i]);
|
||||
producer_group->outputs()[i]->replaceAllUsesWith(outer_output);
|
||||
}
|
||||
producer_group->destroy();
|
||||
producer_group = nullptr; // Just to get a clear error in case someone uses it
|
||||
|
||||
// Inline the temporary nodes into the first group
|
||||
auto consumer_subgraph = &getSubgraph(consumer_group);
|
||||
for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend(); ++it) {
|
||||
Node *node = *it;
|
||||
Node *merged = mergeNodeIntoGroup(consumer_group, node);
|
||||
// If any of the outputs are still used then we need to add them
|
||||
auto outputs = node->outputs();
|
||||
for (std::size_t i = 0; i < outputs.size(); ++i) {
|
||||
auto output = outputs[i];
|
||||
if (output->uses().size() == 0) continue;
|
||||
consumer_subgraph->registerOutput(merged->outputs()[i]);
|
||||
auto new_output = consumer_group->addOutput();
|
||||
output->replaceAllUsesWith(new_output);
|
||||
new_output->setType(output->typeOption());
|
||||
}
|
||||
node->destroy();
|
||||
}
|
||||
}
|
||||
|
||||
Node * mergeNodeIntoGroup(Node* group, Node * n) {
|
||||
JIT_ASSERT(n->kind() != kFusionGroup);
|
||||
auto & subgraph = getSubgraph(group);
|
||||
// map from nodes in the surrounding graph to parameters in the fusion
|
||||
// group's subgraph that correspond to them
|
||||
|
|
@ -245,6 +304,10 @@ struct GraphFuser {
|
|||
if(group->kind() != kFusionGroup) {
|
||||
group = createSingletonFusionGroup(consumer);
|
||||
}
|
||||
if (producer->node()->kind() == kFusionGroup) {
|
||||
mergeFusionGroups(group, producer->node());
|
||||
return group;
|
||||
}
|
||||
Node * merged = mergeNodeIntoGroup(group, producer->node());
|
||||
// remaining uses of this producer can occur because we allow
|
||||
// fusion in cases where uses remain after the consumer
|
||||
|
|
@ -349,8 +412,8 @@ struct GraphFuser {
|
|||
return true;
|
||||
}
|
||||
|
||||
// returns where to continue scanning
|
||||
graph_node_list::iterator scanNode(Node * consumer) {
|
||||
// returns where to continue scanning, and whether any fusion was made
|
||||
std::pair<graph_node_list::iterator, bool> scanNode(Node * consumer) {
|
||||
auto stage_guard = graph->setStageTemporary(consumer->stage());
|
||||
if(isFusableAsExitNode(consumer)) {
|
||||
// handle inputs in reverse topological order as well...
|
||||
|
|
@ -369,17 +432,17 @@ struct GraphFuser {
|
|||
if(tryToMoveChunk(consumer,producer)) {
|
||||
// the chunk before this consumer was re-arranged to allow fusion,
|
||||
// we scan this consumer again to perform the fusion
|
||||
return consumer->reverseIterator();
|
||||
return std::make_pair(consumer->reverseIterator(), true);
|
||||
}
|
||||
if(shouldFuse(consumer, producer)) {
|
||||
auto fusion_group = fuse(consumer,producer);
|
||||
// after fusion, consumer moves into a FusionGroup, so inputs is no longer valid
|
||||
// so we rescan the new FusionGroup for more fusions...
|
||||
return fusion_group->reverseIterator();
|
||||
return std::make_pair(fusion_group->reverseIterator(), true);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ++consumer->reverseIterator();
|
||||
return std::make_pair(++consumer->reverseIterator(), false);
|
||||
}
|
||||
|
||||
void run() {
|
||||
|
|
@ -393,8 +456,30 @@ struct GraphFuser {
|
|||
}
|
||||
topological_index[graph->return_node()] = i++;
|
||||
|
||||
for(auto it = nodes.rbegin(); it != nodes.rend();) {
|
||||
it = scanNode(*it);
|
||||
// Run the pass until no changes are made.
|
||||
// This is neccessary, because the algorithm can miss out on certain fusion
|
||||
// opportunities if ran only once. Consider this graph:
|
||||
//
|
||||
// %1 = f(...)
|
||||
// %2 = g(%1)
|
||||
// %3 = h(%1)
|
||||
// %4 = l(%3)
|
||||
// return (%4, %2)
|
||||
//
|
||||
// where f, g, h, l are simple map ops.
|
||||
// The first iteration will fuse %4 and %3, and see that %1 is an input, but
|
||||
// can't be fused, because it has a different use before the fusion group
|
||||
// in our topological ordering. Then, %2 will be considered, and fused with %1.
|
||||
// If we do another iteration, the algorithm will consider the fusion of these
|
||||
// two groups and fix the situation.
|
||||
bool any_changed = true;
|
||||
while (any_changed) {
|
||||
any_changed = false;
|
||||
for (auto it = nodes.rbegin(); it != nodes.rend();) {
|
||||
bool changed;
|
||||
std::tie(it, changed) = scanNode(*it);
|
||||
any_changed |= changed;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue