pytorch/caffe2/python/predictor/predictor_py_utils.py
Viswanath Sivakumar 152d439400 Allow specifying net type in predictor_exporter
Summary:
predictor_exporter copies the original predict_net's op, external_input and
external_output fields, but ignores the type field. This is reasonable as the
train net would generally have 'dag' type and copying that for inference may
not be applicable. It's good to have a way to specify the net type nevertheless
to run DAGNet for inference. This diff adds a field in predictor_exporter to do
that.

Reviewed By: akyrola

Differential Revision: D5122354

fbshipit-source-id: 0e3cc417128db903c71515135c9e3b87620ae21e
2017-05-24 11:46:27 -07:00

115 lines
3.3 KiB
Python

## @package predictor_py_utils
# Module caffe2.python.predictor.predictor_py_utils
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core
def create_predict_net(predictor_export_meta):
    """
    Return the prediction net rebuilt from the export metadata.

    A fresh ``core.Net`` is constructed instead of reusing
    ``predictor_export_meta.predict_net`` directly, so only the fields
    copied below end up in the result.

    Args:
        predictor_export_meta: object exposing ``predict_net`` (read for
            its ``name`` and ``op``), ``inputs``, ``parameters``,
            ``outputs`` and ``net_type``.

    Returns:
        The ``Proto()`` of the newly built net.
    """
    source_net = predictor_export_meta.predict_net

    # Construct a new net to clear the existing settings.
    net = core.Net(source_net.name or "predict")
    proto = net.Proto()

    proto.op.extend(source_net.op)
    # Parameters are listed as external inputs right after the real inputs.
    proto.external_input.extend(
        predictor_export_meta.inputs + predictor_export_meta.parameters)
    proto.external_output.extend(predictor_export_meta.outputs)

    # The net type is only set when the caller explicitly asked for one;
    # the type of the source net is never copied.
    if predictor_export_meta.net_type is not None:
        proto.type = predictor_export_meta.net_type
    return proto
def create_predict_init_net(ws, predictor_export_meta):
    """
    Return an initialization net that zero-fills all the input and
    output blobs.

    For each blob the shape is taken from ``predictor_export_meta.shapes``
    when an entry exists there; otherwise it is read from the blob stored
    in the provided workspace. This is necessary as there is no shape
    inference functionality in Caffe2.

    Args:
        ws: workspace whose ``blobs`` mapping is consulted for shapes that
            the metadata does not provide.
        predictor_export_meta: object exposing ``inputs``, ``outputs``,
            ``shapes`` and ``extra_init_net``.

    Returns:
        The ``Proto()`` of the "predict-init" net.

    Raises:
        Exception: if a blob has no recorded shape and is not present in
            ``ws.blobs``.
    """
    net = core.Net("predict-init")

    def zero_fill(blob):
        # Prefer the shape recorded in the metadata; fall back to the
        # workspace only when no shape was recorded for this blob.
        shape = predictor_export_meta.shapes.get(blob)
        if shape is None:
            if blob not in ws.blobs:
                raise Exception(
                    "{} not in workspace but needed for shape: {}".format(
                        blob, ws.blobs))
            shape = ws.blobs[blob].fetch().shape
        net.ConstantFill([], blob, shape=shape, value=0.0)

    external_blobs = (
        predictor_export_meta.inputs + predictor_export_meta.outputs)
    for blob in external_blobs:
        zero_fill(blob)

    net.Proto().external_input.extend(external_blobs)

    # Optional user-supplied initialization is appended after the fills.
    if predictor_export_meta.extra_init_net:
        net.AppendNet(predictor_export_meta.extra_init_net)
    return net.Proto()
def get_comp_name(string, name):
    """
    Return ``string`` suffixed with ``'_' + name``.

    When ``name`` is falsy (``None`` or empty), ``string`` is returned
    unchanged.
    """
    if name:
        return string + '_' + name
    return string
def _ProtoMapGet(field, key):
    '''
    Look up ``key`` in a repeated field that is used as a map.

    ``field`` is iterated in order and each entry is expected to expose
    ``key`` and ``value`` attributes. The ``value`` of the first entry
    whose ``key`` equals the requested one is returned; ``None`` is
    returned when no entry matches.
    '''
    for entry in field:
        if entry.key == key:
            return entry.value
    return None
def GetPlan(meta_net_def, key):
    """
    Return the value stored under ``key`` in ``meta_net_def.plans``,
    or ``None`` when there is no such entry.
    """
    return _ProtoMapGet(meta_net_def.plans, key)
def GetPlanOriginal(meta_net_def, key):
    """
    Return the value stored under ``key`` in ``meta_net_def.plans``,
    or ``None`` when there is no such entry.

    This performs the same lookup as ``GetPlan``.
    """
    return _ProtoMapGet(meta_net_def.plans, key)
def GetBlobs(meta_net_def, key):
    """
    Return the blobs stored under ``key`` in ``meta_net_def.blobs``.

    An empty list is returned when there is no entry for ``key``.
    """
    blobs = _ProtoMapGet(meta_net_def.blobs, key)
    # Only a missing entry is replaced by []; an existing value is
    # returned as-is even if it is empty.
    return [] if blobs is None else blobs
def GetNet(meta_net_def, key):
    """
    Return the value stored under ``key`` in ``meta_net_def.nets``,
    or ``None`` when there is no such entry.
    """
    return _ProtoMapGet(meta_net_def.nets, key)
def GetNetOriginal(meta_net_def, key):
    """
    Return the value stored under ``key`` in ``meta_net_def.nets``,
    or ``None`` when there is no such entry.

    This performs the same lookup as ``GetNet``.
    """
    return _ProtoMapGet(meta_net_def.nets, key)
def GetApplicationSpecificInfo(meta_net_def, key):
    """
    Return the value stored under ``key`` in
    ``meta_net_def.applicationSpecificInfo``, or ``None`` when there is
    no such entry.
    """
    return _ProtoMapGet(meta_net_def.applicationSpecificInfo, key)
def AddBlobs(meta_net_def, blob_name, blob_def):
    """
    Append every element of ``blob_def`` to the blob list stored under
    ``blob_name`` in ``meta_net_def.blobs``.

    When no entry for ``blob_name`` exists yet, a new one is added to
    ``meta_net_def.blobs`` with its ``key`` set to ``blob_name``, and the
    elements are appended to that new entry's ``value``.
    """
    blobs = _ProtoMapGet(meta_net_def.blobs, blob_name)
    if blobs is None:
        # First use of this key: create the entry, then work on its value.
        entry = meta_net_def.blobs.add()
        entry.key = blob_name
        blobs = entry.value
    for blob in blob_def:
        blobs.append(blob)
def AddPlan(meta_net_def, plan_name, plan_def):
    """
    Add a new entry to ``meta_net_def.plans`` with ``plan_name`` as its
    key and ``plan_def`` as its value.

    No check is made for an already existing entry with the same key.
    """
    meta_net_def.plans.add(key=plan_name, value=plan_def)
def AddNet(meta_net_def, net_name, net_def):
    """
    Add a new entry to ``meta_net_def.nets`` with ``net_name`` as its
    key and ``net_def`` as its value.

    No check is made for an already existing entry with the same key.
    """
    meta_net_def.nets.add(key=net_name, value=net_def)