onnxruntime/orttraining/orttraining/python/ort_trainer.py
Bowen Bao 15cb4b3023
Fix session load state & run extra_postpasses only once (#4255)
* Fix session load state & run extra_postpasses only once

* add testcase for onnx model as well
2020-06-23 11:45:26 -07:00

1026 lines
47 KiB
Python

import io
import os
import warnings
import numpy as np
import onnx
from onnx import numpy_helper
from onnx import helper
import torch
import torch.nn
import torch.onnx
import onnxruntime as ort
import onnxruntime.capi.postprocess as postprocess
from distutils.version import LooseVersion
import warnings
from .checkpointing_utils import list_checkpoint_files, get_checkpoint_name, CombineZeroCheckpoint
import onnxruntime.capi.pt_patch
# Default ONNX opset version; used as the default value of the opset_version
# arguments of the export / session-creation helpers and of ORTTrainer below.
DEFAULT_OPSET_VERSION = 12
class IODescription():
    """Description of one model input or output.

    The constructor arguments are stored unchanged on attributes carrying a
    trailing underscore:

        name_:        name of the input/output.
        shape_:       list of dimensions; other code in this file treats
                      string entries as symbolic (named) dimensions.
        dtype_:       element type, or None when not specified.
        num_classes_: optional class count, or None when not specified.
    """

    def __init__(self, name, shape, dtype=None, num_classes=None):
        self.name_, self.shape_ = name, shape
        self.dtype_, self.num_classes_ = dtype, num_classes
class ModelDescription():
    """Pairs the input descriptions of a model with its output descriptions.

    Both arguments are stored as given, on ``inputs_`` and ``outputs_``.
    """

    def __init__(self, inputs, outputs):
        self.inputs_, self.outputs_ = inputs, outputs
def resolve_symbolic_dimensions(inputs, input_descs, output_descs):
    """Replace symbolic (string) output dimensions with concrete sizes.

    Every string axis found in an input description is mapped to the size of
    the matching axis of the corresponding input tensor. The same names are
    then substituted into a deep copy of the output descriptions, so the
    descriptions passed in are never modified.

    Args:
        inputs: input tensors (objects exposing ``size()``), in the same
            order as ``input_descs``.
        input_descs: descriptions of the inputs; ``shape_`` may mix ints and
            symbolic axis names.
        output_descs: descriptions of the outputs to resolve.

    Returns:
        A deep copy of ``output_descs`` whose ``shape_`` entries no longer
        contain any symbolic axis name.

    Raises:
        RuntimeError: if an output uses a symbolic axis name that none of the
            inputs defines.
    """
    import copy

    # Symbolic axis name -> concrete size, taken from the actual input tensors.
    # If a name appears on several inputs, the last one seen wins.
    resolved_dims = {}
    for tensor, input_desc in zip(inputs, input_descs):
        for i, axis in enumerate(input_desc.shape_):
            if isinstance(axis, str):
                resolved_dims[axis] = tensor.size()[i]

    output_descs_copy = copy.deepcopy(output_descs)
    for output_desc in output_descs_copy:
        for i, axis in enumerate(output_desc.shape_):
            if not isinstance(axis, str):
                continue
            # Fail with the documented error instead of leaking a KeyError.
            if axis not in resolved_dims:
                raise RuntimeError("Cannot run model with unknown output dimensions")
            output_desc.shape_[i] = resolved_dims[axis]

    return output_descs_copy
def generate_sample(desc, device=None):
    """Create a random sample tensor matching an input/output description.

    Args:
        desc: description providing ``shape_``, ``dtype_`` and ``num_classes_``.
        device: device the generated tensor is moved to.

    Returns:
        Random integers in ``[0, num_classes_)`` when ``num_classes_`` is
        truthy, otherwise values drawn with ``torch.randn``.
    """
    # symbolic dimensions are described with strings. set symbolic dimensions to be 1
    size = []
    for dim in desc.shape_:
        size.append(dim if isinstance(dim, int) else 1)

    if not desc.num_classes_:
        return torch.randn(size, dtype=desc.dtype_).to(device)
    return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device)
def get_device_index(device):
    """Return the numeric index of a device.

    Args:
        device: a ``torch.device`` or a device string such as 'cuda:0',
            'cuda:1' or 'cpu'.

    Returns:
        The device index, or 0 when the device carries no index (e.g. 'cpu').
    """
    # isinstance (rather than an exact type comparison) so that str subclasses
    # are converted too; otherwise ``device.index`` would resolve to str.index.
    if isinstance(device, str):
        # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
        device = torch.device(device)
    return 0 if device.index is None else device.index
def input_get_device_index(input):
    """Return the device index of an input.

    Args:
        input: a tensor-like object with a ``device`` attribute, or a
            list/tuple of such objects.

    Returns:
        The index reported by ``get_device_index`` for the input's device;
        for a list/tuple only the first element is consulted.
    """
    representative = input[0] if isinstance(input, (list, tuple)) else input
    return get_device_index(representative.device)
def get_all_gradients_finite_arg_name(session):
    """Return the name of the session output that reports gradient finiteness.

    Args:
        session: training session whose ``_outputs_meta`` entries expose a
            ``name`` attribute.

    Returns:
        The name of the single output whose name contains
        'all_gradients_finite'.

    Raises:
        RuntimeError: if zero or more than one output matches.
    """
    all_fp16_or_fp32_gradients_finite_node_args = [
        x for x in session._outputs_meta if 'all_gradients_finite' in x.name]
    if len(all_fp16_or_fp32_gradients_finite_node_args) != 1:
        # Adjacent literals instead of a backslash continuation inside the
        # string, which garbled the message text.
        raise RuntimeError("Failed to find a group NodeArg with name that matches 'all_gradients_finite' "
                           "from the training session.")
    return all_fp16_or_fp32_gradients_finite_node_args[0].name
def get_group_accumulated_gradients_output_node_arg_name(session):
    """Return the name of the session output that groups accumulated gradients.

    Args:
        session: training session whose ``_outputs_meta`` entries expose a
            ``name`` attribute.

    Returns:
        The name of the single output whose name contains
        'Group_Accumulated_Gradients'.

    Raises:
        RuntimeError: if zero or more than one output matches.
    """
    # TODO: get the constant string via pybind.
    # optimizer_graph_builder BuildGroupNode with fixed string: 'Group_Accumulated_Gradients'
    accumulated_gradients_output_node_args = [
        x for x in session._outputs_meta if 'Group_Accumulated_Gradients' in x.name]
    if len(accumulated_gradients_output_node_args) != 1:
        # Adjacent literals instead of a backslash continuation inside the
        # string, which garbled the message text.
        raise RuntimeError("Failed to find a group NodeArg with name that matches 'Group_Accumulated_Gradients' "
                           "from the training session.")
    return accumulated_gradients_output_node_args[0].name
