From 152d43940048c3c366a34df830097302bbbea0ff Mon Sep 17 00:00:00 2001 From: Viswanath Sivakumar Date: Wed, 24 May 2017 11:39:49 -0700 Subject: [PATCH] Allow specifying net type in predictor_exporter Summary: predictor_exporter copies the original predict_net's op, external_input and external_output fields, but ignores the type field. This is reasonable as the train net would generally have 'dag' type and copying that for inference may not be applicable. It's good to have a way to specify the net type nevertheless to run DAGNet for inference. This diff adds a field in predictor_exporter to do that. Reviewed By: akyrola Differential Revision: D5122354 fbshipit-source-id: 0e3cc417128db903c71515135c9e3b87620ae21e --- caffe2/python/predictor/predictor_exporter.py | 10 +++++++--- caffe2/python/predictor/predictor_exporter_test.py | 2 ++ caffe2/python/predictor/predictor_py_utils.py | 2 ++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/caffe2/python/predictor/predictor_exporter.py b/caffe2/python/predictor/predictor_exporter.py index 8c1a486cfec..dbcae6cd3c3 100644 --- a/caffe2/python/predictor/predictor_exporter.py +++ b/caffe2/python/predictor/predictor_exporter.py @@ -17,7 +17,8 @@ import collections class PredictorExportMeta(collections.namedtuple( 'PredictorExportMeta', - 'predict_net, parameters, inputs, outputs, shapes, name, extra_init_net')): + 'predict_net, parameters, inputs, outputs, shapes, name, \ + extra_init_net, net_type')): """ Metadata to be used for serializaing a net. @@ -27,6 +28,8 @@ class PredictorExportMeta(collections.namedtuple( Override the named tuple to provide optional name parameter. name will be used to identify multiple prediction nets. + + net_type is the type field in caffe2 NetDef - can be 'simple', 'dag', etc. """ def __new__( cls, @@ -36,7 +39,8 @@ class PredictorExportMeta(collections.namedtuple( outputs, shapes=None, name="", - extra_init_net=None + extra_init_net=None, + net_type=None, ): inputs = map(str, inputs) outputs = map(str, outputs) @@ -49,7 +53,7 @@ class PredictorExportMeta(collections.namedtuple( assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef)) return super(PredictorExportMeta, cls).__new__( cls, predict_net, parameters, inputs, outputs, shapes, name, - extra_init_net) + extra_init_net, net_type) def inputs_name(self): return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE, diff --git a/caffe2/python/predictor/predictor_exporter_test.py b/caffe2/python/predictor/predictor_exporter_test.py index 61e3fb0a5ad..ffa94e7498b 100644 --- a/caffe2/python/predictor/predictor_exporter_test.py +++ b/caffe2/python/predictor/predictor_exporter_test.py @@ -67,6 +67,7 @@ class PredictorExporterTest(unittest.TestCase): outputs=self.predictor_export_meta.outputs, shapes=self.predictor_export_meta.shapes, extra_init_net=extra_init_net, + net_type='dag', ) db_type = 'minidb' @@ -110,6 +111,7 @@ class PredictorExporterTest(unittest.TestCase): # producing good numbers (with our custom implementation) workspace.FeedBlob("data", np.random.randn(2, 5).astype(np.float32)) predict_net = pred_utils.GetNet(meta_net_def, pc.PREDICT_NET_TYPE) + self.assertEqual(predict_net.type, 'dag') workspace.RunNetOnce(predict_net) np.testing.assert_array_almost_equal( workspace.FetchBlob("y"), diff --git a/caffe2/python/predictor/predictor_py_utils.py b/caffe2/python/predictor/predictor_py_utils.py index 202f7bd9db7..bf0e3f0bd4e 100644 --- a/caffe2/python/predictor/predictor_py_utils.py +++ b/caffe2/python/predictor/predictor_py_utils.py @@ -18,6 +18,8 @@ def create_predict_net(predictor_export_meta): net.Proto().external_input.extend( predictor_export_meta.inputs + predictor_export_meta.parameters) net.Proto().external_output.extend(predictor_export_meta.outputs) + if predictor_export_meta.net_type is not None: + net.Proto().type = predictor_export_meta.net_type return net.Proto()