diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py index 909cd6fce4..dc5279c426 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py @@ -56,7 +56,11 @@ def generate_sample_batch(desc, batch_size, device): sample = generate_sample(desc_, device) return sample -def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allreduce_post_accumulation, use_simple_model_desc=True): +def runBertTrainingTest(gradient_accumulation_steps, + use_mixed_precision, + allreduce_post_accumulation, + use_simple_model_desc=True, + use_internel_loss_scale=False): model_desc = bert_model_description() simple_model_desc = remove_extra_info(model_desc) if use_simple_model_desc else model_desc learning_rate_description = ort_trainer_learning_rate_description() @@ -64,17 +68,20 @@ def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allred onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx")) + loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None model = ORTTrainer(onnx_model, None, simple_model_desc, "LambOptimizer", map_optimizer_attributes, learning_rate_description, device, postprocess_model=None, gradient_accumulation_steps=gradient_accumulation_steps, world_rank=0, world_size=1, + loss_scaler=loss_scaler, use_mixed_precision=use_mixed_precision, allreduce_post_accumulation=allreduce_post_accumulation, seed=1) - loss_scaler = LossScaler(model.loss_scale_input_name, True) + if loss_scaler is None: + loss_scaler = LossScaler(model.loss_scale_input_name, True) input_ids_batches = [] segment_ids_batches = [] @@ -105,18 +112,27 @@ def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allred lr = lr_batch_list[batch_count] learning_rate = torch.tensor([lr]).to(device) + training_args = [input_ids, + segment_ids, + input_mask, + masked_lm_labels, + next_sentence_labels, + learning_rate] if use_mixed_precision: - loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device) - actual_loss = model.train_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, learning_rate, loss_scale) + if not use_internel_loss_scale: + loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device) + training_args.append(loss_scale) + actual_loss = model.train_step(*training_args) if isinstance(actual_loss, (list, tuple)): assert len(actual_loss) == 2 actual_loss, actual_all_finite = actual_loss - loss_scaler.update_loss_scale(actual_all_finite.item()) - actual_all_finites = [*actual_all_finites, actual_all_finite.cpu().numpy().item(0)] + if not use_internel_loss_scale: + loss_scaler.update_loss_scale(actual_all_finite.item()) + actual_all_finites = [*actual_all_finites, actual_all_finite.cpu().numpy().item(0)] actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)] else: - loss = model(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, learning_rate) + loss = model(*training_args) actual_losses = [*actual_losses, loss.cpu().numpy().item(0)] if batch_count == num_batches - 1: @@ -125,7 +141,8 @@ def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allred eval_loss = model.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, fetches=['loss']) eval_loss = eval_loss.cpu().numpy().item(0) - if use_mixed_precision: + # If using internal loss scale, all_finites are handled internally too. + if use_mixed_precision and not use_internel_loss_scale: return actual_losses, actual_all_finites, eval_loss else: return actual_losses, eval_loss diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py index 6bd74e21bc..6aa6fcfa28 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py @@ -22,6 +22,21 @@ class TestOrtTrainer(unittest.TestCase): assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") + def testBertTrainingMixedPrecisionInternalLossScale(self): + torch.manual_seed(1) + expected_losses = [11.078125, 11.0, 11.0390625, 11.0, 11.015625, 11.0, 10.9921875, 11.0703125] + expected_eval_loss = [11.046875] + actual_losses, actual_eval_loss = runBertTrainingTest( + gradient_accumulation_steps=1, + use_mixed_precision=True, + allreduce_post_accumulation=False, + use_simple_model_desc=False, + use_internel_loss_scale=True) + + rtol = 1e-01 + assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") + assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") + def testBertTrainingGradientAccumulationMixedPrecision(self): torch.manual_seed(1) expected_losses = [11.046875, 11.171875, 11.0234375, 11.046875, 10.8984375, 10.9921875, 11.078125, 10.96875] diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index 46df724b6a..a27a560154 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -747,7 +747,7 @@ class ORTTrainer(): lr_this_step = self.get_lr_this_step_(self.global_step_) learning_rate = torch.tensor([lr_this_step]) if self.loss_scaler_ is not None and self.use_mixed_precision: - loss_scale = torch.tensor(self.loss_scaler_.loss_scale_) + loss_scale = torch.tensor([self.loss_scaler_.loss_scale_]) if self.onnx_model_ is None: sample_input, _ = self.prepare_input_and_fetches(self.model_desc_.inputs_,