def ort_training_session_run_helper(session, iobinding, inputs, input_descs, output_descs, device, run_options=None):
    """Bind inputs and freshly allocated outputs, then run the session.

    Args:
        session: session object providing ``run_with_iobinding``.
        iobinding: binding object providing ``bind_input`` / ``bind_output``.
        inputs: input tensors, in the same order as ``input_descs``.
        input_descs: descriptions of the inputs.
        output_descs: descriptions of the outputs; symbolic dimensions are
            resolved from the inputs before the output buffers are allocated.
        device: device on which the output tensors are allocated.
        run_options: forwarded unchanged to ``run_with_iobinding``.

    Returns:
        Dict mapping each output name to the torch tensor bound to it.
    """
    for tensor, desc in zip(inputs, input_descs):
        tensor_device_index = input_get_device_index(tensor)
        iobinding.bind_input(desc.name_, tensor.device.type, tensor_device_index,
                             dtype_torch_to_numpy(tensor.dtype),
                             list(tensor.size()), tensor.data_ptr())

    torch_outputs = {}
    for desc in resolve_symbolic_dimensions(inputs, input_descs, output_descs):
        # eval_dtype_, when present, takes precedence over the declared dtype_.
        if hasattr(desc, 'eval_dtype_'):
            buffer_dtype = desc.eval_dtype_
        else:
            buffer_dtype = desc.dtype_
        buffer = torch.zeros(desc.shape_, device=device, dtype=buffer_dtype)
        iobinding.bind_output(desc.name_, buffer.device.type, get_device_index(device),
                              dtype_torch_to_numpy(buffer.dtype),
                              list(buffer.size()), buffer.data_ptr())
        torch_outputs[desc.name_] = buffer

    session.run_with_iobinding(iobinding, run_options)
    return torch_outputs
def FuseSofmaxNLLToSoftmaxCE(onnx_model):
    """Fuse LogSoftmax + NLL loss node pairs into SparseSoftmaxCrossEntropy nodes.

    Repeatedly takes the first node whose op type is "nll_loss" or
    "NegativeLogLikelihoodLoss", looks for a "LogSoftmax" node whose first
    output feeds one of the loss node's first two inputs, removes both nodes
    and appends a single "SparseSoftmaxCrossEntropy" node in their place.
    Processing stops as soon as no loss node is left, or the first loss node
    found has no connected LogSoftmax node.

    Args:
        onnx_model: ONNX model; its graph is modified in place.

    Returns:
        The same ``onnx_model`` object.
    """
    graph_nodes = onnx_model.graph.node
    nll_count = 0
    while True:
        nll_count += 1

        # First remaining NLL loss node, if any.
        nll_loss_node, nll_loss_node_index = None, 0
        for nll_loss_node_index, candidate in enumerate(graph_nodes):
            if candidate.op_type in ("nll_loss", "NegativeLogLikelihoodLoss"):
                nll_loss_node = candidate
                break
        if nll_loss_node is None:
            break

        # LogSoftmax node connected to the loss node; the loss input that is
        # not fed by it is taken as the label.
        softmax_node, softmax_node_index, label_input_name = None, 0, None
        for softmax_node_index, candidate in enumerate(graph_nodes):
            if candidate.op_type != "LogSoftmax":
                continue
            if candidate.output[0] == nll_loss_node.input[0]:
                softmax_node, label_input_name = candidate, nll_loss_node.input[1]
                break
            if candidate.output[0] == nll_loss_node.input[1]:
                softmax_node, label_input_name = candidate, nll_loss_node.input[0]
                break
        if softmax_node is None:
            break

        # An optional third loss input is carried over as the weight input.
        weight_input_name = nll_loss_node.input[2] if len(nll_loss_node.input) > 2 else None

        # delete nll_loss and LogSoftmax nodes in order
        # (higher index first so that the lower index stays valid)
        for position in sorted((nll_loss_node_index, softmax_node_index), reverse=True):
            del graph_nodes[position]

        fused_inputs = [softmax_node.input[0], label_input_name]
        if weight_input_name:
            fused_inputs.append(weight_input_name)
        fused_outputs = [nll_loss_node.output[0], softmax_node.output[0]]

        fused_node = graph_nodes.add()
        fused_node.CopyFrom(onnx.helper.make_node("SparseSoftmaxCrossEntropy", fused_inputs,
                                                  fused_outputs,
                                                  "nll_loss_node_" + str(nll_count)))

    return onnx_model
def delete_input_with_name(input, name):
    """Remove, in place, the first entry of ``input`` whose ``name`` matches.

    Args:
        input: mutable sequence of objects with a ``name`` attribute.
        name: name of the entry to remove.

    Nothing happens when no entry matches; at most one entry is removed.
    """
    for position, entry in enumerate(input):
        if entry.name == name:
            del input[position]
            return
# reference:
# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html
# https://pytorch.org/docs/stable/tensors.html
# also must map to types accepted by:
# MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type)
# torch dtype -> numpy type. torch.double/float/half/long/int/short are aliases
# of the dtypes listed here, so they hit the same entries.
_TORCH_TO_NUMPY_DTYPE = {
    torch.float64: np.float64,
    torch.float32: np.float32,
    torch.float16: np.float16,
    torch.int64: np.longlong,
    torch.int32: np.int32,
    torch.int16: np.int16,
    torch.int8: np.int8,
    torch.uint8: np.uint8,
    torch.bool: np.bool_,
}


def dtype_torch_to_numpy(torch_dtype):
    """Map a torch dtype to the corresponding numpy type.

    Args:
        torch_dtype: a ``torch.dtype``.

    Returns:
        The numpy scalar type matching ``torch_dtype``.

    Raises:
        ValueError: if no mapping is known. Previously such dtypes (including
            torch.bool, which this file uses for some outputs) silently
            produced None, which was then handed to the binding calls.
    """
    try:
        return _TORCH_TO_NUMPY_DTYPE[torch_dtype]
    except KeyError:
        raise ValueError(
            "Torch type to numpy type mapping unavailable for: {}".format(torch_dtype)) from None
