mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
This PR relands the changes introduced in PR https://github.com/pytorch/pytorch/pull/100849. The old PR turned the nnc_* functions into static functions. We now add declarations for them and hope that internal builds will pass. Pull Request resolved: https://github.com/pytorch/pytorch/pull/102228 Approved by: https://github.com/albanD
49 lines
1.3 KiB
C++
49 lines
1.3 KiB
C++
#include <torch/csrc/jit/passes/constant_pooling.h>
|
|
#include <torch/csrc/jit/passes/constant_propagation.h>
|
|
#include <torch/csrc/jit/passes/remove_exceptions.h>
|
|
|
|
#include <torch/csrc/jit/jit_log.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Reports whether `block` directly contains a prim::RaiseException node.
// Only the nodes yielded by block->nodes() are inspected; blocks nested
// inside those nodes are not descended into by this helper.
static bool certainlyThrows(Block* block) {
  bool found_raise = false;
  for (Node* node : block->nodes()) {
    if (node->kind() != prim::RaiseException) {
      continue;
    }
    // One raise node is enough; stop scanning.
    found_raise = true;
    break;
  }
  return found_raise;
}
|
|
|
|
// Walks `block` (and, recursively, every sub-block of its nodes) looking for
// prim::If nodes where one branch directly contains a prim::RaiseException.
// For such a node the condition input is pinned to the constant that selects
// the *other* branch, so that later constant propagation can fold the If and
// drop the raising branch.
//
// Only the condition input of the If node itself is rewired. The condition
// value may have other uses elsewhere in the graph (for example in code that
// runs when this If is never reached), and "the raising branch is not taken"
// tells us nothing about the value at those uses, so they are left untouched.
static void EliminateExceptions(Block* block) {
  auto graph = block->owningGraph();
  // NOTE(review): a fresh pair of constants is inserted at the graph's
  // current insertion point on every (recursive) call; the public entry
  // point runs ConstantPooling afterwards, which presumably deduplicates
  // them -- confirm if this pass is ever invoked without that cleanup.
  Value* false_const = graph->insertConstant(IValue(false));
  Value* true_const = graph->insertConstant(IValue(true));
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::If) {
      // blocks()[0] is taken when the condition is true, blocks()[1] otherwise.
      Block* true_block = n->blocks()[0];
      Block* false_block = n->blocks()[1];
      if (certainlyThrows(true_block)) {
        // The true branch raises: force this If down the false branch.
        n->replaceInput(0, false_const);
      } else if (certainlyThrows(false_block)) {
        // The false branch raises: force this If down the true branch.
        n->replaceInput(0, true_const);
      }
    }

    // Recurse into every nested block, whatever the node kind.
    for (Block* subblock : n->blocks()) {
      EliminateExceptions(subblock);
    }
  }
}
|
|
|
|
// Public entry point of the pass: rewrites exception-raising conditionals in
// `graph`, then runs constant propagation and constant pooling over the
// result. The graph is dumped to the JIT log before and after the rewrite.
void EliminateExceptions(std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before EliminateExceptions: ", graph);

  // Start the recursive rewrite from the graph's top-level block.
  Block* top_level_block = graph->block();
  EliminateExceptions(top_level_block);

  // Cleanup passes over the rewritten graph.
  ConstantPropagation(graph);
  ConstantPooling(graph);

  GRAPH_DUMP("After EliminateExceptions: ", graph);
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|