Add add_relu fusion pass to optimize_for_mobile. (#40252)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40252

As the title says: adds an add_relu fusion pass to optimize_for_mobile.

Test Plan:
python test/test_mobile_optimizer.py

Imported from OSS

Differential Revision: D22126825

fbshipit-source-id: a1880587ba8db9dee0fa450bc463734e4a8693d9
This commit is contained in:
Kimish Patel 2020-07-10 08:08:22 -07:00 committed by Facebook GitHub Bot
parent 75a4862f63
commit 8a79eec98a
4 changed files with 14 additions and 3 deletions

View file

@ -58,8 +58,9 @@ class TestOptimizer(unittest.TestCase):
o = F.conv2d(x, self.conv_weight, self.conv_bias,
self.strides, self.paddings, self.dilations, self.groups)
o = F.relu(o)
o = o.permute([0, 2, 3, 1])
o = F.linear(o, self.linear_weight, self.linear_bias)
x = o.permute([0, 2, 3, 1])
o = F.linear(x, self.linear_weight, self.linear_bias)
o = o + x
return F.relu(o)
class BNTestModule(torch.nn.Module):
@ -90,6 +91,9 @@ class TestOptimizer(unittest.TestCase):
.check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
.check_not("prepacked::linear_clamp_prepack") \
.check_count("prepacked::linear_clamp_run", 1, exactly=True) \
.check_not("aten::add(") \
.check_not("aten::relu(") \
.check_count("aten::add_relu(", 1, exactly=True) \
.run(optimized_scripted_model.graph)
torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

View file

@ -7,6 +7,7 @@
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/fuse_relu.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
@ -302,6 +303,10 @@ script::Module optimizeForMobile(
removeDropout(cloned_module);
}
if (!optimization_blacklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
FuseAddRelu(cloned_module);
}
return cloned_module;
}

View file

@ -9,7 +9,8 @@ namespace jit {
enum class MobileOptimizerType : int8_t {
CONV_BN_FUSION,
INSERT_FOLD_PREPACK_OPS,
REMOVE_DROPOUT
REMOVE_DROPOUT,
FUSE_ADD_RELU,
};
TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);

View file

@ -710,6 +710,7 @@ void initJITBindings(PyObject* module) {
"INSERT_FOLD_PREPACK_OPS",
MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)
.value("REMOVE_DROPOUT", MobileOptimizerType::REMOVE_DROPOUT)
.value("FUSE_ADD_RELU", MobileOptimizerType::FUSE_ADD_RELU)
.export_values();
// This allows PyTorchStreamReader to read from a Python buffer. It requires