def wrap_for_input_match(model, loss_fn, input_names):
    """Wrap a model (and optional loss function) so positional inputs line up.

    The expected argument names are those of ``model.forward`` followed, when
    ``loss_fn`` is given, by the name of the loss function's second parameter
    (the label).

    Args:
        model: torch module to wrap.
        loss_fn: callable taking (prediction, label), or a falsy value when
            there is no separate loss function.
        input_names: names of the inputs, in the order they will be passed.

    Returns:
        - ``model`` itself (or a module combining it with ``loss_fn``) when
          there are at least as many input names as expected names, when some
          input name is unknown, or when the input names are a prefix of the
          expected names;
        - otherwise a ``WrapModel`` that forwards the positional inputs to the
          model as keyword arguments selected by name.

    Raises:
        RuntimeError: if ``loss_fn`` does not take exactly two parameters.
    """
    import inspect

    expected_names = list(inspect.signature(model.forward).parameters.keys())
    if loss_fn:
        loss_params = inspect.signature(loss_fn).parameters
        if len(loss_params) != 2:
            raise RuntimeError("loss function should take two arguments - predict and label.")
        # label shall be the second input to loss_fn.
        expected_names.append(list(loss_params.keys())[1])

    class model_loss_cls(torch.nn.Module):
        def __init__(self, model, loss_fn):
            super(model_loss_cls, self).__init__()
            self.model_ = model
            self.loss_fn_ = loss_fn

        def forward(self, *inputs):
            # here we assume input can be unpacked into input and label
            preds = self.model_(*inputs[:-1])
            return self.loss_fn_(preds, inputs[-1]), preds

    def without_name_matching():
        # Plain model, or model + loss, with no reordering of the inputs.
        return model_loss_cls(model, loss_fn) if loss_fn else model

    # name match is needed only when input_names are a subset
    # of expected inputs (inputs to model and loss_fn combined).
    # More names than expected is likely the case where input arguments are
    # packed (TODO: to unpack the input argument); an equal count does not
    # require name match.
    if len(input_names) >= len(expected_names):
        return without_name_matching()

    # model desc has name(s) not matching the model signature. We cannot do anything in this case.
    # better to warning the user.
    if any(name not in expected_names for name in input_names):
        return without_name_matching()

    # if input_names match the leading expected names, there is no need for wrapping
    if all(given == expected for given, expected in zip(input_names, expected_names)):
        return without_name_matching()

    class WrapModel(torch.nn.Module):
        def __init__(self, model, loss_fn, input_names):
            super(WrapModel, self).__init__()
            self.model_ = model
            self.loss_fn_ = loss_fn
            self.input_names_ = input_names

        def forward(self, *inputs):
            # *inputs is given by torch trace. It is in the order of input_names.
            # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names.
            forward_params = inspect.signature(self.model_.forward).parameters
            input_dict = {key: inputs[self.input_names_.index(key)]
                          for key in forward_params if key in self.input_names_}

            model_out = self.model_(**input_dict)
            if self.loss_fn_ is None:
                return model_out
            # the label is taken to be the last positional input
            return self.loss_fn_(model_out, inputs[-1]), model_out

    return WrapModel(model, loss_fn, input_names)
def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, opset_version=DEFAULT_OPSET_VERSION, _enable_internal_postprocess=True):
    """Export a PyTorch model, optionally combined with a loss function, to ONNX.

    Args:
        model: torch module to export.
        loss_fn: loss function combined with the model, or None.
        model_desc: description of the model inputs and outputs. The
            ``dtype_`` of each output description is overwritten with the
            dtype observed on a sample forward pass.
        device: device the sample inputs are moved to.
        inputs: sample inputs; a torch.Tensor, or a dict/list/tuple of tensors.
        opset_version: ONNX opset version passed to the exporter.
        _enable_internal_postprocess: whether to run
            ``postprocess.run_postprocess`` on the exported model.

    Returns:
        The exported ONNX model.

    Raises:
        RuntimeError: if ``inputs`` is of an unsupported type.
    """
    # example: {input0:{0:'batch'}, input1:{0:'batch'}}
    dynamic_axes = {}
    for io_desc in [*model_desc.inputs_, *model_desc.outputs_]:
        named_axes = {idx: dim for idx, dim in enumerate(io_desc.shape_) if isinstance(dim, str)}
        if named_axes:
            dynamic_axes[io_desc.name_] = named_axes

    input_names = [io_desc.name_ for io_desc in model_desc.inputs_]
    output_names = [io_desc.name_ for io_desc in model_desc.outputs_]

    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]
    if isinstance(inputs, dict):
        sample_inputs = [inputs[io_desc.name_].to(device=device) for io_desc in model_desc.inputs_]
    elif isinstance(inputs, (list, tuple)):
        # only as many leading inputs as the model description declares
        sample_inputs = [tensor.to(device=device) for tensor in inputs[:len(model_desc.inputs_)]]
    else:
        raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")

    # pytorch onnx exporter/trace does not try to match argument names.
    # e.g. for models with optional inputs, it requires all inputs be present.
    # this is a problem because the model graph depends on inputs provided.
    model = wrap_for_input_match(model, loss_fn, input_names)

    # Run once in eval mode to obtain sample outputs and record their dtypes.
    model.eval()
    with torch.no_grad():
        sample_outputs = model(*sample_inputs)
    if isinstance(sample_outputs, torch.Tensor):
        sample_outputs = [sample_outputs]
    for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_):
        output_desc.dtype_ = sample_output.dtype
    model.train()

    # Other export options to use(this is for backward compatibility).
    torch_version = LooseVersion(torch.__version__)
    other_export_options = {'training': True}
    # This option was added after 1.4 release.
    if torch_version > LooseVersion('1.4.0'):
        other_export_options['enable_onnx_checker'] = False
    # This option was added after 1.6 release.
    if torch_version >= LooseVersion('1.6.0'):
        other_export_options['training'] = torch.onnx.TrainingMode.TRAINING

    exported = io.BytesIO()
    torch.onnx._export(model, tuple(sample_inputs), exported,
                       input_names=input_names,
                       output_names=output_names,
                       opset_version=opset_version,
                       dynamic_axes=dynamic_axes,
                       _retain_param_name=True,
                       example_outputs=tuple(sample_outputs),
                       do_constant_folding=False,
                       **other_export_options)
    onnx_model = onnx.load_model_from_string(exported.getvalue())

    # Remove 'model_.' prefix introduced by model wrapper for initializers.
    wrapper_prefix = 'model_.'
    replace_name_dict = {}
    for initializer in onnx_model.graph.initializer:
        if initializer.name.startswith(wrapper_prefix):
            stripped_name = initializer.name[len(wrapper_prefix):]
            replace_name_dict[initializer.name] = stripped_name
            initializer.name = stripped_name
    for node in onnx_model.graph.node:
        for idx, name in enumerate(node.input):
            if name in replace_name_dict:
                node.input[idx] = replace_name_dict[name]

    # onnx model initializer may contain non-trainable registered buffers that are not part
    # of pytorch model named parameters.
    torch_module = model.model_ if hasattr(model, 'model_') else model
    named_parameters = torch_module.named_parameters()
    assert {name for name, _ in named_parameters}.issubset(
        {initializer.name for initializer in onnx_model.graph.initializer}), \
        "Initializer names do not match between PyTorch model and ONNX model, " \
        "please report a bug to ONNX Runtime."

    if _enable_internal_postprocess:
        onnx_model = postprocess.run_postprocess(onnx_model)

    return onnx_model
