mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Add add_relu fusion pass to optimize_for_mobile. (#40252)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40252

As title says.

Test Plan: python test/test_mobile_optimizer.py

Imported from OSS

Differential Revision: D22126825

fbshipit-source-id: a1880587ba8db9dee0fa450bc463734e4a8693d9
This commit is contained in:
parent
75a4862f63
commit
8a79eec98a
4 changed files with 14 additions and 3 deletions
|
|
@ -58,8 +58,9 @@ class TestOptimizer(unittest.TestCase):
|
|||
o = F.conv2d(x, self.conv_weight, self.conv_bias,
|
||||
self.strides, self.paddings, self.dilations, self.groups)
|
||||
o = F.relu(o)
|
||||
o = o.permute([0, 2, 3, 1])
|
||||
o = F.linear(o, self.linear_weight, self.linear_bias)
|
||||
x = o.permute([0, 2, 3, 1])
|
||||
o = F.linear(x, self.linear_weight, self.linear_bias)
|
||||
o = o + x
|
||||
return F.relu(o)
|
||||
|
||||
class BNTestModule(torch.nn.Module):
|
||||
|
|
@ -90,6 +91,9 @@ class TestOptimizer(unittest.TestCase):
|
|||
.check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
|
||||
.check_not("prepacked::linear_clamp_prepack") \
|
||||
.check_count("prepacked::linear_clamp_run", 1, exactly=True) \
|
||||
.check_not("aten::add(") \
|
||||
.check_not("aten::relu(") \
|
||||
.check_count("aten::add_relu(", 1, exactly=True) \
|
||||
.run(optimized_scripted_model.graph)
|
||||
torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include <torch/csrc/jit/passes/fold_conv_bn.h>
|
||||
#include <torch/csrc/jit/passes/freeze_module.h>
|
||||
#include <torch/csrc/jit/passes/fuse_linear.h>
|
||||
#include <torch/csrc/jit/passes/fuse_relu.h>
|
||||
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
|
||||
#include <torch/csrc/jit/passes/prepack_folding.h>
|
||||
#include <torch/csrc/jit/passes/remove_dropout.h>
|
||||
|
|
@ -302,6 +303,10 @@ script::Module optimizeForMobile(
|
|||
removeDropout(cloned_module);
|
||||
}
|
||||
|
||||
if (!optimization_blacklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
|
||||
FuseAddRelu(cloned_module);
|
||||
}
|
||||
|
||||
return cloned_module;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ namespace jit {
|
|||
enum class MobileOptimizerType : int8_t {
|
||||
CONV_BN_FUSION,
|
||||
INSERT_FOLD_PREPACK_OPS,
|
||||
REMOVE_DROPOUT
|
||||
REMOVE_DROPOUT,
|
||||
FUSE_ADD_RELU,
|
||||
};
|
||||
|
||||
TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);
|
||||
|
|
|
|||
|
|
@ -710,6 +710,7 @@ void initJITBindings(PyObject* module) {
|
|||
"INSERT_FOLD_PREPACK_OPS",
|
||||
MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)
|
||||
.value("REMOVE_DROPOUT", MobileOptimizerType::REMOVE_DROPOUT)
|
||||
.value("FUSE_ADD_RELU", MobileOptimizerType::FUSE_ADD_RELU)
|
||||
.export_values();
|
||||
|
||||
// This allows PyTorchStreamReader to read from a Python buffer. It requires
|
||||
|
|
|
|||
Loading…
Reference in a new issue