mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
120 lines
5 KiB
Python
120 lines
5 KiB
Python
import onnx
|
|
import copy
|
|
from onnx import shape_inference
|
|
from onnxruntime.capi import _pybind_state as C
|
|
|
|
|
|
def print_list(name, value):
    """Print ``name`` and the items of ``value`` on one comma-separated line.

    Args:
        name: Label printed before the colon; it is concatenated with ':',
            so it must be a str.
        value: Iterable of strings, joined with ', '.

    The printed line has the form ``<name>: <item>, <item>, ...``.
    """
    label = name + ':'
    items = ', '.join(value)
    print(label, items)
|
|
|
|
|
|
def dim_str(dim):
    """Return a printable form of a single shape dimension.

    ``dim`` is presumably an ONNX ``TensorShapeProto.Dimension``. Any object
    exposing ``HasField``, ``dim_value`` and ``dim_param`` works.

    Returns:
        The concrete size as a string when ``dim_value`` is set, otherwise the
        symbolic ``dim_param`` name when that is set, otherwise ``'n/a'``.
    """
    has_value = dim.HasField('dim_value')
    # Neither a concrete nor a symbolic size is recorded for this dimension.
    if not has_value and not dim.HasField('dim_param'):
        return 'n/a'
    return str(dim.dim_value) if has_value else dim.dim_param
|
|
|
|
def print_type(name, type):
    """Print the element type and shape of a tensor type on one line.

    Args:
        name: Tensor name, printed in square brackets (must be a str).
        type: A type object exposing ``tensor_type.elem_type`` and
            ``tensor_type.shape.dim``. This is presumably an ONNX ``TypeProto``;
            ``elem_type`` is printed as-is, which for ONNX would be the numeric
            data-type code. The parameter name shadows the builtin but is kept
            for caller compatibility.

    Each dimension is rendered with ``dim_str`` and the dimensions are joined
    with ',' inside square brackets.
    """
    label = '[' + name + ']'
    tensor_type = type.tensor_type
    elem_type = tensor_type.elem_type
    shape_text = '[' + ','.join(dim_str(d) for d in tensor_type.shape.dim) + ']'
    print(label, 'type:', elem_type, '| size:', shape_text)
|
|
|
|
|
|
"""
|
|
# MNIST
|
|
original_model = onnx.load('mnist_original.onnx')
|
|
config = C.ModuleGradientGraphBuilderConfiguration()
|
|
weight_names_to_train = set()
|
|
for initializer in original_model.graph.initializer:
|
|
weight_names_to_train.add(initializer.name)
|
|
config.weight_names_to_train = weight_names_to_train
|
|
output_names = set()
|
|
for output in original_model.graph.output:
|
|
output_names.add(output.name)
|
|
config.output_names = output_names
|
|
|
|
models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)]
|
|
onnx.save(models[0], 'minst_gradient_graph.onnx')
|
|
onnx.save(models[1], 'mnist_forward.onnx')
|
|
onnx.save(models[2], 'mnist_backward.onnx')
|
|
|
|
|
|
#BERT
|
|
original_model = onnx.load('BertForSequenceClassification_full_training.onnx')
|
|
config = C.ModuleGradientGraphBuilderConfiguration()
|
|
weight_names_to_train = set()
|
|
for initializer in original_model.graph.initializer:
|
|
weight_names_to_train.add(initializer.name)
|
|
config.weight_names_to_train = weight_names_to_train
|
|
output_names = set()
|
|
for output in original_model.graph.output:
|
|
output_names.add(output.name)
|
|
config.output_names = output_names
|
|
|
|
models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)]
|
|
onnx.save(models[0], 'bert_gradient_graph.onnx')
|
|
onnx.save(models[1], 'bert_forward.onnx')
|
|
onnx.save(models[2], 'bert_backward.onnx')
|
|
"""
|
|
|
|
# BERT with loss: load the model and configure the gradient-graph builder.
#   - Train every initializer whose name starts with 'bert.' or 'cls.'.
#   - Mark the graph input 'input3' as one that requires a gradient.
original_model = onnx.load('bert-tiny-loss.onnx')

config = C.ModuleGradientGraphBuilderConfiguration()

initializer_names_to_train = [
    initializer.name
    for initializer in original_model.graph.initializer
    # str.startswith accepts a tuple: true if the name has either prefix.
    if initializer.name.startswith(('bert.', 'cls.'))
]
config.initializer_names_to_train = initializer_names_to_train

input_names_require_grad = ['input3']
config.input_names_require_grad = input_names_require_grad
|
|
|
|
# Run the builder on the serialized model.
# Then fetch its serialized forward, backward and gradient models, parse them,
# and write each one to disk.
module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
module_gradient_graph_builder.build_and_split(original_model.SerializeToString(), config)

# The getters are called in the order forward, backward, gradient.
forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model())
backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model())
gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model())

for _model, _path in (
    (gradient_model, 'bert_gradient_graph.onnx'),
    (forward_model, 'bert_forward.onnx'),
    (backward_model, 'bert_backward.onnx'),
):
    onnx.save(_model, _path)
|
|
|
|
# Dump every name list the builder reports about the forward/backward split.
split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()

# Each label printed is exactly the attribute name it reads.
# This includes the binding's own 'intializer' spelling.
for _info_field in (
    'user_input_names',
    'initializer_names_to_train',
    'user_output_names',
    'backward_user_input_names',
    'backward_intializer_names_as_input',
    'intermediate_tensor_names',
    'user_output_grad_names',
    'backward_output_grad_names',
):
    print_list(_info_field, getattr(split_graphs_info, _info_field))
|
|
|
|
# Map every tensor name reported in split_graphs_info to its type.
# Each value starts as None and is filled in below.
# A name listed in several groups keeps the position of its first appearance.
_split_name_groups = (
    split_graphs_info.user_input_names,
    split_graphs_info.initializer_names_to_train,
    split_graphs_info.user_output_names,
    split_graphs_info.backward_user_input_names,
    split_graphs_info.backward_intializer_names_as_input,
    split_graphs_info.intermediate_tensor_names,
    split_graphs_info.user_output_grad_names,
    split_graphs_info.backward_output_grad_names,
)
type_map = {name: None for group in _split_name_groups for name in group}


def _record_type(tensor_name, tensor_type):
    """Store tensor_type for tensor_name if the name is tracked and still unresolved."""
    if tensor_name in type_map and type_map[tensor_name] is None:
        type_map[tensor_name] = tensor_type


for graph_input in forward_model.graph.input:
    _record_type(graph_input.name, graph_input.type)

for graph_output in forward_model.graph.output:
    _record_type(graph_output.name, graph_output.type)
    # '<output>_grad' is given the same type as the forward output it belongs to.
    _record_type(graph_output.name + '_grad', graph_output.type)

# Fallback for names that never appear among the forward model's inputs/outputs.
# Presumably these include gradients produced by the backward graph.
# Without this step they would stay None.
for value_info in list(backward_model.graph.input) + list(backward_model.graph.output):
    _record_type(value_info.name, value_info.type)

for tensor_name, tensor_type in type_map.items():
    if tensor_type is None:
        # print_type dereferences type.tensor_type and would raise AttributeError on None.
        print('[' + tensor_name + ']', 'type: n/a')
    else:
        print_type(tensor_name, tensor_type)
|