def create_ort_training_session_with_optimizer(model, device, training_optimizer_name, lr_params_feed_name,
                                               map_optimizer_attributes, world_rank=-1, world_size=1,
                                               gradient_accumulation_steps=1, bind_parameters=False,
                                               use_mixed_precision=False, allreduce_post_accumulation=False,
                                               partition_optimizer=False,
                                               enable_grad_norm_clip=True,
                                               frozen_weights=[], opset_version=DEFAULT_OPSET_VERSION):
    """Create an ORT training session, with optimizer, for an ONNX model.

    Args:
        model: ONNX model; its first graph output is used as the loss output.
        device: device used for the torch parameters created when
            ``bind_parameters`` is True.
        training_optimizer_name: optimizer name set on the training parameters.
        lr_params_feed_name: name of the learning-rate input.
        map_optimizer_attributes: callable mapping a weight name to a dict of
            float/int optimizer attributes, or None.
        world_rank, world_size, gradient_accumulation_steps,
        use_mixed_precision, allreduce_post_accumulation, partition_optimizer,
        enable_grad_norm_clip: copied onto the training parameters.
        bind_parameters: when True, every initializer is turned into a graph
            input backed by a torch parameter bound to both io bindings.
        frozen_weights: names of initializers excluded from training. The list
            is only read, never modified.
        opset_version: accepted but not used by this function.

    Returns:
        Tuple ``(session, train_io_binding, eval_io_binding, output_name,
        torch_params, output_types)``.

    Raises:
        RuntimeError: if a name in ``frozen_weights`` is not an initializer.
        ValueError: if an optimizer attribute is neither float nor int.
    """
    output_name = model.graph.output[0].name
    ort_parameters = ort.TrainingParameters()
    ort_parameters.loss_output_name = output_name
    ort_parameters.use_mixed_precision = use_mixed_precision
    ort_parameters.world_rank = world_rank
    ort_parameters.world_size = world_size
    ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps
    ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation
    ort_parameters.partition_optimizer = partition_optimizer
    ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip
    ort_parameters.set_gradients_as_graph_outputs = False

    output_types = {output.name: output.type.tensor_type for output in model.graph.output}

    # pybind does not allow to add directly to ort_parameters.weights_to_train.
    # Have to work around by using a temporary weights_to_train.
    torch_params = {}
    optimizer_attributes_map = {}
    optimizer_int_attributes_map = {}

    # Build the name sets once instead of rescanning lists for every lookup.
    initializer_names = {initializer.name for initializer in model.graph.initializer}
    unused_frozen_weights = [n for n in frozen_weights if n not in initializer_names]
    if unused_frozen_weights:
        raise RuntimeError("{} in frozen_weights not found in model weights.".format(unused_frozen_weights))
    frozen_names = set(frozen_weights)

    weights_to_train = set()
    for initializer in model.graph.initializer:
        if initializer.name in frozen_names:
            continue
        weights_to_train.add(initializer.name)
        float_attributes = optimizer_attributes_map[initializer.name] = {}
        int_attributes = optimizer_int_attributes_map[initializer.name] = {}
        if map_optimizer_attributes is None:
            continue
        for k, v in map_optimizer_attributes(initializer.name).items():
            if isinstance(v, float):
                float_attributes[k] = v
            elif isinstance(v, int):
                int_attributes[k] = v
            else:
                raise ValueError("Optimizer attributes must be either float or int.")

    if bind_parameters:
        # Replace every initializer by a graph input of the same name, type
        # and dims; the values live in torch parameters instead.
        for initializer in model.graph.initializer:
            torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device))
            delete_input_with_name(model.graph.input, initializer.name)
            model.graph.input.extend(
                [helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)])
            torch_params[initializer.name] = torch_tensor

        del model.graph.initializer[:]

    ort_parameters.weights_to_train = weights_to_train
    ort_parameters.training_optimizer_name = training_optimizer_name
    ort_parameters.lr_params_feed_name = lr_params_feed_name
    ort_parameters.optimizer_attributes_map = optimizer_attributes_map
    ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map

    session = ort.TrainingSession(model.SerializeToString(), ort_parameters)
    train_io_binding = session.io_binding()
    eval_io_binding = session.io_binding()

    if bind_parameters:
        for param, torch_tensor in torch_params.items():
            # the same parameter storage is bound for both training and evaluation
            for binding in (train_io_binding, eval_io_binding):
                binding.bind_input(param, torch_tensor.device.type, get_device_index(torch_tensor.device),
                                   dtype_torch_to_numpy(torch_tensor.dtype), list(torch_tensor.size()),
                                   torch_tensor.data_ptr())

    return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types
