Add support for BERT fine-tuning (MVP 3)

Additional changes include a major refactoring to use the new backend API
This commit is contained in:
Thiago Crepaldi 2020-11-13 09:41:59 -08:00
parent 78831d009b
commit ff79e8743f
5 changed files with 182 additions and 308 deletions

View file

@ -6,8 +6,9 @@ import onnxruntime
import os
import torch
import warnings
from onnxruntime.capi import _pybind_state as C
from inspect import signature
from onnxruntime.capi import _pybind_state as C
from . import _utils
@ -22,29 +23,16 @@ class ORTModule(torch.nn.Module):
# User module is wrapped to use its initializers and save computed gradients
self._original_module = module
self._original_module_grad_output_len = -1
self._original_module_forward_input_grads = []
self._onnx_training = None
self._onnx_training_inputs_desc = []
self._onnx_training_outputs_desc = []
self._onnx_gradient = None
self._grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
# Forward pass
self._onnx_forward = None
self._forward_session = None
self._onnx_forward_initializers_desc = []
self._onnx_forward_inputs_desc = []
self._onnx_forward_outputs_desc = []
self._onnx_forward_intermediate_outputs_desc = []
# Backward pass
self._onnx_backward = None
self._backward_session = None
self._onnx_backward_initializers_desc = []
self._onnx_backward_inputs_desc = []
self._onnx_backward_gradient_inputs_desc = []
self._onnx_backward_outputs_desc = []
# Log level
self._loglevel = getattr(logging, 'WARNING')
@ -60,14 +48,13 @@ class ORTModule(torch.nn.Module):
ONNX model is exported the first time this method is executed.
Next, the full training graph is split into a forward and a backward graph, which are used
to instantiate ONNX Runtime `InferenceSession`s
TODO: #ImproveGraphSplitting
Additionally to that, several descriptor lists are generated to help identify
model input, output, initializer, intermediate and gradient tensors.
'''
if not self._onnx_forward:
self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs)
self._onnx_gradient, self._onnx_forward, self._onnx_backward = ORTModule._build_gradient_graph(self._onnx_training, self._grad_builder_config)
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config)
# TODO: PyTorch exporter bug: changes the initializer order
self._onnx_graphs_info.initializer_grad_names_to_train = [ p[0]+'_grad' for p in self._original_module.named_parameters()]
if self._save_onnx:
onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
@ -75,44 +62,13 @@ class ORTModule(torch.nn.Module):
onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx')
onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx')
# TODO: Consider moving this to the backend. We don't want to append '_grad' to get correct tensor names
self._onnx_graphs_types = ORTModule._get_io_info_from_onnx_graph(self._onnx_forward, self._onnx_graphs_info)
# TODO: hard-coding to CPU only
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString(), providers=['CPUExecutionProvider'])
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString(), providers=['CPUExecutionProvider'])
# Forward I/O description
if not self._onnx_training_inputs_desc:
self._onnx_training_inputs_desc = self._get_input_from_graph(self._onnx_training)
logging.debug(f'Training inputs:\n\t {self._onnx_training_inputs_desc}')
if not self._onnx_training_outputs_desc:
self._onnx_training_outputs_desc = self._get_output_from_graph(self._onnx_training)
logging.debug(f'Training outputs:\n\t {self._onnx_training_outputs_desc}')
if not self._onnx_forward_initializers_desc:
self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
logging.debug(f'Forward initializers:\n\t {self._onnx_forward_initializers_desc}')
if not self._onnx_forward_inputs_desc:
self._onnx_forward_inputs_desc = self._get_input_from_graph(self._onnx_forward)
logging.debug(f'Forward inputs:\n\t {self._onnx_forward_inputs_desc}')
if not self._onnx_forward_outputs_desc:
self._onnx_forward_outputs_desc = self._get_output_from_graph(self._onnx_forward)
logging.debug(f'Forward outputs:\n\t {self._onnx_forward_outputs_desc}')
if not self._onnx_forward_intermediate_outputs_desc:
self._onnx_forward_intermediate_outputs_desc = self._get_intermediate_from_forward_graph(self._onnx_forward)
logging.debug(f'Forward intermediate outputs:\n\t {self._onnx_forward_intermediate_outputs_desc}')
# Backward I/O description
if not self._onnx_backward_initializers_desc:
self._onnx_backward_initializers_desc = self._get_input_from_graph(self._onnx_backward, True)
logging.debug(f'Backward initializers: {self._onnx_backward_initializers_desc}')
if not self._onnx_backward_inputs_desc:
self._onnx_backward_inputs_desc = self._get_input_from_graph(self._onnx_backward, False, self._onnx_backward_initializers_desc)
logging.debug(f'Backward inputs: {self._onnx_backward_inputs_desc}')
if not self._onnx_backward_gradient_inputs_desc:
self._onnx_backward_gradient_inputs_desc = self._get_gradient_input_from_graph(self._onnx_backward, self._onnx_forward_inputs_desc, self._onnx_forward_initializers_desc, self._onnx_forward_intermediate_outputs_desc)
logging.debug(f'Backward gradient inputs: {self._onnx_backward_gradient_inputs_desc}')
if not self._onnx_backward_outputs_desc:
self._onnx_backward_outputs_desc = self._get_output_from_graph(self._onnx_backward)
logging.debug(f'Backward outputs: {self._onnx_backward_outputs_desc}')
# Use a custom torch.autograd.Function to associate self.backward_graph as the
# gradient implementation for self.forward_graph.
class _ORTModuleFunction(torch.autograd.Function):
@ -127,9 +83,6 @@ class ORTModule(torch.nn.Module):
* (Partial) user input
* (Partial) Initializers
* Intermediate tensors
TODO: #ImproveGraphSplitting
String matching to separate user input from initializer
'''
# Convert input to dict of torch tensors
@ -143,10 +96,12 @@ class ORTModule(torch.nn.Module):
outputs = tuple(torch.from_numpy(item) for item in outputs)
# Save input, initializers and intermediate tensors to be used during backward
initializer_names = [item['name'] for item in self._onnx_backward_initializers_desc]
input_names = [item['name'] for item in self._onnx_backward_inputs_desc if item['name'] not in initializer_names]
ctx_input = tuple(v for k,v in data_dict.items() if k in input_names)
ctx_initializer = tuple(v for k,v in data_dict.items() if k in initializer_names)
user_input = self._onnx_graphs_info.user_input_names
backward_user_input = self._onnx_graphs_info.backward_user_input_names
ctx_input = tuple(data_dict[name] for name in user_input if name in backward_user_input)
forward_initializer = self._onnx_graphs_info.initializer_names_to_train
backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input
ctx_initializer = tuple(data_dict[name] for name in forward_initializer if name in backward_intializer)
intermediate = tuple(torch.from_numpy(item) for item in intermediate)
ctx.save_for_backward(*[*ctx_input, *ctx_initializer, *intermediate])
@ -163,23 +118,17 @@ class ORTModule(torch.nn.Module):
* Tensor stashed (in a particular order) during forward:
* (partial) user input, (partial) initializers and intermediate tensors
TODO: #ImproveGraphSplitting
Length of `*grad_output` is needed to detect intermediate tensors during backward pass
TODO: Input gradient is hard-coded to torch.tensor([1.])
'''
saved_tensors = ctx.saved_tensors
# Used to create backward input
if self._original_module_grad_output_len == -1:
self._original_module_grad_output_len = len(grad_output)
grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output])
result = [torch.tensor([1])]* len(self._onnx_training_inputs_desc)
result = [torch.tensor([1])]* len(self._onnx_graphs_info.user_input_names)
result += [torch.from_numpy(grad) for grad in grad_weights]
return tuple(result)
return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*inputs, **kwargs))
proc_inputs = [data for data in inputs if data is not None]
return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*proc_inputs, **kwargs))
def _convert_forward_input_to_list(self, *inputs, **kwargs):
'''Creates forward `*inputs` list from user input and PyTorch initializers
@ -189,8 +138,10 @@ class ORTModule(torch.nn.Module):
ONNX Runtime forward requires an ordered list of:
* User input: computed from forward InferenceSession
* Initializers: computed from original PyTorch model parameters
'''
This code assumes the exported model's inputs and initializers
are the same as those of the original PyTorch model
'''
# List containing both user inputs and initializers, in this order
result = []
@ -219,12 +170,8 @@ class ORTModule(torch.nn.Module):
def _convert_forward_input_list_to_dict(self, *inputs):
'''Convert forward `*inputs` list to dict
TODO: #ImproveGraphSplitting
Additionally, a list of gradient names of initializers are created to be used by backprop
TODO: Input gradient is being ignored for MVP
'''
# Dictionary containing both inputs and initializers
result = {}
@ -237,15 +184,6 @@ class ORTModule(torch.nn.Module):
# Initializers
for param in self._original_module.named_parameters():
result.update({param[0]: inputs[result_len]})
# TODO: Create order list of input grads to use during backward.
# (for scenarios where gradients of input is required - not covered on MVP)
# if len(self._original_module_forward_input_grads) < len(self._onnx_training_inputs_desc):
# self._original_module_forward_input_grads.append(param[0]+'_grad')
# TODO: Create order list of initializer grads to use during backward.
# if len(self._original_module_forward_input_grads) < len(self._onnx_backward_outputs_desc) + len(self._onnx_training_inputs_desc):
if len(self._original_module_forward_input_grads) < len(self._onnx_backward_outputs_desc):
self._original_module_forward_input_grads.append(param[0]+'_grad')
result_len += 1
return result
@ -253,46 +191,45 @@ class ORTModule(torch.nn.Module):
def _convert_backward_input_list_to_dict(self, *inputs):
'''Convert backward `*inputs` list to dict
ONNX Runtime backend requires dict as input, which is composed of:
ONNX Runtime backward requires dict as input, which is composed of:
* User input
Although not necessary, all user inputs are used for simplicity
* (Partial) Initializers
init_begin = len(user_input)
init_count = len(Pre-computed list of initializer)
* Intermediate tensors TODO: #ImproveGraphSplitting
Intermediate tensors are inferred from input position:
interm_begin = len(user_input) + len(initializer)
interm_count = len(all_inputs) - len(user_input) - len(initializer) - len(grad_output)
* Gradient wrt outputs TODO: #ImproveGraphSplitting
Gradient tensors are inferred from input position:
grads_begin = len(user_input) + len(initializer) + len(intermediate)
grads_count = len(all_inputs) - len(user_input) - len(initializer) - len(intermediate)
* Intermediate tensors
* Gradient wrt outputs
'''
# Dictionary containing both inputs and initializers
result = {}
backward_user_input = self._onnx_graphs_info.backward_user_input_names
backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input
intermediate = self._onnx_graphs_info.intermediate_tensor_names
backward_output_grad_names = self._onnx_graphs_info.backward_output_grad_names
# Extract info about stashed input and grad output
# Inputs
result_len = 0
for idx, input_data in enumerate(self._forward_session.get_inputs()):
result.update({ input_data.name : inputs[idx]})
result_len += 1
inputs_pos = 0
for idx, name in enumerate(backward_user_input):
result.update({ name : inputs[idx]})
inputs_pos += 1
# Initializers
for initializer in self._onnx_backward_initializers_desc:
result.update({initializer['name']: inputs[result_len]})
result_len += 1
for idx, name in enumerate(backward_intializer, inputs_pos):
result.update({name: inputs[idx]})
inputs_pos += 1
# Intermediate
intermediate_len = len(inputs) - result_len - self._original_module_grad_output_len
for idx in range(intermediate_len):
result.update({self._onnx_forward_intermediate_outputs_desc[idx]['name']: inputs[result_len]})
result_len += 1
for idx, name in enumerate(intermediate, inputs_pos):
result.update({name: inputs[idx]})
inputs_pos += 1
# Grad outputs
for idx in range(len(inputs)-result_len):
result.update({self._onnx_backward_gradient_inputs_desc[idx]['name']: inputs[result_len]})
result_len += 1
for idx, name in enumerate(backward_output_grad_names, inputs_pos):
result.update({name: inputs[idx]})
inputs_pos += 1
return result
@ -303,10 +240,10 @@ class ORTModule(torch.nn.Module):
to distinguish intermediate from output tensors
'''
output_names = [out['name'] for out in self._onnx_forward_outputs_desc]
forward_output = self._forward_session.run(output_names, inputs)
output = forward_output[:len(self._onnx_training_outputs_desc)]
intermediates = forward_output[len(self._onnx_training_outputs_desc):]
forward_output = self._forward_session.run([*self._onnx_graphs_info.user_output_names,
*self._onnx_graphs_info.intermediate_tensor_names], inputs)
output = forward_output[:len(self._onnx_graphs_info.user_output_names)]
intermediates = forward_output[len(self._onnx_graphs_info.user_output_names):]
return output, intermediates
def _run_backward_graph(self, *inputs, **kwargs):
@ -323,7 +260,7 @@ class ORTModule(torch.nn.Module):
# Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement)
data = self._convert_dict_torch_to_numpy(data)
return self._backward_session.run(self._original_module_forward_input_grads, data)
return self._backward_session.run(self._onnx_graphs_info.initializer_grad_names_to_train, data)
@staticmethod
def _get_forward_graph(module, *inputs, **kwargs):
@ -338,6 +275,11 @@ class ORTModule(torch.nn.Module):
# Deepcopy inputs, since input values may change after model run.
sample_inputs_copy = copy.deepcopy(inputs)
# Ignore optional *inputs explicitly specified as None
sig = signature(module.forward)
all_input_names = sig.parameters.keys()
input_names = [name for idx, name in enumerate(all_input_names) if inputs[idx] is not None]
# TODO: Support contrib ops? The user model gives no hint that they are needed
# from onnxruntime.training import register_custom_ops_pytorch_exporter
# register_custom_ops_pytorch_exporter.register_custom_op()
@ -346,173 +288,68 @@ class ORTModule(torch.nn.Module):
torch.onnx.export(module,
tuple(sample_inputs_copy),
f,
input_names=input_names,
opset_version=ONNX_OPSET_VERSION,
do_constant_folding=False,
training=torch.onnx.TrainingMode.TRAINING)
return onnx.load_model_from_string(f.getvalue())
def _get_initializer_from_graph(self, graph):
    '''Builds one descriptor per initializer stored in `graph`

    The result is a list, in the order of `graph.graph.initializer`, of dicts shaped as:
        { 'name': name, 'shape': [int1,...,intN], 'dtype': <converted dtype> }
    where 'shape' is the initializer's `dims` and 'dtype' is its `data_type`
    converted by `_utils.dtype_onnx_to_torch`.

    For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
    '''

    # TODO: There is a tradeoff between memory footprint and total model export time.
    # Ideally the model would be exported with
    #   torch.onnx.export(.., export_params=False, keep_initializers_as_inputs=True)
    # to get a minimal ONNX model with initializers as inputs. That, however, forces us to
    # guess that only names ending with '.weight' and '.bias' are initializers; otherwise
    # inputs cannot be told apart from initializers once the model is exported.
    # Options are:
    #   1) Memory footprint first: export ONNX twice, varying keep_initializers_as_inputs.
    #      The ONNX model is small (400 bytes vs 1.6MB for MNIST), but export takes twice the time.
    #   2) Export time first: export ONNX once, using export_params=True.
    #      The ONNX model is bigger, but export takes half the time.
    # As performance is not the main goal in this first deliverable, approach 2) is used for simplicity.
    return [
        {
            'name': tensor.name,
            'shape': tensor.dims,
            'dtype': _utils.dtype_onnx_to_torch(tensor.data_type),
        }
        for tensor in graph.graph.initializer
    ]
def _get_input_from_graph(self, graph, initializers_only=False, append_initializers=[]):
'''Returns a descriptor list of input tensors for an ONNX `graph`

