onnxruntime/samples/python/mnist/graph_spliter.py

121 lines
5 KiB
Python
Raw Normal View History

2020-10-16 06:26:22 +00:00
import onnx
import copy
from onnx import shape_inference
from onnxruntime.capi import _pybind_state as C
def print_list(name, value):
    """Print a labelled, comma-separated list on one line.

    The output looks like ``name: a, b, c``.

    Args:
        name: Label printed before the colon.
        value: Iterable of strings; the items are joined with ``', '``.
    """
    label = name + ':'
    items = ', '.join(value)
    print(label, items)


def dim_str(dim):
    """Render one shape dimension entry (e.g. an element of ``tensor_type.shape.dim``) as text.

    Returns:
        ``str(dim.dim_value)`` when ``dim_value`` is set, otherwise
        ``dim.dim_param`` when that is set, otherwise ``'n/a'``.
    """
    if not dim.HasField('dim_value'):
        # No concrete size: fall back to the symbolic name, if any.
        return dim.dim_param if dim.HasField('dim_param') else 'n/a'
    return str(dim.dim_value)


def print_type(name, type):
    """Print the element type and shape of a tensor-typed value.

    Output format: ``[name] type: <elem_type> | size: [d0,d1,...]``, where each
    dimension is rendered with ``dim_str``. If ``type`` is ``None`` (no type
    information is available), prints ``[name] type: n/a | size: n/a`` instead
    of raising ``AttributeError``.

    Args:
        name: Value name, shown in brackets.
        type: A ``TypeProto``-like object exposing ``tensor_type.elem_type`` and
            ``tensor_type.shape.dim`` (e.g. ``ValueInfoProto.type`` from a graph
            input/output), or ``None``. The parameter name shadows the builtin
            ``type``; it is kept so existing keyword callers keep working.
    """
    if type is None:
        # Callers may not have found a type for every name; report that
        # instead of failing on ``None.tensor_type``.
        print('[' + name + ']', 'type:', 'n/a', '| size:', 'n/a')
        return
    tensor_type = type.tensor_type
    dims = ','.join([dim_str(d) for d in tensor_type.shape.dim])
    print('[' + name + ']', 'type:', tensor_type.elem_type, '| size:', '[' + dims + ']')


# Earlier revisions of this sample also contained two further examples here
# (MNIST from 'mnist_original.onnx' and BERT from
# 'BertForSequenceClassification_full_training.onnx'), disabled by wrapping
# them in a bare string literal. They were written against a different
# ModuleGradientGraphBuilder API than the active example below
# (config.weight_names_to_train / config.output_names, with build_and_split()
# returning the serialized models as a list). The dead code has been removed;
# it remains available in version control history.


# BERT with loss.
# Builds the gradient graph for 'bert-tiny-loss.onnx' (presumably a small BERT
# export whose graph already computes the loss - verify), splits it into
# forward and backward graphs, saves all three models, and prints the split
# metadata together with the tensor types recoverable from the forward model.
original_model = onnx.load('bert-tiny-loss.onnx')
config = C.ModuleGradientGraphBuilderConfiguration()
# Train only the initializers whose names start with 'bert.' or 'cls.'; any
# other initializer is not listed as trainable.
config.initializer_names_to_train = [
    initializer.name
    for initializer in original_model.graph.initializer
    if initializer.name.startswith(('bert.', 'cls.'))
]
# Additionally request a gradient with respect to the graph input 'input3'.
config.input_names_require_grad = ['input3']

module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
module_gradient_graph_builder.build_and_split(original_model.SerializeToString(), config)
forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model())
backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model())
gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model())
onnx.save(gradient_model, 'bert_gradient_graph.onnx')
onnx.save(forward_model, 'bert_forward.onnx')
onnx.save(backward_model, 'bert_backward.onnx')

split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()

# Name-list attributes of the split info, in the order they are reported.
# NOTE: 'backward_intializer_names_as_input' (sic) must match the attribute
# name exposed by the binding, so the misspelling is intentional here.
SPLIT_INFO_FIELDS = (
    'user_input_names',
    'initializer_names_to_train',
    'user_output_names',
    'backward_user_input_names',
    'backward_intializer_names_as_input',
    'intermediate_tensor_names',
    'user_output_grad_names',
    'backward_output_grad_names',
)
for field in SPLIT_INFO_FIELDS:
    print_list(field, getattr(split_graphs_info, field))


def _fill_missing_type(types_by_name, name, value_type):
    """Record ``value_type`` for ``name`` if the name is tracked and still has no type."""
    if name in types_by_name and types_by_name[name] is None:
        types_by_name[name] = value_type


# Every name from the split info (first-seen order) -> its type, or None while
# no type has been found for it.
type_map = {
    name: None
    for field in SPLIT_INFO_FIELDS
    for name in getattr(split_graphs_info, field)
}
# Types are recovered from the forward model's graph inputs and outputs; the
# first match wins. A forward output 'X' also supplies the type of 'X_grad',
# i.e. this script assumes a gradient has the same type as its value.
for graph_input in forward_model.graph.input:
    _fill_missing_type(type_map, graph_input.name, graph_input.type)
for graph_output in forward_model.graph.output:
    _fill_missing_type(type_map, graph_output.name, graph_output.type)
    _fill_missing_type(type_map, graph_output.name + '_grad', graph_output.type)

for tensor_name, tensor_type in type_map.items():
    if tensor_type is None:
        # Neither a forward input/output nor the '_grad' of a forward output,
        # so no type was found. Print a placeholder instead of crashing in
        # print_type with AttributeError on None.tensor_type.
        print('[' + tensor_name + ']', 'type:', 'n/a', '| size:', 'n/a')
    else:
        print_type(tensor_name, tensor_type)