def save_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", checkpoint_state_dict=None):
    """Save the model state, plus optional extra state, to a checkpoint file.

    Args:
        model: object providing ``state_dict()``, ``partition_optimizer_``,
            ``world_rank`` and ``world_size``.
        checkpoint_dir: existing directory in which the file is written.
        checkpoint_prefix: prefix used to build the checkpoint file name.
        checkpoint_state_dict: optional dict of additional entries to save.
            It is not modified; the saved dict contains its entries plus a
            'model' entry holding ``model.state_dict()`` (replacing any
            existing 'model' entry).

    An existing checkpoint file with the same name is overwritten after a
    warning.
    """
    # Copy so that the caller's dict is left untouched.
    state_to_save = {} if checkpoint_state_dict is None else dict(checkpoint_state_dict)
    state_to_save['model'] = model.state_dict()

    assert os.path.exists(checkpoint_dir), "ERROR: Checkpoint directory doesn't exist: {}".format(checkpoint_dir)

    checkpoint_name = get_checkpoint_name(checkpoint_prefix, model.partition_optimizer_, model.world_rank, model.world_size)
    checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)

    if os.path.exists(checkpoint_file):
        warnings.warn("{} already exists, overwriting.".format(checkpoint_file))

    torch.save(state_to_save, checkpoint_file)
def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict):
    """Load one checkpoint file into ``model``.

    Args:
        model: object providing ``load_state_dict``, ``world_rank`` and
            ``world_size``.
        checkpoint_dir: directory containing the checkpoint file.
        checkpoint_prefix: prefix used to build the checkpoint file name.
        is_partitioned: passed to ``get_checkpoint_name``; also selects the
            wording of the error raised when the file is missing.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        The loaded checkpoint dict without its 'model' entry.
    """
    checkpoint_file = os.path.join(
        checkpoint_dir,
        get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size))

    if not is_partitioned:
        assert_msg = "Couldn't find checkpoint file {}.".format(checkpoint_file)
    else:
        assert_msg = ("Couldn't find checkpoint file {}."
                      "Optimizer partitioning is enabled using ZeRO. Please make sure that the "
                      "checkpoint file exists for rank {} of {}.").format(
                          checkpoint_file, model.world_rank, model.world_size)
    assert os.path.exists(checkpoint_file), assert_msg

    checkpoint_state = torch.load(checkpoint_file, map_location='cpu')
    # the 'model' entry is consumed here; everything else is handed back
    model.load_state_dict(checkpoint_state.pop('model'), strict=strict)
    return checkpoint_state
def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict):
    """Load a model from several checkpoint files combined into one state.

    Args:
        model: object providing ``load_state_dict``.
        checkpoint_dir: directory containing the checkpoint files.
        checkpoint_prefix: prefix identifying the checkpoint files.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        Dict merging the non-'model' entries of every checkpoint file.
    """
    checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix)

    aggregate_state_dict = CombineZeroCheckpoint(checkpoint_files).aggregate_checkpoints()
    model.load_state_dict(aggregate_state_dict, strict=strict)

    # aggregate other keys in the state_dict.
    # Values will be overwritten for matching keys among workers
    all_checkpoint_states = {}
    for path in checkpoint_files:
        per_file_state = torch.load(path, map_location='cpu')
        per_file_state.pop('model')
        all_checkpoint_states.update(per_file_state)
    return all_checkpoint_states
def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False):
    """Load a checkpoint into ``model``.

    When more than one file matches the prefix, the checkpoint is treated as
    partitioned: if the model itself does not partition the optimizer, all
    files are combined; otherwise a single file is loaded.

    Args:
        model: object providing ``partition_optimizer_`` and the attributes
            required by the single/multi checkpoint loaders.
        checkpoint_dir: directory containing the checkpoint file(s).
        checkpoint_prefix: prefix identifying the checkpoint file(s).
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        The remaining (non-'model') checkpoint state.
    """
    checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
    is_partitioned = len(checkpoint_files) > 1
    if is_partitioned:
        warnings.warn(f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." +
                      "Attempting to load ZeRO checkpoint.")

    if (not model.partition_optimizer_) and is_partitioned:
        return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict)
    return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict)
