Propagate module lints for mobile scripted module. (#37046)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37046
ghstack-source-id: 102669259

Creating a python api entry to generate mobile model lints which takes a scripted module as argument and returns a map of module lints.

The initial version is to create placeholder which included module bundled input as the first lint instance. More lints will be added in the future.

Test Plan: python test/test_optimizer.py

Reviewed By: dreiss

Differential Revision: D21164648

fbshipit-source-id: 9e8f4e19d74b5464a55cc73b9dc18f358c5947d6
This commit is contained in:
Xingying Cheng 2020-04-27 10:16:59 -07:00 committed by Facebook GitHub Bot
parent 5b9f7f7b0e
commit 5c9d1e4824
2 changed files with 102 additions and 8 deletions

View file

@ -1,7 +1,8 @@
import unittest
import torch
import torch.backends.xnnpack
from torch.utils import mobile_optimizer
import torch.utils.bundled_inputs
from torch.utils.mobile_optimizer import *
from torch.nn import functional as F
FileCheck = torch._C.FileCheck
@ -66,7 +67,7 @@ class TestOptimizer(unittest.TestCase):
scripted_model.eval()
initial_result = scripted_model(input_data)
optimized_scripted_model = mobile_optimizer.optimize_for_mobile(scripted_model)
optimized_scripted_model = optimize_for_mobile(scripted_model)
optimized_result = optimized_scripted_model(input_data)
FileCheck().check_not("Tensor = aten::conv2d") \
@ -79,5 +80,56 @@ class TestOptimizer(unittest.TestCase):
torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
def test_generate_mobile_module_lints(self):
class MyTestModule(torch.nn.Module):
def __init__(self):
super(MyTestModule, self).__init__()
self.fc = torch.nn.Linear(4, 4)
self.dropout = torch.nn.Dropout(p=0.5)
def forward(self, inputs):
out = self.fc(inputs)
out = self.dropout(out)
return out
class MyBNModule(torch.nn.Module):
def __init__(self):
super(MyBNModule, self).__init__()
self.bn = torch.nn.BatchNorm2d(4, affine=True)
def forward(self, inputs):
bn = self.bn(inputs)
return bn
class MyBundledInputModule(torch.nn.Module):
def __init__(self):
super(MyBundledInputModule, self).__init__()
def forward(self, inputs):
return inputs
def get_lint_count_by_type(lint_type, module_lint_List):
return len([lint_dict for lint_dict in module_lint_List if lint_dict['name'] == lint_type.name])
test_module = torch.jit.script(MyTestModule())
test_module_lint_list = generate_mobile_module_lints(test_module)
self.assertEqual(len(test_module_lint_list), 4)
self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)
bn_module = torch.jit.script(MyBNModule())
bn_module_lint_list = generate_mobile_module_lints(bn_module)
self.assertEqual(len(bn_module_lint_list), 4)
self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)
bi_module = torch.jit.script(MyBundledInputModule())
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
bi_module, [(torch.tensor([1]),)], [])
bi_module_lint_list = generate_mobile_module_lints(bi_module)
self.assertEqual(len(bi_module_lint_list), 0)
if __name__ == '__main__':
unittest.main()

View file

@ -3,19 +3,61 @@ This module contains utility method for mobile model optimization and lint.
"""
import torch
from enum import Enum
class LintCode(Enum):
BUNDLED_INPUT = 1
REQUIRES_GRAD = 2
DROPOUT = 3
BATCHNORM = 4
def optimize_for_mobile(scripted_model):
def optimize_for_mobile(script_module):
"""
Args:
scripted_model: An instance of torch script module with type of ScriptModule
script_module: An instance of torch script module with type of ScriptModule
Returns:
scripted_model: A new optimized torch script module
script_module: A new optimized torch script module
"""
if not isinstance(scripted_model, torch.jit.ScriptModule):
if not isinstance(script_module, torch.jit.ScriptModule):
raise TypeError(
'Got {}, but ScriptModule is expected.'.format(type(scripted_model)))
'Got {}, but ScriptModule is expected.'.format(type(script_module)))
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(scripted_model._c)
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c)
return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)
def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
"""
Args:
script_module: An instance of torch script module with type of ScriptModule
Returns:
lint_map: A list of dictionary that contains modules lints
"""
if not isinstance(script_module, torch.jit.ScriptModule):
raise TypeError(
'Got {}, but ScriptModule is expected.'.format(type(script_module)))
lint_list = []
if not hasattr(script_module, "_generate_bundled_inputs"):
lint_list.append({"name": LintCode.BUNDLED_INPUT.name, "message": "No bundled input, please add bundled inputs before "
"saving the module using torch.utils.bundled_inputs.augment_model_with_bundled_inputs."})
for name, param in script_module.named_parameters():
if param.requires_grad:
lint_list.append({"name": LintCode.REQUIRES_GRAD.name, "message": "Param {} requires grad, "
"please set torch.no_grad() to reduce memory usage and improve computation speed during "
"inference phase.".format(name)})
op_names = torch.jit.export_opnames(script_module)
for op_name in op_names:
if "dropout" in op_name:
lint_list.append({"name": LintCode.DROPOUT.name, "message": "Operator {} exists, remember to call eval() before "
"saving the module.".format(op_name)})
if "batch_norm" in op_name:
lint_list.append({"name": LintCode.BATCHNORM.name, "message": "Operator {} exists, remember to call eval() before "
"saving the module.".format(op_name)})
return lint_list