Update ort_trainer.py with lazy onnx export (#3244)

* Delay onnx export to avoid extra info

* handle cases where onnx model is provided at initialization

* address comments

* fix rebase error
This commit is contained in:
Bowen Bao 2020-03-24 13:34:15 -07:00 committed by GitHub
parent 98c28060b0
commit 6474801ceb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 94 additions and 52 deletions

View file

@ -19,6 +19,16 @@ def ort_trainer_learning_rate_description():
return IODescription('Learning_Rate', [1, ], torch.float32)
def remove_extra_info(model_desc):
simple_model_desc = copy.deepcopy(model_desc)
for input_desc in simple_model_desc.inputs_:
input_desc.dtype_ = None
input_desc.num_classes_ = None
for output_desc in simple_model_desc.outputs_:
output_desc.dtype_ = None
output_desc.num_classes_ = None
return simple_model_desc
def bert_model_description():
vocab_size = 30528
input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=vocab_size)
@ -49,12 +59,13 @@ def generate_sample_batch(desc, batch_size, device):
def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allreduce_post_accumulation):
model_desc = bert_model_description()
simple_model_desc = remove_extra_info(model_desc)
learning_rate_description = ort_trainer_learning_rate_description()
device = torch.device("cuda", 0)
onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx"))
model = ORTTrainer(onnx_model, None, model_desc, "LambOptimizer",
model = ORTTrainer(onnx_model, None, simple_model_desc, "LambOptimizer",
map_optimizer_attributes,
learning_rate_description,
device, postprocess_model=None,

View file

@ -10,7 +10,7 @@ import onnxruntime as ort
from distutils.version import LooseVersion
class IODescription():
def __init__(self, name, shape, dtype, num_classes=None):
def __init__(self, name, shape, dtype=None, num_classes=None):
self.name_ = name
self.shape_ = shape
self.dtype_ = dtype
@ -92,7 +92,7 @@ def ort_training_session_run_helper(session, iobinding, inputs, input_descs, out
output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs)
torch_outputs = {}
for output_desc in output_descs_resolved:
torch_tensor = torch.zeros(output_desc.shape_, device=device,
torch_tensor = torch.zeros(output_desc.shape_, device=device,
dtype=output_desc.eval_dtype_ if hasattr(output_desc, 'eval_dtype_')
else output_desc.dtype_)
@ -195,11 +195,11 @@ def dtype_torch_to_numpy(torch_dtype):
return np.float32
elif torch_dtype == torch.float16 or torch_dtype == torch.half:
return np.float16
elif torch_dtype == torch.int64 or torch_dtype == torch.long:
elif torch_dtype == torch.int64 or torch_dtype == torch.long:
return np.longlong
elif torch_dtype == torch.int32 or torch_dtype == torch.int:
elif torch_dtype == torch.int32 or torch_dtype == torch.int:
return np.int32
elif torch_dtype == torch.int16 or torch_dtype == torch.short:
elif torch_dtype == torch.int16 or torch_dtype == torch.short:
return np.int16
def wrap_for_input_match(model, input_names):
@ -217,7 +217,7 @@ def wrap_for_input_match(model, input_names):
return model
if not all(x in ordered_list_keys for x in input_names):
# model desc has name(s) not matching the model signature. We cannot do anything in this case.
# model desc has name(s) not matching the model signature. We cannot do anything in this case.
# better to warning the user.
return model
@ -253,7 +253,7 @@ def wrap_for_input_match(model, input_names):
model = WrapModel(model, input_names)
return model
def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device):
def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs):
# example: {input0:{0:'batch'}, input1:{0:'batch'}}
dynamic_axes = {}
for input in model_desc.inputs_:
@ -275,15 +275,22 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device):
input_names = [input.name_ for input in model_desc.inputs_]
output_names = [output.name_ for output in model_desc.outputs_]
sample_inputs = []
for input_desc in model_desc.inputs_:
input_sample = generate_sample(input_desc, device)
sample_inputs.append(input_sample)
sample_outputs = []
for output_desc in model_desc.outputs_:
output_sample = generate_sample(output_desc, device)
sample_outputs.append(output_sample)
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
if isinstance(inputs, dict):
sample_inputs = [inputs[k.name_].to(device=device) for k in model_desc.inputs_]
elif isinstance(inputs, (list, tuple)):
sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(model_desc.inputs_)]
else:
raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")
model.eval()
with torch.no_grad():
sample_outputs = model(*sample_inputs)
if isinstance(sample_outputs, torch.Tensor):
sample_outputs = [sample_outputs]
for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_):
output_desc.dtype_ = sample_output.dtype
model.train()
f = io.BytesIO()
@ -294,15 +301,15 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device):
# e.g. for models with optional inputs, it requires all inputs be present.
# this is a problem because the model graph depends on inputs provided.
model = wrap_for_input_match(model, input_names)
# Other export options to use(this is for backward compatibility).
other_export_options = {}
# This option was added after 1.4 release.
if LooseVersion(torch.__version__) > LooseVersion('1.4.0'):
other_export_options['enable_onnx_checker'] = 'False'
other_export_options['enable_onnx_checker'] = False
torch.onnx._export(model, tuple(sample_inputs), f,
input_names=input_names,
input_names=input_names,
output_names=output_names,
opset_version=10,
dynamic_axes=dynamic_axes,
@ -490,7 +497,7 @@ def create_ort_training_session_bind_parameters(model, device, world_rank=-1, wo
if device.type == 'cuda' and hasattr(device, "index") and device.index is not None:
from onnxruntime.capi._pybind_state import set_cuda_device_id
set_cuda_device_id(device.index)
set_cuda_device_id(device.index)
session = ort.TrainingSession(model.SerializeToString(), ort_parameters)
train_io_binding = session.io_binding()
@ -523,24 +530,24 @@ class ORTTrainer():
Params:
model: either:
- a PyTorch model: if 'loss_fn' is not provided, the 'model's first output must be the loss.
- a PyTorch model: if 'loss_fn' is not provided, the 'model's first output must be the loss.
if 'loss_fn' is provided, 'model' and 'loss_fn' are combined as described in 'loss_fn:'
- an ONNX model: the 'model's first output must be the loss.
loss_fn: a PyTorch loss function. It takes two inputs [prediction, label] and ouput a loss tensor.
loss_fn: a PyTorch loss function. It takes two inputs [prediction, label] and ouput a loss tensor.
If provided, 'loss_fn' is combined with the PyTorch 'model' to form a Combined PyTorch model.
Inputs to the Combined PyTorch model are concatination of the 'model's input and 'loss_fn's label input.
Outputs of the Combined PyTorch model are concatination of 'loss_fn's loss output and 'model's outputs.
model_desc: model input and output description. it is used to specify input/output shapes, types, and names.
'model_desc' must be consistent with the training model.
model_desc: model input and output description. it is used to specify input/output shapes, types, and names.
'model_desc' must be consistent with the training model.
training_optimizer_name: one of: 'SGDOptimizer', 'AdamOptimizer', 'LambOptimizer'.
map_optimizer_attributes: for optimizers with weight dependent parameters,
'map_optimizer_attributes' maps weight name to a set of optimization parameters.
learning_rate_description: in form of IODescription(Learning_Rate_Name, [1,], torch.float32).
Because learning_rate is an input to the training model, Learning_Rate_Name shall be set so that
there is no name conflict within the above model.
Because learning_rate is an input to the training model, Learning_Rate_Name shall be set so that
there is no name conflict within the above model.
device:
gradient_accumulation_steps:
postprocess_model: a callable to postprocess the ONNX model that is converted from PyTorch.
postprocess_model: a callable to postprocess the ONNX model that is converted from PyTorch.
world_rank:
world_size:
use_mixed_precision:
@ -585,30 +592,36 @@ class ORTTrainer():
if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None:
print("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.")
self.training_optimizer_name_ = training_optimizer_name
self.learning_rate_description_ = learning_rate_description
self.map_optimizer_attributes_ = map_optimizer_attributes
self.allreduce_post_accumulation_ = allreduce_post_accumulation
self.partition_optimizer_ = partition_optimizer
self.enable_grad_norm_clip_ = enable_grad_norm_clip
self.loss_scale_input_name = ''
if self.torch_model_ is not None:
self.onnx_model_ = convert_model_loss_fn_to_onnx(self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'))
self._init_session()
if self.post_process_model_fn_:
self.post_process_model_fn_(self.onnx_model_)
def _init_session(self):
if self.onnx_model_ is None:
return
if self.use_mixed_precision:
self.loss_scale_input_name, self.scaled_loss_output_name = add_loss_scale_input(self.onnx_model_)
self.input_desc_with_lr_and_loss_scale = [*self.input_desc_with_lr, IODescription(self.loss_scale_input_name, [], torch.float32)]
else:
self.loss_scale_input_name, self.scaled_loss_output_name = '', ''
self.enable_grad_norm_clip_ = enable_grad_norm_clip
self.verify_fully_optimized_model(self.onnx_model_)
self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \
create_ort_training_session_with_optimizer(
self.onnx_model_, device,
training_optimizer_name, learning_rate_description.name_, map_optimizer_attributes,
self.onnx_model_, self.device_,
self.training_optimizer_name_, self.learning_rate_description_.name_, self.map_optimizer_attributes_,
self.world_rank, self.world_size,
self.gradient_accumulation_steps, bind_parameters=False,
use_mixed_precision=use_mixed_precision, allreduce_post_accumulation=allreduce_post_accumulation,
use_mixed_precision=self.use_mixed_precision, allreduce_post_accumulation=self.allreduce_post_accumulation_,
loss_scale_input_name=self.loss_scale_input_name, scaled_loss_output_name=self.scaled_loss_output_name,
partition_optimizer=partition_optimizer, enable_grad_norm_clip=self.enable_grad_norm_clip_)
partition_optimizer=self.partition_optimizer_, enable_grad_norm_clip=self.enable_grad_norm_clip_)
# ORT backend has modified model output dtype from float32 to float16.
for o_desc in self.model_desc_.outputs_:
@ -625,13 +638,23 @@ class ORTTrainer():
IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool)]
if self.use_mixed_precision:
# when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine
# when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine
# if the gradient is usable.
self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [
*self.model_desc_.outputs_,
IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)]
self.device_ = device
def _init_onnx_model(self, inputs):
if self.onnx_model_ is not None:
return
if self.torch_model_ is not None:
self.onnx_model_ = convert_model_loss_fn_to_onnx(self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'), inputs)
if self.post_process_model_fn_:
self.post_process_model_fn_(self.onnx_model_)
self._init_session()
def train(self):
self.is_train = True
@ -671,7 +694,7 @@ class ORTTrainer():
with open(path, "wb") as f:
f.write(self.onnx_model_.SerializeToString())
def prepare_input_and_fetches(self, input_desc_with_, learning_rate, loss_scale, *args, **kwargs):
def prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs):
fetches = None
if type(args) == tuple and len(args) == 1 and type(args[0]) == list:
input = tuple(args[0])
@ -681,10 +704,10 @@ class ORTTrainer():
for input_desc in input_desc_with_:
if input_desc.name_ in kwargs:
input = input + (kwargs[input_desc.name_],)
if learning_rate is not None:
input = input + (learning_rate,)
if loss_scale is not None:
input = input + (loss_scale,)
if internal_learning_rate is not None:
input = input + (internal_learning_rate,)
if internal_loss_scale is not None:
input = input + (internal_loss_scale,)
fetches = None
if 'fetches' in kwargs:
@ -707,7 +730,7 @@ class ORTTrainer():
# *args and **kwargs together contain ONLY and COMPLETE inputs to the PyTorch model.
# In this case, changes to the training script is minimized.
# 2. without internal learning rate and loss scale (in fp16 cases) generators,
# *args and **kwargs passed in from the training script shall contains
# *args and **kwargs passed in from the training script shall contains
# inputs to the PyTorch model plus learning_rate and loss_scale.
# it optionally contains the fetches.
# localized arguments (*args) contains inputs to the ONNX model.
@ -722,17 +745,19 @@ class ORTTrainer():
if self.loss_scaler_ is not None and self.use_mixed_precision:
loss_scale = torch.tensor(self.loss_scaler_.loss_scale_)
if self.onnx_model_ is None:
sample_input, _ = self.prepare_input_and_fetches(self.model_desc_.inputs_,
None, None, *args, **kwargs)
self._init_onnx_model(sample_input)
if self.use_mixed_precision:
input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale,
learning_rate, loss_scale, *args, **kwargs)
else:
input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr,
learning_rate, loss_scale, *args, **kwargs)
if self.use_mixed_precision:
assert len(self.input_desc_with_lr_and_loss_scale) == len(input)
input_descs = self.input_desc_with_lr_and_loss_scale
else:
input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr,
learning_rate, loss_scale, *args, **kwargs)
assert len(self.input_desc_with_lr) == len(input)
input_descs = self.input_desc_with_lr
@ -796,7 +821,7 @@ class ORTTrainer():
def eval_step(self, *args, **kwargs):
"""
inputs: model inputs and/or labels.
outputs: if 'fetches' is not provided, outputs are loss and
outputs: if 'fetches' is not provided, outputs are loss and
(if in mixed mode and is finishing gradient accumulation) all_finite.
if fetches is provided, outputs contains these requested with fetches.
fetches: names of requested outputs
@ -806,6 +831,12 @@ class ORTTrainer():
input, fetches = self.prepare_input_and_fetches(self.model_desc_.inputs_,
None, None, *args, **kwargs)
if self.onnx_model_ is None:
if self.torch_model_ is not None:
self._init_onnx_model(input)
else:
raise RuntimeError("Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer.")
input_desc = self.model_desc_.inputs_[0:len(input)]
if fetches is None:
output_desc = self.model_desc_.outputs_
@ -918,7 +949,7 @@ class LossScaler():
self.stable_steps_ = 0
def update_loss_scale(self, is_all_finite):
if not self.is_dynamic_scale_:
if not self.is_dynamic_scale_:
return
if is_all_finite: