mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[JIT] Add Reinplacing to MKLDNN Subgraphs (#53908)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53908 This adds reinplacing to MKLDNN Subgraphs so that we replace `aten::add` with `aten::add_`. Normally you would have to prove device and dtype, but we know that already, and because we have explicit broadcast nodes for other reasons we dont have to prove that the output shape of add is the same as inputs. Ive tested correctness on resnet, I'm going to do more extensive testing as well. When I benchmarked the "unsafe" version (always inplace) I saw average speedups of ~16% for both Single threaded and Multithreaded. I dont think the "safe" version will be far beyond; when I looked at resnet for example every `add` and `relu` were reinplaced. Theres some question of reusing other alias / liveness / inplacing passes in SR. I thought about it, however I didnt want to add a cross-dependency between very different parts of the code base with a bunch of different assumptions. The logic here is also covering a simpler case and does not add much complexity IMO. Test Plan: Imported from OSS Reviewed By: Krovatkin Differential Revision: D27132969 Pulled By: eellison fbshipit-source-id: 121a38daaedf01363f6b66a814beaaa72a0ab0dc
This commit is contained in:
parent
81c6e5fb38
commit
9be4c75fa0
3 changed files with 272 additions and 5 deletions
|
|
@ -1741,7 +1741,8 @@ class TestFrozenOptimizations(JitTestCase):
|
|||
scripted_mod = torch.jit.script(mod)
|
||||
scripted_mod = torch.jit.freeze(scripted_mod)
|
||||
self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
|
||||
FileCheck().check("aten::to_mkldnn").check_not("aten::add_").check("aten::div_").run(scripted_mod.graph)
|
||||
# add gets uninplaced and reinplaced
|
||||
FileCheck().check("aten::to_mkldnn").check("aten::add_").check("aten::div_").run(scripted_mod.graph)
|
||||
inp = torch.rand([20, 20])
|
||||
self.assertEqual(scripted_mod(inp), mod(inp))
|
||||
self.assertEqual(scripted_mod(inp), mod(inp))
|
||||
|
|
@ -1816,3 +1817,102 @@ class TestFrozenOptimizations(JitTestCase):
|
|||
FileCheck().check("aten::cudnn_convolution_relu").run(frozen_mod.graph)
|
||||
|
||||
self.assertEqual(mod_eager(inp), frozen_mod(inp))
|
||||
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
|
||||
class TestMKLDNNReinplacing(JitTestCase):
|
||||
def setUp(self):
|
||||
self.default_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
def tearDown(self):
|
||||
torch.set_default_dtype(self.default_dtype)
|
||||
|
||||
def getConv(self):
|
||||
return nn.Conv2d(3, 32, kernel_size=3, stride=2).eval()
|
||||
|
||||
def getInput(self):
|
||||
return torch.rand([4, 3, 4, 4])
|
||||
|
||||
def freezeAndConvert(self, mod):
|
||||
mod = torch.jit.freeze(torch.jit.script(mod.eval()))
|
||||
self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
|
||||
return mod
|
||||
|
||||
def checkResults(self, mod1, mod2):
|
||||
inp = self.getInput()
|
||||
self.assertEqual(mod1(inp), mod2(inp))
|
||||
|
||||
def test_successful(self):
|
||||
# simple conv-relu
|
||||
mod_eager = nn.Sequential(self.getConv(), nn.ReLU(), nn.ReLU())
|
||||
mod = self.freezeAndConvert(mod_eager)
|
||||
FileCheck().check("mkldnn_convolution").check_next("aten::relu_").check_next("aten::relu_").run(mod.graph)
|
||||
self.checkResults(mod_eager, mod)
|
||||
|
||||
def test_merge_liveness(self):
|
||||
class Mod(nn.Module):
|
||||
def __init__(self, tensor):
|
||||
super().__init__()
|
||||
self.tensor = tensor
|
||||
|
||||
def forward(self, x):
|
||||
# this mul can be inplaced since x is dead after this use
|
||||
temporary = x * self.tensor
|
||||
# temporary livespan is the return node,
|
||||
# add can not be inplaced
|
||||
return temporary + temporary, temporary
|
||||
|
||||
mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
|
||||
mod = self.freezeAndConvert(mod_eager)
|
||||
FileCheck().check("aten::mul_").check_not("aten::add_").run(mod.graph)
|
||||
self.checkResults(mod_eager, mod)
|
||||
|
||||
def test_always_alive_values(self):
|
||||
class Mod(nn.Module):
|
||||
def __init__(self, tensor):
|
||||
super().__init__()
|
||||
self.tensor = tensor
|
||||
|
||||
def forward(self, x):
|
||||
# x can't be inplaced because its a return value,
|
||||
# check that the inplacing pass doesnt try to inplace
|
||||
# self.tensor because its always alive
|
||||
return x * self.tensor, x
|
||||
|
||||
mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
|
||||
mod = self.freezeAndConvert(mod_eager)
|
||||
FileCheck().check_not("aten::mul_").run(mod.graph)
|
||||
self.checkResults(mod_eager, mod)
|
||||
|
||||
conv = self.getConv()
|
||||
|
||||
class Mod(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tensor = torch.rand([4, 32, 1, 1])
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, x):
|
||||
# the shapes dont add up on this just testing a particular pattern
|
||||
conv_output = self.conv(x)
|
||||
return conv_output, self.conv(torch.add(x, x))
|
||||
|
||||
mod = self.freezeAndConvert(Mod())
|
||||
# x is an input to the graph, and so it should not be inplaced
|
||||
# in the torch.add(x, x) call
|
||||
FileCheck().check_not("aten::add_").run(mod.graph)
|
||||
|
||||
def test_switch_inputs_to_inplace(self):
|
||||
class Mod(nn.Module):
|
||||
def __init__(self, tensor):
|
||||
super().__init__()
|
||||
self.tensor = tensor
|
||||
|
||||
def forward(self, x):
|
||||
# self.tensor cannot be inplaced, however x can,
|
||||
# and bc add is commutative we can reverse inputs to add_
|
||||
return self.tensor + x
|
||||
|
||||
mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
|
||||
mod = self.freezeAndConvert(mod_eager)
|
||||
FileCheck().check("aten::add_").run(mod.graph)
|
||||
self.checkResults(mod_eager, mod)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from jit.test_export_modes import TestExportModes # noqa: F401
|
|||
from jit.test_class_type import TestClassType # noqa: F401
|
||||
from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401
|
||||
from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401
|
||||
from jit.test_freezing import TestFreezing, TestFrozenOptimizations # noqa: F401
|
||||
from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401
|
||||
from jit.test_peephole import TestPeephole # noqa: F401
|
||||
from jit.test_save_load import TestSaveLoad # noqa: F401
|
||||
from jit.test_module_containers import TestModuleContainers # noqa: F401
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@
|
|||
// clang-format off
|
||||
// moving ConvUtils include induces import cycle
|
||||
#include <ATen/native/ConvUtils.h>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
// clang-format on
|
||||
|
||||
namespace torch {
|
||||
|
|
@ -39,6 +41,170 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
|
|||
return AliasAnalysisKind::FROM_SCHEMA;
|
||||
}
|
||||
|
||||
using ValueSet = std::unordered_set<Value*>;
|
||||
using ValueSetPtr = std::shared_ptr<std::unordered_set<Value*>>;
|
||||
|
||||
Node* getLastUse(Value* v) {
|
||||
auto last_use_node = v->node();
|
||||
for (const auto& use : v->uses()) {
|
||||
if (use.user->isAfter(last_use_node)) {
|
||||
last_use_node = use.user;
|
||||
}
|
||||
}
|
||||
return last_use_node;
|
||||
}
|
||||
|
||||
void merge_sets(
|
||||
std::unordered_map<Value*, ValueSetPtr>& alias_mapping,
|
||||
Value* existing,
|
||||
Value* new_v) {
|
||||
if (alias_mapping[existing] == alias_mapping[new_v]) {
|
||||
return;
|
||||
}
|
||||
auto existing_set = alias_mapping[existing];
|
||||
auto set_to_remove = alias_mapping[new_v];
|
||||
for (auto it = set_to_remove->begin(); it != set_to_remove->end(); it++) {
|
||||
existing_set->insert(*it);
|
||||
alias_mapping[*it] = existing_set;
|
||||
}
|
||||
}
|
||||
|
||||
// no uses of tensors in container types
|
||||
void assertNonTensorTypeDoesNotContainTensors(TypePtr type) {
|
||||
if (type->cast<TensorType>()) {
|
||||
return;
|
||||
}
|
||||
for (auto t : type->containedTypes()) {
|
||||
TORCH_INTERNAL_ASSERT(!t->cast<TensorType>());
|
||||
}
|
||||
}
|
||||
|
||||
void InplaceMKLDNNSubgraph(std::shared_ptr<Graph> graph) {
|
||||
// This function first calculates aliasing sets,
|
||||
// then calculates the last node each aliasing set is alive for.
|
||||
// Then we go through each node, if it's a node which has an equivalent
|
||||
// inplace node and the aliasing set for its input is dead afer this node, we
|
||||
// inplace it. Then we merge the aliasing sets for the input and output of the
|
||||
// node and extend the liveness of the set. To inplace a node you need to
|
||||
// prove device and dtype of the input and output are the same, which we've
|
||||
// already done, and prove that the output size is the same as the input size,
|
||||
// which is achieved by explicit Broadcast nodes (which we inserted for other
|
||||
// reasons).
|
||||
// The graphs here are simple subgraphs without uses of Tensors in
|
||||
// containers (Lists, GetAttrs, etc)
|
||||
|
||||
// CALCULATE ALIASING SETS
|
||||
|
||||
auto aliasDb = torch::make_unique<AliasDb>(graph);
|
||||
|
||||
// map from Value to its Aliasing Set
|
||||
std::unordered_map<Value*, ValueSetPtr> alias_mapping;
|
||||
ValueSet set;
|
||||
ValueSetPtr input_set = std::make_shared<ValueSet>(set);
|
||||
for (Value* v : graph->inputs()) {
|
||||
if (v->type()->cast<TensorType>()) {
|
||||
input_set->insert(v);
|
||||
alias_mapping[v] = input_set;
|
||||
} else {
|
||||
assertNonTensorTypeDoesNotContainTensors(v->type());
|
||||
}
|
||||
}
|
||||
|
||||
for (Node* n : graph->nodes()) {
|
||||
for (Value* output : n->outputs()) {
|
||||
if (!output->type()->cast<TensorType>()) {
|
||||
assertNonTensorTypeDoesNotContainTensors(output->type());
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_set<Value*> new_set = {output};
|
||||
alias_mapping[output] = std::make_shared<ValueSet>(new_set);
|
||||
for (Value* input : n->inputs()) {
|
||||
if (aliasDb->mayAlias(input, output)) {
|
||||
merge_sets(alias_mapping, input, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CALCULATE ALIASING SET LIVENESS
|
||||
|
||||
// map from aliased set -> last use of set
|
||||
std::unordered_map<ValueSetPtr, Node*> set_liveness;
|
||||
for (auto& set : alias_mapping) {
|
||||
if (set_liveness.count(set.second)) {
|
||||
continue;
|
||||
}
|
||||
Node* last = nullptr;
|
||||
for (auto it = set.second->begin(); it != set.second->end(); it++) {
|
||||
Value* v = *it;
|
||||
auto k = v->node()->kind();
|
||||
if (k == prim::Constant || k == prim::ConstantMKLDNNTensor ||
|
||||
k == prim::Param) {
|
||||
last = graph->return_node();
|
||||
continue;
|
||||
}
|
||||
|
||||
auto last_use = getLastUse(v);
|
||||
if (!last || last_use->isAfter(last)) {
|
||||
last = last_use;
|
||||
}
|
||||
}
|
||||
set_liveness[set.second] = last;
|
||||
}
|
||||
|
||||
// REUSING MEMORY BY REINPLACING NODES
|
||||
std::vector<Node*> nodes_to_inplace;
|
||||
|
||||
auto add_to_inplace_set = [&](Node* node) {
|
||||
// defer making the inplacing change because that would invalidate the old
|
||||
// Node output Value*
|
||||
nodes_to_inplace.push_back(node);
|
||||
TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
|
||||
auto output_liveness_end =
|
||||
set_liveness[alias_mapping[node->outputs().at(0)]];
|
||||
merge_sets(alias_mapping, node->inputs().at(0), node->output());
|
||||
set_liveness[alias_mapping[node->output()]] = output_liveness_end;
|
||||
};
|
||||
|
||||
for (Node* node : graph->nodes()) {
|
||||
auto k = node->kind();
|
||||
if (k == aten::relu || k == aten::sigmoid || k == aten::dropout) {
|
||||
if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
|
||||
continue;
|
||||
}
|
||||
add_to_inplace_set(node);
|
||||
} else if (k == aten::mul || k == aten::add) {
|
||||
// the binary operators (add/mul) are commutative and only take tensor
|
||||
// inputs, so we can inplace either the first or second input
|
||||
int64_t reusable_value_index = -1;
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
TORCH_INTERNAL_ASSERT(node->inputs().at(i)->type()->cast<TensorType>());
|
||||
if (!set_liveness[alias_mapping[node->inputs().at(i)]]->isAfter(node)) {
|
||||
reusable_value_index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (reusable_value_index == -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (reusable_value_index == 1) {
|
||||
node->insertInput(0, node->inputs().at(1));
|
||||
node->removeInput(2);
|
||||
}
|
||||
add_to_inplace_set(node);
|
||||
}
|
||||
}
|
||||
|
||||
for (Node* node : nodes_to_inplace) {
|
||||
node->replaceWithNewSymbol(
|
||||
Symbol::fromQualString(node->schema().name() + "_"));
|
||||
node->destroy();
|
||||
}
|
||||
}
|
||||
|
||||
Operation BroadOp(const Node* node) {
|
||||
return [](Stack* stack) {
|
||||
auto b = pop(stack).toTensor();
|
||||
|
|
@ -239,7 +405,7 @@ void moveWeightsToMKLDNN(Node* n) {
|
|||
}
|
||||
}
|
||||
|
||||
void computeSubgraphInMKLDNN(Node* subgraph_node) {
|
||||
void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
|
||||
auto graph = subgraph_node->owningGraph();
|
||||
Value* none_value = nullptr;
|
||||
{
|
||||
|
|
@ -467,7 +633,7 @@ class MKLDNNSubgraphSlicer {
|
|||
return true;
|
||||
}
|
||||
|
||||
if (n->kind() == aten::add) {
|
||||
if (n->kind() == aten::add || n->kind() == aten::mul) {
|
||||
// mkldnn doesn't currently support Tensor-Scalar add
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
if (!n->inputs().at(i)->type()->cast<TensorType>()) {
|
||||
|
|
@ -489,7 +655,8 @@ class MKLDNNSubgraphSlicer {
|
|||
while (curNode != *block_->nodes().end()) {
|
||||
auto nextNode = curNode->next();
|
||||
if (curNode->kind() == prim::MKLDNNGroup) {
|
||||
computeSubgraphInMKLDNN(curNode);
|
||||
ComputeSubgraphInMKLDNN(curNode);
|
||||
InplaceMKLDNNSubgraph(SubgraphUtils::getSubgraph(curNode));
|
||||
SubgraphUtils::unmergeSubgraph(curNode);
|
||||
}
|
||||
curNode = nextNode;
|
||||
|
|
|
|||
Loading…
Reference in a new issue