class ORTTrainer():
def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_optimizer_attributes,
learning_rate_description, device, gradient_accumulation_steps=1,
world_rank=0, world_size=1, use_mixed_precision=False, allreduce_post_accumulation=False,
global_step=0, get_lr_this_step=None, loss_scaler=None, partition_optimizer=False,
enable_grad_norm_clip=True, frozen_weights=[], _opset_version=DEFAULT_OPSET_VERSION,
_enable_internal_postprocess=True, _extra_postprocess=None):
super(ORTTrainer, self).__init__()
"""
Initialize ORTTrainer.
Args:
model: one of
- a PyTorch model (class that inherits from torch.nn.Module)
- a combined PyTorch model and loss function.
Inputs to this combined PyTorch model are a concatenation of the
model's input and the loss function's label input.
Outputs are a concatenation of the loss function's output and the
model's output.
- a combined ONNX model and loss function.
loss_fn: one of
- a PyTorch loss function if 'model' is a PyTorch model. A loss
function takes two inputs (prediction, label) and outputs a loss
tensor.
- None if model is already combined with a loss function.
model_desc: Specify input/output shapes, types, and names.
Must be consistent with the training model.
training_optimizer_name: one of
- 'SGDOptimizer'
- 'AdamOptimizer'
- 'LambOptimizer'
map_optimizer_attributes: for optimizers with weight-dependent
parameters. A callable that maps weight name to a set of optimization
parameters.
Defaults to None.
learning_rate_description: the name, shape and type of the learning
rate in form of IODescription(Learning_Rate_Name, [1,], torch.float32).
Because learning_rate is an input to the training model,
Learning_Rate_Name must be specified so that there is no name conflict
within the model.
device: device to store tensors (e.g. 'cpu', 'cuda', 'cuda:<int_idx>').
gradient_accumulation_steps: number of training steps to accumulate
gradients before averaging and applying them.
Defaults to 1.
postprocess_model: a callable to postprocess the ONNX model that is
converted from PyTorch.
Defaults to None.
world_rank: rank id used for distributed training.
Defaults to 0.
world_size: number of ranks participating in distributed training.
Defaults to 1.
use_mixed_precision: flag to enable mixed precision (aka fp16).
Defaults to False.
allreduce_post_accumulation: controls whether overlaping gradient
computation is applied with allreduce.
Defaults to False.
global_step: training step that is used as input to 'get_lr_this_step'.
Defaults to 0.
get_lr_this_step: functor used as learning rate scheduler.
It uses 'global_step' as input.
Defaults to None.
loss_scaler: updates loss scale automatically when 'use_mixed_precision'
is specified.
Defaults to None.
partition_optimizer: controls whether to partition the optimizer state.
Defaults to False.
enable_grad_norm_clip: enables gradient norm clipping.
Defaults to True.
frozen_weights: list of model parameters to be frozen (not trained).
Defaults to [].
_enable_internal_postprocess: whether to run or not the internal postprocesses.
Defaults to True
_extra_postprocess: a callable to postprocess the ONNX model that is converted from PyTorch.
Defaults to None
"""
warnings.warn('DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it')
self.is_train = True
self.torch_model_ = None
self.onnx_model_ = None
if isinstance(model, torch.nn.Module):
self.torch_model_ = model
self.loss_fn_ = loss_fn
else:
self.onnx_model_ = model
if loss_fn is not None:
warnings.warn("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.")
# TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn
self.loss_fn_ = None
self.model_desc_ = model_desc
self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description]
self.world_rank = world_rank
self.world_size = world_size
self.use_mixed_precision = use_mixed_precision
self.session = None
self.device_ = device
self.gradient_accumulation_steps = gradient_accumulation_steps
# we use self.current_step to count calls to train_step. It is used for gradient accumulation.
# gradients are being accumulated when self.current_step is not divisible by gradient_accumulation_steps.
# gradients are updated when self.current_step is divisible by gradient_accumulation_steps.
self.current_step = 0
# we use self.global_step_ to count optimizations being performed.
# it is used to calculate learning rate if self.get_lr_this_step_ is provided.
self.global_step_ = global_step
self._extra_postprocess = _extra_postprocess
self.get_lr_this_step_ = get_lr_this_step
self.loss_scaler_ = loss_scaler
if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None:
warnings.warn("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.")
self.training_optimizer_name_ = training_optimizer_name
self.learning_rate_description_ = learning_rate_description
self.map_optimizer_attributes_ = map_optimizer_attributes
self.allreduce_post_accumulation_ = allreduce_post_accumulation
self.partition_optimizer_ = partition_optimizer
self.enable_grad_norm_clip_ = enable_grad_norm_clip
self.frozen_weights_ = frozen_weights
self.opset_version_ = _opset_version
self.state_dict_ = None
self._enable_internal_postprocess = _enable_internal_postprocess
# use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs.
# see prepare_input_and_fetches for more details.
self.loss_scale_input_name = 'default_loss_scale_input_name'
self._init_session()
def _init_session(self):
if self.onnx_model_ is None:
return
if self._enable_internal_postprocess:
self._onnx_model_ = postprocess.run_postprocess(self.onnx_model_)
if self._extra_postprocess:
self._extra_postprocess(self.onnx_model_)
self._verify_fully_optimized_model(self.onnx_model_)
self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \
create_ort_training_session_with_optimizer(
self.onnx_model_, self.device_,
self.training_optimizer_name_, self.learning_rate_description_.name_, self.map_optimizer_attributes_,
self.world_rank, self.world_size,
self.gradient_accumulation_steps, bind_parameters=False,
use_mixed_precision=self.use_mixed_precision, allreduce_post_accumulation=self.allreduce_post_accumulation_,
partition_optimizer=self.partition_optimizer_,
enable_grad_norm_clip=self.enable_grad_norm_clip_,
frozen_weights=self.frozen_weights_, opset_version=self.opset_version_)
self.loss_scale_input_name = self.session.loss_scale_input_name
if self.use_mixed_precision:
self.input_desc_with_lr_and_loss_scale = [
*self.input_desc_with_lr,
IODescription(self.loss_scale_input_name, [], torch.float32)]
# ORT backend has modified model output dtype from float32 to float16.
for o_desc in self.model_desc_.outputs_:
if self.use_mixed_precision and o_desc.dtype_ == torch.float32 and not self.session.is_output_fp32_node(o_desc.name_):
o_desc.eval_dtype_ = torch.float16
else:
o_desc.eval_dtype_ = o_desc.dtype_
# gradient accumulation buffers are connected to a single node with a boolean, dimension 1 tensor output.
# add a matching output to drive gradient accumulation.
if self.gradient_accumulation_steps > 1:
self.output_desc_with_group_accumulated_gradients = [
*self.model_desc_.outputs_,
IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool)]
if self.use_mixed_precision:
# when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine
# if the gradient is usable.
self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [
*self.model_desc_.outputs_,
IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)]
if self.state_dict_:
self.load_state_dict(self.state_dict_, self.strict_)
self.state_dict_ = None
def _init_onnx_model(self, inputs):
if self.onnx_model_ is not None:
return
if self.torch_model_ is not None:
# NOTE: pt model is moved to cpu to conserve gpu memory.
self.torch_model_.cpu()
# torch buffers created using 'register_buffer' are not meant to be trainable.
torch_buffers = list(dict(self.torch_model_.named_buffers()).keys())
self.frozen_weights_ = self.frozen_weights_ + torch_buffers
self.onnx_model_ = convert_model_loss_fn_to_onnx(
self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'), inputs, opset_version=self.opset_version_, _enable_internal_postprocess=self._enable_internal_postprocess)
self._init_session()
def train(self):
self.is_train = True
def eval(self):
self.is_train = False
def _update_onnx_model_initializers(self, state_tensors):
# replace the initializers with new value
new_weights = []
replace_indices = []
for i, w in enumerate(self.onnx_model_.graph.initializer):
if w.name in state_tensors:
new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name))
replace_indices.append(i)
replace_indices.sort(reverse=True)
for w_i in replace_indices:
del self.onnx_model_.graph.initializer[w_i]
self.onnx_model_.graph.initializer.extend(new_weights)
def state_dict(self):
if not self.session:
warnings.warn("ONNXRuntime training session is not initialized yet. "
"Please run train_step or eval_step at least once before calling state_dict().")
return {}
# extract trained weights
session_state = self.session.get_state()
torch_state = {}
for name in session_state:
torch_state[name] = torch.from_numpy(session_state[name])
# extract untrained weights and buffer
for n in self.onnx_model_.graph.initializer:
if n.name not in torch_state:
torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n))
return torch_state
def load_state_dict(self, state_dict, strict=False):
# Note: It may happen ONNX model has not yet been initialized
# In this case we cache a reference to desired state and delay the restore until after initialization
# Unexpected behavior will result if the user changes the reference before initialization
if not self.session:
self.state_dict_ = state_dict
self.strict_ = strict
return
# update onnx model from loaded state dict
cur_initializers_names = [n.name for n in self.onnx_model_.graph.initializer]
new_initializers = {}
for name in state_dict:
if name in cur_initializers_names:
new_initializers[name] = state_dict[name].numpy()
elif strict:
raise RuntimeError("Checkpoint tensor: {} is not present in the model.".format(name))
self._update_onnx_model_initializers(new_initializers)
# create new session based on updated onnx model
self.state_dict_ = None
self._init_session()
# load training state
session_state = {name:state_dict[name].numpy() for name in state_dict}
self.session.load_state(session_state, strict)
def save_as_onnx(self, path):
if not self.session:
warnings.warn("ONNXRuntime training session is not initialized yet. "
"Please run train_step or eval_step at least once before calling save_as_onnx().")
return
state_tensors = self.session.get_state()
self._update_onnx_model_initializers(state_tensors)
with open(path, "wb") as f:
f.write(self.onnx_model_.SerializeToString())
def _prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs):
fetches = None
if type(args) == tuple and len(args) == 1 and type(args[0]) == list:
input = tuple(args[0])
else:
input = args
for input_desc in input_desc_with_:
if input_desc.name_ in kwargs:
input = input + (kwargs[input_desc.name_],)
if internal_learning_rate is not None:
input = input + (internal_learning_rate,)
if internal_loss_scale is not None:
input = input + (internal_loss_scale,)
elif self.use_mixed_precision:
# loss_scale input name is needed to call train_step, for example:
# kwargs[model.loss_scale_input_name] = loss_scale
# outputs = model.train_step(*args, **kwargs)
# However, when first time train_step is called model.loss_scale_input_name is not set.
# To workaround this problem, we use the special name 'default_loss_scale_input_name' to indicate
# the loss_scale.
if 'default_loss_scale_input_name' in kwargs.keys():
input = input + (kwargs['default_loss_scale_input_name'],)
fetches = None
if 'fetches' in kwargs:
fetches = kwargs['fetches']
return input, fetches
def train_step(self, *args, **kwargs):
    """
    Run one training step through the ORT training session.

    inputs: model inputs, labels, learning rate, and, if in mixed_precision mode, loss_scale.
    outputs: if fetches is not provided, outputs are loss and
             (if in mixed mode and is finishing gradient accumulation) all_finite.
             if fetches is provided, outputs contains these requested with fetches.
    fetches: names of requested outputs

    Side effects visible in this method: self.current_step is incremented on
    every call; self.global_step_ is incremented when a gradient accumulation
    cycle completes (and, on the all_finite path, only if all_finite is truthy).
    A single result is returned bare rather than wrapped in a list.
    """
    # inputs to the ONNX model includes inputs to the original PyTorch model
    # plus learning rate and loss_scale if self.use_mixed_precision is True.
    # 1. when there are internal learning_rate and loss_scale (in fp16 cases) generators,
    #   *args and **kwargs together contain ONLY and COMPLETE inputs to the PyTorch model.
    #   In this case, changes to the training script is minimized.
    # 2. without internal learning rate and loss scale (in fp16 cases) generators,
    #   *args and **kwargs passed in from the training script shall contains
    #   inputs to the PyTorch model plus learning_rate and loss_scale.
    #   it optionally contains the fetches.
    #   localized arguments (*args) contains inputs to the ONNX model.
    #   named arguments can contain both inputs, learning_rate and loss_scale, and the fetches

    learning_rate, loss_scale = None, None
    if self.get_lr_this_step_ is not None:
        # *args, **kwargs contains inputs to the pytorch model
        lr_this_step = self.get_lr_this_step_(self.global_step_)
        learning_rate = torch.tensor([lr_this_step])
    if self.loss_scaler_ is not None and self.use_mixed_precision:
        loss_scale = torch.tensor([self.loss_scaler_.loss_scale_])

    # First call: no ONNX model yet, so build one from a sample made of the
    # model inputs only (no learning rate / loss scale appended).
    if self.onnx_model_ is None:
        sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_,
                                                          None, None, *args, **kwargs)
        self._init_onnx_model(sample_input)

    # Pick the input descriptions matching the precision mode and check that
    # the assembled feed has exactly one value per description.
    if self.use_mixed_precision:
        input, fetches = self._prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale,
                                                         learning_rate, loss_scale, *args, **kwargs)
        assert len(self.input_desc_with_lr_and_loss_scale) == len(input)
        input_descs = self.input_desc_with_lr_and_loss_scale
    else:
        input, fetches = self._prepare_input_and_fetches(self.input_desc_with_lr,
                                                         learning_rate, loss_scale, *args, **kwargs)
        assert len(self.input_desc_with_lr) == len(input)
        input_descs = self.input_desc_with_lr

    self.current_step += 1

    # handle gradient accumulation in fully optimized mode
    run_options = None
    has_if_all_finite = False
    # NOTE(review): this is a truthiness test, while the result selection at the
    # end uses `fetches is not None`; an empty fetches list therefore takes one
    # of the branches below but still yields an empty result list.
    if fetches:
        output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch]
    elif self.current_step % self.gradient_accumulation_steps != 0:
        # Mid-accumulation step: only the path to the accumulated-gradient
        # outputs is requested.
        run_options = ort.RunOptions()
        run_options.only_execute_path_to_fetches = True
        run_options.training_mode = True
        output_desc = self.output_desc_with_group_accumulated_gradients
    elif self.use_mixed_precision:
        # End of an accumulation cycle in mixed precision: also fetch the
        # all-finite flag (last entry of the output description list).
        has_if_all_finite = True
        output_desc = self.output_desc_with_all_fp_16_or_fp32_gradients_finite
    else:
        output_desc = self.model_desc_.outputs_

    if not isinstance(input, (list, tuple)):
        input = (input,)

    session_run_results = ort_training_session_run_helper(self.session, self.train_io_binding, input,
                                                          input_descs, output_desc,
                                                          self.device_,
                                                          run_options)

    if has_if_all_finite:
        # After session run with all_fp32_gradients_finite, we need to clear the iobinding's output state.
        # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce
        # because all_fp32_gradients_finite is still in the feed.
        self.train_io_binding.clear_binding_outputs()
        all_finite = session_run_results[self.output_desc_with_all_fp_16_or_fp32_gradients_finite[-1].name_]
        if self.loss_scaler_ is not None:
            self.loss_scaler_.update_loss_scale(all_finite)
        if all_finite:
            # optimization has done, increase self.global_step_
            self.global_step_ = self.global_step_ + 1
    elif self.current_step % self.gradient_accumulation_steps == 0:
        # optimization has done, increase self.global_step_
        self.global_step_ = self.global_step_ + 1

    if fetches is not None:
        results = [session_run_results[fetch] for fetch in fetches]
    elif has_if_all_finite and self.loss_scaler_ is None:
        # return descripted outputs plus the all_finite flag so that the training script can handle loss scaling.
        results = [session_run_results[output_desc.name_] for output_desc in self.output_desc_with_all_fp_16_or_fp32_gradients_finite]
    else:
        results = [session_run_results[output_desc.name_] for output_desc in self.model_desc_.outputs_]
    # A lone result is unwrapped for the caller's convenience.
    return results[0] if len(results) == 1 else results
