mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Allow specifying net type in predictor_exporter
Summary: predictor_exporter copies the original predict_net's op, external_input and external_output fields, but ignores the type field. This is reasonable as the train net would generally have 'dag' type and copying that for inference may not be applicable. It's good to have a way to specify the net type nevertheless to run DAGNet for inference. This diff adds a field in predictor_exporter to do that. Reviewed By: akyrola Differential Revision: D5122354 fbshipit-source-id: 0e3cc417128db903c71515135c9e3b87620ae21e
This commit is contained in:
parent
03503140fd
commit
152d439400
3 changed files with 11 additions and 3 deletions
|
|
@ -17,7 +17,8 @@ import collections
|
|||
|
||||
class PredictorExportMeta(collections.namedtuple(
|
||||
'PredictorExportMeta',
|
||||
'predict_net, parameters, inputs, outputs, shapes, name, extra_init_net')):
|
||||
'predict_net, parameters, inputs, outputs, shapes, name, \
|
||||
extra_init_net, net_type')):
|
||||
"""
|
||||
Metadata to be used for serializaing a net.
|
||||
|
||||
|
|
@ -27,6 +28,8 @@ class PredictorExportMeta(collections.namedtuple(
|
|||
|
||||
Override the named tuple to provide optional name parameter.
|
||||
name will be used to identify multiple prediction nets.
|
||||
|
||||
net_type is the type field in caffe2 NetDef - can be 'simple', 'dag', etc.
|
||||
"""
|
||||
def __new__(
|
||||
cls,
|
||||
|
|
@ -36,7 +39,8 @@ class PredictorExportMeta(collections.namedtuple(
|
|||
outputs,
|
||||
shapes=None,
|
||||
name="",
|
||||
extra_init_net=None
|
||||
extra_init_net=None,
|
||||
net_type=None,
|
||||
):
|
||||
inputs = map(str, inputs)
|
||||
outputs = map(str, outputs)
|
||||
|
|
@ -49,7 +53,7 @@ class PredictorExportMeta(collections.namedtuple(
|
|||
assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef))
|
||||
return super(PredictorExportMeta, cls).__new__(
|
||||
cls, predict_net, parameters, inputs, outputs, shapes, name,
|
||||
extra_init_net)
|
||||
extra_init_net, net_type)
|
||||
|
||||
def inputs_name(self):
|
||||
return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE,
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ class PredictorExporterTest(unittest.TestCase):
|
|||
outputs=self.predictor_export_meta.outputs,
|
||||
shapes=self.predictor_export_meta.shapes,
|
||||
extra_init_net=extra_init_net,
|
||||
net_type='dag',
|
||||
)
|
||||
|
||||
db_type = 'minidb'
|
||||
|
|
@ -110,6 +111,7 @@ class PredictorExporterTest(unittest.TestCase):
|
|||
# producing good numbers (with our custom implementation)
|
||||
workspace.FeedBlob("data", np.random.randn(2, 5).astype(np.float32))
|
||||
predict_net = pred_utils.GetNet(meta_net_def, pc.PREDICT_NET_TYPE)
|
||||
self.assertEqual(predict_net.type, 'dag')
|
||||
workspace.RunNetOnce(predict_net)
|
||||
np.testing.assert_array_almost_equal(
|
||||
workspace.FetchBlob("y"),
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ def create_predict_net(predictor_export_meta):
|
|||
net.Proto().external_input.extend(
|
||||
predictor_export_meta.inputs + predictor_export_meta.parameters)
|
||||
net.Proto().external_output.extend(predictor_export_meta.outputs)
|
||||
if predictor_export_meta.net_type is not None:
|
||||
net.Proto().type = predictor_export_meta.net_type
|
||||
return net.Proto()
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue