pytorch/benchmarks/framework_overhead_benchmark/pt_wrapper_module.py
Xiang Gao 20ac736200 Remove py2 compatible future imports (#44735)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44735

Reviewed By: mruberry

Differential Revision: D23731306

Pulled By: ezyang

fbshipit-source-id: 0ba009a99e475ddbe22981be8ac636f8a1c8b02f
2020-09-16 12:55:57 -07:00

43 lines
1.9 KiB
Python

import torch
class WrapperModule:
    """Wrap an instance of ``wrapped_type`` for framework-overhead benchmarking.

    In graph mode the wrapped instance is traced with ``torch.jit.trace``.
    ``module_config.num_params`` input tensors, each holding a single random
    float (``torch.randn(1)``), are created once and reused as the forward
    arguments on every benchmark iteration.

    Args:
        wrapped_type:
            Type (class) to be wrapped. It is expected to:
            - be constructible as ``wrapped_type(module_config.pt_fn)``;
            - provide a ``forward`` method that takes
              ``module_config.num_params`` positional tensor arguments.
        module_config:
            Object providing ``pt_fn`` (the callable passed to
            ``wrapped_type``), ``graph_mode`` (whether to trace the
            instance) and ``num_params`` (number of arguments the wrapped
            ``forward`` method takes).
        debug:
            Whether debug mode is enabled. When true and the module is a
            ``torch.jit.ScriptModule``, its graph and code are printed.
        save:
            In graph mode, whether to save the traced module to
            ``<wrapped_type name>_<pt_fn name>.pt`` in the current directory.
    """

    def __init__(self, wrapped_type, module_config, debug, save=False):
        pt_fn = module_config.pt_fn
        # Construct the module before drawing the random inputs so the RNG
        # is consumed in the same order as before.
        self.module = wrapped_type(pt_fn)
        self.module_name = wrapped_type.__name__
        # One single-element random tensor per forward argument. Kept as a
        # list (the attribute's existing type) and reused on every iteration.
        self.tensor_inputs = [torch.randn(1) for _ in range(module_config.num_params)]
        if module_config.graph_mode:
            # torch.jit.trace documents example_inputs as a tuple; pass one explicitly.
            self.module = torch.jit.trace(self.module, tuple(self.tensor_inputs))
            if save:
                file_name = "{}_{}.pt".format(self.module_name, pt_fn.__name__)
                torch.jit.save(self.module, file_name)
                print("Generated graph is saved in {}".format(file_name))
        print(
            "Benchmarking module {} with fn {}: Graph mode:{}".format(
                self.module_name, pt_fn.__name__, module_config.graph_mode
            )
        )
        # Only scripted/traced modules expose .graph and .code.
        if debug and isinstance(self.module, torch.jit.ScriptModule):
            print(self.module.graph)
            print(self.module.code)

    def forward(self, niters):
        """Call the wrapped module's ``forward`` ``niters`` times with the stored inputs.

        Runs under ``torch.no_grad()``; the results are discarded.
        """
        with torch.no_grad():
            for _ in range(niters):
                self.module.forward(*self.tensor_inputs)