mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Propagate module lints for mobile scripted module. (#37046)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37046 ghstack-source-id: 102669259 Creating a python api entry to generate mobile model lints which takes a scripted module as argument and returns a map of module lints. The initial version is to create placeholder which included module bundled input as the first lint instance. More lints will be added in the future. Test Plan: python test/test_optimizer.py Reviewed By: dreiss Differential Revision: D21164648 fbshipit-source-id: 9e8f4e19d74b5464a55cc73b9dc18f358c5947d6
This commit is contained in:
parent
5b9f7f7b0e
commit
5c9d1e4824
2 changed files with 102 additions and 8 deletions
|
|
@ -1,7 +1,8 @@
|
|||
import unittest
|
||||
import torch
|
||||
import torch.backends.xnnpack
|
||||
from torch.utils import mobile_optimizer
|
||||
import torch.utils.bundled_inputs
|
||||
from torch.utils.mobile_optimizer import *
|
||||
from torch.nn import functional as F
|
||||
|
||||
FileCheck = torch._C.FileCheck
|
||||
|
|
@ -66,7 +67,7 @@ class TestOptimizer(unittest.TestCase):
|
|||
scripted_model.eval()
|
||||
initial_result = scripted_model(input_data)
|
||||
|
||||
optimized_scripted_model = mobile_optimizer.optimize_for_mobile(scripted_model)
|
||||
optimized_scripted_model = optimize_for_mobile(scripted_model)
|
||||
optimized_result = optimized_scripted_model(input_data)
|
||||
|
||||
FileCheck().check_not("Tensor = aten::conv2d") \
|
||||
|
|
@ -79,5 +80,56 @@ class TestOptimizer(unittest.TestCase):
|
|||
|
||||
torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
|
||||
|
||||
def test_generate_mobile_module_lints(self):
|
||||
class MyTestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyTestModule, self).__init__()
|
||||
self.fc = torch.nn.Linear(4, 4)
|
||||
self.dropout = torch.nn.Dropout(p=0.5)
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.fc(inputs)
|
||||
out = self.dropout(out)
|
||||
return out
|
||||
|
||||
class MyBNModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyBNModule, self).__init__()
|
||||
self.bn = torch.nn.BatchNorm2d(4, affine=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
bn = self.bn(inputs)
|
||||
return bn
|
||||
|
||||
class MyBundledInputModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyBundledInputModule, self).__init__()
|
||||
|
||||
def forward(self, inputs):
|
||||
return inputs
|
||||
|
||||
def get_lint_count_by_type(lint_type, module_lint_List):
|
||||
return len([lint_dict for lint_dict in module_lint_List if lint_dict['name'] == lint_type.name])
|
||||
|
||||
test_module = torch.jit.script(MyTestModule())
|
||||
test_module_lint_list = generate_mobile_module_lints(test_module)
|
||||
self.assertEqual(len(test_module_lint_list), 4)
|
||||
self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
|
||||
self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
|
||||
self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)
|
||||
|
||||
bn_module = torch.jit.script(MyBNModule())
|
||||
bn_module_lint_list = generate_mobile_module_lints(bn_module)
|
||||
self.assertEqual(len(bn_module_lint_list), 4)
|
||||
self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
|
||||
self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
|
||||
self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)
|
||||
|
||||
bi_module = torch.jit.script(MyBundledInputModule())
|
||||
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
|
||||
bi_module, [(torch.tensor([1]),)], [])
|
||||
bi_module_lint_list = generate_mobile_module_lints(bi_module)
|
||||
self.assertEqual(len(bi_module_lint_list), 0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -3,19 +3,61 @@ This module contains utility method for mobile model optimization and lint.
|
|||
"""
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
|
||||
class LintCode(Enum):
|
||||
BUNDLED_INPUT = 1
|
||||
REQUIRES_GRAD = 2
|
||||
DROPOUT = 3
|
||||
BATCHNORM = 4
|
||||
|
||||
def optimize_for_mobile(scripted_model):
|
||||
def optimize_for_mobile(script_module):
|
||||
"""
|
||||
Args:
|
||||
scripted_model: An instance of torch script module with type of ScriptModule
|
||||
script_module: An instance of torch script module with type of ScriptModule
|
||||
|
||||
Returns:
|
||||
scripted_model: A new optimized torch script module
|
||||
script_module: A new optimized torch script module
|
||||
"""
|
||||
if not isinstance(scripted_model, torch.jit.ScriptModule):
|
||||
if not isinstance(script_module, torch.jit.ScriptModule):
|
||||
raise TypeError(
|
||||
'Got {}, but ScriptModule is expected.'.format(type(scripted_model)))
|
||||
'Got {}, but ScriptModule is expected.'.format(type(script_module)))
|
||||
|
||||
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(scripted_model._c)
|
||||
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c)
|
||||
return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)
|
||||
|
||||
|
||||
def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
|
||||
"""
|
||||
Args:
|
||||
script_module: An instance of torch script module with type of ScriptModule
|
||||
|
||||
Returns:
|
||||
lint_map: A list of dictionary that contains modules lints
|
||||
"""
|
||||
if not isinstance(script_module, torch.jit.ScriptModule):
|
||||
raise TypeError(
|
||||
'Got {}, but ScriptModule is expected.'.format(type(script_module)))
|
||||
|
||||
lint_list = []
|
||||
|
||||
if not hasattr(script_module, "_generate_bundled_inputs"):
|
||||
lint_list.append({"name": LintCode.BUNDLED_INPUT.name, "message": "No bundled input, please add bundled inputs before "
|
||||
"saving the module using torch.utils.bundled_inputs.augment_model_with_bundled_inputs."})
|
||||
|
||||
for name, param in script_module.named_parameters():
|
||||
if param.requires_grad:
|
||||
lint_list.append({"name": LintCode.REQUIRES_GRAD.name, "message": "Param {} requires grad, "
|
||||
"please set torch.no_grad() to reduce memory usage and improve computation speed during "
|
||||
"inference phase.".format(name)})
|
||||
|
||||
op_names = torch.jit.export_opnames(script_module)
|
||||
for op_name in op_names:
|
||||
if "dropout" in op_name:
|
||||
lint_list.append({"name": LintCode.DROPOUT.name, "message": "Operator {} exists, remember to call eval() before "
|
||||
"saving the module.".format(op_name)})
|
||||
if "batch_norm" in op_name:
|
||||
lint_list.append({"name": LintCode.BATCHNORM.name, "message": "Operator {} exists, remember to call eval() before "
|
||||
"saving the module.".format(op_name)})
|
||||
|
||||
return lint_list
|
||||
|
|
|
|||
Loading…
Reference in a new issue