mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Add support to BERT fine tuning (MVP 3)
Additional changes include major refactoring to use new backend API
This commit is contained in:
parent
78831d009b
commit
ff79e8743f
5 changed files with 182 additions and 308 deletions
|
|
@ -6,8 +6,9 @@ import onnxruntime
|
|||
import os
|
||||
import torch
|
||||
import warnings
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from inspect import signature
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from . import _utils
|
||||
|
||||
|
||||
|
|
@ -22,29 +23,16 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# User module is wrapped to use its initializers and save computed gradients
|
||||
self._original_module = module
|
||||
self._original_module_grad_output_len = -1
|
||||
self._original_module_forward_input_grads = []
|
||||
self._onnx_training = None
|
||||
self._onnx_training_inputs_desc = []
|
||||
self._onnx_training_outputs_desc = []
|
||||
self._onnx_gradient = None
|
||||
self._grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
|
||||
# Forward pass
|
||||
self._onnx_forward = None
|
||||
self._forward_session = None
|
||||
self._onnx_forward_initializers_desc = []
|
||||
self._onnx_forward_inputs_desc = []
|
||||
self._onnx_forward_outputs_desc = []
|
||||
self._onnx_forward_intermediate_outputs_desc = []
|
||||
|
||||
# Backward pass
|
||||
self._onnx_backward = None
|
||||
self._backward_session = None
|
||||
self._onnx_backward_initializers_desc = []
|
||||
self._onnx_backward_inputs_desc = []
|
||||
self._onnx_backward_gradient_inputs_desc = []
|
||||
self._onnx_backward_outputs_desc = []
|
||||
|
||||
# Log level
|
||||
self._loglevel = getattr(logging, 'WARNING')
|
||||
|
|
@ -60,14 +48,13 @@ class ORTModule(torch.nn.Module):
|
|||
ONNX model is exported the first time this method is executed.
|
||||
Next, a full training graph is splitted in forward and backward graph which are used
|
||||
to instantiate ONNX Runtime InferenceSession`s
|
||||
|
||||
TODO: #ImproveGraphSplitting
|
||||
Additionally to that, several descriptor lists are generated to help identify
|
||||
model input, output, initializer, intermediate and gradient tensors.
|
||||
'''
|
||||
if not self._onnx_forward:
|
||||
self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs)
|
||||
self._onnx_gradient, self._onnx_forward, self._onnx_backward = ORTModule._build_gradient_graph(self._onnx_training, self._grad_builder_config)
|
||||
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config)
|
||||
# TODO: PyTorch exporter bug: changes the initializer order
|
||||
self._onnx_graphs_info.initializer_grad_names_to_train = [ p[0]+'_grad' for p in self._original_module.named_parameters()]
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
|
||||
|
|
@ -75,44 +62,13 @@ class ORTModule(torch.nn.Module):
|
|||
onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx')
|
||||
onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx')
|
||||
|
||||
# TODO: Consider moving this to the backend. We don't want to append '_grad' to get correct tensor names
|
||||
self._onnx_graphs_types = ORTModule._get_io_info_from_onnx_graph(self._onnx_forward, self._onnx_graphs_info)
|
||||
|
||||
# TODO: hard-coding to CPU only
|
||||
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString(), providers=['CPUExecutionProvider'])
|
||||
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString(), providers=['CPUExecutionProvider'])
|
||||
|
||||
# Forward I/O description
|
||||
if not self._onnx_training_inputs_desc:
|
||||
self._onnx_training_inputs_desc = self._get_input_from_graph(self._onnx_training)
|
||||
logging.debug(f'Training inputs:\n\t {self._onnx_training_inputs_desc}')
|
||||
if not self._onnx_training_outputs_desc:
|
||||
self._onnx_training_outputs_desc = self._get_output_from_graph(self._onnx_training)
|
||||
logging.debug(f'Training outputs:\n\t {self._onnx_training_outputs_desc}')
|
||||
if not self._onnx_forward_initializers_desc:
|
||||
self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
|
||||
logging.debug(f'Forward initializers:\n\t {self._onnx_forward_initializers_desc}')
|
||||
if not self._onnx_forward_inputs_desc:
|
||||
self._onnx_forward_inputs_desc = self._get_input_from_graph(self._onnx_forward)
|
||||
logging.debug(f'Forward inputs:\n\t {self._onnx_forward_inputs_desc}')
|
||||
if not self._onnx_forward_outputs_desc:
|
||||
self._onnx_forward_outputs_desc = self._get_output_from_graph(self._onnx_forward)
|
||||
logging.debug(f'Forward outputs:\n\t {self._onnx_forward_outputs_desc}')
|
||||
if not self._onnx_forward_intermediate_outputs_desc:
|
||||
self._onnx_forward_intermediate_outputs_desc = self._get_intermediate_from_forward_graph(self._onnx_forward)
|
||||
logging.debug(f'Forward intermediate outputs:\n\t {self._onnx_forward_intermediate_outputs_desc}')
|
||||
|
||||
# Backward I/O description
|
||||
if not self._onnx_backward_initializers_desc:
|
||||
self._onnx_backward_initializers_desc = self._get_input_from_graph(self._onnx_backward, True)
|
||||
logging.debug(f'Backward initializers: {self._onnx_backward_initializers_desc}')
|
||||
if not self._onnx_backward_inputs_desc:
|
||||
self._onnx_backward_inputs_desc = self._get_input_from_graph(self._onnx_backward, False, self._onnx_backward_initializers_desc)
|
||||
logging.debug(f'Backward inputs: {self._onnx_backward_inputs_desc}')
|
||||
if not self._onnx_backward_gradient_inputs_desc:
|
||||
self._onnx_backward_gradient_inputs_desc = self._get_gradient_input_from_graph(self._onnx_backward, self._onnx_forward_inputs_desc, self._onnx_forward_initializers_desc, self._onnx_forward_intermediate_outputs_desc)
|
||||
logging.debug(f'Backward gradient inputs: {self._onnx_backward_gradient_inputs_desc}')
|
||||
if not self._onnx_backward_outputs_desc:
|
||||
self._onnx_backward_outputs_desc = self._get_output_from_graph(self._onnx_backward)
|
||||
logging.debug(f'Backward outputs: {self._onnx_backward_outputs_desc}')
|
||||
|
||||
# Use a custom torch.autograd.Function to associate self.backward_graph as the
|
||||
# gradient implementation for self.forward_graph.
|
||||
class _ORTModuleFunction(torch.autograd.Function):
|
||||
|
|
@ -127,9 +83,6 @@ class ORTModule(torch.nn.Module):
|
|||
* (Partial) user input
|
||||
* (Partial) Initializers
|
||||
* Intermediate tensors
|
||||
|
||||
TODO: #ImproveGraphSplitting
|
||||
String matching to separate user input from initializer
|
||||
'''
|
||||
|
||||
# Convert input to dict of torch tensors
|
||||
|
|
@ -143,10 +96,12 @@ class ORTModule(torch.nn.Module):
|
|||
outputs = tuple(torch.from_numpy(item) for item in outputs)
|
||||
|
||||
# Save input, initializers and intermediate tensors to be used during backward
|
||||
initializer_names = [item['name'] for item in self._onnx_backward_initializers_desc]
|
||||
input_names = [item['name'] for item in self._onnx_backward_inputs_desc if item['name'] not in initializer_names]
|
||||
ctx_input = tuple(v for k,v in data_dict.items() if k in input_names)
|
||||
ctx_initializer = tuple(v for k,v in data_dict.items() if k in initializer_names)
|
||||
user_input = self._onnx_graphs_info.user_input_names
|
||||
backward_user_input = self._onnx_graphs_info.backward_user_input_names
|
||||
ctx_input = tuple(data_dict[name] for name in user_input if name in backward_user_input)
|
||||
forward_initializer = self._onnx_graphs_info.initializer_names_to_train
|
||||
backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input
|
||||
ctx_initializer = tuple(data_dict[name] for name in forward_initializer if name in backward_intializer)
|
||||
intermediate = tuple(torch.from_numpy(item) for item in intermediate)
|
||||
ctx.save_for_backward(*[*ctx_input, *ctx_initializer, *intermediate])
|
||||
|
||||
|
|
@ -163,23 +118,17 @@ class ORTModule(torch.nn.Module):
|
|||
* Tensor stashed (in a particular order) during forward:
|
||||
* (partial) user input, (partial) initializers and intermediate tensors
|
||||
|
||||
TODO: #ImproveGraphSplitting
|
||||
Length of `*grad_output` is needed to detect intermediate tensors during backward pass
|
||||
|
||||
TODO: Input gradient is hard-coded to torch.tensor([1.])
|
||||
'''
|
||||
saved_tensors = ctx.saved_tensors
|
||||
# Used to create backward input
|
||||
if self._original_module_grad_output_len == -1:
|
||||
self._original_module_grad_output_len = len(grad_output)
|
||||
|
||||
grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output])
|
||||
|
||||
result = [torch.tensor([1])]* len(self._onnx_training_inputs_desc)
|
||||
result = [torch.tensor([1])]* len(self._onnx_graphs_info.user_input_names)
|
||||
result += [torch.from_numpy(grad) for grad in grad_weights]
|
||||
return tuple(result)
|
||||
|
||||
return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*inputs, **kwargs))
|
||||
proc_inputs = [data for data in inputs if data is not None]
|
||||
return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*proc_inputs, **kwargs))
|
||||
|
||||
def _convert_forward_input_to_list(self, *inputs, **kwargs):
|
||||
'''Creates forward `*inputs` list from user input and PyTorch initializers
|
||||
|
|
@ -189,8 +138,10 @@ class ORTModule(torch.nn.Module):
|
|||
ONNX Runtime forward requires an order list of:
|
||||
* User input: computed from forward InferenceSession
|
||||
* Initializers: computed from original PyTorch model parameters
|
||||
'''
|
||||
|
||||
This codes assumes the exported model's inputs and initializers
|
||||
are the same as the original PyTorch model
|
||||
'''
|
||||
# List containing both user inputs and initializers, in this order
|
||||
result = []
|
||||
|
||||
|
|
@ -219,12 +170,8 @@ class ORTModule(torch.nn.Module):
|
|||
def _convert_forward_input_list_to_dict(self, *inputs):
|
||||
'''Convert forward `*inputs` list to dict
|
||||
|
||||
TODO: #ImproveGraphSplitting
|
||||
Additionally, a list of gradient names of initializers are created to be used by backprop
|
||||
|
||||
TODO: Input gradient is being ignored for MVP
|
||||
'''
|
||||
|
||||
# Dictionary containing both inputs and initializers
|
||||
result = {}
|
||||
|
||||
|
|
@ -237,15 +184,6 @@ class ORTModule(torch.nn.Module):
|
|||
# Initializers
|
||||
for param in self._original_module.named_parameters():
|
||||
result.update({param[0]: inputs[result_len]})
|
||||
# TODO: Create order list of input grads to use during backward.
|
||||
# (for scenarios where gradients of input is required - not covered on MVP)
|
||||
# if len(self._original_module_forward_input_grads) < len(self._onnx_training_inputs_desc):
|
||||
# self._original_module_forward_input_grads.append(param[0]+'_grad')
|
||||
|
||||
# TODO: Create order list of initializer grads to use during backward.
|
||||
# if len(self._original_module_forward_input_grads) < len(self._onnx_backward_outputs_desc) + len(self._onnx_training_inputs_desc):
|
||||
if len(self._original_module_forward_input_grads) < len(self._onnx_backward_outputs_desc):
|
||||
self._original_module_forward_input_grads.append(param[0]+'_grad')
|
||||
result_len += 1
|
||||
|
||||
return result
|
||||
|
|
@ -253,46 +191,45 @@ class ORTModule(torch.nn.Module):
|
|||
def _convert_backward_input_list_to_dict(self, *inputs):
|
||||
'''Convert backward `*inputs` list to dict
|
||||
|
||||
ONNX Runtime backend requires dict as input, which is composed of:
|
||||
ONNX Runtime backward requires dict as input, which is composed of:
|
||||
* User input
|
||||
Although not necessary, all user inputs are used for simplicity
|
||||
* (Partial) Initializers
|
||||
init_begin = len(user_input)
|
||||
init_count = len(Pre-computed list of initializer)
|
||||
* Intermediate tensors TODO: #ImproveGraphSplitting
|
||||
Intermediate tensors are inferred from input position:
|
||||
interm_begin = len(user_input) + len(initializer)
|
||||
interm_count = len(all_inputs) - len(user_input) - len(initializer) - len(grad_output)
|
||||
* Gradient wrt outputs TODO: #ImproveGraphSplitting
|
||||
Gradient tensors are inferred from input position:
|
||||
grads_begin = len(user_input) + len(initializer) + len(intermediate)
|
||||
grads_count = len(all_inputs) - len(user_input) - len(initializer) - len(intermediate)
|
||||
* Intermediate tensors
|
||||
* Gradient wrt outputs
|
||||
'''
|
||||
|
||||
# Dictionary containing both inputs and initializers
|
||||
result = {}
|
||||
|
||||
backward_user_input = self._onnx_graphs_info.backward_user_input_names
|
||||
backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input
|
||||
intermediate = self._onnx_graphs_info.intermediate_tensor_names
|
||||
backward_output_grad_names = self._onnx_graphs_info.backward_output_grad_names
|
||||
|
||||
# Extract info about stashed input and grad output
|
||||
# Inputs
|
||||
result_len = 0
|
||||
for idx, input_data in enumerate(self._forward_session.get_inputs()):
|
||||
result.update({ input_data.name : inputs[idx]})
|
||||
result_len += 1
|
||||
inputs_pos = 0
|
||||
for idx, name in enumerate(backward_user_input):
|
||||
result.update({ name : inputs[idx]})
|
||||
inputs_pos += 1
|
||||
|
||||
# Initializers
|
||||
for initializer in self._onnx_backward_initializers_desc:
|
||||
result.update({initializer['name']: inputs[result_len]})
|
||||
result_len += 1
|
||||
for idx, name in enumerate(backward_intializer, inputs_pos):
|
||||
result.update({name: inputs[idx]})
|
||||
inputs_pos += 1
|
||||
|
||||
# Intermediate
|
||||
intermediate_len = len(inputs) - result_len - self._original_module_grad_output_len
|
||||
for idx in range(intermediate_len):
|
||||
result.update({self._onnx_forward_intermediate_outputs_desc[idx]['name']: inputs[result_len]})
|
||||
result_len += 1
|
||||
for idx, name in enumerate(intermediate, inputs_pos):
|
||||
result.update({name: inputs[idx]})
|
||||
inputs_pos += 1
|
||||
|
||||
# Grad outputs
|
||||
for idx in range(len(inputs)-result_len):
|
||||
result.update({self._onnx_backward_gradient_inputs_desc[idx]['name']: inputs[result_len]})
|
||||
result_len += 1
|
||||
for idx, name in enumerate(backward_output_grad_names, inputs_pos):
|
||||
result.update({name: inputs[idx]})
|
||||
inputs_pos += 1
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -303,10 +240,10 @@ class ORTModule(torch.nn.Module):
|
|||
to distinguish intermediate from output tensors
|
||||
'''
|
||||
|
||||
output_names = [out['name'] for out in self._onnx_forward_outputs_desc]
|
||||
forward_output = self._forward_session.run(output_names, inputs)
|
||||
output = forward_output[:len(self._onnx_training_outputs_desc)]
|
||||
intermediates = forward_output[len(self._onnx_training_outputs_desc):]
|
||||
forward_output = self._forward_session.run([*self._onnx_graphs_info.user_output_names,
|
||||
*self._onnx_graphs_info.intermediate_tensor_names], inputs)
|
||||
output = forward_output[:len(self._onnx_graphs_info.user_output_names)]
|
||||
intermediates = forward_output[len(self._onnx_graphs_info.user_output_names):]
|
||||
return output, intermediates
|
||||
|
||||
def _run_backward_graph(self, *inputs, **kwargs):
|
||||
|
|
@ -323,7 +260,7 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement)
|
||||
data = self._convert_dict_torch_to_numpy(data)
|
||||
return self._backward_session.run(self._original_module_forward_input_grads, data)
|
||||
return self._backward_session.run(self._onnx_graphs_info.initializer_grad_names_to_train, data)
|
||||
|
||||
@staticmethod
|
||||
def _get_forward_graph(module, *inputs, **kwargs):
|
||||
|
|
@ -338,6 +275,11 @@ class ORTModule(torch.nn.Module):
|
|||
# Deepcopy inputs, since input values may change after model run.
|
||||
sample_inputs_copy = copy.deepcopy(inputs)
|
||||
|
||||
# Ignore optional *inputs explicitly specified as None
|
||||
sig = signature(module.forward)
|
||||
all_input_names = sig.parameters.keys()
|
||||
input_names = [name for idx, name in enumerate(all_input_names) if inputs[idx] is not None]
|
||||
|
||||
# TODO: Support contrib OPs support? user model has no hint
|
||||
# from onnxruntime.training import register_custom_ops_pytorch_exporter
|
||||
# register_custom_ops_pytorch_exporter.register_custom_op()
|
||||
|
|
@ -346,173 +288,68 @@ class ORTModule(torch.nn.Module):
|
|||
torch.onnx.export(module,
|
||||
tuple(sample_inputs_copy),
|
||||
f,
|
||||
input_names=input_names,
|
||||
opset_version=ONNX_OPSET_VERSION,
|
||||
do_constant_folding=False,
|
||||
training=torch.onnx.TrainingMode.TRAINING)
|
||||
|
||||
return onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
def _get_initializer_from_graph(self, graph):
|
||||
'''Returns a descriptor list of initializers for `graph`
|
||||
|
||||
The list descriptor has the following format:
|
||||
[{ 'name': name, 'shape':[int1,...,intN], 'dtype': <onnx.dtype> ]}]
|
||||
|
||||
For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
|
||||
'''
|
||||
|
||||
# TODO: There is a tradeoff between memory footprint and total model export time
|
||||
# Ideally we want to export the model using torch.onnx.export(.., export_params=False, keep_initializers_as_inputs=True)
|
||||
# to obtain an ONNX model with minimal size and initializers as input.
|
||||
# However, this results in (guessing) assuming only initializer's name end with '.weight' and '.bias'.
|
||||
# Otherwise, it is not possible to separate input from initializer after the model is exported
|
||||
# Options are:
|
||||
# 1) If memory footprint is more important, we can export ONNX twice, varying keep_initializers_as_inputs flag
|
||||
# ONNX model is small (400 bytes vs 1.6MB for MNIST), but export takes twice the time
|
||||
# 2) If total export time is more important, we can export ONNX once, using export_params=True
|
||||
# ONNX model is bigger, but export takes half the time
|
||||
|
||||
# As performance is not the main goal in this first deliverable, using approach 2) for simplicity
|
||||
initializers = []
|
||||
for initializer in graph.graph.initializer:
|
||||
name = initializer.name
|
||||
shape = initializer.dims
|
||||
dtype = _utils.dtype_onnx_to_torch(initializer.data_type)
|
||||
initializers.append({'name': name, 'shape': shape, 'dtype': dtype})
|
||||
return initializers
|
||||
|
||||
def _get_input_from_graph(self, graph, initializers_only=False, append_initializers=[]):
|
||||
'''Returns a descriptor list of input tensors for an ONNX `graph`
|
||||
|
||||
When `initializers_only=True`, only input initializers are returned. Otherwise, both
|
||||
user input and initializers are considered.
|
||||
This is being used to get backward initializer list TODO: #ImproveGraphSplitting
|
||||
|
||||
When `append_initializers` is not empty, this list is appended to the end of the result list
|
||||
This is being used to get backward input list TODO: #ImproveGraphSplitting
|
||||
|
||||
The list descriptor has the following format:
|
||||
[{ 'name': name, 'shape':[int1,...,intN], 'dtype': <onnx.dtype> ]}]
|
||||
|
||||
For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
|
||||
'''
|
||||
|
||||
inputs = []
|
||||
for elem in graph.graph.input:
|
||||
for initializer in self._onnx_forward_initializers_desc:
|
||||
if elem.name == initializer['name']:
|
||||
if initializers_only:
|
||||
name = elem.name
|
||||
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
|
||||
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
|
||||
inputs.append({'name': name, 'shape': shape, 'dtype': dtype})
|
||||
break
|
||||
else:
|
||||
if not initializers_only:
|
||||
name = elem.name
|
||||
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
|
||||
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
|
||||
inputs.append({'name': name, 'shape': shape, 'dtype': dtype})
|
||||
if append_initializers:
|
||||
inputs.extend(append_initializers)
|
||||
return inputs
|
||||
|
||||
def _get_gradient_input_from_graph(self, backward_graph, forward_input, forward_initializer, forward_intermediate):
|
||||
'''Returns a descriptor list of gradient output for `backward_graph`
|
||||
|
||||
Gradient output tensors are found through an elimination process, that cross reference
|
||||
inputs from the backward graph to the forward input, initializer and intermediate tensors.
|
||||
|
||||
The list descriptor has the following format:
|
||||
[{ 'name': name, 'shape':[int1,...,intN], 'dtype': <onnx.dtype> ]}]
|
||||
|
||||
For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
|
||||
|
||||
TODO: #ImproveGraphSplitting
|
||||
'''
|
||||
grads = []
|
||||
found = False
|
||||
for elem in backward_graph.graph.input:
|
||||
for item in forward_input:
|
||||
if elem.name == item['name']:
|
||||
# skip output
|
||||
break
|
||||
else:
|
||||
for item in forward_initializer:
|
||||
if elem.name == item['name']:
|
||||
# skip output
|
||||
break
|
||||
else:
|
||||
for item in forward_intermediate:
|
||||
if elem.name == item['name']:
|
||||
# skip output
|
||||
break
|
||||
else:
|
||||
name = elem.name
|
||||
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
|
||||
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
|
||||
grads.append({'name': name, 'shape': shape, 'dtype': dtype})
|
||||
return grads
|
||||
|
||||
def _get_output_from_graph(self, graph):
|
||||
'''Returns a descriptor list of output tensors for an ONNX `graph`
|
||||
|
||||
The list descriptor has the following format:
|
||||
[{ 'name': name, 'shape':[int1,...,intN], 'dtype': <onnx.dtype> ]}]
|
||||
|
||||
For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
|
||||
'''
|
||||
outputs = []
|
||||
for elem in graph.graph.output:
|
||||
for initializer in self._onnx_forward_initializers_desc:
|
||||
if elem.name == initializer['name']:
|
||||
# skip initializers
|
||||
break
|
||||
else:
|
||||
name = elem.name
|
||||
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
|
||||
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
|
||||
outputs.append({'name': name, 'shape': shape, 'dtype': dtype})
|
||||
return outputs
|
||||
|
||||
def _get_intermediate_from_forward_graph(self, forward_graph):
|
||||
'''Returns a descriptor list with all intermediate tensors for `forward_graph`
|
||||
|
||||
Intermediate tensors are found through an elimination process, that cross reference
|
||||
outputs from the forward graph to the original model (exported to ONNX)
|
||||
|
||||
The list descriptor has the following format:
|
||||
[{ 'name': name, 'shape':[int1,...,intN], 'dtype': <onnx.dtype> ]}]
|
||||
|
||||
TODO: #ImproveGraphSplitting
|
||||
'''
|
||||
intermediates = []
|
||||
for elem in forward_graph.graph.output:
|
||||
for output in self._onnx_training_outputs_desc:
|
||||
if elem.name == output['name']:
|
||||
# skip output
|
||||
break
|
||||
else:
|
||||
name = elem.name
|
||||
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
|
||||
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
|
||||
intermediates.append({'name': name, 'shape': shape, 'dtype': dtype})
|
||||
return intermediates
|
||||
|
||||
@staticmethod
|
||||
def _build_gradient_graph(forward_graph, config):
|
||||
'''Adds gradient nodes on top of an existing ONNX graph (with training flag)
|
||||
|
||||
TODO: #SplittingGraphAtFrontend
|
||||
'''
|
||||
if not config.weight_names_to_train:
|
||||
weight_names_to_train = set()
|
||||
def _build_fw_bw_grad_graphs(forward_graph, config):
|
||||
'''Adds gradient nodes on top of an existing ONNX graph (with training flag)'''
|
||||
if not config.initializer_names_to_train:
|
||||
initializer_names_to_train = []
|
||||
for initializer in forward_graph.graph.initializer:
|
||||
weight_names_to_train.add(initializer.name)
|
||||
config.weight_names_to_train = weight_names_to_train
|
||||
output_names = set()
|
||||
for output in forward_graph.graph.output:
|
||||
output_names.add(output.name)
|
||||
config.output_names = output_names
|
||||
models = [onnx.load_model_from_string(model_as_string)
|
||||
for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(forward_graph.SerializeToString(), config)]
|
||||
return models[0], models[1], models[2]
|
||||
initializer_names_to_train.append(initializer.name)
|
||||
config.initializer_names_to_train = initializer_names_to_train
|
||||
|
||||
# TODO: Add support to input with grad required
|
||||
config.input_names_require_grad = []
|
||||
# input_names_require_grad = []
|
||||
# input_names_require_grad.append('input.1')
|
||||
# config.input_names_require_grad = input_names_require_grad
|
||||
|
||||
module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
|
||||
module_gradient_graph_builder.build_and_split(forward_graph.SerializeToString(), config)
|
||||
forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model())
|
||||
backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model())
|
||||
gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model())
|
||||
split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()
|
||||
|
||||
return gradient_model, forward_model, backward_model, split_graphs_info
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_io_info_from_onnx_graph(model, graphs_info):
|
||||
type_map = {}
|
||||
for name in graphs_info.user_input_names:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.initializer_names_to_train:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.user_output_names:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.backward_user_input_names:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.backward_intializer_names_as_input:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.intermediate_tensor_names:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.user_output_grad_names:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.backward_output_grad_names:
|
||||
type_map[name] = None
|
||||
|
||||
for input in model.graph.input:
|
||||
if input.name in type_map and type_map[input.name] is None:
|
||||
type_map[input.name] = input.type
|
||||
|
||||
for output in model.graph.output:
|
||||
if output.name in type_map and type_map[output.name] is None:
|
||||
type_map[output.name] = output.type
|
||||
output_grad_name = output.name + '_grad'
|
||||
if output_grad_name in type_map and type_map[output_grad_name] is None:
|
||||
type_map[output_grad_name] = output.type
|
||||
|
||||
return type_map
|
||||
|
|
@ -143,7 +143,7 @@ def main():
|
|||
# Set log level
|
||||
numeric_level = getattr(logging, args.log_level.upper(), None)
|
||||
if not isinstance(numeric_level, int):
|
||||
raise ValueError('Invalid log level: %s' % loglevel)
|
||||
raise ValueError('Invalid log level: %s' % args.log_level)
|
||||
logging.basicConfig(level=numeric_level)
|
||||
else:
|
||||
print('Training MNIST on vanilla PyTorch....')
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
|
||||
import pdb
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import torch
|
||||
import wget
|
||||
import os
|
||||
import pandas as pd
|
||||
import zipfile
|
||||
from transformers import BertTokenizer
|
||||
from transformers import BertTokenizer, AutoConfig
|
||||
from keras.preprocessing.sequence import pad_sequences
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||
|
|
@ -50,13 +48,10 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
if step == args.train_steps:
|
||||
break
|
||||
|
||||
# Progress update every 40 batches.
|
||||
if step % args.log_interval == 0 and not step == 0:
|
||||
# Calculate elapsed time in minutes.
|
||||
elapsed = format_time(time.time() - t0)
|
||||
|
||||
# Report progress.
|
||||
print(' Batch {:>5,} of {:>5,}. Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))
|
||||
# TODO: Dynamic axis is not supported yet
|
||||
if batch[0].shape[0] != args.batch_size:
|
||||
logging.warning(f'Dynamic axis is not supported yet {len(batch)}/{args.batch_size}')
|
||||
continue
|
||||
|
||||
# Unpack this training batch from our dataloader.
|
||||
#
|
||||
|
|
@ -78,12 +73,23 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
model.zero_grad()
|
||||
|
||||
# Perform a forward pass (evaluate the model on this training batch).
|
||||
# This will return the loss (rather than the model output) because we
|
||||
# have provided the `labels`.
|
||||
# This will return the loss (rather than the model output) because we have provided the `labels`.
|
||||
# The documentation for this `model` function is here:
|
||||
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
|
||||
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
|
||||
|
||||
# TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule
|
||||
outputs = model(b_input_ids, b_input_mask, None, None, None, None, b_labels)
|
||||
# outputs = model(b_input_ids,
|
||||
# token_type_ids = None,
|
||||
# attention_mask = b_input_mask,
|
||||
# labels = b_labels)
|
||||
outputs = model(b_input_ids,
|
||||
b_input_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
b_labels)
|
||||
|
||||
if args.view_graphs:
|
||||
import torchviz
|
||||
pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters())))
|
||||
|
|
@ -91,15 +97,21 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
|
||||
# The call to `model` always returns a tuple, so we need to pull the
|
||||
# loss value out of the tuple.
|
||||
# pdb.set_trace()
|
||||
loss = outputs[0]
|
||||
|
||||
# Progress update every 40 batches.
|
||||
if step % args.log_interval == 0 and not step == 0:
|
||||
# Calculate elapsed time in minutes.
|
||||
elapsed = format_time(time.time() - t0)
|
||||
|
||||
# Report progress.
|
||||
print(f'Batch {step} of {len(train_dataloader)}. Elapsed: {elapsed}. Loss: {loss.item()}')
|
||||
|
||||
# Accumulate the training loss over all of the batches so that we can
|
||||
# calculate the average loss at the end. `loss` is a Tensor containing a
|
||||
# single value; the `.item()` function just returns the Python value
|
||||
# from the tensor.
|
||||
total_loss += loss.item()
|
||||
# total_loss += loss
|
||||
|
||||
# Perform a backward pass to calculate the gradients.
|
||||
loss.backward()
|
||||
|
|
@ -122,7 +134,7 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
print("\n Average training loss: {0:.2f}".format(avg_train_loss))
|
||||
print(" Training epoch took: {:}".format(format_time(time.time() - t0)))
|
||||
|
||||
def test(model, validation_dataloader, device):
|
||||
def test(model, validation_dataloader, device, args):
|
||||
# ========================================
|
||||
# Validation
|
||||
# ========================================
|
||||
|
|
@ -143,12 +155,16 @@ def test(model, validation_dataloader, device):
|
|||
# Evaluate data for one epoch
|
||||
for batch in validation_dataloader:
|
||||
|
||||
# TODO: Dynamic axis is not supported yet
|
||||
if batch[0].shape[0] != args.test_batch_size:
|
||||
logging.warning(f'Dynamic axis is not supported yet {len(batch)}/{args.batch_size}')
|
||||
continue
|
||||
|
||||
# Add batch to GPU
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
|
||||
# Unpack the inputs from our dataloader
|
||||
b_input_ids, b_input_mask, b_labels = batch
|
||||
|
||||
# Telling the model not to compute or store gradients, saving memory and
|
||||
# speeding up validation
|
||||
with torch.no_grad():
|
||||
|
|
@ -160,18 +176,22 @@ def test(model, validation_dataloader, device):
|
|||
# differentiates sentence 1 and 2 in 2-sentence tasks.
|
||||
# The documentation for this `model` function is here:
|
||||
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
|
||||
|
||||
# TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule
|
||||
# TODO: original sample had the last argument equal to None, but b_labels is because model was
|
||||
# exported using 3 inputs for training, so validation must follow.
|
||||
# Another approach would be checkpoint the trained model, re-export the model for validation with the checkpoint
|
||||
outputs = model(b_input_ids,
|
||||
b_input_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None)
|
||||
b_labels)
|
||||
|
||||
# Get the "logits" output by the model. The "logits" are the output
|
||||
# values prior to applying an activation function like the softmax.
|
||||
logits = outputs[0]
|
||||
logits = outputs[1]
|
||||
|
||||
# Move logits and labels to CPU
|
||||
logits = logits.detach().cpu().numpy()
|
||||
|
|
@ -190,7 +210,7 @@ def test(model, validation_dataloader, device):
|
|||
print(" Accuracy: {0:.2f}".format(eval_accuracy/nb_eval_steps))
|
||||
print(" Validation took: {:}".format(format_time(time.time() - t0)))
|
||||
|
||||
def load_dataset():
|
||||
def load_dataset(args):
|
||||
# 2. Loading CoLA Dataset
|
||||
print('Downloading dataset...')
|
||||
|
||||
|
|
@ -276,18 +296,15 @@ def load_dataset():
|
|||
train_masks = torch.tensor(train_masks)
|
||||
validation_masks = torch.tensor(validation_masks)
|
||||
|
||||
# The DataLoader needs to know our batch size for training, so we specify it
|
||||
batch_size = 32
|
||||
|
||||
# Create the DataLoader for our training set.
|
||||
train_data = TensorDataset(train_inputs, train_masks, train_labels)
|
||||
train_sampler = RandomSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size)
|
||||
|
||||
# Create the DataLoader for our validation set.
|
||||
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
|
||||
validation_sampler = SequentialSampler(validation_data)
|
||||
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)
|
||||
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=args.test_batch_size)
|
||||
|
||||
return train_dataloader, validation_dataloader
|
||||
|
||||
|
|
@ -310,6 +327,10 @@ def main():
|
|||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
||||
parser.add_argument('--pytorch-only', action='store_true', default=False,
|
||||
help='disables ONNX Runtime training')
|
||||
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
|
||||
help='input batch size for training (default: 32)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',
|
||||
help='input batch size for testing (default: 32)')
|
||||
parser.add_argument('--view-graphs', action='store_true', default=False,
|
||||
help='views forward and backward graphs')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
|
|
@ -322,6 +343,11 @@ def main():
|
|||
help='how many batches to wait before logging training status (default: 40)')
|
||||
parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
|
||||
help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
|
||||
parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING',
|
||||
help='Log level (default: WARNING)')
|
||||
parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H',
|
||||
help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Device (CPU vs CUDA)
|
||||
|
|
@ -333,17 +359,28 @@ def main():
|
|||
print('No GPU available, using the CPU instead.')
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Set log level
|
||||
numeric_level = getattr(logging, args.log_level.upper(), None)
|
||||
if not isinstance(numeric_level, int):
|
||||
raise ValueError('Invalid log level: %s' % args.log_level)
|
||||
logging.basicConfig(level=numeric_level)
|
||||
|
||||
# 2. Dataloader
|
||||
train_dataloader, validation_dataloader = load_dataset()
|
||||
train_dataloader, validation_dataloader = load_dataset(args)
|
||||
|
||||
# 3. Modeling
|
||||
# Load BertForSequenceClassification, the pretrained BERT model with a single
|
||||
# linear classification layer on top.
|
||||
config = AutoConfig.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=2,
|
||||
num_hidden_layers=args.num_hidden_layers,
|
||||
output_attentions = False, # Whether the model returns attentions weights.
|
||||
output_hidden_states = False, # Whether the model returns all hidden-states.
|
||||
)
|
||||
model = BertForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
|
||||
num_labels = 2, # The number of output labels--2 for binary classification.
|
||||
output_attentions = False, # Whether the model returns attentions weights.
|
||||
output_hidden_states = False, # Whether the model returns all hidden-states.
|
||||
config=config,
|
||||
)
|
||||
|
||||
if not args.pytorch_only:
|
||||
|
|
@ -382,7 +419,7 @@ def main():
|
|||
# 4. Train loop (fine-tune)
|
||||
for epoch_i in range(0, args.epochs):
|
||||
train(model, optimizer, scheduler, train_dataloader, epoch_i, device, args)
|
||||
test(model, validation_dataloader, device)
|
||||
test(model, validation_dataloader, device, args)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -15,4 +15,4 @@ echo "Copying PyTorch frontend source-code to build folder"
|
|||
cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/
|
||||
|
||||
echo "Running Flexible API (ORTModule)"
|
||||
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py
|
||||
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py --no-cuda --epochs 4 --log-interval 20 --log-level=DEBUG
|
||||
|
|
|
|||
|
|
@ -15,4 +15,4 @@ echo "Copying PyTorch frontend source-code to build folder"
|
|||
cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/
|
||||
|
||||
echo "Running Flexible API (ORTModule)"
|
||||
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py
|
||||
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py --epochs 10 --log-interval 100 --log-level=DEBUG
|
||||
|
|
|
|||
Loading…
Reference in a new issue