mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49980 From ``` ./python/libcst/libcst codemod remove_unused_imports.RemoveUnusedImportsWithGlean --no-format caffe2/ ``` Test Plan: Standard sandcastle tests Reviewed By: xush6528 Differential Revision: D25727359 fbshipit-source-id: c4f60005b10546423dc093d31d46deb418352286
88 lines
2.8 KiB
Python
88 lines
2.8 KiB
Python
## @package onnx
|
|
# Module caffe2.python.onnx.bin.conversion
|
|
|
|
|
|
|
|
|
|
|
|
import json
|
|
|
|
from caffe2.proto import caffe2_pb2
|
|
import click
|
|
from onnx import ModelProto
|
|
|
|
from caffe2.python.onnx.backend import Caffe2Backend as c2
|
|
import caffe2.python.onnx.frontend as c2_onnx
|
|
|
|
|
|
@click.command(
|
|
help='convert caffe2 net to onnx model',
|
|
context_settings={
|
|
'help_option_names': ['-h', '--help']
|
|
}
|
|
)
|
|
@click.argument('caffe2_net', type=click.File('rb'))
|
|
@click.option('--caffe2-net-name',
|
|
type=str,
|
|
help="Name of the caffe2 net")
|
|
@click.option('--caffe2-init-net',
|
|
type=click.File('rb'),
|
|
help="Path of the caffe2 init net pb file")
|
|
@click.option('--value-info',
|
|
type=str,
|
|
help='A json string providing the '
|
|
'type and shape information of the inputs')
|
|
@click.option('-o', '--output', required=True,
|
|
type=click.File('wb'),
|
|
help='Output path for the onnx model pb file')
|
|
def caffe2_to_onnx(caffe2_net,
|
|
caffe2_net_name,
|
|
caffe2_init_net,
|
|
value_info,
|
|
output):
|
|
c2_net_proto = caffe2_pb2.NetDef()
|
|
c2_net_proto.ParseFromString(caffe2_net.read())
|
|
if not c2_net_proto.name and not caffe2_net_name:
|
|
raise click.BadParameter(
|
|
'The input caffe2 net does not have name, '
|
|
'--caffe2-net-name must be provided')
|
|
c2_net_proto.name = caffe2_net_name or c2_net_proto.name
|
|
if caffe2_init_net:
|
|
c2_init_net_proto = caffe2_pb2.NetDef()
|
|
c2_init_net_proto.ParseFromString(caffe2_init_net.read())
|
|
c2_init_net_proto.name = '{}_init'.format(caffe2_net_name)
|
|
else:
|
|
c2_init_net_proto = None
|
|
|
|
if value_info:
|
|
value_info = json.loads(value_info)
|
|
|
|
onnx_model = c2_onnx.caffe2_net_to_onnx_model(
|
|
predict_net=c2_net_proto,
|
|
init_net=c2_init_net_proto,
|
|
value_info=value_info)
|
|
|
|
output.write(onnx_model.SerializeToString())
|
|
|
|
|
|
@click.command(
|
|
help='convert onnx model to caffe2 net',
|
|
context_settings={
|
|
'help_option_names': ['-h', '--help']
|
|
}
|
|
)
|
|
@click.argument('onnx_model', type=click.File('rb'))
|
|
@click.option('-o', '--output', required=True,
|
|
type=click.File('wb'),
|
|
help='Output path for the caffe2 net file')
|
|
@click.option('--init-net-output',
|
|
required=True,
|
|
type=click.File('wb'),
|
|
help='Output path for the caffe2 init net file')
|
|
def onnx_to_caffe2(onnx_model, output, init_net_output):
|
|
onnx_model_proto = ModelProto()
|
|
onnx_model_proto.ParseFromString(onnx_model.read())
|
|
|
|
init_net, predict_net = c2.onnx_graph_to_caffe2_net(onnx_model_proto)
|
|
init_net_output.write(init_net.SerializeToString())
|
|
output.write(predict_net.SerializeToString())
|