mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Update ort_trainer.py with lazy onnx export (#3244)
* Delay onnx export to avoid extra info * handle cases where onnx model is provided at initialization * address comments * fix rebase error
This commit is contained in:
parent
98c28060b0
commit
6474801ceb
2 changed files with 94 additions and 52 deletions
|
|
@ -19,6 +19,16 @@ def ort_trainer_learning_rate_description():
|
|||
return IODescription('Learning_Rate', [1, ], torch.float32)
|
||||
|
||||
|
||||
def remove_extra_info(model_desc):
|
||||
simple_model_desc = copy.deepcopy(model_desc)
|
||||
for input_desc in simple_model_desc.inputs_:
|
||||
input_desc.dtype_ = None
|
||||
input_desc.num_classes_ = None
|
||||
for output_desc in simple_model_desc.outputs_:
|
||||
output_desc.dtype_ = None
|
||||
output_desc.num_classes_ = None
|
||||
return simple_model_desc
|
||||
|
||||
def bert_model_description():
|
||||
vocab_size = 30528
|
||||
input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=vocab_size)
|
||||
|
|
@ -49,12 +59,13 @@ def generate_sample_batch(desc, batch_size, device):
|
|||
|
||||
def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allreduce_post_accumulation):
|
||||
model_desc = bert_model_description()
|
||||
simple_model_desc = remove_extra_info(model_desc)
|
||||
learning_rate_description = ort_trainer_learning_rate_description()
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx"))
|
||||
|
||||
model = ORTTrainer(onnx_model, None, model_desc, "LambOptimizer",
|
||||
model = ORTTrainer(onnx_model, None, simple_model_desc, "LambOptimizer",
|
||||
map_optimizer_attributes,
|
||||
learning_rate_description,
|
||||
device, postprocess_model=None,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import onnxruntime as ort
|
|||
from distutils.version import LooseVersion
|
||||
|
||||
class IODescription():
|
||||
def __init__(self, name, shape, dtype, num_classes=None):
|
||||
def __init__(self, name, shape, dtype=None, num_classes=None):
|
||||
self.name_ = name
|
||||
self.shape_ = shape
|
||||
self.dtype_ = dtype
|
||||
|
|
@ -92,7 +92,7 @@ def ort_training_session_run_helper(session, iobinding, inputs, input_descs, out
|
|||
output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs)
|
||||
torch_outputs = {}
|
||||
for output_desc in output_descs_resolved:
|
||||
torch_tensor = torch.zeros(output_desc.shape_, device=device,
|
||||
torch_tensor = torch.zeros(output_desc.shape_, device=device,
|
||||
dtype=output_desc.eval_dtype_ if hasattr(output_desc, 'eval_dtype_')
|
||||
else output_desc.dtype_)
|
||||
|
||||
|
|
@ -195,11 +195,11 @@ def dtype_torch_to_numpy(torch_dtype):
|
|||
return np.float32
|
||||
elif torch_dtype == torch.float16 or torch_dtype == torch.half:
|
||||
return np.float16
|
||||
elif torch_dtype == torch.int64 or torch_dtype == torch.long:
|
||||
elif torch_dtype == torch.int64 or torch_dtype == torch.long:
|
||||
return np.longlong
|
||||
elif torch_dtype == torch.int32 or torch_dtype == torch.int:
|
||||
elif torch_dtype == torch.int32 or torch_dtype == torch.int:
|
||||
return np.int32
|
||||
elif torch_dtype == torch.int16 or torch_dtype == torch.short:
|
||||
elif torch_dtype == torch.int16 or torch_dtype == torch.short:
|
||||
return np.int16
|
||||
|
||||
def wrap_for_input_match(model, input_names):
|
||||
|
|
@ -217,7 +217,7 @@ def wrap_for_input_match(model, input_names):
|
|||
return model
|
||||
|
||||
if not all(x in ordered_list_keys for x in input_names):
|
||||
# model desc has name(s) not matching the model signature. We cannot do anything in this case.
|
||||
# model desc has name(s) not matching the model signature. We cannot do anything in this case.
|
||||
# better to warning the user.
|
||||
return model
|
||||
|
||||
|
|
@ -253,7 +253,7 @@ def wrap_for_input_match(model, input_names):
|
|||
model = WrapModel(model, input_names)
|
||||
return model
|
||||
|
||||
def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device):
|
||||
def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs):
|
||||
# example: {input0:{0:'batch'}, input1:{0:'batch'}}
|
||||
dynamic_axes = {}
|
||||
for input in model_desc.inputs_:
|
||||
|
|
@ -275,15 +275,22 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device):
|
|||
input_names = [input.name_ for input in model_desc.inputs_]
|
||||
output_names = [output.name_ for output in model_desc.outputs_]
|
||||
|
||||
sample_inputs = []
|
||||
for input_desc in model_desc.inputs_:
|
||||
input_sample = generate_sample(input_desc, device)
|
||||
sample_inputs.append(input_sample)
|
||||
|
||||
sample_outputs = []
|
||||
for output_desc in model_desc.outputs_:
|
||||
output_sample = generate_sample(output_desc, device)
|
||||
sample_outputs.append(output_sample)
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = [inputs]
|
||||
if isinstance(inputs, dict):
|
||||
sample_inputs = [inputs[k.name_].to(device=device) for k in model_desc.inputs_]
|
||||
elif isinstance(inputs, (list, tuple)):
|
||||
sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(model_desc.inputs_)]
|
||||
else:
|
||||
raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
sample_outputs = model(*sample_inputs)
|
||||
if isinstance(sample_outputs, torch.Tensor):
|
||||
sample_outputs = [sample_outputs]
|
||||
for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_):
|
||||
output_desc.dtype_ = sample_output.dtype
|
||||
model.train()
|
||||
|
||||
f = io.BytesIO()
|
||||
|
||||
|
|
@ -294,15 +301,15 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device):
|
|||
# e.g. for models with optional inputs, it requires all inputs be present.
|
||||
# this is a problem because the model graph depends on inputs provided.
|
||||
model = wrap_for_input_match(model, input_names)
|
||||
|
||||
|
||||
# Other export options to use(this is for backward compatibility).
|
||||
other_export_options = {}
|
||||
# This option was added after 1.4 release.
|
||||
if LooseVersion(torch.__version__) > LooseVersion('1.4.0'):
|
||||
other_export_options['enable_onnx_checker'] = 'False'
|
||||
other_export_options['enable_onnx_checker'] = False
|
||||
|
||||
torch.onnx._export(model, tuple(sample_inputs), f,
|
||||
input_names=input_names,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=10,
|
||||
dynamic_axes=dynamic_axes,
|
||||
|
|
@ -490,7 +497,7 @@ def create_ort_training_session_bind_parameters(model, device, world_rank=-1, wo
|
|||
|
||||
if device.type == 'cuda' and hasattr(device, "index") and device.index is not None:
|
||||
from onnxruntime.capi._pybind_state import set_cuda_device_id
|
||||
set_cuda_device_id(device.index)
|
||||
set_cuda_device_id(device.index)
|
||||
session = ort.TrainingSession(model.SerializeToString(), ort_parameters)
|
||||
|
||||
train_io_binding = session.io_binding()
|
||||
|
|
@ -523,24 +530,24 @@ class ORTTrainer():
|
|||
|
||||
Params:
|
||||
model: either:
|
||||
- a PyTorch model: if 'loss_fn' is not provided, the 'model's first output must be the loss.
|
||||
- a PyTorch model: if 'loss_fn' is not provided, the 'model's first output must be the loss.
|
||||
if 'loss_fn' is provided, 'model' and 'loss_fn' are combined as described in 'loss_fn:'
|
||||
- an ONNX model: the 'model's first output must be the loss.
|
||||
loss_fn: a PyTorch loss function. It takes two inputs [prediction, label] and ouput a loss tensor.
|
||||
loss_fn: a PyTorch loss function. It takes two inputs [prediction, label] and ouput a loss tensor.
|
||||
If provided, 'loss_fn' is combined with the PyTorch 'model' to form a Combined PyTorch model.
|
||||
Inputs to the Combined PyTorch model are concatination of the 'model's input and 'loss_fn's label input.
|
||||
Outputs of the Combined PyTorch model are concatination of 'loss_fn's loss output and 'model's outputs.
|
||||
model_desc: model input and output description. it is used to specify input/output shapes, types, and names.
|
||||
'model_desc' must be consistent with the training model.
|
||||
model_desc: model input and output description. it is used to specify input/output shapes, types, and names.
|
||||
'model_desc' must be consistent with the training model.
|
||||
training_optimizer_name: one of: 'SGDOptimizer', 'AdamOptimizer', 'LambOptimizer'.
|
||||
map_optimizer_attributes: for optimizers with weight dependent parameters,
|
||||
'map_optimizer_attributes' maps weight name to a set of optimization parameters.
|
||||
learning_rate_description: in form of IODescription(Learning_Rate_Name, [1,], torch.float32).
|
||||
Because learning_rate is an input to the training model, Learning_Rate_Name shall be set so that
|
||||
there is no name conflict within the above model.
|
||||
Because learning_rate is an input to the training model, Learning_Rate_Name shall be set so that
|
||||
there is no name conflict within the above model.
|
||||
device:
|
||||
gradient_accumulation_steps:
|
||||
postprocess_model: a callable to postprocess the ONNX model that is converted from PyTorch.
|
||||
postprocess_model: a callable to postprocess the ONNX model that is converted from PyTorch.
|
||||
world_rank:
|
||||
world_size:
|
||||
use_mixed_precision:
|
||||
|
|
@ -585,30 +592,36 @@ class ORTTrainer():
|
|||
|
||||
if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None:
|
||||
print("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.")
|
||||
self.training_optimizer_name_ = training_optimizer_name
|
||||
self.learning_rate_description_ = learning_rate_description
|
||||
self.map_optimizer_attributes_ = map_optimizer_attributes
|
||||
self.allreduce_post_accumulation_ = allreduce_post_accumulation
|
||||
self.partition_optimizer_ = partition_optimizer
|
||||
self.enable_grad_norm_clip_ = enable_grad_norm_clip
|
||||
self.loss_scale_input_name = ''
|
||||
|
||||
if self.torch_model_ is not None:
|
||||
self.onnx_model_ = convert_model_loss_fn_to_onnx(self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'))
|
||||
self._init_session()
|
||||
|
||||
if self.post_process_model_fn_:
|
||||
self.post_process_model_fn_(self.onnx_model_)
|
||||
def _init_session(self):
|
||||
if self.onnx_model_ is None:
|
||||
return
|
||||
|
||||
if self.use_mixed_precision:
|
||||
self.loss_scale_input_name, self.scaled_loss_output_name = add_loss_scale_input(self.onnx_model_)
|
||||
self.input_desc_with_lr_and_loss_scale = [*self.input_desc_with_lr, IODescription(self.loss_scale_input_name, [], torch.float32)]
|
||||
else:
|
||||
self.loss_scale_input_name, self.scaled_loss_output_name = '', ''
|
||||
self.enable_grad_norm_clip_ = enable_grad_norm_clip
|
||||
|
||||
self.verify_fully_optimized_model(self.onnx_model_)
|
||||
self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \
|
||||
create_ort_training_session_with_optimizer(
|
||||
self.onnx_model_, device,
|
||||
training_optimizer_name, learning_rate_description.name_, map_optimizer_attributes,
|
||||
self.onnx_model_, self.device_,
|
||||
self.training_optimizer_name_, self.learning_rate_description_.name_, self.map_optimizer_attributes_,
|
||||
self.world_rank, self.world_size,
|
||||
self.gradient_accumulation_steps, bind_parameters=False,
|
||||
use_mixed_precision=use_mixed_precision, allreduce_post_accumulation=allreduce_post_accumulation,
|
||||
use_mixed_precision=self.use_mixed_precision, allreduce_post_accumulation=self.allreduce_post_accumulation_,
|
||||
loss_scale_input_name=self.loss_scale_input_name, scaled_loss_output_name=self.scaled_loss_output_name,
|
||||
partition_optimizer=partition_optimizer, enable_grad_norm_clip=self.enable_grad_norm_clip_)
|
||||
partition_optimizer=self.partition_optimizer_, enable_grad_norm_clip=self.enable_grad_norm_clip_)
|
||||
|
||||
# ORT backend has modified model output dtype from float32 to float16.
|
||||
for o_desc in self.model_desc_.outputs_:
|
||||
|
|
@ -625,13 +638,23 @@ class ORTTrainer():
|
|||
IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool)]
|
||||
|
||||
if self.use_mixed_precision:
|
||||
# when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine
|
||||
# when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine
|
||||
# if the gradient is usable.
|
||||
self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [
|
||||
*self.model_desc_.outputs_,
|
||||
IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)]
|
||||
|
||||
self.device_ = device
|
||||
def _init_onnx_model(self, inputs):
|
||||
if self.onnx_model_ is not None:
|
||||
return
|
||||
|
||||
if self.torch_model_ is not None:
|
||||
self.onnx_model_ = convert_model_loss_fn_to_onnx(self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'), inputs)
|
||||
|
||||
if self.post_process_model_fn_:
|
||||
self.post_process_model_fn_(self.onnx_model_)
|
||||
|
||||
self._init_session()
|
||||
|
||||
def train(self):
|
||||
self.is_train = True
|
||||
|
|
@ -671,7 +694,7 @@ class ORTTrainer():
|
|||
with open(path, "wb") as f:
|
||||
f.write(self.onnx_model_.SerializeToString())
|
||||
|
||||
def prepare_input_and_fetches(self, input_desc_with_, learning_rate, loss_scale, *args, **kwargs):
|
||||
def prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs):
|
||||
fetches = None
|
||||
if type(args) == tuple and len(args) == 1 and type(args[0]) == list:
|
||||
input = tuple(args[0])
|
||||
|
|
@ -681,10 +704,10 @@ class ORTTrainer():
|
|||
for input_desc in input_desc_with_:
|
||||
if input_desc.name_ in kwargs:
|
||||
input = input + (kwargs[input_desc.name_],)
|
||||
if learning_rate is not None:
|
||||
input = input + (learning_rate,)
|
||||
if loss_scale is not None:
|
||||
input = input + (loss_scale,)
|
||||
if internal_learning_rate is not None:
|
||||
input = input + (internal_learning_rate,)
|
||||
if internal_loss_scale is not None:
|
||||
input = input + (internal_loss_scale,)
|
||||
|
||||
fetches = None
|
||||
if 'fetches' in kwargs:
|
||||
|
|
@ -707,7 +730,7 @@ class ORTTrainer():
|
|||
# *args and **kwargs together contain ONLY and COMPLETE inputs to the PyTorch model.
|
||||
# In this case, changes to the training script is minimized.
|
||||
# 2. without internal learning rate and loss scale (in fp16 cases) generators,
|
||||
# *args and **kwargs passed in from the training script shall contains
|
||||
# *args and **kwargs passed in from the training script shall contains
|
||||
# inputs to the PyTorch model plus learning_rate and loss_scale.
|
||||
# it optionally contains the fetches.
|
||||
# localized arguments (*args) contains inputs to the ONNX model.
|
||||
|
|
@ -722,17 +745,19 @@ class ORTTrainer():
|
|||
if self.loss_scaler_ is not None and self.use_mixed_precision:
|
||||
loss_scale = torch.tensor(self.loss_scaler_.loss_scale_)
|
||||
|
||||
if self.onnx_model_ is None:
|
||||
sample_input, _ = self.prepare_input_and_fetches(self.model_desc_.inputs_,
|
||||
None, None, *args, **kwargs)
|
||||
self._init_onnx_model(sample_input)
|
||||
|
||||
if self.use_mixed_precision:
|
||||
input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale,
|
||||
learning_rate, loss_scale, *args, **kwargs)
|
||||
else:
|
||||
input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr,
|
||||
learning_rate, loss_scale, *args, **kwargs)
|
||||
|
||||
if self.use_mixed_precision:
|
||||
assert len(self.input_desc_with_lr_and_loss_scale) == len(input)
|
||||
input_descs = self.input_desc_with_lr_and_loss_scale
|
||||
else:
|
||||
input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr,
|
||||
learning_rate, loss_scale, *args, **kwargs)
|
||||
assert len(self.input_desc_with_lr) == len(input)
|
||||
input_descs = self.input_desc_with_lr
|
||||
|
||||
|
|
@ -796,7 +821,7 @@ class ORTTrainer():
|
|||
def eval_step(self, *args, **kwargs):
|
||||
"""
|
||||
inputs: model inputs and/or labels.
|
||||
outputs: if 'fetches' is not provided, outputs are loss and
|
||||
outputs: if 'fetches' is not provided, outputs are loss and
|
||||
(if in mixed mode and is finishing gradient accumulation) all_finite.
|
||||
if fetches is provided, outputs contains these requested with fetches.
|
||||
fetches: names of requested outputs
|
||||
|
|
@ -806,6 +831,12 @@ class ORTTrainer():
|
|||
input, fetches = self.prepare_input_and_fetches(self.model_desc_.inputs_,
|
||||
None, None, *args, **kwargs)
|
||||
|
||||
if self.onnx_model_ is None:
|
||||
if self.torch_model_ is not None:
|
||||
self._init_onnx_model(input)
|
||||
else:
|
||||
raise RuntimeError("Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer.")
|
||||
|
||||
input_desc = self.model_desc_.inputs_[0:len(input)]
|
||||
if fetches is None:
|
||||
output_desc = self.model_desc_.outputs_
|
||||
|
|
@ -918,7 +949,7 @@ class LossScaler():
|
|||
self.stable_steps_ = 0
|
||||
|
||||
def update_loss_scale(self, is_all_finite):
|
||||
if not self.is_dynamic_scale_:
|
||||
if not self.is_dynamic_scale_:
|
||||
return
|
||||
|
||||
if is_all_finite:
|
||||
|
|
|
|||
Loading…
Reference in a new issue