diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py index 3e70a48d11f..e68b976a1db 100644 --- a/test/test_mobile_optimizer.py +++ b/test/test_mobile_optimizer.py @@ -1,7 +1,8 @@ import unittest import torch import torch.backends.xnnpack -from torch.utils import mobile_optimizer +import torch.utils.bundled_inputs +from torch.utils.mobile_optimizer import * from torch.nn import functional as F FileCheck = torch._C.FileCheck @@ -66,7 +67,7 @@ class TestOptimizer(unittest.TestCase): scripted_model.eval() initial_result = scripted_model(input_data) - optimized_scripted_model = mobile_optimizer.optimize_for_mobile(scripted_model) + optimized_scripted_model = optimize_for_mobile(scripted_model) optimized_result = optimized_scripted_model(input_data) FileCheck().check_not("Tensor = aten::conv2d") \ @@ -79,5 +80,56 @@ class TestOptimizer(unittest.TestCase): torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3) + def test_generate_mobile_module_lints(self): + class MyTestModule(torch.nn.Module): + def __init__(self): + super(MyTestModule, self).__init__() + self.fc = torch.nn.Linear(4, 4) + self.dropout = torch.nn.Dropout(p=0.5) + + def forward(self, inputs): + out = self.fc(inputs) + out = self.dropout(out) + return out + + class MyBNModule(torch.nn.Module): + def __init__(self): + super(MyBNModule, self).__init__() + self.bn = torch.nn.BatchNorm2d(4, affine=True) + + def forward(self, inputs): + bn = self.bn(inputs) + return bn + + class MyBundledInputModule(torch.nn.Module): + def __init__(self): + super(MyBundledInputModule, self).__init__() + + def forward(self, inputs): + return inputs + + def get_lint_count_by_type(lint_type, module_lint_List): + return len([lint_dict for lint_dict in module_lint_List if lint_dict['name'] == lint_type.name]) + + test_module = torch.jit.script(MyTestModule()) + test_module_lint_list = generate_mobile_module_lints(test_module) + self.assertEqual(len(test_module_lint_list), 4) + self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1) + self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1) + self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2) + + bn_module = torch.jit.script(MyBNModule()) + bn_module_lint_list = generate_mobile_module_lints(bn_module) + self.assertEqual(len(bn_module_lint_list), 4) + self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1) + self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1) + self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2) + + bi_module = torch.jit.script(MyBundledInputModule()) + torch.utils.bundled_inputs.augment_model_with_bundled_inputs( + bi_module, [(torch.tensor([1]),)], []) + bi_module_lint_list = generate_mobile_module_lints(bi_module) + self.assertEqual(len(bi_module_lint_list), 0) + if __name__ == '__main__': unittest.main() diff --git a/torch/utils/mobile_optimizer.py b/torch/utils/mobile_optimizer.py index b0132df9f8e..b6415b2d285 100644 --- a/torch/utils/mobile_optimizer.py +++ b/torch/utils/mobile_optimizer.py @@ -3,19 +3,61 @@ This module contains utility method for mobile model optimization and lint. """ import torch +from enum import Enum +class LintCode(Enum): + BUNDLED_INPUT = 1 + REQUIRES_GRAD = 2 + DROPOUT = 3 + BATCHNORM = 4 -def optimize_for_mobile(scripted_model): +def optimize_for_mobile(script_module): """ Args: - scripted_model: An instance of torch script module with type of ScriptModule + script_module: An instance of torch script module with type of ScriptModule Returns: - scripted_model: A new optimized torch script module + script_module: A new optimized torch script module """ - if not isinstance(scripted_model, torch.jit.ScriptModule): + if not isinstance(script_module, torch.jit.ScriptModule): raise TypeError( - 'Got {}, but ScriptModule is expected.'.format(type(scripted_model))) + 'Got {}, but ScriptModule is expected.'.format(type(script_module))) - optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(scripted_model._c) + optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c) return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module) + + +def generate_mobile_module_lints(script_module: torch.jit.ScriptModule): + """ + Args: + script_module: An instance of torch script module with type of ScriptModule + + Returns: + lint_map: A list of dictionary that contains modules lints + """ + if not isinstance(script_module, torch.jit.ScriptModule): + raise TypeError( + 'Got {}, but ScriptModule is expected.'.format(type(script_module))) + + lint_list = [] + + if not hasattr(script_module, "_generate_bundled_inputs"): + lint_list.append({"name": LintCode.BUNDLED_INPUT.name, "message": "No bundled input, please add bundled inputs before " + "saving the module using torch.utils.bundled_inputs.augment_model_with_bundled_inputs."}) + + for name, param in script_module.named_parameters(): + if param.requires_grad: + lint_list.append({"name": LintCode.REQUIRES_GRAD.name, "message": "Param {} requires grad, " + "please set torch.no_grad() to reduce memory usage and improve computation speed during " + "inference phase.".format(name)}) + + op_names = torch.jit.export_opnames(script_module) + for op_name in op_names: + if "dropout" in op_name: + lint_list.append({"name": LintCode.DROPOUT.name, "message": "Operator {} exists, remember to call eval() before " + "saving the module.".format(op_name)}) + if "batch_norm" in op_name: + lint_list.append({"name": LintCode.BATCHNORM.name, "message": "Operator {} exists, remember to call eval() before " + "saving the module.".format(op_name)}) + + return lint_list