[JIT] Add Reinplacing to MKLDNN Subgraphs (#53908)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53908

This adds reinplacing to MKLDNN subgraphs so that we replace `aten::add` with `aten::add_`. Normally you would have to prove that the device and dtype match, but we already know they do, and because we have explicit broadcast nodes for other reasons we don't have to prove that the output shape of `add` is the same as its inputs'.

I've tested correctness on ResNet, and I'm going to do more extensive testing as well. When I benchmarked the "unsafe" version (always inplace), I saw average speedups of ~16% for both single-threaded and multithreaded runs. I don't think the "safe" version will be far behind; when I looked at ResNet, for example, every `add` and `relu` was reinplaced.

There's some question of reusing other alias / liveness / inplacing passes in SR. I thought about it; however, I didn't want to add a cross-dependency between very different parts of the codebase that carry a bunch of different assumptions. The logic here also covers a simpler case and doesn't add much complexity, IMO.
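
For a concrete picture of the effect, here is a minimal sketch modeled on the new tests (not part of the change itself): the `AddRelu` module is made up, the sketch assumes an MKLDNN-enabled build, and it calls the pass directly as `torch._C._jit_pass_convert_frozen_ops_to_mkldnn`, which is how the tests' `run_pass` helper resolves it. After freezing and conversion, the intermediate `add`/`relu` should show up as their in-place variants in the graph:

import torch
import torch.nn as nn
from torch.testing import FileCheck

class AddRelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=3, stride=2)
        self.bias = torch.rand([4, 32, 1, 1])

    def forward(self, x):
        # tensor-tensor add, so it is eligible for MKLDNN conversion
        return torch.relu(self.conv(x) + self.bias)

mod = torch.jit.freeze(torch.jit.script(AddRelu().eval()))
torch._C._jit_pass_convert_frozen_ops_to_mkldnn(mod.graph)
# the intermediate add and relu should now be rewritten to in-place variants
FileCheck().check("aten::add_").check("aten::relu_").run(mod.graph)
mod(torch.rand([4, 3, 8, 8]))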

Test Plan: Imported from OSS

Reviewed By: Krovatkin

Differential Revision: D27132969

Pulled By: eellison

fbshipit-source-id: 121a38daaedf01363f6b66a814beaaa72a0ab0dc
Elias Ellison 2021-03-22 18:30:47 -07:00 committed by Facebook GitHub Bot
parent 81c6e5fb38
commit 9be4c75fa0
3 changed files with 272 additions and 5 deletions


@@ -1741,7 +1741,8 @@ class TestFrozenOptimizations(JitTestCase):
        scripted_mod = torch.jit.script(mod)
        scripted_mod = torch.jit.freeze(scripted_mod)
        self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
-       FileCheck().check("aten::to_mkldnn").check_not("aten::add_").check("aten::div_").run(scripted_mod.graph)
+       # add gets uninplaced and reinplaced
+       FileCheck().check("aten::to_mkldnn").check("aten::add_").check("aten::div_").run(scripted_mod.graph)
        inp = torch.rand([20, 20])
        self.assertEqual(scripted_mod(inp), mod(inp))
        self.assertEqual(scripted_mod(inp), mod(inp))
@@ -1816,3 +1817,102 @@ class TestFrozenOptimizations(JitTestCase):
        FileCheck().check("aten::cudnn_convolution_relu").run(frozen_mod.graph)
        self.assertEqual(mod_eager(inp), frozen_mod(inp))

@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
class TestMKLDNNReinplacing(JitTestCase):
    def setUp(self):
        self.default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float)

    def tearDown(self):
        torch.set_default_dtype(self.default_dtype)

    def getConv(self):
        return nn.Conv2d(3, 32, kernel_size=3, stride=2).eval()

    def getInput(self):
        return torch.rand([4, 3, 4, 4])

    def freezeAndConvert(self, mod):
        mod = torch.jit.freeze(torch.jit.script(mod.eval()))
        self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
        return mod

    def checkResults(self, mod1, mod2):
        inp = self.getInput()
        self.assertEqual(mod1(inp), mod2(inp))

    def test_successful(self):
        # simple conv-relu
        mod_eager = nn.Sequential(self.getConv(), nn.ReLU(), nn.ReLU())
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check("mkldnn_convolution").check_next("aten::relu_").check_next("aten::relu_").run(mod.graph)
        self.checkResults(mod_eager, mod)

    def test_merge_liveness(self):
        class Mod(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                # this mul can be inplaced since x is dead after this use
                temporary = x * self.tensor
                # temporary's lifespan extends to the return node,
                # so the add cannot be inplaced
                return temporary + temporary, temporary

        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check("aten::mul_").check_not("aten::add_").run(mod.graph)
        self.checkResults(mod_eager, mod)

    def test_always_alive_values(self):
        class Mod(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                # x can't be inplaced because it's a return value;
                # check that the inplacing pass doesn't try to inplace
                # self.tensor because it's always alive
                return x * self.tensor, x

        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check_not("aten::mul_").run(mod.graph)
        self.checkResults(mod_eager, mod)

        conv = self.getConv()

        class Mod(nn.Module):
            def __init__(self):
                super().__init__()
                self.tensor = torch.rand([4, 32, 1, 1])
                self.conv = conv

            def forward(self, x):
                # the shapes don't add up here; this is just testing a particular pattern
                conv_output = self.conv(x)
                return conv_output, self.conv(torch.add(x, x))

        mod = self.freezeAndConvert(Mod())
        # x is an input to the graph, and so it should not be inplaced
        # in the torch.add(x, x) call
        FileCheck().check_not("aten::add_").run(mod.graph)

    def test_switch_inputs_to_inplace(self):
        class Mod(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                # self.tensor cannot be inplaced, however x can,
                # and because add is commutative we can reverse the inputs to add_
                return self.tensor + x

        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check("aten::add_").run(mod.graph)
        self.checkResults(mod_eager, mod)


@@ -19,7 +19,7 @@ from jit.test_export_modes import TestExportModes # noqa: F401
from jit.test_class_type import TestClassType # noqa: F401
from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401
from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401
-from jit.test_freezing import TestFreezing, TestFrozenOptimizations # noqa: F401
+from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401
from jit.test_peephole import TestPeephole # noqa: F401
from jit.test_save_load import TestSaveLoad # noqa: F401
from jit.test_module_containers import TestModuleContainers # noqa: F401


@@ -24,6 +24,8 @@
// clang-format off
// moving ConvUtils include induces import cycle
#include <ATen/native/ConvUtils.h>
#include <algorithm>
#include <memory>
// clang-format on

namespace torch {
@@ -39,6 +41,170 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return AliasAnalysisKind::FROM_SCHEMA;
}

using ValueSet = std::unordered_set<Value*>;
using ValueSetPtr = std::shared_ptr<std::unordered_set<Value*>>;

Node* getLastUse(Value* v) {
  auto last_use_node = v->node();
  for (const auto& use : v->uses()) {
    if (use.user->isAfter(last_use_node)) {
      last_use_node = use.user;
    }
  }
  return last_use_node;
}

void merge_sets(
    std::unordered_map<Value*, ValueSetPtr>& alias_mapping,
    Value* existing,
    Value* new_v) {
  if (alias_mapping[existing] == alias_mapping[new_v]) {
    return;
  }
  auto existing_set = alias_mapping[existing];
  auto set_to_remove = alias_mapping[new_v];
  for (auto it = set_to_remove->begin(); it != set_to_remove->end(); it++) {
    existing_set->insert(*it);
    alias_mapping[*it] = existing_set;
  }
}

// no uses of tensors in container types
void assertNonTensorTypeDoesNotContainTensors(TypePtr type) {
  if (type->cast<TensorType>()) {
    return;
  }
  for (auto t : type->containedTypes()) {
    TORCH_INTERNAL_ASSERT(!t->cast<TensorType>());
  }
}

void InplaceMKLDNNSubgraph(std::shared_ptr<Graph> graph) {
  // This function first calculates aliasing sets, then calculates the last
  // node each aliasing set is alive for. Then we go through each node, and if
  // it's a node which has an equivalent inplace node and the aliasing set for
  // its input is dead after this node, we inplace it. Then we merge the
  // aliasing sets for the input and output of the node and extend the
  // liveness of the set. To inplace a node you need to prove that the device
  // and dtype of the input and output are the same, which we've already done,
  // and prove that the output size is the same as the input size, which is
  // achieved by explicit Broadcast nodes (which we inserted for other
  // reasons).
  // The graphs here are simple subgraphs without uses of Tensors in
  // containers (Lists, GetAttrs, etc)

  // CALCULATE ALIASING SETS

  auto aliasDb = torch::make_unique<AliasDb>(graph);

  // map from Value to its Aliasing Set
  std::unordered_map<Value*, ValueSetPtr> alias_mapping;
  ValueSet set;
  ValueSetPtr input_set = std::make_shared<ValueSet>(set);
  for (Value* v : graph->inputs()) {
    if (v->type()->cast<TensorType>()) {
      input_set->insert(v);
      alias_mapping[v] = input_set;
    } else {
      assertNonTensorTypeDoesNotContainTensors(v->type());
    }
  }

  for (Node* n : graph->nodes()) {
    for (Value* output : n->outputs()) {
      if (!output->type()->cast<TensorType>()) {
        assertNonTensorTypeDoesNotContainTensors(output->type());
        continue;
      }

      std::unordered_set<Value*> new_set = {output};
      alias_mapping[output] = std::make_shared<ValueSet>(new_set);
      for (Value* input : n->inputs()) {
        if (aliasDb->mayAlias(input, output)) {
          merge_sets(alias_mapping, input, output);
        }
      }
    }
  }

  // CALCULATE ALIASING SET LIVENESS

  // map from aliased set -> last use of set
  std::unordered_map<ValueSetPtr, Node*> set_liveness;
  for (auto& set : alias_mapping) {
    if (set_liveness.count(set.second)) {
      continue;
    }
    Node* last = nullptr;
    for (auto it = set.second->begin(); it != set.second->end(); it++) {
      Value* v = *it;
      auto k = v->node()->kind();
      if (k == prim::Constant || k == prim::ConstantMKLDNNTensor ||
          k == prim::Param) {
        last = graph->return_node();
        continue;
      }

      auto last_use = getLastUse(v);
      if (!last || last_use->isAfter(last)) {
        last = last_use;
      }
    }
    set_liveness[set.second] = last;
  }

  // REUSING MEMORY BY REINPLACING NODES

  std::vector<Node*> nodes_to_inplace;

  auto add_to_inplace_set = [&](Node* node) {
    // defer making the inplacing change because that would invalidate the old
    // Node output Value*
    nodes_to_inplace.push_back(node);
    TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
    auto output_liveness_end =
        set_liveness[alias_mapping[node->outputs().at(0)]];
    merge_sets(alias_mapping, node->inputs().at(0), node->output());
    set_liveness[alias_mapping[node->output()]] = output_liveness_end;
  };

  for (Node* node : graph->nodes()) {
    auto k = node->kind();
    if (k == aten::relu || k == aten::sigmoid || k == aten::dropout) {
      if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
        continue;
      }
      add_to_inplace_set(node);
    } else if (k == aten::mul || k == aten::add) {
      // the binary operators (add/mul) are commutative and only take tensor
      // inputs, so we can inplace either the first or second input
      int64_t reusable_value_index = -1;
      for (size_t i = 0; i < 2; i++) {
        TORCH_INTERNAL_ASSERT(node->inputs().at(i)->type()->cast<TensorType>());
        if (!set_liveness[alias_mapping[node->inputs().at(i)]]->isAfter(node)) {
          reusable_value_index = i;
          break;
        }
      }

      if (reusable_value_index == -1) {
        continue;
      }

      if (reusable_value_index == 1) {
        node->insertInput(0, node->inputs().at(1));
        node->removeInput(2);
      }
      add_to_inplace_set(node);
    }
  }

  for (Node* node : nodes_to_inplace) {
    node->replaceWithNewSymbol(
        Symbol::fromQualString(node->schema().name() + "_"));
    node->destroy();
  }
}

Operation BroadOp(const Node* node) {
  return [](Stack* stack) {
    auto b = pop(stack).toTensor();
@@ -239,7 +405,7 @@ void moveWeightsToMKLDNN(Node* n) {
  }
}

-void computeSubgraphInMKLDNN(Node* subgraph_node) {
+void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
  auto graph = subgraph_node->owningGraph();
  Value* none_value = nullptr;
  {
@@ -467,7 +633,7 @@ class MKLDNNSubgraphSlicer {
      return true;
    }

-    if (n->kind() == aten::add) {
+    if (n->kind() == aten::add || n->kind() == aten::mul) {
      // mkldnn doesn't currently support Tensor-Scalar add
      for (size_t i = 0; i < 2; i++) {
        if (!n->inputs().at(i)->type()->cast<TensorType>()) {
@@ -489,7 +655,8 @@
    while (curNode != *block_->nodes().end()) {
      auto nextNode = curNode->next();
      if (curNode->kind() == prim::MKLDNNGroup) {
-        computeSubgraphInMKLDNN(curNode);
+        ComputeSubgraphInMKLDNN(curNode);
+        InplaceMKLDNNSubgraph(SubgraphUtils::getSubgraph(curNode));
        SubgraphUtils::unmergeSubgraph(curNode);
      }
      curNode = nextNode;