mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
### Description The PR implements FloatE4M3FN, FloatE5M2, FloatE4MEFNUZ, FloatE5M2FNUZ as described in PR https://github.com/onnx/onnx/pull/4805. It uses CUDA API to cast float/half to float8 if CUDA>=11.8, a custom implementation if CUDA<11.8. * It implements, Cast, QuantizeLinear, DequantizeLinear for all types on CPU, only for types FloatE4M3FN, FloatE5M2 on CUDA. * It extends the supported types for control flow operator, Shape, Reshape, Identity, If, Loop, Scan, Reshape * It implements Equal(19). * Cast, QuantizeLinear, DequantizeLinear operators now support a parameter `saturate` only valid for float 8 types. It is true by default. In that case, any value out of range is converted into the maximum float 8 value. If false, it is infinite. * QuantizeLinear, DequantizeLinear now supports multiple scales on CUDA (and ROCm by extension), scale = 1D tensor with one scale per channel ### Motivation and Context Supports latest onnx version. Fixes [AB#15395](https://aiinfra.visualstudio.com/6a833879-cd9b-44a4-a9de-adc2d818f13c/_workitems/edit/15395) --------- Co-authored-by: Xavier Dupre <xadupre@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net> Co-authored-by: Randy Shuai <rashuai@microsoft.com> Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> Co-authored-by: Scott McKay <Scott.McKay@microsoft.com>
1241 lines
50 KiB
Python
1241 lines
50 KiB
Python
import io
|
|
import os
|
|
import warnings
|
|
|
|
import numpy as np
|
|
import onnx
|
|
import torch
|
|
import torch.nn
|
|
import torch.onnx
|
|
from onnx import helper, numpy_helper
|
|
from packaging.version import Version as LooseVersion
|
|
|
|
import onnxruntime as ort
|
|
import onnxruntime.capi.pt_patch
|
|
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
|
|
|
from ..training import postprocess
|
|
from .checkpointing_utils import CombineZeroCheckpoint, get_checkpoint_name, list_checkpoint_files
|
|
|
|
DEFAULT_OPSET_VERSION = 14
|
|
|
|
|
|
class IODescription:
|
|
def __init__(self, name, shape, dtype=None, num_classes=None):
|
|
self.name_ = name
|
|
self.shape_ = shape
|
|
self.dtype_ = dtype
|
|
self.num_classes_ = num_classes
|
|
|
|
|
|
class ModelDescription:
|
|
def __init__(self, inputs, outputs):
|
|
self.inputs_ = inputs
|
|
self.outputs_ = outputs
|
|
|
|
|
|
def resolve_symbolic_dimensions(inputs, input_descs, output_descs):
|
|
import copy
|
|
|
|
output_descs_copy = copy.deepcopy(output_descs)
|
|
resolved_dims = {}
|
|
for input, input_desc in zip(inputs, input_descs):
|
|
for i, axis in enumerate(input_desc.shape_):
|
|
if isinstance(axis, str):
|
|
resolved_dims[axis] = input.size()[i]
|
|
|
|
for output_desc in output_descs_copy:
|
|
for i, axis in enumerate(output_desc.shape_):
|
|
if isinstance(axis, str):
|
|
output_desc.shape_[i] = resolved_dims[axis]
|
|
|
|
if any(isinstance(axis, str) for axis in output_desc.shape_ for output_desc in output_descs):
|
|
raise RuntimeError("Cannot run model with unknown output dimensions")
|
|
|
|
return output_descs_copy
|
|
|
|
|
|
def generate_sample(desc, device=None):
|
|
# symbolic dimensions are described with strings. set symbolic dimensions to be 1
|
|
size = [s if isinstance(s, (int)) else 1 for s in desc.shape_]
|
|
if desc.num_classes_:
|
|
return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device)
|
|
else:
|
|
return torch.randn(size, dtype=desc.dtype_).to(device)
|
|
|
|
|
|
def get_device_index(device):
|
|
if type(device) == str:
|
|
# could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
|
|
device = torch.device(device)
|
|
return 0 if device.index is None else device.index
|
|
|
|
|
|
def input_get_device_index(input):
|
|
if isinstance(input, (list, tuple)):
|
|
device_index = get_device_index(input[0].device)
|
|
else:
|
|
device_index = get_device_index(input.device)
|
|
|
|
return device_index
|
|
|
|
|
|
def get_all_gradients_finite_arg_name(session):
|
|
all_fp16_or_fp32_gradients_finite_node_args = [x for x in session._outputs_meta if "all_gradients_finite" in x.name]
|
|
if len(all_fp16_or_fp32_gradients_finite_node_args) < 1:
|
|
raise RuntimeError(
|
|
"Failed to find a group NodeArg with name that matches 'all_gradients_finite'\
|
|
from the training session."
|
|
)
|
|
|
|
return all_fp16_or_fp32_gradients_finite_node_args[0].name
|
|
|
|
|
|
def get_group_accumulated_gradients_output_node_arg_name(session):
|
|
# TODO: get the constant string via pybind.
|
|
# optimizer_graph_builder BuildGroupNode with fixed string: 'Group_Accumulated_Gradients'
|
|
accumulated_gradients_output_node_args = [
|
|
x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name
|
|
]
|
|
if len(accumulated_gradients_output_node_args) != 1:
|
|
raise RuntimeError(
|
|
"Failed to find a group NodeArg with name that matches 'Group_Accumulated_Gradients'\
|
|
from the training session."
|
|
)
|
|
|
|
return accumulated_gradients_output_node_args[0].name
|
|
|
|
|
|
def ort_training_session_run_helper(session, iobinding, inputs, input_descs, output_descs, device, run_options=None):
|
|
for input, input_desc in zip(inputs, input_descs):
|
|
device_index = input_get_device_index(input)
|
|
iobinding.bind_input(
|
|
input_desc.name_,
|
|
input.device.type,
|
|
device_index,
|
|
dtype_torch_to_numpy(input.dtype),
|
|
list(input.size()),
|
|
input.data_ptr(),
|
|
)
|
|
|
|
output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs)
|
|
torch_outputs = {}
|
|
for output_desc in output_descs_resolved:
|
|
torch_tensor = torch.zeros(
|
|
output_desc.shape_,
|
|
device=device,
|
|
dtype=output_desc.eval_dtype_ if hasattr(output_desc, "eval_dtype_") else output_desc.dtype_,
|
|
)
|
|
iobinding.bind_output(
|
|
output_desc.name_,
|
|
torch_tensor.device.type,
|
|
get_device_index(device),
|
|
dtype_torch_to_numpy(torch_tensor.dtype),
|
|
list(torch_tensor.size()),
|
|
torch_tensor.data_ptr(),
|
|
)
|
|
torch_outputs[output_desc.name_] = torch_tensor
|
|
|
|
session.run_with_iobinding(iobinding, run_options)
|
|
return torch_outputs
|
|
|
|
|
|
def FuseSofmaxNLLToSoftmaxCE(onnx_model): # noqa: N802
|
|
nll_count = 0
|
|
while True:
|
|
nll_count = nll_count + 1
|
|
nll_loss_node = None
|
|
nll_loss_node_index = 0
|
|
for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007
|
|
if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss":
|
|
nll_loss_node = node
|
|
break
|
|
|
|
if nll_loss_node is None:
|
|
break
|
|
|
|
softmax_node = None
|
|
softmax_node_index = 0
|
|
label_input_name = None
|
|
weight_input_name = None
|
|
for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007
|
|
if node.op_type == "LogSoftmax":
|
|
# has to be connected to nll_loss
|
|
if len(nll_loss_node.input) > 2:
|
|
weight_input_name = nll_loss_node.input[2]
|
|
if node.output[0] == nll_loss_node.input[0]:
|
|
softmax_node = node
|
|
label_input_name = nll_loss_node.input[1]
|
|
break
|
|
elif node.output[0] == nll_loss_node.input[1]:
|
|
softmax_node = node
|
|
label_input_name = nll_loss_node.input[0]
|
|
break
|
|
else:
|
|
if softmax_node is not None:
|
|
break
|
|
|
|
if softmax_node is None:
|
|
break
|
|
|
|
# delete nll_loss and LogSoftmax nodes in order
|
|
if nll_loss_node_index < softmax_node_index:
|
|
del onnx_model.graph.node[softmax_node_index]
|
|
del onnx_model.graph.node[nll_loss_node_index]
|
|
else:
|
|
del onnx_model.graph.node[nll_loss_node_index]
|
|
del onnx_model.graph.node[softmax_node_index]
|
|
|
|
probability_output_name = softmax_node.output[0]
|
|
node = onnx_model.graph.node.add()
|
|
inputs = (
|
|
[softmax_node.input[0], label_input_name, weight_input_name]
|
|
if weight_input_name
|
|
else [softmax_node.input[0], label_input_name]
|
|
)
|
|
node.CopyFrom(
|
|
onnx.helper.make_node(
|
|
"SparseSoftmaxCrossEntropy",
|
|
inputs,
|
|
[nll_loss_node.output[0], probability_output_name],
|
|
"nll_loss_node_" + str(nll_count),
|
|
)
|
|
)
|
|
|
|
return onnx_model
|
|
|
|
|
|
def delete_input_with_name(input, name):
|
|
index = 0
|
|
for i in input:
|
|
if i.name == name:
|
|
del input[index]
|
|
break
|
|
index = index + 1
|
|
|
|
|
|
# reference:
|
|
# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html
|
|
# https://pytorch.org/docs/stable/tensors.html
|
|
# also must map to types accepted by:
|
|
# MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type)
|
|
def dtype_torch_to_numpy(torch_dtype):
|
|
if torch_dtype == torch.float64 or torch_dtype == torch.double:
|
|
return np.float64
|
|
elif torch_dtype == torch.float32 or torch_dtype == torch.float:
|
|
return np.float32
|
|
elif torch_dtype == torch.float16 or torch_dtype == torch.half:
|
|
return np.float16
|
|
elif torch_dtype == torch.int64 or torch_dtype == torch.long:
|
|
return np.longlong
|
|
elif torch_dtype == torch.int32 or torch_dtype == torch.int:
|
|
return np.int32
|
|
elif torch_dtype == torch.int16 or torch_dtype == torch.short:
|
|
return np.int16
|
|
elif torch_dtype == torch.bool:
|
|
return bool
|
|
else:
|
|
raise Exception("Torch type to numpy type mapping unavailable for: " + str(torch_dtype))
|
|
|
|
|
|
class model_loss_cls(torch.nn.Module): # noqa: N801
|
|
def __init__(self, model, loss_fn):
|
|
super().__init__()
|
|
self.model_ = model
|
|
self.loss_fn_ = loss_fn
|
|
|
|
def forward(self, *inputs):
|
|
# here we assume input can be unpacked into input and label
|
|
input, label = inputs[:-1], inputs[-1]
|
|
preds = self.model_(*input)
|
|
return self.loss_fn_(preds, label), preds
|
|
|
|
|
|
class WrapModel(torch.nn.Module):
|
|
def __init__(self, model, loss_fn, input_names):
|
|
super().__init__()
|
|
self.model_ = model
|
|
self.loss_fn_ = loss_fn
|
|
self.input_names_ = input_names
|
|
|
|
def forward(self, *inputs):
|
|
import inspect
|
|
|
|
# *inputs is given by torch trace. It is in the order of input_names.
|
|
# model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names.
|
|
sig = inspect.signature(self.model_.forward)
|
|
list(sig.parameters.keys())
|
|
|
|
input_dict = {}
|
|
for key in sig.parameters:
|
|
if key in self.input_names_:
|
|
input_dict[key] = inputs[self.input_names_.index(key)]
|
|
|
|
model_out = self.model_(**input_dict)
|
|
if self.loss_fn_ is None:
|
|
return model_out
|
|
|
|
label = inputs[-1]
|
|
preds = model_out
|
|
return self.loss_fn_(preds, label), preds
|
|
|
|
|
|
def wrap_for_input_match(model, loss_fn, input_names):
|
|
import inspect
|
|
|
|
sig = inspect.signature(model.forward)
|
|
ordered_list_keys = list(sig.parameters.keys())
|
|
if loss_fn:
|
|
sig_loss = inspect.signature(loss_fn)
|
|
if len(sig_loss.parameters) != 2:
|
|
raise RuntimeError("loss function should take two arguments - predict and label.")
|
|
|
|
# label shall be the second input to loss_fn.
|
|
ordered_list_keys = [*ordered_list_keys, list(sig_loss.parameters.keys())[1]]
|
|
|
|
# name match is needed only when input_names are a subset
|
|
# of expected inputs (inputs to model and loss_fn combined).
|
|
if len(input_names) > len(ordered_list_keys):
|
|
# this is likely the case where input arguments are packed.
|
|
# TODO: to unpack the input argument.
|
|
return model_loss_cls(model, loss_fn) if loss_fn else model
|
|
elif len(input_names) == len(ordered_list_keys):
|
|
# in this case, we do not require name match.
|
|
return model_loss_cls(model, loss_fn) if loss_fn else model
|
|
|
|
if not all(x in ordered_list_keys for x in input_names):
|
|
# model desc has name(s) not matching the model signature. We cannot do anything in this case.
|
|
# better to warning the user.
|
|
return model_loss_cls(model, loss_fn) if loss_fn else model
|
|
|
|
# if input_names match ordered_list_keys, there is not need for wrapping
|
|
match = True
|
|
for i, input_name in enumerate(input_names):
|
|
if input_name != ordered_list_keys[i]:
|
|
match = False
|
|
break
|
|
|
|
if match:
|
|
return model_loss_cls(model, loss_fn) if loss_fn else model
|
|
|
|
model = WrapModel(model, loss_fn, input_names)
|
|
|
|
return model
|
|
|
|
|
|
def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, opset_version=DEFAULT_OPSET_VERSION):
|
|
# example: {input0:{0:'batch'}, input1:{0:'batch'}}
|
|
dynamic_axes = {}
|
|
for input in model_desc.inputs_:
|
|
symbolic_axis = {}
|
|
for i, axis in enumerate(input.shape_):
|
|
if isinstance(axis, str):
|
|
symbolic_axis[i] = axis
|
|
if len(symbolic_axis):
|
|
dynamic_axes[input.name_] = symbolic_axis
|
|
|
|
for output in model_desc.outputs_:
|
|
symbolic_axis = {}
|
|
for i, axis in enumerate(output.shape_):
|
|
if isinstance(axis, str):
|
|
symbolic_axis[i] = axis
|
|
if len(symbolic_axis):
|
|
dynamic_axes[output.name_] = symbolic_axis
|
|
|
|
input_names = [input.name_ for input in model_desc.inputs_]
|
|
output_names = [output.name_ for output in model_desc.outputs_]
|
|
|
|
if isinstance(inputs, torch.Tensor):
|
|
inputs = [inputs]
|
|
if isinstance(inputs, dict):
|
|
sample_inputs = [inputs[k.name_].to(device=device) for k in model_desc.inputs_]
|
|
elif isinstance(inputs, (list, tuple)):
|
|
sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(model_desc.inputs_)]
|
|
else:
|
|
raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")
|
|
|
|
# pytorch onnx exporter/trace does not try to match argument names.
|
|
# e.g. for models with optional inputs, it requires all inputs be present.
|
|
# this is a problem because the model graph depends on inputs provided.
|
|
model = wrap_for_input_match(model, loss_fn, input_names)
|
|
|
|
model.eval()
|
|
with torch.no_grad():
|
|
import copy
|
|
|
|
# Deepcopy inputs, since input values may change after model run.
|
|
sample_inputs_copy = copy.deepcopy(sample_inputs)
|
|
try:
|
|
# Deepcopy model, in case model is stateful and changes after model run.
|
|
model_copy = copy.deepcopy(model)
|
|
except Exception:
|
|
model_copy = model
|
|
warnings.warn(
|
|
"This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX."
|
|
" Compute will continue, but unexpected results may occur!"
|
|
)
|
|
|
|
sample_outputs = model_copy(*sample_inputs_copy)
|
|
if isinstance(sample_outputs, torch.Tensor):
|
|
sample_outputs = [sample_outputs]
|
|
for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_):
|
|
output_desc.dtype_ = sample_output.dtype
|
|
model.train()
|
|
|
|
f = io.BytesIO()
|
|
|
|
# Other export options to use(this is for backward compatibility).
|
|
other_export_options = {}
|
|
other_export_options["training"] = True
|
|
|
|
# This option was added after 1.4 release.
|
|
if LooseVersion(torch.__version__) > LooseVersion("1.4.0") and LooseVersion(torch.__version__) < LooseVersion(
|
|
"1.10.0"
|
|
):
|
|
other_export_options["enable_onnx_checker"] = False
|
|
# This option was added after 1.6 release.
|
|
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
|
other_export_options["training"] = torch.onnx.TrainingMode.TRAINING
|
|
|
|
# Deepcopy inputs, since input values may change after model run.
|
|
import copy
|
|
|
|
sample_inputs_copy = copy.deepcopy(sample_inputs)
|
|
|
|
# Enable contrib ops export from PyTorch
|
|
from onnxruntime.tools import pytorch_export_contrib_ops
|
|
|
|
pytorch_export_contrib_ops.register()
|
|
|
|
torch.onnx._export(
|
|
model,
|
|
tuple(sample_inputs_copy),
|
|
f,
|
|
input_names=input_names,
|
|
output_names=output_names,
|
|
opset_version=opset_version,
|
|
dynamic_axes=dynamic_axes,
|
|
do_constant_folding=False,
|
|
**other_export_options,
|
|
)
|
|
|
|
onnx_model = onnx.load_model_from_string(f.getvalue())
|
|
|
|
# Remove 'model_.' prefix introduced by model wrapper for initializers.
|
|
if isinstance(model, (WrapModel, model_loss_cls)):
|
|
replace_name_dict = {}
|
|
for n in onnx_model.graph.initializer:
|
|
if n.name.startswith("model_."):
|
|
replace_name_dict[n.name] = n.name[len("model_.") :]
|
|
n.name = replace_name_dict[n.name]
|
|
for n in onnx_model.graph.node:
|
|
for i, name in enumerate(n.input):
|
|
if name in replace_name_dict:
|
|
n.input[i] = replace_name_dict[name]
|
|
|
|
return onnx_model
|
|
|
|
|
|
def create_ort_training_session_with_optimizer(
|
|
model,
|
|
device,
|
|
training_optimizer_name,
|
|
lr_params_feed_name,
|
|
map_optimizer_attributes,
|
|
world_rank=-1,
|
|
world_size=1,
|
|
gradient_accumulation_steps=1,
|
|
bind_parameters=False,
|
|
use_mixed_precision=False,
|
|
allreduce_post_accumulation=False,
|
|
deepspeed_zero_stage=0,
|
|
enable_grad_norm_clip=True,
|
|
frozen_weights=[], # noqa: B006
|
|
opset_version=DEFAULT_OPSET_VERSION,
|
|
use_deterministic_compute=False,
|
|
use_memory_efficient_gradient=False,
|
|
enable_adasum=False,
|
|
optimized_model_filepath="",
|
|
):
|
|
output_name = model.graph.output[0].name
|
|
ort_parameters = ort.TrainingParameters()
|
|
ort_parameters.loss_output_name = output_name
|
|
ort_parameters.use_mixed_precision = use_mixed_precision
|
|
ort_parameters.world_rank = world_rank
|
|
ort_parameters.world_size = world_size
|
|
ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps
|
|
ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation
|
|
ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage
|
|
ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip
|
|
ort_parameters.set_gradients_as_graph_outputs = False
|
|
ort_parameters.use_memory_efficient_gradient = use_memory_efficient_gradient
|
|
ort_parameters.enable_adasum = enable_adasum
|
|
output_types = {}
|
|
for output in model.graph.output:
|
|
output_types[output.name] = output.type.tensor_type
|
|
|
|
# pybind does not allow to add directly to ort_parameters.weights_to_train.
|
|
# Have to work around by using a temporary weights_to_train.
|
|
torch_params = {}
|
|
optimizer_attributes_map = {}
|
|
optimizer_int_attributes_map = {}
|
|
|
|
unused_frozen_weights = [n for n in frozen_weights if n not in [i.name for i in model.graph.initializer]]
|
|
if unused_frozen_weights:
|
|
raise RuntimeError(f"{unused_frozen_weights} in frozen_weights not found in model weights.")
|
|
|
|
weights_to_train = set()
|
|
for initializer in model.graph.initializer:
|
|
if initializer.name in frozen_weights:
|
|
continue
|
|
weights_to_train.add(initializer.name)
|
|
if map_optimizer_attributes is not None:
|
|
attributes = map_optimizer_attributes(initializer.name)
|
|
optimizer_attributes_map[initializer.name] = {}
|
|
optimizer_int_attributes_map[initializer.name] = {}
|
|
for k, v in attributes.items():
|
|
if isinstance(v, float):
|
|
optimizer_attributes_map[initializer.name][k] = v
|
|
elif isinstance(v, int):
|
|
optimizer_int_attributes_map[initializer.name][k] = v
|
|
else:
|
|
raise ValueError("Optimizer attributes must be either float or int.")
|
|
else:
|
|
optimizer_attributes_map[initializer.name] = {}
|
|
optimizer_int_attributes_map[initializer.name] = {}
|
|
|
|
if bind_parameters:
|
|
for initializer in model.graph.initializer:
|
|
torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device))
|
|
delete_input_with_name(model.graph.input, initializer.name)
|
|
model.graph.input.extend(
|
|
[helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)]
|
|
)
|
|
torch_params[initializer.name] = torch_tensor
|
|
|
|
del model.graph.initializer[:]
|
|
|
|
ort_parameters.weights_to_train = weights_to_train
|
|
ort_parameters.training_optimizer_name = training_optimizer_name
|
|
ort_parameters.lr_params_feed_name = lr_params_feed_name
|
|
ort_parameters.optimizer_attributes_map = optimizer_attributes_map
|
|
ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map
|
|
|
|
sessionOptions = ort.SessionOptions() # noqa: N806
|
|
sessionOptions.use_deterministic_compute = use_deterministic_compute
|
|
if len(optimized_model_filepath) > 0:
|
|
sessionOptions.optimized_model_filepath = optimized_model_filepath
|
|
session = ort.TrainingSession(model.SerializeToString(), ort_parameters, sessionOptions)
|
|
train_io_binding = session.io_binding()
|
|
eval_io_binding = session.io_binding()
|
|
|
|
if bind_parameters:
|
|
for param in torch_params:
|
|
torch_tensor = torch_params[param]
|
|
|
|
train_io_binding.bind_input(
|
|
param,
|
|
torch_tensor.device.type,
|
|
get_device_index(torch_tensor.device),
|
|
dtype_torch_to_numpy(torch_params[param].dtype),
|
|
list(torch_tensor.size()),
|
|
torch_tensor.data_ptr(),
|
|
)
|
|
eval_io_binding.bind_input(
|
|
param,
|
|
torch_tensor.device.type,
|
|
get_device_index(torch_tensor.device),
|
|
dtype_torch_to_numpy(torch_params[param].dtype),
|
|
list(torch_tensor.size()),
|
|
torch_tensor.data_ptr(),
|
|
)
|
|
|
|
return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types
|
|
|
|
|
|
def save_checkpoint(
|
|
model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", checkpoint_state_dict=None, include_optimizer_state=True
|
|
):
|
|
if checkpoint_state_dict is None:
|
|
checkpoint_state_dict = {"model": model.state_dict(include_optimizer_state)}
|
|
else:
|
|
checkpoint_state_dict.update({"model": model.state_dict(include_optimizer_state)})
|
|
|
|
assert os.path.exists(checkpoint_dir), f"ERROR: Checkpoint directory doesn't exist: {checkpoint_dir}"
|
|
|
|
checkpoint_name = get_checkpoint_name(
|
|
checkpoint_prefix, model.deepspeed_zero_stage_, model.world_rank, model.world_size
|
|
)
|
|
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
|
|
|
|
if os.path.exists(checkpoint_file):
|
|
warnings.warn(f"{checkpoint_file} already exists, overwriting.")
|
|
|
|
torch.save(checkpoint_state_dict, checkpoint_file)
|
|
|
|
|
|
def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict):
|
|
checkpoint_name = get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size)
|
|
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
|
|
|
|
if is_partitioned:
|
|
assert_msg = (
|
|
"Couldn't find checkpoint file {}."
|
|
"Optimizer partitioning is enabled using ZeRO. Please make sure that the "
|
|
"checkpoint file exists for rank {} of {}."
|
|
).format(checkpoint_file, model.world_rank, model.world_size)
|
|
else:
|
|
assert_msg = f"Couldn't find checkpoint file {checkpoint_file}."
|
|
|
|
assert os.path.exists(checkpoint_file), assert_msg
|
|
|
|
checkpoint_state = torch.load(checkpoint_file, map_location="cpu")
|
|
|
|
model.load_state_dict(checkpoint_state["model"], strict=strict)
|
|
del checkpoint_state["model"]
|
|
return checkpoint_state
|
|
|
|
|
|
def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict):
|
|
checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
|
|
|
|
ckpt_agg = CombineZeroCheckpoint(checkpoint_files)
|
|
aggregate_state_dict = ckpt_agg.aggregate_checkpoints()
|
|
|
|
model.load_state_dict(aggregate_state_dict, strict=strict)
|
|
|
|
# aggregate other keys in the state_dict.
|
|
# Values will be overwritten for matching keys among workers
|
|
all_checkpoint_states = {}
|
|
for checkpoint_file in checkpoint_files:
|
|
checkpoint_state = torch.load(checkpoint_file, map_location="cpu")
|
|
del checkpoint_state["model"]
|
|
all_checkpoint_states.update(checkpoint_state)
|
|
return all_checkpoint_states
|
|
|
|
|
|
def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False):
|
|
checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
|
|
is_partitioned = False
|
|
if len(checkpoint_files) > 1:
|
|
warnings.warn(
|
|
f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}."
|
|
"Attempting to load ZeRO checkpoint."
|
|
)
|
|
is_partitioned = True
|
|
if (not model.deepspeed_zero_stage_) and is_partitioned:
|
|
return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict)
|
|
else:
|
|
return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict)
|
|
|
|
|
|
class ORTTrainer:
|
|
def __init__(
|
|
self,
|
|
model,
|
|
loss_fn,
|
|
model_desc,
|
|
training_optimizer_name,
|
|
map_optimizer_attributes,
|
|
learning_rate_description,
|
|
device,
|
|
gradient_accumulation_steps=1,
|
|
world_rank=0,
|
|
world_size=1,
|
|
use_mixed_precision=False,
|
|
allreduce_post_accumulation=False,
|
|
global_step=0,
|
|
get_lr_this_step=None,
|
|
loss_scaler=None,
|
|
deepspeed_zero_stage=0,
|
|
enable_grad_norm_clip=True,
|
|
frozen_weights=[], # noqa: B006
|
|
_opset_version=DEFAULT_OPSET_VERSION,
|
|
_enable_internal_postprocess=True,
|
|
_extra_postprocess=None,
|
|
_use_deterministic_compute=False,
|
|
use_memory_efficient_gradient=False,
|
|
run_symbolic_shape_infer=False,
|
|
enable_adasum=False,
|
|
optimized_model_filepath="",
|
|
):
|
|
super().__init__()
|
|
"""
|
|
Initialize ORTTrainer.
|
|
|
|
Args:
|
|
|
|
model: one of
|
|
- a PyTorch model (class that inherits from torch.nn.Module)
|
|
- a combined PyTorch model and loss function.
|
|
Inputs to this combined PyTorch model are a concatenation of the
|
|
model's input and the loss function's label input.
|
|
Outputs are a concatenation of the loss function's output and the
|
|
model's output.
|
|
- a combined ONNX model and loss function.
|
|
loss_fn: one of
|
|
- a PyTorch loss function if 'model' is a PyTorch model. A loss
|
|
function takes two inputs (prediction, label) and outputs a loss
|
|
tensor.
|
|
- None if model is already combined with a loss function.
|
|
model_desc: Specify input/output shapes, types, and names.
|
|
Must be consistent with the training model.
|
|
training_optimizer_name: one of
|
|
- 'SGDOptimizer'
|
|
- 'AdamOptimizer'
|
|
- 'LambOptimizer'
|
|
map_optimizer_attributes: for optimizers with weight-dependent
|
|
parameters. A callable that maps weight name to a set of optimization
|
|
parameters.
|
|
Defaults to None.
|
|
learning_rate_description: the name, shape and type of the learning
|
|
rate in form of IODescription(Learning_Rate_Name, [1,], torch.float32).
|
|
Because learning_rate is an input to the training model,
|
|
Learning_Rate_Name must be specified so that there is no name conflict
|
|
within the model.
|
|
device: device to store tensors (e.g. 'cpu', 'cuda', 'cuda:<int_idx>').
|
|
gradient_accumulation_steps: number of training steps to accumulate
|
|
gradients before averaging and applying them.
|
|
Defaults to 1.
|
|
world_rank: rank id used for distributed training.
|
|
Defaults to 0.
|
|
world_size: number of ranks participating in distributed training.
|
|
Defaults to 1.
|
|
use_mixed_precision: flag to enable mixed precision (aka fp16).
|
|
Defaults to False.
|
|
allreduce_post_accumulation: controls whether overlaping gradient
|
|
computation is applied with allreduce.
|
|
Defaults to False.
|
|
global_step: training step that is used as input to 'get_lr_this_step'.
|
|
Defaults to 0.
|
|
get_lr_this_step: functor used as learning rate scheduler.
|
|
It uses 'global_step' as input.
|
|
Defaults to None.
|
|
loss_scaler: updates loss scale automatically when 'use_mixed_precision'
|
|
is specified.
|
|
Defaults to None.
|
|
deepspeed_zero_stage: controls whether to partition state using the DeepSpeed ZeRO technique. Stages 0 and 1 are supported.
|
|
Defaults to 0 (disabled).
|
|
enable_grad_norm_clip: enables gradient norm clipping.
|
|
Defaults to True.
|
|
frozen_weights: list of model parameters to be frozen (not trained).
|
|
Defaults to [].
|
|
_enable_internal_postprocess: whether to run or not the internal postprocesses.
|
|
Defaults to True
|
|
_extra_postprocess: a callable to postprocess the ONNX model that is converted from PyTorch.
|
|
Defaults to None
|
|
use_memory_efficient_gradient: use memory aware gradient builder.
|
|
Defaults to False
|
|
run_symbolic_shape_infer: run symbolic shape inference
|
|
Defaults to False
|
|
optimized_model_filepath: path to output the optimized training graph.
|
|
Defaults to "" (no output).
|
|
"""
|
|
warnings.warn(
|
|
"ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.",
|
|
FutureWarning,
|
|
)
|
|
warnings.warn(
|
|
"DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it"
|
|
)
|
|
self.is_train = True
|
|
|
|
self.torch_model_ = None
|
|
self.onnx_model_ = None
|
|
self._enable_internal_postprocess = _enable_internal_postprocess
|
|
self._extra_postprocess = _extra_postprocess
|
|
|
|
if isinstance(model, torch.nn.Module):
|
|
self.torch_model_ = model
|
|
self.loss_fn_ = loss_fn
|
|
self._torch_state_dict_keys = list(model.state_dict().keys())
|
|
else:
|
|
self._torch_state_dict_keys = []
|
|
self.onnx_model_ = model
|
|
if loss_fn is not None:
|
|
warnings.warn("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.")
|
|
# TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn
|
|
self.loss_fn_ = None
|
|
|
|
if self._enable_internal_postprocess:
|
|
postprocess.run_postprocess(self.onnx_model_)
|
|
|
|
if self._extra_postprocess:
|
|
self._extra_postprocess(self.onnx_model_)
|
|
|
|
self.model_desc_ = model_desc
|
|
self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description]
|
|
|
|
self.world_rank = world_rank
|
|
self.world_size = world_size
|
|
self.use_mixed_precision = use_mixed_precision
|
|
|
|
self.session = None
|
|
self.device_ = device
|
|
self.gradient_accumulation_steps = gradient_accumulation_steps
|
|
# we use self.current_step to count calls to train_step. It is used for gradient accumulation.
|
|
# gradients are being accumulated when self.current_step is not divisible by gradient_accumulation_steps.
|
|
# gradients are updated when self.current_step is divisible by gradient_accumulation_steps.
|
|
self.current_step = 0
|
|
|
|
# we use self.global_step_ to count optimizations being performed.
|
|
# it is used to calculate learning rate if self.get_lr_this_step_ is provided.
|
|
self.global_step_ = global_step
|
|
self.get_lr_this_step_ = get_lr_this_step
|
|
self.loss_scaler_ = loss_scaler
|
|
|
|
if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None:
|
|
warnings.warn("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.")
|
|
self.training_optimizer_name_ = training_optimizer_name
|
|
self.learning_rate_description_ = learning_rate_description
|
|
self.map_optimizer_attributes_ = map_optimizer_attributes
|
|
self.allreduce_post_accumulation_ = allreduce_post_accumulation
|
|
self.deepspeed_zero_stage_ = deepspeed_zero_stage
|
|
self.enable_grad_norm_clip_ = enable_grad_norm_clip
|
|
self.frozen_weights_ = frozen_weights
|
|
self.opset_version_ = _opset_version
|
|
self.state_dict_ = None
|
|
self._use_deterministic_compute = _use_deterministic_compute
|
|
self.use_memory_efficient_gradient = use_memory_efficient_gradient
|
|
self.run_symbolic_shape_infer = run_symbolic_shape_infer
|
|
self.enable_adasum = enable_adasum
|
|
self.optimized_model_filepath = optimized_model_filepath
|
|
|
|
# use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs.
|
|
# see prepare_input_and_fetches for more details.
|
|
self.loss_scale_input_name = "default_loss_scale_input_name"
|
|
|
|
self._init_session()
|
|
|
|
def _init_session(self):
|
|
if self.onnx_model_ is None:
|
|
return
|
|
|
|
self._verify_fully_optimized_model(self.onnx_model_)
|
|
|
|
if self.run_symbolic_shape_infer:
|
|
self.onnx_model_ = SymbolicShapeInference.infer_shapes(
|
|
self.onnx_model_, auto_merge=True, guess_output_rank=True
|
|
)
|
|
|
|
# old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error.
|
|
# for example, load_state_dict will be called before returing the function, and it calls _init_session again
|
|
del self.session
|
|
(
|
|
self.session,
|
|
self.train_io_binding,
|
|
self.eval_io_binding,
|
|
self.output_name,
|
|
_,
|
|
self.output_types,
|
|
) = create_ort_training_session_with_optimizer(
|
|
self.onnx_model_,
|
|
self.device_,
|
|
self.training_optimizer_name_,
|
|
self.learning_rate_description_.name_,
|
|
self.map_optimizer_attributes_,
|
|
self.world_rank,
|
|
self.world_size,
|
|
self.gradient_accumulation_steps,
|
|
bind_parameters=False,
|
|
use_mixed_precision=self.use_mixed_precision,
|
|
allreduce_post_accumulation=self.allreduce_post_accumulation_,
|
|
deepspeed_zero_stage=self.deepspeed_zero_stage_,
|
|
enable_grad_norm_clip=self.enable_grad_norm_clip_,
|
|
frozen_weights=self.frozen_weights_,
|
|
opset_version=self.opset_version_,
|
|
use_deterministic_compute=self._use_deterministic_compute,
|
|
use_memory_efficient_gradient=self.use_memory_efficient_gradient,
|
|
enable_adasum=self.enable_adasum,
|
|
optimized_model_filepath=self.optimized_model_filepath,
|
|
)
|
|
|
|
self.loss_scale_input_name = self.session.loss_scale_input_name
|
|
|
|
if self.use_mixed_precision:
|
|
self.input_desc_with_lr_and_loss_scale = [
|
|
*self.input_desc_with_lr,
|
|
IODescription(self.loss_scale_input_name, [], torch.float32),
|
|
]
|
|
|
|
# ORT backend has modified model output dtype from float32 to float16.
|
|
for o_desc in self.model_desc_.outputs_:
|
|
if (
|
|
self.use_mixed_precision
|
|
and o_desc.dtype_ == torch.float32
|
|
and not self.session.is_output_fp32_node(o_desc.name_)
|
|
):
|
|
o_desc.eval_dtype_ = torch.float16
|
|
else:
|
|
o_desc.eval_dtype_ = o_desc.dtype_
|
|
|
|
# gradient accumulation buffers are connected to a single node with a boolean, dimension 1 tensor output.
|
|
# add a matching output to drive gradient accumulation.
|
|
if self.gradient_accumulation_steps > 1:
|
|
self.output_desc_with_group_accumulated_gradients = [
|
|
*self.model_desc_.outputs_,
|
|
IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool),
|
|
]
|
|
|
|
if self.use_mixed_precision:
|
|
# when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine
|
|
# if the gradient is usable.
|
|
self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [
|
|
*self.model_desc_.outputs_,
|
|
IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool),
|
|
]
|
|
|
|
if self.state_dict_:
|
|
self.load_state_dict(self.state_dict_, self.strict_)
|
|
self.state_dict_ = None
|
|
|
|
def _init_onnx_model(self, inputs):
|
|
if self.onnx_model_ is not None:
|
|
return
|
|
|
|
if self.torch_model_ is not None:
|
|
# NOTE: pt model is moved to cpu to conserve gpu memory.
|
|
self.torch_model_.cpu()
|
|
# torch buffers created using 'register_buffer' are not meant to be trainable.
|
|
torch_buffers = list(dict(self.torch_model_.named_buffers()).keys())
|
|
self.frozen_weights_ = self.frozen_weights_ + torch_buffers
|
|
self.onnx_model_ = convert_model_loss_fn_to_onnx(
|
|
self.torch_model_,
|
|
self.loss_fn_,
|
|
self.model_desc_,
|
|
torch.device("cpu"),
|
|
inputs,
|
|
opset_version=self.opset_version_,
|
|
)
|
|
|
|
if self._enable_internal_postprocess:
|
|
postprocess.run_postprocess(self.onnx_model_)
|
|
|
|
if self._extra_postprocess:
|
|
self._extra_postprocess(self.onnx_model_)
|
|
|
|
self._init_session()
|
|
|
|
def train(self):
|
|
self.is_train = True
|
|
|
|
def eval(self):
|
|
self.is_train = False
|
|
|
|
def _update_onnx_model_initializers(self, state_tensors):
|
|
# replace the initializers with new value
|
|
new_weights = []
|
|
replace_indices = []
|
|
for i, w in enumerate(self.onnx_model_.graph.initializer):
|
|
if w.name in state_tensors:
|
|
new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name))
|
|
replace_indices.append(i)
|
|
replace_indices.sort(reverse=True)
|
|
for w_i in replace_indices:
|
|
del self.onnx_model_.graph.initializer[w_i]
|
|
self.onnx_model_.graph.initializer.extend(new_weights)
|
|
|
|
def state_dict(self, include_optimizer_state=True):
|
|
if not self.session:
|
|
warnings.warn(
|
|
"ONNXRuntime training session is not initialized yet. "
|
|
"Please run train_step or eval_step at least once before calling state_dict()."
|
|
)
|
|
return {}
|
|
|
|
# extract trained weights
|
|
session_state = self.session.get_state()
|
|
torch_state = {}
|
|
for name in session_state:
|
|
torch_state[name] = torch.from_numpy(session_state[name])
|
|
|
|
# extract untrained weights and buffer
|
|
for n in self.onnx_model_.graph.initializer:
|
|
if n.name not in torch_state:
|
|
torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n))
|
|
|
|
# Need to remove redundant initializers and name suffices to map back to original torch state names
|
|
if not include_optimizer_state and self._torch_state_dict_keys:
|
|
return {key: torch_state[key] for key in self._torch_state_dict_keys if key in torch_state}
|
|
return torch_state
|
|
|
|
def load_state_dict(self, state_dict, strict=False):
|
|
# Note: It may happen ONNX model has not yet been initialized
|
|
# In this case we cache a reference to desired state and delay the restore until after initialization
|
|
# Unexpected behavior will result if the user changes the reference before initialization
|
|
if not self.session:
|
|
self.state_dict_ = state_dict
|
|
self.strict_ = strict
|
|
return
|
|
|
|
# update onnx model from loaded state dict
|
|
cur_initializers_names = [n.name for n in self.onnx_model_.graph.initializer]
|
|
new_initializers = {}
|
|
|
|
for name in state_dict:
|
|
if name in cur_initializers_names:
|
|
new_initializers[name] = state_dict[name].numpy()
|
|
elif strict:
|
|
raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.")
|
|
|
|
self._update_onnx_model_initializers(new_initializers)
|
|
|
|
# create new session based on updated onnx model
|
|
self.state_dict_ = None
|
|
self._init_session()
|
|
|
|
# load training state
|
|
session_state = {name: state_dict[name].numpy() for name in state_dict}
|
|
self.session.load_state(session_state, strict)
|
|
|
|
def save_as_onnx(self, path):
|
|
if not self.session:
|
|
warnings.warn(
|
|
"ONNXRuntime training session is not initialized yet. "
|
|
"Please run train_step or eval_step at least once before calling save_as_onnx()."
|
|
)
|
|
return
|
|
state_tensors = self.session.get_state()
|
|
self._update_onnx_model_initializers(state_tensors)
|
|
|
|
with open(path, "wb") as f:
|
|
f.write(self.onnx_model_.SerializeToString())
|
|
|
|
def _prepare_input_and_fetches(
|
|
self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs
|
|
):
|
|
fetches = None
|
|
if type(args) == tuple and len(args) == 1 and type(args[0]) == list:
|
|
input = tuple(args[0])
|
|
else:
|
|
input = args
|
|
|
|
for input_desc in input_desc_with_:
|
|
if input_desc.name_ in kwargs:
|
|
input = (*input, kwargs[input_desc.name_])
|
|
if internal_learning_rate is not None:
|
|
input = (*input, internal_learning_rate)
|
|
if internal_loss_scale is not None:
|
|
input = (*input, internal_loss_scale)
|
|
elif self.use_mixed_precision:
|
|
# loss_scale input name is needed to call train_step, for example:
|
|
# kwargs[model.loss_scale_input_name] = loss_scale
|
|
# outputs = model.train_step(*args, **kwargs)
|
|
# However, when first time train_step is called model.loss_scale_input_name is not set.
|
|
# To workaround this problem, we use the special name 'default_loss_scale_input_name' to indicate
|
|
# the loss_scale.
|
|
if "default_loss_scale_input_name" in kwargs:
|
|
input = (*input, kwargs["default_loss_scale_input_name"])
|
|
|
|
fetches = None
|
|
if "fetches" in kwargs:
|
|
fetches = kwargs["fetches"]
|
|
|
|
return input, fetches
|
|
|
|
def train_step(self, *args, **kwargs):
|
|
"""
|
|
inputs: model inputs, labels, learning rate, and, if in mixed_precision mode, loss_scale.
|
|
outputs: if fetches is not provided, outputs are loss and
|
|
(if in mixed mode and is finishing gradient accumulation) all_finite.
|
|
if fetches is provided, outputs contains these requested with fetches.
|
|
fetches: names of requested outputs
|
|
"""
|
|
|
|
# inputs to the ONNX model includes inputs to the original PyTorch model
|
|
# plus learning rate and loss_scale if self.use_mixed_precision is True.
|
|
# 1. when there are internal learning_rate and loss_scale (in fp16 cases) generators,
|
|
# *args and **kwargs together contain ONLY and COMPLETE inputs to the PyTorch model.
|
|
# In this case, changes to the training script is minimized.
|
|
# 2. without internal learning rate and loss scale (in fp16 cases) generators,
|
|
# *args and **kwargs passed in from the training script shall contains
|
|
# inputs to the PyTorch model plus learning_rate and loss_scale.
|
|
# it optionally contains the fetches.
|
|
# localized arguments (*args) contains inputs to the ONNX model.
|
|
# named arguments can contain both inputs, learning_rate and loss_scale, and the fetches
|
|
|
|
learning_rate, loss_scale = None, None
|
|
if self.get_lr_this_step_ is not None:
|
|
# $args, **kwargs contains inputs to the pytorch model
|
|
lr_this_step = self.get_lr_this_step_(self.global_step_)
|
|
learning_rate = torch.tensor([lr_this_step])
|
|
if self.loss_scaler_ is not None and self.use_mixed_precision:
|
|
loss_scale = torch.tensor([self.loss_scaler_.loss_scale_])
|
|
|
|
if self.onnx_model_ is None:
|
|
sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs)
|
|
self._init_onnx_model(sample_input)
|
|
|
|
if self.use_mixed_precision:
|
|
input, fetches = self._prepare_input_and_fetches(
|
|
self.input_desc_with_lr_and_loss_scale, learning_rate, loss_scale, *args, **kwargs
|
|
)
|
|
assert len(self.input_desc_with_lr_and_loss_scale) == len(input)
|
|
input_descs = self.input_desc_with_lr_and_loss_scale
|
|
else:
|
|
input, fetches = self._prepare_input_and_fetches(
|
|
self.input_desc_with_lr, learning_rate, loss_scale, *args, **kwargs
|
|
)
|
|
assert len(self.input_desc_with_lr) == len(input)
|
|
input_descs = self.input_desc_with_lr
|
|
|
|
self.current_step += 1
|
|
|
|
# handle gradient accumulation in fully optimized mode
|
|
run_options = None
|
|
has_if_all_finite = False
|
|
if fetches:
|
|
output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch]
|
|
elif self.current_step % self.gradient_accumulation_steps != 0:
|
|
run_options = ort.RunOptions()
|
|
run_options.only_execute_path_to_fetches = True
|
|
output_desc = self.output_desc_with_group_accumulated_gradients
|
|
elif self.use_mixed_precision:
|
|
has_if_all_finite = True
|
|
output_desc = self.output_desc_with_all_fp_16_or_fp32_gradients_finite
|
|
else:
|
|
output_desc = self.model_desc_.outputs_
|
|
|
|
if not isinstance(input, (list, tuple)):
|
|
input = (input,)
|
|
|
|
session_run_results = ort_training_session_run_helper(
|
|
self.session, self.train_io_binding, input, input_descs, output_desc, self.device_, run_options
|
|
)
|
|
|
|
if has_if_all_finite:
|
|
# After session run with all_fp32_gradients_finite, we need to clear the iobinding's output state.
|
|
# Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce
|
|
# because all_fp32_gradients_finite is still in the feed.
|
|
self.train_io_binding.clear_binding_outputs()
|
|
all_finite = session_run_results[self.output_desc_with_all_fp_16_or_fp32_gradients_finite[-1].name_]
|
|
if self.loss_scaler_ is not None:
|
|
self.loss_scaler_.update_loss_scale(all_finite)
|
|
if all_finite:
|
|
# optimization has done, increase self.global_step_
|
|
self.global_step_ = self.global_step_ + 1
|
|
elif self.current_step % self.gradient_accumulation_steps == 0:
|
|
# optimization has done, increase self.global_step_
|
|
self.global_step_ = self.global_step_ + 1
|
|
|
|
if fetches is not None:
|
|
results = [session_run_results[fetch] for fetch in fetches]
|
|
elif has_if_all_finite and self.loss_scaler_ is None:
|
|
# return descripted outputs plus the all_finite flag so that the training script can handle loss scaling.
|
|
results = [
|
|
session_run_results[output_desc.name_]
|
|
for output_desc in self.output_desc_with_all_fp_16_or_fp32_gradients_finite
|
|
]
|
|
else:
|
|
results = [session_run_results[output_desc.name_] for output_desc in self.model_desc_.outputs_]
|
|
return results[0] if len(results) == 1 else results
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
if self.is_train:
|
|
return self.train_step(*args, **kwargs)
|
|
else:
|
|
return self.eval_step(*args, **kwargs)
|
|
|
|
def eval_step(self, *args, **kwargs):
|
|
"""
|
|
inputs: model inputs and/or labels.
|
|
outputs: if 'fetches' is not provided, outputs are loss and
|
|
(if in mixed mode and is finishing gradient accumulation) all_finite.
|
|
if fetches is provided, outputs contains these requested with fetches.
|
|
fetches: names of requested outputs
|
|
"""
|
|
|
|
# with model_loss_cls, the last input is label, first output is loss
|
|
input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs)
|
|
|
|
if self.onnx_model_ is None:
|
|
if self.torch_model_ is not None:
|
|
self._init_onnx_model(input)
|
|
else:
|
|
raise RuntimeError(
|
|
"Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer."
|
|
)
|
|
|
|
input_desc = self.model_desc_.inputs_[0 : len(input)]
|
|
if fetches is None:
|
|
output_desc = self.model_desc_.outputs_
|
|
else:
|
|
output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch]
|
|
|
|
if not isinstance(input, (list, tuple)):
|
|
input = (input,)
|
|
|
|
run_options = ort.RunOptions()
|
|
run_options.only_execute_path_to_fetches = True
|
|
run_options.training_mode = False
|
|
|
|
session_run_results = ort_training_session_run_helper(
|
|
self.session, self.eval_io_binding, input, input_desc, output_desc, self.device_, run_options
|
|
)
|
|
|
|
if len(session_run_results) == 1:
|
|
return session_run_results[list(session_run_results.keys())[0]]
|
|
else:
|
|
return [session_run_results[output_desc.name_] for output_desc in output_desc]
|
|
|
|
def _verify_fully_optimized_model(self, model):
|
|
assert len(model.graph.output) > 0
|
|
# model's first output must be the loss tensor
|
|
if model.graph.output[0].type.tensor_type.elem_type not in {
|
|
onnx.TensorProto.FLOAT,
|
|
onnx.TensorProto.FLOAT16,
|
|
onnx.TensorProto.DOUBLE,
|
|
onnx.TensorProto.COMPLEX64,
|
|
onnx.TensorProto.COMPLEX128,
|
|
onnx.TensorProto.BFLOAT16,
|
|
onnx.TensorProto.FLOAT8E4M3FN,
|
|
onnx.TensorProto.FLOAT8E4M3FNUZ,
|
|
onnx.TensorProto.FLOAT8E5M2,
|
|
onnx.TensorProto.FLOAT8E5M2FNUZ,
|
|
}:
|
|
raise RuntimeError(
|
|
"the first output of a model to run with fully optimized ORT backend must be float types."
|
|
)
|
|
if len(model.graph.output[0].type.tensor_type.shape.dim) != 0:
|
|
raise RuntimeError(
|
|
"the first output of a model to run with fully optimized ORT backend assumed to be loss and must be a scalar."
|
|
)
|
|
|
|
|
|
class LossScaler:
|
|
def __init__(
|
|
self,
|
|
loss_scale_input_name,
|
|
is_dynamic_scale,
|
|
loss_scale=float(1 << 16), # noqa: B008
|
|
up_scale_window=2000,
|
|
min_loss_scale=1.0,
|
|
max_loss_scale=float(1 << 24), # noqa: B008
|
|
):
|
|
super().__init__()
|
|
self.loss_scale_input_name_ = loss_scale_input_name
|
|
self.is_dynamic_scale_ = is_dynamic_scale
|
|
self.initial_loss_scale_ = loss_scale
|
|
self.up_scale_window_ = up_scale_window
|
|
self.min_loss_scale_ = min_loss_scale
|
|
self.max_loss_scale_ = max_loss_scale
|
|
self.loss_scale_ = loss_scale
|
|
self.stable_steps_ = 0
|
|
|
|
def update_loss_scale(self, is_all_finite):
|
|
if not self.is_dynamic_scale_:
|
|
return
|
|
|
|
if is_all_finite:
|
|
self.stable_steps_ += 1
|
|
|
|
if self.stable_steps_ >= self.up_scale_window_:
|
|
self.loss_scale_ = min(self.max_loss_scale_, self.loss_scale_ * 2)
|
|
self.stable_steps_ = 0
|
|
else:
|
|
self.loss_scale_ = max(self.min_loss_scale_, self.loss_scale_ / 2)
|
|
self.stable_steps_ = 0
|
|
|
|
def reset(self):
|
|
self.loss_scale_ = self.initial_loss_scale_
|
|
self.stable_steps_ = 0
|