From 6474801cebfb5c660a88e3d8eebabdb54ac933dc Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Tue, 24 Mar 2020 13:34:15 -0700 Subject: [PATCH] Update ort_trainer.py with lazy onnx export (#3244) * Delay onnx export to avoid extra info * handle cases where onnx model is provided at initialization * address comments * fix rebase error --- .../python/onnxruntime_test_ort_trainer.py | 13 +- orttraining/orttraining/python/ort_trainer.py | 133 +++++++++++------- 2 files changed, 94 insertions(+), 52 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py index 3a7c912459..7d7bb9fbbd 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py @@ -19,6 +19,16 @@ def ort_trainer_learning_rate_description(): return IODescription('Learning_Rate', [1, ], torch.float32) +def remove_extra_info(model_desc): + simple_model_desc = copy.deepcopy(model_desc) + for input_desc in simple_model_desc.inputs_: + input_desc.dtype_ = None + input_desc.num_classes_ = None + for output_desc in simple_model_desc.outputs_: + output_desc.dtype_ = None + output_desc.num_classes_ = None + return simple_model_desc + def bert_model_description(): vocab_size = 30528 input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=vocab_size) @@ -49,12 +59,13 @@ def generate_sample_batch(desc, batch_size, device): def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allreduce_post_accumulation): model_desc = bert_model_description() + simple_model_desc = remove_extra_info(model_desc) learning_rate_description = ort_trainer_learning_rate_description() device = torch.device("cuda", 0) onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx")) - model = ORTTrainer(onnx_model, None, model_desc, "LambOptimizer", + model = ORTTrainer(onnx_model, None, simple_model_desc, "LambOptimizer", map_optimizer_attributes, learning_rate_description, device, postprocess_model=None, diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index 32f2011fde..c55c44e81a 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -10,7 +10,7 @@ import onnxruntime as ort from distutils.version import LooseVersion class IODescription(): - def __init__(self, name, shape, dtype, num_classes=None): + def __init__(self, name, shape, dtype=None, num_classes=None): self.name_ = name self.shape_ = shape self.dtype_ = dtype @@ -92,7 +92,7 @@ def ort_training_session_run_helper(session, iobinding, inputs, input_descs, out output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs) torch_outputs = {} for output_desc in output_descs_resolved: - torch_tensor = torch.zeros(output_desc.shape_, device=device, + torch_tensor = torch.zeros(output_desc.shape_, device=device, dtype=output_desc.eval_dtype_ if hasattr(output_desc, 'eval_dtype_') else output_desc.dtype_) @@ -195,11 +195,11 @@ def dtype_torch_to_numpy(torch_dtype): return np.float32 elif torch_dtype == torch.float16 or torch_dtype == torch.half: return np.float16 - elif torch_dtype == torch.int64 or torch_dtype == torch.long: + elif torch_dtype == torch.int64 or torch_dtype == torch.long: return np.longlong - elif torch_dtype == torch.int32 or torch_dtype == torch.int: + elif torch_dtype == torch.int32 or torch_dtype == torch.int: return np.int32 - elif torch_dtype == torch.int16 or torch_dtype == torch.short: + elif torch_dtype == torch.int16 or torch_dtype == torch.short: return np.int16 def wrap_for_input_match(model, input_names): @@ -217,7 +217,7 @@ def wrap_for_input_match(model, input_names): return model if not all(x in ordered_list_keys for x in input_names): - # model desc has name(s) not matching the model signature. We cannot do anything in this case. + # model desc has name(s) not matching the model signature. We cannot do anything in this case. # better to warning the user. return model @@ -253,7 +253,7 @@ def wrap_for_input_match(model, input_names): model = WrapModel(model, input_names) return model -def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device): +def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs): # example: {input0:{0:'batch'}, input1:{0:'batch'}} dynamic_axes = {} for input in model_desc.inputs_: @@ -275,15 +275,22 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device): input_names = [input.name_ for input in model_desc.inputs_] output_names = [output.name_ for output in model_desc.outputs_] - sample_inputs = [] - for input_desc in model_desc.inputs_: - input_sample = generate_sample(input_desc, device) - sample_inputs.append(input_sample) - - sample_outputs = [] - for output_desc in model_desc.outputs_: - output_sample = generate_sample(output_desc, device) - sample_outputs.append(output_sample) + if isinstance(inputs, torch.Tensor): + inputs = [inputs] + if isinstance(inputs, dict): + sample_inputs = [inputs[k.name_].to(device=device) for k in model_desc.inputs_] + elif isinstance(inputs, (list, tuple)): + sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(model_desc.inputs_)] + else: + raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.") + model.eval() + with torch.no_grad(): + sample_outputs = model(*sample_inputs) + if isinstance(sample_outputs, torch.Tensor): + sample_outputs = [sample_outputs] + for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_): + output_desc.dtype_ = sample_output.dtype + model.train() f = io.BytesIO() @@ -294,15 +301,15 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device): # e.g. for models with optional inputs, it requires all inputs be present. # this is a problem because the model graph depends on inputs provided. model = wrap_for_input_match(model, input_names) - + # Other export options to use(this is for backward compatibility). other_export_options = {} # This option was added after 1.4 release. if LooseVersion(torch.__version__) > LooseVersion('1.4.0'): - other_export_options['enable_onnx_checker'] = 'False' + other_export_options['enable_onnx_checker'] = False torch.onnx._export(model, tuple(sample_inputs), f, - input_names=input_names, + input_names=input_names, output_names=output_names, opset_version=10, dynamic_axes=dynamic_axes, @@ -490,7 +497,7 @@ def create_ort_training_session_bind_parameters(model, device, world_rank=-1, wo if device.type == 'cuda' and hasattr(device, "index") and device.index is not None: from onnxruntime.capi._pybind_state import set_cuda_device_id - set_cuda_device_id(device.index) + set_cuda_device_id(device.index) session = ort.TrainingSession(model.SerializeToString(), ort_parameters) train_io_binding = session.io_binding() @@ -523,24 +530,24 @@ class ORTTrainer(): Params: model: either: - - a PyTorch model: if 'loss_fn' is not provided, the 'model's first output must be the loss. + - a PyTorch model: if 'loss_fn' is not provided, the 'model's first output must be the loss. if 'loss_fn' is provided, 'model' and 'loss_fn' are combined as described in 'loss_fn:' - an ONNX model: the 'model's first output must be the loss. - loss_fn: a PyTorch loss function. It takes two inputs [prediction, label] and ouput a loss tensor. + loss_fn: a PyTorch loss function. It takes two inputs [prediction, label] and ouput a loss tensor. If provided, 'loss_fn' is combined with the PyTorch 'model' to form a Combined PyTorch model. Inputs to the Combined PyTorch model are concatination of the 'model's input and 'loss_fn's label input. Outputs of the Combined PyTorch model are concatination of 'loss_fn's loss output and 'model's outputs. - model_desc: model input and output description. it is used to specify input/output shapes, types, and names. - 'model_desc' must be consistent with the training model. + model_desc: model input and output description. it is used to specify input/output shapes, types, and names. + 'model_desc' must be consistent with the training model. training_optimizer_name: one of: 'SGDOptimizer', 'AdamOptimizer', 'LambOptimizer'. map_optimizer_attributes: for optimizers with weight dependent parameters, 'map_optimizer_attributes' maps weight name to a set of optimization parameters. learning_rate_description: in form of IODescription(Learning_Rate_Name, [1,], torch.float32). - Because learning_rate is an input to the training model, Learning_Rate_Name shall be set so that - there is no name conflict within the above model. + Because learning_rate is an input to the training model, Learning_Rate_Name shall be set so that + there is no name conflict within the above model. device: gradient_accumulation_steps: - postprocess_model: a callable to postprocess the ONNX model that is converted from PyTorch. + postprocess_model: a callable to postprocess the ONNX model that is converted from PyTorch. world_rank: world_size: use_mixed_precision: @@ -585,30 +592,36 @@ class ORTTrainer(): if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None: print("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.") + self.training_optimizer_name_ = training_optimizer_name + self.learning_rate_description_ = learning_rate_description + self.map_optimizer_attributes_ = map_optimizer_attributes + self.allreduce_post_accumulation_ = allreduce_post_accumulation + self.partition_optimizer_ = partition_optimizer + self.enable_grad_norm_clip_ = enable_grad_norm_clip + self.loss_scale_input_name = '' - if self.torch_model_ is not None: - self.onnx_model_ = convert_model_loss_fn_to_onnx(self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu')) + self._init_session() - if self.post_process_model_fn_: - self.post_process_model_fn_(self.onnx_model_) + def _init_session(self): + if self.onnx_model_ is None: + return if self.use_mixed_precision: self.loss_scale_input_name, self.scaled_loss_output_name = add_loss_scale_input(self.onnx_model_) self.input_desc_with_lr_and_loss_scale = [*self.input_desc_with_lr, IODescription(self.loss_scale_input_name, [], torch.float32)] else: self.loss_scale_input_name, self.scaled_loss_output_name = '', '' - self.enable_grad_norm_clip_ = enable_grad_norm_clip self.verify_fully_optimized_model(self.onnx_model_) self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \ create_ort_training_session_with_optimizer( - self.onnx_model_, device, - training_optimizer_name, learning_rate_description.name_, map_optimizer_attributes, + self.onnx_model_, self.device_, + self.training_optimizer_name_, self.learning_rate_description_.name_, self.map_optimizer_attributes_, self.world_rank, self.world_size, self.gradient_accumulation_steps, bind_parameters=False, - use_mixed_precision=use_mixed_precision, allreduce_post_accumulation=allreduce_post_accumulation, + use_mixed_precision=self.use_mixed_precision, allreduce_post_accumulation=self.allreduce_post_accumulation_, loss_scale_input_name=self.loss_scale_input_name, scaled_loss_output_name=self.scaled_loss_output_name, - partition_optimizer=partition_optimizer, enable_grad_norm_clip=self.enable_grad_norm_clip_) + partition_optimizer=self.partition_optimizer_, enable_grad_norm_clip=self.enable_grad_norm_clip_) # ORT backend has modified model output dtype from float32 to float16. for o_desc in self.model_desc_.outputs_: @@ -625,13 +638,23 @@ class ORTTrainer(): IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool)] if self.use_mixed_precision: - # when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine + # when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine # if the gradient is usable. self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [ *self.model_desc_.outputs_, IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)] - self.device_ = device + def _init_onnx_model(self, inputs): + if self.onnx_model_ is not None: + return + + if self.torch_model_ is not None: + self.onnx_model_ = convert_model_loss_fn_to_onnx(self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'), inputs) + + if self.post_process_model_fn_: + self.post_process_model_fn_(self.onnx_model_) + + self._init_session() def train(self): self.is_train = True @@ -671,7 +694,7 @@ class ORTTrainer(): with open(path, "wb") as f: f.write(self.onnx_model_.SerializeToString()) - def prepare_input_and_fetches(self, input_desc_with_, learning_rate, loss_scale, *args, **kwargs): + def prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs): fetches = None if type(args) == tuple and len(args) == 1 and type(args[0]) == list: input = tuple(args[0]) @@ -681,10 +704,10 @@ class ORTTrainer(): for input_desc in input_desc_with_: if input_desc.name_ in kwargs: input = input + (kwargs[input_desc.name_],) - if learning_rate is not None: - input = input + (learning_rate,) - if loss_scale is not None: - input = input + (loss_scale,) + if internal_learning_rate is not None: + input = input + (internal_learning_rate,) + if internal_loss_scale is not None: + input = input + (internal_loss_scale,) fetches = None if 'fetches' in kwargs: @@ -707,7 +730,7 @@ class ORTTrainer(): # *args and **kwargs together contain ONLY and COMPLETE inputs to the PyTorch model. # In this case, changes to the training script is minimized. # 2. without internal learning rate and loss scale (in fp16 cases) generators, - # *args and **kwargs passed in from the training script shall contains + # *args and **kwargs passed in from the training script shall contains # inputs to the PyTorch model plus learning_rate and loss_scale. # it optionally contains the fetches. # localized arguments (*args) contains inputs to the ONNX model. @@ -722,17 +745,19 @@ class ORTTrainer(): if self.loss_scaler_ is not None and self.use_mixed_precision: loss_scale = torch.tensor(self.loss_scaler_.loss_scale_) + if self.onnx_model_ is None: + sample_input, _ = self.prepare_input_and_fetches(self.model_desc_.inputs_, + None, None, *args, **kwargs) + self._init_onnx_model(sample_input) + if self.use_mixed_precision: input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale, learning_rate, loss_scale, *args, **kwargs) - else: - input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr, - learning_rate, loss_scale, *args, **kwargs) - - if self.use_mixed_precision: assert len(self.input_desc_with_lr_and_loss_scale) == len(input) input_descs = self.input_desc_with_lr_and_loss_scale else: + input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr, + learning_rate, loss_scale, *args, **kwargs) assert len(self.input_desc_with_lr) == len(input) input_descs = self.input_desc_with_lr @@ -796,7 +821,7 @@ class ORTTrainer(): def eval_step(self, *args, **kwargs): """ inputs: model inputs and/or labels. - outputs: if 'fetches' is not provided, outputs are loss and + outputs: if 'fetches' is not provided, outputs are loss and (if in mixed mode and is finishing gradient accumulation) all_finite. if fetches is provided, outputs contains these requested with fetches. fetches: names of requested outputs @@ -806,6 +831,12 @@ class ORTTrainer(): input, fetches = self.prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) + if self.onnx_model_ is None: + if self.torch_model_ is not None: + self._init_onnx_model(input) + else: + raise RuntimeError("Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer.") + input_desc = self.model_desc_.inputs_[0:len(input)] if fetches is None: output_desc = self.model_desc_.outputs_ @@ -918,7 +949,7 @@ class LossScaler(): self.stable_steps_ = 0 def update_loss_scale(self, is_all_finite): - if not self.is_dynamic_scale_: + if not self.is_dynamic_scale_: return if is_all_finite: