From 9be4c75fa0399a292e95ef8f3c79457e4b5b2338 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Mon, 22 Mar 2021 18:30:47 -0700 Subject: [PATCH] [JIT] Add Reinplacing to MKLDNN Subgraphs (#53908) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53908 This adds reinplacing to MKLDNN Subgraphs so that we replace `aten::add` with `aten::add_`. Normally you would have to prove device and dtype, but we know that already, and because we have explicit broadcast nodes for other reasons we dont have to prove that the output shape of add is the same as inputs. Ive tested correctness on resnet, I'm going to do more extensive testing as well. When I benchmarked the "unsafe" version (always inplace) I saw average speedups of ~16% for both Single threaded and Multithreaded. I dont think the "safe" version will be far beyond; when I looked at resnet for example every `add` and `relu` were reinplaced. Theres some question of reusing other alias / liveness / inplacing passes in SR. I thought about it, however I didnt want to add a cross-dependency between very different parts of the code base with a bunch of different assumptions. The logic here is also covering a simpler case and does not add much complexity IMO. Test Plan: Imported from OSS Reviewed By: Krovatkin Differential Revision: D27132969 Pulled By: eellison fbshipit-source-id: 121a38daaedf01363f6b66a814beaaa72a0ab0dc --- test/jit/test_freezing.py | 102 ++++++++++- test/test_jit.py | 2 +- .../csrc/jit/passes/frozen_ops_to_mkldnn.cpp | 173 +++++++++++++++++- 3 files changed, 272 insertions(+), 5 deletions(-) diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 9efdaaa5d53..f74902b3a2b 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -1741,7 +1741,8 @@ class TestFrozenOptimizations(JitTestCase): scripted_mod = torch.jit.script(mod) scripted_mod = torch.jit.freeze(scripted_mod) self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) - FileCheck().check("aten::to_mkldnn").check_not("aten::add_").check("aten::div_").run(scripted_mod.graph) + # add gets uninplaced and reinplaced + FileCheck().check("aten::to_mkldnn").check("aten::add_").check("aten::div_").run(scripted_mod.graph) inp = torch.rand([20, 20]) self.assertEqual(scripted_mod(inp), mod(inp)) self.assertEqual(scripted_mod(inp), mod(inp)) @@ -1816,3 +1817,102 @@ class TestFrozenOptimizations(JitTestCase): FileCheck().check("aten::cudnn_convolution_relu").run(frozen_mod.graph) self.assertEqual(mod_eager(inp), frozen_mod(inp)) +@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled") +class TestMKLDNNReinplacing(JitTestCase): + def setUp(self): + self.default_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float) + + def tearDown(self): + torch.set_default_dtype(self.default_dtype) + + def getConv(self): + return nn.Conv2d(3, 32, kernel_size=3, stride=2).eval() + + def getInput(self): + return torch.rand([4, 3, 4, 4]) + + def freezeAndConvert(self, mod): + mod = torch.jit.freeze(torch.jit.script(mod.eval())) + self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) + return mod + + def checkResults(self, mod1, mod2): + inp = self.getInput() + self.assertEqual(mod1(inp), mod2(inp)) + + def test_successful(self): + # simple conv-relu + mod_eager = nn.Sequential(self.getConv(), nn.ReLU(), nn.ReLU()) + mod = self.freezeAndConvert(mod_eager) + FileCheck().check("mkldnn_convolution").check_next("aten::relu_").check_next("aten::relu_").run(mod.graph) + self.checkResults(mod_eager, mod) + + def test_merge_liveness(self): + class Mod(nn.Module): + def __init__(self, tensor): + super().__init__() + self.tensor = tensor + + def forward(self, x): + # this mul can be inplaced since x is dead after this use + temporary = x * self.tensor + # temporary livespan is the return node, + # add can not be inplaced + return temporary + temporary, temporary + + mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) + mod = self.freezeAndConvert(mod_eager) + FileCheck().check("aten::mul_").check_not("aten::add_").run(mod.graph) + self.checkResults(mod_eager, mod) + + def test_always_alive_values(self): + class Mod(nn.Module): + def __init__(self, tensor): + super().__init__() + self.tensor = tensor + + def forward(self, x): + # x can't be inplaced because its a return value, + # check that the inplacing pass doesnt try to inplace + # self.tensor because its always alive + return x * self.tensor, x + + mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) + mod = self.freezeAndConvert(mod_eager) + FileCheck().check_not("aten::mul_").run(mod.graph) + self.checkResults(mod_eager, mod) + + conv = self.getConv() + + class Mod(nn.Module): + def __init__(self): + super().__init__() + self.tensor = torch.rand([4, 32, 1, 1]) + self.conv = conv + + def forward(self, x): + # the shapes dont add up on this just testing a particular pattern + conv_output = self.conv(x) + return conv_output, self.conv(torch.add(x, x)) + + mod = self.freezeAndConvert(Mod()) + # x is an input to the graph, and so it should not be inplaced + # in the torch.add(x, x) call + FileCheck().check_not("aten::add_").run(mod.graph) + + def test_switch_inputs_to_inplace(self): + class Mod(nn.Module): + def __init__(self, tensor): + super().__init__() + self.tensor = tensor + + def forward(self, x): + # self.tensor cannot be inplaced, however x can, + # and bc add is commutative we can reverse inputs to add_ + return self.tensor + x + + mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) + mod = self.freezeAndConvert(mod_eager) + FileCheck().check("aten::add_").run(mod.graph) + self.checkResults(mod_eager, mod) diff --git a/test/test_jit.py b/test/test_jit.py index baa6133bd12..5d77f8ae514 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -19,7 +19,7 @@ from jit.test_export_modes import TestExportModes # noqa: F401 from jit.test_class_type import TestClassType # noqa: F401 from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401 from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401 -from jit.test_freezing import TestFreezing, TestFrozenOptimizations # noqa: F401 +from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401 from jit.test_peephole import TestPeephole # noqa: F401 from jit.test_save_load import TestSaveLoad # noqa: F401 from jit.test_module_containers import TestModuleContainers # noqa: F401 diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp index cf47044c54e..c64845fe5fa 100644 --- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp +++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp @@ -24,6 +24,8 @@ // clang-format off // moving ConvUtils include induces import cycle #include +#include +#include // clang-format on namespace torch { @@ -39,6 +41,170 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() { return AliasAnalysisKind::FROM_SCHEMA; } +using ValueSet = std::unordered_set; +using ValueSetPtr = std::shared_ptr>; + +Node* getLastUse(Value* v) { + auto last_use_node = v->node(); + for (const auto& use : v->uses()) { + if (use.user->isAfter(last_use_node)) { + last_use_node = use.user; + } + } + return last_use_node; +} + +void merge_sets( + std::unordered_map& alias_mapping, + Value* existing, + Value* new_v) { + if (alias_mapping[existing] == alias_mapping[new_v]) { + return; + } + auto existing_set = alias_mapping[existing]; + auto set_to_remove = alias_mapping[new_v]; + for (auto it = set_to_remove->begin(); it != set_to_remove->end(); it++) { + existing_set->insert(*it); + alias_mapping[*it] = existing_set; + } +} + +// no uses of tensors in container types +void assertNonTensorTypeDoesNotContainTensors(TypePtr type) { + if (type->cast()) { + return; + } + for (auto t : type->containedTypes()) { + TORCH_INTERNAL_ASSERT(!t->cast()); + } +} + +void InplaceMKLDNNSubgraph(std::shared_ptr graph) { + // This function first calculates aliasing sets, + // then calculates the last node each aliasing set is alive for. + // Then we go through each node, if it's a node which has an equivalent + // inplace node and the aliasing set for its input is dead afer this node, we + // inplace it. Then we merge the aliasing sets for the input and output of the + // node and extend the liveness of the set. To inplace a node you need to + // prove device and dtype of the input and output are the same, which we've + // already done, and prove that the output size is the same as the input size, + // which is achieved by explicit Broadcast nodes (which we inserted for other + // reasons). + // The graphs here are simple subgraphs without uses of Tensors in + // containers (Lists, GetAttrs, etc) + + // CALCULATE ALIASING SETS + + auto aliasDb = torch::make_unique(graph); + + // map from Value to its Aliasing Set + std::unordered_map alias_mapping; + ValueSet set; + ValueSetPtr input_set = std::make_shared(set); + for (Value* v : graph->inputs()) { + if (v->type()->cast()) { + input_set->insert(v); + alias_mapping[v] = input_set; + } else { + assertNonTensorTypeDoesNotContainTensors(v->type()); + } + } + + for (Node* n : graph->nodes()) { + for (Value* output : n->outputs()) { + if (!output->type()->cast()) { + assertNonTensorTypeDoesNotContainTensors(output->type()); + continue; + } + + std::unordered_set new_set = {output}; + alias_mapping[output] = std::make_shared(new_set); + for (Value* input : n->inputs()) { + if (aliasDb->mayAlias(input, output)) { + merge_sets(alias_mapping, input, output); + } + } + } + } + + // CALCULATE ALIASING SET LIVENESS + + // map from aliased set -> last use of set + std::unordered_map set_liveness; + for (auto& set : alias_mapping) { + if (set_liveness.count(set.second)) { + continue; + } + Node* last = nullptr; + for (auto it = set.second->begin(); it != set.second->end(); it++) { + Value* v = *it; + auto k = v->node()->kind(); + if (k == prim::Constant || k == prim::ConstantMKLDNNTensor || + k == prim::Param) { + last = graph->return_node(); + continue; + } + + auto last_use = getLastUse(v); + if (!last || last_use->isAfter(last)) { + last = last_use; + } + } + set_liveness[set.second] = last; + } + + // REUSING MEMORY BY REINPLACING NODES + std::vector nodes_to_inplace; + + auto add_to_inplace_set = [&](Node* node) { + // defer making the inplacing change because that would invalidate the old + // Node output Value* + nodes_to_inplace.push_back(node); + TORCH_INTERNAL_ASSERT(node->outputs().size() == 1); + auto output_liveness_end = + set_liveness[alias_mapping[node->outputs().at(0)]]; + merge_sets(alias_mapping, node->inputs().at(0), node->output()); + set_liveness[alias_mapping[node->output()]] = output_liveness_end; + }; + + for (Node* node : graph->nodes()) { + auto k = node->kind(); + if (k == aten::relu || k == aten::sigmoid || k == aten::dropout) { + if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) { + continue; + } + add_to_inplace_set(node); + } else if (k == aten::mul || k == aten::add) { + // the binary operators (add/mul) are commutative and only take tensor + // inputs, so we can inplace either the first or second input + int64_t reusable_value_index = -1; + for (size_t i = 0; i < 2; i++) { + TORCH_INTERNAL_ASSERT(node->inputs().at(i)->type()->cast()); + if (!set_liveness[alias_mapping[node->inputs().at(i)]]->isAfter(node)) { + reusable_value_index = i; + break; + } + } + + if (reusable_value_index == -1) { + continue; + } + + if (reusable_value_index == 1) { + node->insertInput(0, node->inputs().at(1)); + node->removeInput(2); + } + add_to_inplace_set(node); + } + } + + for (Node* node : nodes_to_inplace) { + node->replaceWithNewSymbol( + Symbol::fromQualString(node->schema().name() + "_")); + node->destroy(); + } +} + Operation BroadOp(const Node* node) { return [](Stack* stack) { auto b = pop(stack).toTensor(); @@ -239,7 +405,7 @@ void moveWeightsToMKLDNN(Node* n) { } } -void computeSubgraphInMKLDNN(Node* subgraph_node) { +void ComputeSubgraphInMKLDNN(Node* subgraph_node) { auto graph = subgraph_node->owningGraph(); Value* none_value = nullptr; { @@ -467,7 +633,7 @@ class MKLDNNSubgraphSlicer { return true; } - if (n->kind() == aten::add) { + if (n->kind() == aten::add || n->kind() == aten::mul) { // mkldnn doesn't currently support Tensor-Scalar add for (size_t i = 0; i < 2; i++) { if (!n->inputs().at(i)->type()->cast()) { @@ -489,7 +655,8 @@ class MKLDNNSubgraphSlicer { while (curNode != *block_->nodes().end()) { auto nextNode = curNode->next(); if (curNode->kind() == prim::MKLDNNGroup) { - computeSubgraphInMKLDNN(curNode); + ComputeSubgraphInMKLDNN(curNode); + InplaceMKLDNNSubgraph(SubgraphUtils::getSubgraph(curNode)); SubgraphUtils::unmergeSubgraph(curNode); } curNode = nextNode;