Allow specifying net type in predictor_exporter

Summary:
predictor_exporter copies the original predict_net's op, external_input and
external_output fields, but ignores the type field. This is reasonable as the
train net would generally have 'dag' type and copying that for inference may
not be applicable. It's good to have a way to specify the net type nevertheless
to run DAGNet for inference. This diff adds a field in predictor_exporter to do
that.

Reviewed By: akyrola

Differential Revision: D5122354

fbshipit-source-id: 0e3cc417128db903c71515135c9e3b87620ae21e
This commit is contained in:
Viswanath Sivakumar 2017-05-24 11:39:49 -07:00 committed by Facebook Github Bot
parent 03503140fd
commit 152d439400
3 changed files with 11 additions and 3 deletions

View file

@ -17,7 +17,8 @@ import collections
class PredictorExportMeta(collections.namedtuple(
'PredictorExportMeta',
'predict_net, parameters, inputs, outputs, shapes, name, extra_init_net')):
'predict_net, parameters, inputs, outputs, shapes, name, \
extra_init_net, net_type')):
"""
Metadata to be used for serializaing a net.
@ -27,6 +28,8 @@ class PredictorExportMeta(collections.namedtuple(
Override the named tuple to provide optional name parameter.
name will be used to identify multiple prediction nets.
net_type is the type field in caffe2 NetDef - can be 'simple', 'dag', etc.
"""
def __new__(
cls,
@ -36,7 +39,8 @@ class PredictorExportMeta(collections.namedtuple(
outputs,
shapes=None,
name="",
extra_init_net=None
extra_init_net=None,
net_type=None,
):
inputs = map(str, inputs)
outputs = map(str, outputs)
@ -49,7 +53,7 @@ class PredictorExportMeta(collections.namedtuple(
assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef))
return super(PredictorExportMeta, cls).__new__(
cls, predict_net, parameters, inputs, outputs, shapes, name,
extra_init_net)
extra_init_net, net_type)
def inputs_name(self):
return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE,

View file

@ -67,6 +67,7 @@ class PredictorExporterTest(unittest.TestCase):
outputs=self.predictor_export_meta.outputs,
shapes=self.predictor_export_meta.shapes,
extra_init_net=extra_init_net,
net_type='dag',
)
db_type = 'minidb'
@ -110,6 +111,7 @@ class PredictorExporterTest(unittest.TestCase):
# producing good numbers (with our custom implementation)
workspace.FeedBlob("data", np.random.randn(2, 5).astype(np.float32))
predict_net = pred_utils.GetNet(meta_net_def, pc.PREDICT_NET_TYPE)
self.assertEqual(predict_net.type, 'dag')
workspace.RunNetOnce(predict_net)
np.testing.assert_array_almost_equal(
workspace.FetchBlob("y"),

View file

@ -18,6 +18,8 @@ def create_predict_net(predictor_export_meta):
net.Proto().external_input.extend(
predictor_export_meta.inputs + predictor_export_meta.parameters)
net.Proto().external_output.extend(predictor_export_meta.outputs)
if predictor_export_meta.net_type is not None:
net.Proto().type = predictor_export_meta.net_type
return net.Proto()