def __call__(self, *args, **kwargs):
    """Dispatch to train_step when self.is_train is truthy, otherwise to eval_step.

    All positional and keyword arguments are forwarded unchanged and the
    selected step's return value is returned as-is.
    """
    step = self.train_step if self.is_train else self.eval_step
    return step(*args, **kwargs)
def eval_step(self, *args, **kwargs):
    """
    Run one evaluation (non-training) pass through the ORT session.

    inputs: model inputs and/or labels.
    outputs: if 'fetches' is not provided, the outputs described by the model
             description; if the session returns exactly one result it is
             returned bare, otherwise a list ordered like the selected output
             descriptions is returned.
    fetches: names of requested outputs

    Raises RuntimeError when neither an ONNX model nor a PyTorch model is
    available.
    """
    # with model_loss_cls, the last input is label, first output is loss
    model_inputs, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_,
                                                            None, None, *args, **kwargs)

    if self.onnx_model_ is None:
        if self.torch_model_ is None:
            raise RuntimeError("Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer.")
        self._init_onnx_model(model_inputs)

    # Only as many input descriptions as there are supplied inputs are used.
    selected_input_descs = self.model_desc_.inputs_[0:len(model_inputs)]

    if fetches is not None:
        selected_output_descs = [candidate
                                 for requested in fetches
                                 for candidate in self.model_desc_.outputs_
                                 if candidate.name_ == requested]
    else:
        selected_output_descs = self.model_desc_.outputs_

    if not isinstance(model_inputs, (list, tuple)):
        model_inputs = (model_inputs,)

    eval_options = ort.RunOptions()
    eval_options.only_execute_path_to_fetches = True
    eval_options.training_mode = False

    run_results = ort_training_session_run_helper(self.session, self.eval_io_binding, model_inputs,
                                                  selected_input_descs,
                                                  selected_output_descs,
                                                  self.device_,
                                                  eval_options)

    if len(run_results) != 1:
        return [run_results[desc.name_] for desc in selected_output_descs]
    # Exactly one entry: hand it back without wrapping it in a list.
    only_name, = run_results.keys()
    return run_results[only_name]
