mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: predictor_exporter copies the original predict_net's op, external_input and external_output fields, but ignores the type field. This is reasonable as the train net would generally have 'dag' type and copying that for inference may not be applicable. It's good to have a way to specify the net type nevertheless to run DAGNet for inference. This diff adds a field in predictor_exporter to do that. Reviewed By: akyrola Differential Revision: D5122354 fbshipit-source-id: 0e3cc417128db903c71515135c9e3b87620ae21e
115 lines
3.3 KiB
Python
115 lines
3.3 KiB
Python
## @package predictor_py_utils
|
|
# Module caffe2.python.predictor.predictor_py_utils
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import core
|
|
|
|
|
|
def create_predict_net(predictor_export_meta):
    """
    Build the prediction net proto from the exported predictor metadata.

    A brand-new net is constructed instead of reusing the stored
    predict_net so that the stored net's other settings are not carried
    over. Only these pieces are populated:

    - ops: copied from ``predictor_export_meta.predict_net``
    - external inputs: the meta's ``inputs`` followed by its ``parameters``
    - external outputs: the meta's ``outputs``
    - type: set only when ``predictor_export_meta.net_type`` is not None

    The net is named after the stored predict_net, falling back to
    "predict" when that name is empty. Returns the new net's proto.
    """
    source_net = predictor_export_meta.predict_net
    net = core.Net(source_net.name or "predict")

    net.Proto().op.extend(source_net.op)

    external_inputs = (
        predictor_export_meta.inputs + predictor_export_meta.parameters)
    net.Proto().external_input.extend(external_inputs)
    net.Proto().external_output.extend(predictor_export_meta.outputs)

    # The net type is optional; leave the proto's default when unspecified.
    requested_type = predictor_export_meta.net_type
    if requested_type is not None:
        net.Proto().type = requested_type

    return net.Proto()
|
|
|
|
|
|
def create_predict_init_net(ws, predictor_export_meta):
    """
    Build an initialization net proto that zero-fills every input and
    output blob of the predictor.

    For each blob the shape is taken from ``predictor_export_meta.shapes``
    when present there; otherwise it is read from the blob currently held
    in the workspace ``ws``. Per the original note, this is needed because
    there is no shape inference functionality in Caffe2.

    The inputs and outputs are also registered as external inputs of the
    init net, and ``predictor_export_meta.extra_init_net`` is appended
    when it is set.

    Raises:
        Exception: if a blob has no recorded shape and is not in ``ws``.
    """
    init_net = core.Net("predict-init")

    external_blobs = (
        predictor_export_meta.inputs + predictor_export_meta.outputs)

    for blob in external_blobs:
        # Prefer the explicitly exported shape; fall back to the workspace.
        shape = predictor_export_meta.shapes.get(blob)
        if shape is None:
            if blob not in ws.blobs:
                raise Exception(
                    "{} not in workspace but needed for shape: {}".format(
                        blob, ws.blobs))
            shape = ws.blobs[blob].fetch().shape
        init_net.ConstantFill([], blob, shape=shape, value=0.0)

    init_net.Proto().external_input.extend(external_blobs)

    extra_init_net = predictor_export_meta.extra_init_net
    if extra_init_net:
        init_net.AppendNet(extra_init_net)

    return init_net.Proto()
|
|
|
|
|
|
def get_comp_name(string, name):
    """
    Return ``string`` joined to ``name`` with an underscore.

    When ``name`` is falsy (for example empty or None), ``string`` is
    returned unchanged.
    """
    if not name:
        return string
    return string + '_' + name
|
|
|
|
|
|
def _ProtoMapGet(field, key):
    '''
    Look up ``key`` in a repeated field of entries carrying ``key`` and
    ``value`` attributes.

    The repeated field stands in for a map (the original note says
    protobuf has no map construct here). Returns the ``value`` of the
    first entry whose ``key`` equals the given key, or None when no
    entry matches.
    '''
    matching_values = (entry.value for entry in field if entry.key == key)
    return next(matching_values, None)
|
|
|
|
|
|
def GetPlan(meta_net_def, key):
    """
    Return the value stored under ``key`` in ``meta_net_def.plans``, or
    None when no entry has that key.
    """
    plan_entries = meta_net_def.plans
    return _ProtoMapGet(plan_entries, key)
|
|
|
|
|
|
def GetPlanOriginal(meta_net_def, key):
    """
    Return the value stored under ``key`` in ``meta_net_def.plans``, or
    None when no entry has that key.

    In this file the lookup is the same as the one done by ``GetPlan``.
    """
    plan_entries = meta_net_def.plans
    return _ProtoMapGet(plan_entries, key)
|
|
|
|
|
|
def GetBlobs(meta_net_def, key):
    """
    Return the blobs stored under ``key`` in ``meta_net_def.blobs``.

    Unlike the other getters, a missing key yields an empty list rather
    than None.
    """
    found = _ProtoMapGet(meta_net_def.blobs, key)
    return [] if found is None else found
|
|
|
|
|
|
def GetNet(meta_net_def, key):
    """
    Return the value stored under ``key`` in ``meta_net_def.nets``, or
    None when no entry has that key.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
|
|
|
|
|
|
def GetNetOriginal(meta_net_def, key):
    """
    Return the value stored under ``key`` in ``meta_net_def.nets``, or
    None when no entry has that key.

    In this file the lookup is the same as the one done by ``GetNet``.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
|
|
|
|
|
|
def GetApplicationSpecificInfo(meta_net_def, key):
    """
    Return the value stored under ``key`` in
    ``meta_net_def.applicationSpecificInfo``, or None when no entry has
    that key.
    """
    info_entries = meta_net_def.applicationSpecificInfo
    return _ProtoMapGet(info_entries, key)
|
|
|
|
|
|
def AddBlobs(meta_net_def, blob_name, blob_def):
    """
    Append every item of ``blob_def`` to the blob list stored under
    ``blob_name`` in ``meta_net_def.blobs``.

    If no entry with that key exists yet, a new one is added and keyed
    with ``blob_name`` before the items are appended to its value.
    """
    target = _ProtoMapGet(meta_net_def.blobs, blob_name)
    if target is None:
        # No entry for this name yet: create it and append to its value.
        new_entry = meta_net_def.blobs.add()
        new_entry.key = blob_name
        target = new_entry.value

    for item in blob_def:
        target.append(item)
|
|
|
|
|
|
def AddPlan(meta_net_def, plan_name, plan_def):
    """
    Add a new entry to ``meta_net_def.plans`` with ``plan_name`` as its
    key and ``plan_def`` as its value.

    No check for an existing entry with the same key is performed.
    """
    plan_entries = meta_net_def.plans
    plan_entries.add(key=plan_name, value=plan_def)
|
|
|
|
|
|
def AddNet(meta_net_def, net_name, net_def):
    """
    Add a new entry to ``meta_net_def.nets`` with ``net_name`` as its key
    and ``net_def`` as its value.

    No check for an existing entry with the same key is performed.
    """
    net_entries = meta_net_def.nets
    net_entries.add(key=net_name, value=net_def)
|