diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py index 528a62a53ef..b6e9764c379 100644 --- a/test/test_mobile_optimizer.py +++ b/test/test_mobile_optimizer.py @@ -58,8 +58,9 @@ class TestOptimizer(unittest.TestCase): o = F.conv2d(x, self.conv_weight, self.conv_bias, self.strides, self.paddings, self.dilations, self.groups) o = F.relu(o) - o = o.permute([0, 2, 3, 1]) - o = F.linear(o, self.linear_weight, self.linear_bias) + x = o.permute([0, 2, 3, 1]) + o = F.linear(x, self.linear_weight, self.linear_bias) + o = o + x return F.relu(o) class BNTestModule(torch.nn.Module): @@ -90,6 +91,9 @@ class TestOptimizer(unittest.TestCase): .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \ .check_not("prepacked::linear_clamp_prepack") \ .check_count("prepacked::linear_clamp_run", 1, exactly=True) \ + .check_not("aten::add(") \ + .check_not("aten::relu(") \ + .check_count("aten::add_relu(", 1, exactly=True) \ .run(optimized_scripted_model.graph) torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3) diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp index d450d543c2b..8308167e1da 100644 --- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp +++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -302,6 +303,10 @@ script::Module optimizeForMobile( removeDropout(cloned_module); } + if (!optimization_blacklist.count(MobileOptimizerType::FUSE_ADD_RELU)) { + FuseAddRelu(cloned_module); + } + return cloned_module; } diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.h b/torch/csrc/jit/passes/xnnpack_rewrite.h index 2d61ab6fa3a..19bfcdd35ed 100644 --- a/torch/csrc/jit/passes/xnnpack_rewrite.h +++ b/torch/csrc/jit/passes/xnnpack_rewrite.h @@ -9,7 +9,8 @@ namespace jit { enum class MobileOptimizerType : int8_t { CONV_BN_FUSION, INSERT_FOLD_PREPACK_OPS, - REMOVE_DROPOUT + REMOVE_DROPOUT, + FUSE_ADD_RELU, }; TORCH_API void insertPrePackedOps(std::shared_ptr& graph); diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index d170edcce41..9c0f39c6267 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -710,6 +710,7 @@ void initJITBindings(PyObject* module) { "INSERT_FOLD_PREPACK_OPS", MobileOptimizerType::INSERT_FOLD_PREPACK_OPS) .value("REMOVE_DROPOUT", MobileOptimizerType::REMOVE_DROPOUT) + .value("FUSE_ADD_RELU", MobileOptimizerType::FUSE_ADD_RELU) .export_values(); // This allows PyTorchStreamReader to read from a Python buffer. It requires