def _verify_fully_optimized_model(self, model):
    """Check that the model's first graph output can serve as the loss.

    The model must have at least one graph output (asserted). The first
    output's element type must be one of the floating/complex types listed
    below and its declared shape must have no dimensions.

    Raises RuntimeError if either the type or the scalar-shape check fails.
    """
    assert(len(model.graph.output) > 0)

    # model's first output must be the loss tensor
    loss_tensor_type = model.graph.output[0].type.tensor_type
    accepted_elem_types = (
        onnx.TensorProto.FLOAT,
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.DOUBLE,
        onnx.TensorProto.COMPLEX64,
        onnx.TensorProto.COMPLEX128,
        onnx.TensorProto.BFLOAT16,
    )
    if loss_tensor_type.elem_type not in accepted_elem_types:
        raise RuntimeError("the first output of a model to run with fully optimized ORT backend must be float types.")

    if len(loss_tensor_type.shape.dim) > 0:
        raise RuntimeError(
            "the first output of a model to run with fully optimized ORT backend assumed to be loss and must be a scalar.")
class LossScaler:
    """Holds the loss scale used for mixed precision training and adjusts it.

    When dynamic scaling is enabled, the scale is halved (not below
    min_loss_scale) whenever a step reports non-finite gradients, and doubled
    (not above max_loss_scale) after up_scale_window consecutive finite steps.
    """

    def __init__(self, loss_scale_input_name, is_dynamic_scale,
                 loss_scale=float(1 << 16),
                 up_scale_window=2000,
                 min_loss_scale=1.0, max_loss_scale=float(1 << 24)):
        """
        loss_scale_input_name: name of the loss scale input.
        is_dynamic_scale: when falsy, update_loss_scale() is a no-op.
        loss_scale: starting scale; also the value restored by reset().
        up_scale_window: number of consecutive finite steps before doubling.
        min_loss_scale / max_loss_scale: bounds applied when halving / doubling.
        """
        super().__init__()
        # Static configuration.
        self.loss_scale_input_name_ = loss_scale_input_name
        self.is_dynamic_scale_ = is_dynamic_scale
        self.up_scale_window_ = up_scale_window
        self.min_loss_scale_ = min_loss_scale
        self.max_loss_scale_ = max_loss_scale
        self.initial_loss_scale_ = loss_scale
        # Mutable state.
        self.loss_scale_ = loss_scale
        self.stable_steps_ = 0

    def update_loss_scale(self, is_all_finite):
        """Adjust the scale after a step; does nothing unless scaling is dynamic."""
        if not self.is_dynamic_scale_:
            return

        if not is_all_finite:
            # Overflow: back off and restart the stable-step count.
            self.loss_scale_ = max(self.min_loss_scale_, self.loss_scale_ / 2)
            self.stable_steps_ = 0
            return

        self.stable_steps_ += 1
        if self.stable_steps_ >= self.up_scale_window_:
            # Enough consecutive finite steps: try a larger scale.
            self.loss_scale_ = min(self.max_loss_scale_, self.loss_scale_ * 2)
            self.stable_steps_ = 0

    def reset(self):
        """Restore the initial loss scale and clear the stable-step count."""
        self.stable_steps_ = 0
        self.loss_scale_ = self.initial_loss_scale_