From 5c9d1e48242587a9b1958df2d2efea3472072f4f Mon Sep 17 00:00:00 2001
From: Xingying Cheng
Date: Mon, 27 Apr 2020 10:16:59 -0700
Subject: [PATCH] Propagate module lints for mobile scripted module. (#37046)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37046

ghstack-source-id: 102669259

Create a Python API entry point that generates mobile model lints: it takes a scripted module as an argument and returns a list of module lints. This initial version is a placeholder that includes module bundled input as the first lint instance. More lints will be added in the future.

Test Plan: python test/test_mobile_optimizer.py

Reviewed By: dreiss

Differential Revision: D21164648

fbshipit-source-id: 9e8f4e19d74b5464a55cc73b9dc18f358c5947d6
---
 test/test_mobile_optimizer.py   | 56 +++++++++++++++++++++++++++++++--
 torch/utils/mobile_optimizer.py | 54 +++++++++++++++++++++++++++----
 2 files changed, 102 insertions(+), 8 deletions(-)

diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py
index 3e70a48d11f..e68b976a1db 100644
--- a/test/test_mobile_optimizer.py
+++ b/test/test_mobile_optimizer.py
@@ -1,7 +1,8 @@
 import unittest
 import torch
 import torch.backends.xnnpack
-from torch.utils import mobile_optimizer
+import torch.utils.bundled_inputs
+from torch.utils.mobile_optimizer import *
 from torch.nn import functional as F
 
 FileCheck = torch._C.FileCheck
@@ -66,7 +67,7 @@ class TestOptimizer(unittest.TestCase):
         scripted_model.eval()
         initial_result = scripted_model(input_data)
 
-        optimized_scripted_model = mobile_optimizer.optimize_for_mobile(scripted_model)
+        optimized_scripted_model = optimize_for_mobile(scripted_model)
         optimized_result = optimized_scripted_model(input_data)
 
         FileCheck().check_not("Tensor = aten::conv2d") \
@@ -79,5 +80,56 @@ class TestOptimizer(unittest.TestCase):
         torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
 
+    def test_generate_mobile_module_lints(self):
+        class MyTestModule(torch.nn.Module):
+            def __init__(self):
+                super(MyTestModule, self).__init__()
+                self.fc = torch.nn.Linear(4, 4)
+                self.dropout = torch.nn.Dropout(p=0.5)
+
+            def forward(self, inputs):
+                out = self.fc(inputs)
+                out = self.dropout(out)
+                return out
+
+        class MyBNModule(torch.nn.Module):
+            def __init__(self):
+                super(MyBNModule, self).__init__()
+                self.bn = torch.nn.BatchNorm2d(4, affine=True)
+
+            def forward(self, inputs):
+                bn = self.bn(inputs)
+                return bn
+
+        class MyBundledInputModule(torch.nn.Module):
+            def __init__(self):
+                super(MyBundledInputModule, self).__init__()
+
+            def forward(self, inputs):
+                return inputs
+
+        def get_lint_count_by_type(lint_type, module_lint_list):
+            return len([lint_dict for lint_dict in module_lint_list if lint_dict['name'] == lint_type.name])
+
+        test_module = torch.jit.script(MyTestModule())
+        test_module_lint_list = generate_mobile_module_lints(test_module)
+        self.assertEqual(len(test_module_lint_list), 4)
+        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
+        self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
+        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)
+
+        bn_module = torch.jit.script(MyBNModule())
+        bn_module_lint_list = generate_mobile_module_lints(bn_module)
+        self.assertEqual(len(bn_module_lint_list), 4)
+        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
+        self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
+        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)
+
+        bi_module = torch.jit.script(MyBundledInputModule())
+        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
+            bi_module, [(torch.tensor([1]),)], [])
+        bi_module_lint_list = generate_mobile_module_lints(bi_module)
+        self.assertEqual(len(bi_module_lint_list), 0)
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/torch/utils/mobile_optimizer.py b/torch/utils/mobile_optimizer.py
index b0132df9f8e..b6415b2d285 100644
--- a/torch/utils/mobile_optimizer.py
+++ b/torch/utils/mobile_optimizer.py
@@ -3,19 +3,61 @@ This module contains utility method for mobile model optimization and lint.
 """
 
 import torch
+from enum import Enum
 
+class LintCode(Enum):
+    BUNDLED_INPUT = 1
+    REQUIRES_GRAD = 2
+    DROPOUT = 3
+    BATCHNORM = 4
 
-def optimize_for_mobile(scripted_model):
+def optimize_for_mobile(script_module):
     """
     Args:
-        scripted_model: An instance of torch script module with type of ScriptModule
+        script_module: An instance of torch script module of type ScriptModule
 
     Returns:
-        scripted_model: A new optimized torch script module
+        script_module: A new optimized torch script module
     """
-    if not isinstance(scripted_model, torch.jit.ScriptModule):
+    if not isinstance(script_module, torch.jit.ScriptModule):
         raise TypeError(
-            'Got {}, but ScriptModule is expected.'.format(type(scripted_model)))
+            'Got {}, but ScriptModule is expected.'.format(type(script_module)))
 
-    optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(scripted_model._c)
+    optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c)
     return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)
+
+
+def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
+    """
+    Args:
+        script_module: An instance of torch script module of type ScriptModule
+
+    Returns:
+        lint_list: A list of dictionaries that contains the module lints
+    """
+    if not isinstance(script_module, torch.jit.ScriptModule):
+        raise TypeError(
+            'Got {}, but ScriptModule is expected.'.format(type(script_module)))
+
+    lint_list = []
+
+    if not hasattr(script_module, "_generate_bundled_inputs"):
+        lint_list.append({"name": LintCode.BUNDLED_INPUT.name, "message": "No bundled input, please add bundled inputs before "
+                          "saving the module using torch.utils.bundled_inputs.augment_model_with_bundled_inputs."})
+
+    for name, param in script_module.named_parameters():
+        if param.requires_grad:
+            lint_list.append({"name": LintCode.REQUIRES_GRAD.name, "message": "Param {} requires grad, "
+                              "please set torch.no_grad() to reduce memory usage and improve computation speed during "
+                              "inference phase.".format(name)})
+
+    op_names = torch.jit.export_opnames(script_module)
+    for op_name in op_names:
+        if "dropout" in op_name:
+            lint_list.append({"name": LintCode.DROPOUT.name, "message": "Operator {} exists, remember to call eval() before "
+                              "saving the module.".format(op_name)})
+        if "batch_norm" in op_name:
+            lint_list.append({"name": LintCode.BATCHNORM.name, "message": "Operator {} exists, remember to call eval() before "
+                              "saving the module.".format(op_name)})
+
+    return lint_list
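
For reviewers, a minimal usage sketch of the API this patch adds. DemoModule, its layer sizes, and the bundled-input shape are illustrative placeholders, not part of the change:

    import torch
    import torch.utils.bundled_inputs
    from torch.utils.mobile_optimizer import LintCode, generate_mobile_module_lints

    class DemoModule(torch.nn.Module):
        def __init__(self):
            super(DemoModule, self).__init__()
            self.fc = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.fc(x)

    scripted = torch.jit.script(DemoModule())

    # Each lint is a dict with a "name" (a LintCode name) and a "message".
    # This module yields one BUNDLED_INPUT lint, plus one REQUIRES_GRAD lint
    # each for the Linear weight and bias.
    for lint in generate_mobile_module_lints(scripted):
        print(lint["name"], "-", lint["message"])

    # Attaching bundled inputs clears the BUNDLED_INPUT lint on a rerun.
    torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
        scripted, [(torch.zeros(1, 4),)], [])
    lints = generate_mobile_module_lints(scripted)
    assert not any(l["name"] == LintCode.BUNDLED_INPUT.name for l in lints)

Because each lint is a plain dict keyed by a LintCode name, callers can filter on specific lint types, as the test's get_lint_count_by_type helper does.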