mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
fix internal loss scale (#3483)
* Changed internal loss scale to 1-D * added test Co-authored-by: root <root@525204a066204ea794f942530b05ae7f000000.axlncovkyjne5caro2tmz3zryb.xx.internal.cloudapp.net>
This commit is contained in:
parent
20c7dd9f5c
commit
f5ba9c922d
3 changed files with 41 additions and 9 deletions
|
|
@ -56,7 +56,11 @@ def generate_sample_batch(desc, batch_size, device):
|
|||
sample = generate_sample(desc_, device)
|
||||
return sample
|
||||
|
||||
def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allreduce_post_accumulation, use_simple_model_desc=True):
|
||||
def runBertTrainingTest(gradient_accumulation_steps,
|
||||
use_mixed_precision,
|
||||
allreduce_post_accumulation,
|
||||
use_simple_model_desc=True,
|
||||
use_internel_loss_scale=False):
|
||||
model_desc = bert_model_description()
|
||||
simple_model_desc = remove_extra_info(model_desc) if use_simple_model_desc else model_desc
|
||||
learning_rate_description = ort_trainer_learning_rate_description()
|
||||
|
|
@ -64,17 +68,20 @@ def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allred
|
|||
|
||||
onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx"))
|
||||
|
||||
loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None
|
||||
model = ORTTrainer(onnx_model, None, simple_model_desc, "LambOptimizer",
|
||||
map_optimizer_attributes,
|
||||
learning_rate_description,
|
||||
device, postprocess_model=None,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
world_rank=0, world_size=1,
|
||||
loss_scaler=loss_scaler,
|
||||
use_mixed_precision=use_mixed_precision,
|
||||
allreduce_post_accumulation=allreduce_post_accumulation,
|
||||
seed=1)
|
||||
|
||||
loss_scaler = LossScaler(model.loss_scale_input_name, True)
|
||||
if loss_scaler is None:
|
||||
loss_scaler = LossScaler(model.loss_scale_input_name, True)
|
||||
|
||||
input_ids_batches = []
|
||||
segment_ids_batches = []
|
||||
|
|
@ -105,18 +112,27 @@ def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allred
|
|||
lr = lr_batch_list[batch_count]
|
||||
|
||||
learning_rate = torch.tensor([lr]).to(device)
|
||||
training_args = [input_ids,
|
||||
segment_ids,
|
||||
input_mask,
|
||||
masked_lm_labels,
|
||||
next_sentence_labels,
|
||||
learning_rate]
|
||||
if use_mixed_precision:
|
||||
loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device)
|
||||
actual_loss = model.train_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, learning_rate, loss_scale)
|
||||
if not use_internel_loss_scale:
|
||||
loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device)
|
||||
training_args.append(loss_scale)
|
||||
actual_loss = model.train_step(*training_args)
|
||||
if isinstance(actual_loss, (list, tuple)):
|
||||
assert len(actual_loss) == 2
|
||||
actual_loss, actual_all_finite = actual_loss
|
||||
loss_scaler.update_loss_scale(actual_all_finite.item())
|
||||
actual_all_finites = [*actual_all_finites, actual_all_finite.cpu().numpy().item(0)]
|
||||
if not use_internel_loss_scale:
|
||||
loss_scaler.update_loss_scale(actual_all_finite.item())
|
||||
actual_all_finites = [*actual_all_finites, actual_all_finite.cpu().numpy().item(0)]
|
||||
|
||||
actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)]
|
||||
else:
|
||||
loss = model(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, learning_rate)
|
||||
loss = model(*training_args)
|
||||
actual_losses = [*actual_losses, loss.cpu().numpy().item(0)]
|
||||
|
||||
if batch_count == num_batches - 1:
|
||||
|
|
@ -125,7 +141,8 @@ def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allred
|
|||
eval_loss = model.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, fetches=['loss'])
|
||||
eval_loss = eval_loss.cpu().numpy().item(0)
|
||||
|
||||
if use_mixed_precision:
|
||||
# If using internal loss scale, all_finites are handled internally too.
|
||||
if use_mixed_precision and not use_internel_loss_scale:
|
||||
return actual_losses, actual_all_finites, eval_loss
|
||||
else:
|
||||
return actual_losses, eval_loss
|
||||
|
|
|
|||
|
|
@ -22,6 +22,21 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
|
||||
|
||||
def testBertTrainingMixedPrecisionInternalLossScale(self):
|
||||
torch.manual_seed(1)
|
||||
expected_losses = [11.078125, 11.0, 11.0390625, 11.0, 11.015625, 11.0, 10.9921875, 11.0703125]
|
||||
expected_eval_loss = [11.046875]
|
||||
actual_losses, actual_eval_loss = runBertTrainingTest(
|
||||
gradient_accumulation_steps=1,
|
||||
use_mixed_precision=True,
|
||||
allreduce_post_accumulation=False,
|
||||
use_simple_model_desc=False,
|
||||
use_internel_loss_scale=True)
|
||||
|
||||
rtol = 1e-01
|
||||
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
|
||||
|
||||
def testBertTrainingGradientAccumulationMixedPrecision(self):
|
||||
torch.manual_seed(1)
|
||||
expected_losses = [11.046875, 11.171875, 11.0234375, 11.046875, 10.8984375, 10.9921875, 11.078125, 10.96875]
|
||||
|
|
|
|||
|
|
@ -747,7 +747,7 @@ class ORTTrainer():
|
|||
lr_this_step = self.get_lr_this_step_(self.global_step_)
|
||||
learning_rate = torch.tensor([lr_this_step])
|
||||
if self.loss_scaler_ is not None and self.use_mixed_precision:
|
||||
loss_scale = torch.tensor(self.loss_scaler_.loss_scale_)
|
||||
loss_scale = torch.tensor([self.loss_scaler_.loss_scale_])
|
||||
|
||||
if self.onnx_model_ is None:
|
||||
sample_input, _ = self.prepare_input_and_fetches(self.model_desc_.inputs_,
|
||||
|
|
|
|||
Loading…
Reference in a new issue