Every entry of `graph.graph.input` is compared by name against the descriptors in
`self._onnx_forward_initializers_desc` to tell initializers apart from other inputs.

When `initializers_only=True`, only input initializers are returned. Otherwise, both
user input and initializers are considered.
This is being used to get backward initializer list TODO: #ImproveGraphSplitting

When `append_initializers` is not empty, this list is appended to the end of the result list
This is being used to get backward input list TODO: #ImproveGraphSplitting

The list descriptor has the following format:
[{ 'name': name, 'shape': [int1,...,intN], 'dtype': <dtype> }]
where 'shape' collects `dim_value` of each dimension and 'dtype' is the tensor's
`elem_type` converted by `_utils.dtype_onnx_to_torch`.

For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
'''
# NOTE(review): `append_initializers=[]` is a mutable default argument. The code below only
# reads it (it is handed to `inputs.extend`), so the shared default is not modified here.
# NOTE(review): indentation is not preserved in this view, so it cannot be confirmed from here
# whether the `break` below sits inside `if initializers_only:` and which statement the
# `else:` pairs with - verify against the real file before changing this logic.
inputs = []
for elem in graph.graph.input:
# Look for an initializer descriptor with the same name as this graph input
for initializer in self._onnx_forward_initializers_desc:
if elem.name == initializer['name']:
if initializers_only:
name = elem.name
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
inputs.append({'name': name, 'shape': shape, 'dtype': dtype})
break
else:
if not initializers_only:
name = elem.name
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
inputs.append({'name': name, 'shape': shape, 'dtype': dtype})
# Caller-provided descriptors always go last, after the ones collected from the graph
if append_initializers:
inputs.extend(append_initializers)
return inputs
def _get_gradient_input_from_graph(self, backward_graph, forward_input, forward_initializer, forward_intermediate):
    '''Returns a descriptor list of the gradient inputs of `backward_graph`

    Gradient inputs are found through an elimination process: every input of the
    backward graph whose name matches a forward input, forward initializer or forward
    intermediate descriptor is skipped, and whatever remains is reported.

    Args:
        backward_graph: ONNX model whose `graph.input` entries are inspected
        forward_input: descriptor list (dicts with a 'name' key) of forward inputs
        forward_initializer: descriptor list (dicts with a 'name' key) of forward initializers
        forward_intermediate: descriptor list (dicts with a 'name' key) of forward intermediate tensors

    Returns:
        List of descriptors, in `backward_graph.graph.input` order, with the format:
        [{ 'name': name, 'shape': [int1,...,intN], 'dtype': <dtype> }]
        where 'dtype' is the tensor's `elem_type` converted by `_utils.dtype_onnx_to_torch`.

    For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461

    TODO: #ImproveGraphSplitting
    '''

    # A single name set replaces three nested linear scans per backward input;
    # membership is the same name-equality test the scans performed.
    known_names = {
        item['name']
        for group in (forward_input, forward_initializer, forward_intermediate)
        for item in group
    }
    return [
        {
            'name': elem.name,
            'shape': [dim.dim_value for dim in elem.type.tensor_type.shape.dim],
            'dtype': _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type),
        }
        for elem in backward_graph.graph.input
        if elem.name not in known_names
    ]
def _get_output_from_graph(self, graph):
    '''Describes the output tensors of an ONNX `graph`, leaving initializers out

    A graph output is dropped when its name equals the name of any descriptor in
    `self._onnx_forward_initializers_desc`; every other output is described, in graph order, as:
        { 'name': name, 'shape': [int1,...,intN], 'dtype': <dtype> }
    where 'dtype' is the tensor's `elem_type` converted by `_utils.dtype_onnx_to_torch`.

    For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
    '''

    # Names that must not be reported as outputs
    initializer_names = [desc['name'] for desc in self._onnx_forward_initializers_desc]
    return [
        {
            'name': tensor.name,
            'shape': [axis.dim_value for axis in tensor.type.tensor_type.shape.dim],
            'dtype': _utils.dtype_onnx_to_torch(tensor.type.tensor_type.elem_type),
        }
        for tensor in graph.graph.output
        if tensor.name not in initializer_names
    ]
def _get_intermediate_from_forward_graph(self, forward_graph):
    '''Describes the intermediate tensors exposed by `forward_graph`

    Intermediate tensors are found by elimination: any output of the forward graph whose
    name equals the name of a descriptor in `self._onnx_training_outputs_desc` is skipped,
    and each remaining output is described, in graph order, as:
        { 'name': name, 'shape': [int1,...,intN], 'dtype': <dtype> }
    where 'dtype' is the tensor's `elem_type` converted by `_utils.dtype_onnx_to_torch`.

    TODO: #ImproveGraphSplitting
    '''

    # Names of the outputs already present in the original (exported) model
    training_output_names = [desc['name'] for desc in self._onnx_training_outputs_desc]
    return [
        {
            'name': tensor.name,
            'shape': [axis.dim_value for axis in tensor.type.tensor_type.shape.dim],
            'dtype': _utils.dtype_onnx_to_torch(tensor.type.tensor_type.elem_type),
        }
        for tensor in forward_graph.graph.output
        if tensor.name not in training_output_names
    ]
@staticmethod
def _build_gradient_graph(forward_graph, config):
'''Adds gradient nodes on top of an existing ONNX graph (with training flag)
TODO: #SplittingGraphAtFrontend
'''
if not config.weight_names_to_train:
weight_names_to_train = set()
def _build_fw_bw_grad_graphs(forward_graph, config):
'''Adds gradient nodes on top of an existing ONNX graph (with training flag)'''
if not config.initializer_names_to_train:
initializer_names_to_train = []
for initializer in forward_graph.graph.initializer:
weight_names_to_train.add(initializer.name)
config.weight_names_to_train = weight_names_to_train
output_names = set()
for output in forward_graph.graph.output:
output_names.add(output.name)
config.output_names = output_names
models = [onnx.load_model_from_string(model_as_string)
for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(forward_graph.SerializeToString(), config)]
return models[0], models[1], models[2]
initializer_names_to_train.append(initializer.name)
config.initializer_names_to_train = initializer_names_to_train
# TODO: Add support to input with grad required
config.input_names_require_grad = []
# input_names_require_grad = []
# input_names_require_grad.append('input.1')
# config.input_names_require_grad = input_names_require_grad
module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
module_gradient_graph_builder.build_and_split(forward_graph.SerializeToString(), config)
forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model())
backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model())
gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model())
split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()
return gradient_model, forward_model, backward_model, split_graphs_info
@staticmethod
def _get_io_info_from_onnx_graph(model, graphs_info):
    '''Maps every tensor name listed in `graphs_info` to its ONNX type found in `model`

    All names from the name lists of `graphs_info` start out mapped to None. A name is
    then given the `type` of the `model.graph.input` or `model.graph.output` entry carrying
    that name. In addition, the name of an output followed by '_grad' receives that
    output's type. Only the first match is kept; names with no match stay None.

    Args:
        model: ONNX model whose graph inputs and outputs provide the types
        graphs_info: object exposing the name lists read below

    Returns:
        dict of tensor name -> ONNX type (or None)
    '''
    name_groups = (
        graphs_info.user_input_names,
        graphs_info.initializer_names_to_train,
        graphs_info.user_output_names,
        graphs_info.backward_user_input_names,
        graphs_info.backward_intializer_names_as_input,
        graphs_info.intermediate_tensor_names,
        graphs_info.user_output_grad_names,
        graphs_info.backward_output_grad_names,
    )
    type_map = {name: None for group in name_groups for name in group}

    for graph_input in model.graph.input:
        key = graph_input.name
        if key in type_map and type_map[key] is None:
            type_map[key] = graph_input.type

    for graph_output in model.graph.output:
        # The gradient of an output shares the type of the output itself
        for key in (graph_output.name, graph_output.name + '_grad'):
            if key in type_map and type_map[key] is None:
                type_map[key] = graph_output.type

    return type_map

View file

@ -143,7 +143,7 @@ def main():
# Set log level
numeric_level = getattr(logging, args.log_level.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError('Invalid log level: %s' % loglevel)
raise ValueError('Invalid log level: %s' % args.log_level)
logging.basicConfig(level=numeric_level)
else:
print('Training MNIST on vanilla PyTorch....')

View file

@ -1,13 +1,11 @@
import pdb
import logging
import argparse
import torch
import wget
import os
import pandas as pd
import zipfile
from transformers import BertTokenizer
from transformers import BertTokenizer, AutoConfig
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
@ -50,13 +48,10 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
if step == args.train_steps:
break
# Progress update every 40 batches.
if step % args.log_interval == 0 and not step == 0:
# Calculate elapsed time in minutes.
elapsed = format_time(time.time() - t0)
# Report progress.
print(' Batch {:>5,} of {:>5,}. Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))
# TODO: Dynamic axis is not supported yet
if batch[0].shape[0] != args.batch_size:
logging.warning(f'Dynamic axis is not supported yet {len(batch)}/{args.batch_size}')
continue
# Unpack this training batch from our dataloader.
#
@ -78,12 +73,23 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
model.zero_grad()
# Perform a forward pass (evaluate the model on this training batch).
# This will return the loss (rather than the model output) because we
# have provided the `labels`.
# This will return the loss (rather than the model output) because we have provided the `labels`.
# The documentation for this `model` function is here:
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
# TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule
outputs = model(b_input_ids, b_input_mask, None, None, None, None, b_labels)
# outputs = model(b_input_ids,
# token_type_ids = None,
# attention_mask = b_input_mask,
# labels = b_labels)
outputs = model(b_input_ids,
b_input_mask,
None,
None,
None,
None,
b_labels)
if args.view_graphs:
import torchviz
pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters())))
@ -91,15 +97,21 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
# The call to `model` always returns a tuple, so we need to pull the
# loss value out of the tuple.
# pdb.set_trace()
loss = outputs[0]
# Progress update every 40 batches.
if step % args.log_interval == 0 and not step == 0:
# Calculate elapsed time in minutes.
elapsed = format_time(time.time() - t0)
# Report progress.
print(f'Batch {step} of {len(train_dataloader)}. Elapsed: {elapsed}. Loss: {loss.item()}')
# Accumulate the training loss over all of the batches so that we can
# calculate the average loss at the end. `loss` is a Tensor containing a
# single value; the `.item()` function just returns the Python value
# from the tensor.
total_loss += loss.item()
# total_loss += loss
# Perform a backward pass to calculate the gradients.
loss.backward()
@ -122,7 +134,7 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
print("\n Average training loss: {0:.2f}".format(avg_train_loss))
print(" Training epoch took: {:}".format(format_time(time.time() - t0)))
def test(model, validation_dataloader, device):
def test(model, validation_dataloader, device, args):
# ========================================
# Validation
# ========================================
@ -143,12 +155,16 @@ def test(model, validation_dataloader, device):
# Evaluate data for one epoch
for batch in validation_dataloader:
# TODO: Dynamic axis is not supported yet
if batch[0].shape[0] != args.test_batch_size:
logging.warning(f'Dynamic axis is not supported yet {len(batch)}/{args.batch_size}')
continue
# Add batch to GPU
batch = tuple(t.to(device) for t in batch)
# Unpack the inputs from our dataloader
b_input_ids, b_input_mask, b_labels = batch
# Telling the model not to compute or store gradients, saving memory and
# speeding up validation
with torch.no_grad():
@ -160,18 +176,22 @@ def test(model, validation_dataloader, device):
# differentiates sentence 1 and 2 in 2-sentence tasks.
# The documentation for this `model` function is here:
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
# TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule
# TODO: the original sample passed None as the last argument; b_labels is passed here because the model was
# exported with 3 inputs for training, so validation must use the same inputs.
# Another approach would be to checkpoint the trained model and re-export it for validation from that checkpoint
outputs = model(b_input_ids,
b_input_mask,
None,
None,
None,
None,
None)
b_labels)
# Get the "logits" output by the model. The "logits" are the output
# values prior to applying an activation function like the softmax.
logits = outputs[0]
logits = outputs[1]
# Move logits and labels to CPU
logits = logits.detach().cpu().numpy()
@ -190,7 +210,7 @@ def test(model, validation_dataloader, device):
print(" Accuracy: {0:.2f}".format(eval_accuracy/nb_eval_steps))
print(" Validation took: {:}".format(format_time(time.time() - t0)))
def load_dataset():
def load_dataset(args):
# 2. Loading CoLA Dataset
print('Downloading dataset...')
@ -276,18 +296,15 @@ def load_dataset():
train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)
# The DataLoader needs to know our batch size for training, so we specify it
batch_size = 32
# Create the DataLoader for our training set.
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size)
# Create the DataLoader for our validation set.
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=args.test_batch_size)
return train_dataloader, validation_dataloader
@ -310,6 +327,10 @@ def main():
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--pytorch-only', action='store_true', default=False,
help='disables ONNX Runtime training')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',
help='input batch size for testing (default: 32)')
parser.add_argument('--view-graphs', action='store_true', default=False,
help='views forward and backward graphs')
parser.add_argument('--no-cuda', action='store_true', default=False,
@ -322,6 +343,11 @@ def main():
help='how many batches to wait before logging training status (default: 40)')
parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING',
help='Log level (default: WARNING)')
parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H',
help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)')
args = parser.parse_args()
# Device (CPU vs CUDA)
@ -333,17 +359,28 @@ def main():
print('No GPU available, using the CPU instead.')
device = torch.device("cpu")
# Set log level
numeric_level = getattr(logging, args.log_level.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError('Invalid log level: %s' % args.log_level)
logging.basicConfig(level=numeric_level)
# 2. Dataloader
train_dataloader, validation_dataloader = load_dataset()
train_dataloader, validation_dataloader = load_dataset(args)
# 3. Modeling
# Load BertForSequenceClassification, the pretrained BERT model with a single
# linear classification layer on top.
config = AutoConfig.from_pretrained(
"bert-base-uncased",
num_labels=2,
num_hidden_layers=args.num_hidden_layers,
output_attentions = False, # Whether the model returns attentions weights.
output_hidden_states = False, # Whether the model returns all hidden-states.
)
model = BertForSequenceClassification.from_pretrained(
"bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
num_labels = 2, # The number of output labels--2 for binary classification.
output_attentions = False, # Whether the model returns attentions weights.
output_hidden_states = False, # Whether the model returns all hidden-states.
config=config,
)
if not args.pytorch_only:
@ -382,7 +419,7 @@ def main():
# 4. Train loop (fine-tune)
for epoch_i in range(0, args.epochs):
train(model, optimizer, scheduler, train_dataloader, epoch_i, device, args)
test(model, validation_dataloader, device)
test(model, validation_dataloader, device, args)
if __name__ == '__main__':
main()

View file

@ -15,4 +15,4 @@ echo "Copying PyTorch frontend source-code to build folder"
cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/
echo "Running Flexible API (ORTModule)"
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py --no-cuda --epochs 4 --log-interval 20 --log-level=DEBUG

View file

@ -15,4 +15,4 @@ echo "Copying PyTorch frontend source-code to build folder"
cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/
echo "Running Flexible API (ORTModule)"
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py --epochs 10 --log-interval 100 --log-level=DEBUG