fix internal loss scale (#3483)

* Changed the internal loss scale to a 1-D tensor

* Added a test for the internal loss scale

Co-authored-by: root <root@525204a066204ea794f942530b05ae7f000000.axlncovkyjne5caro2tmz3zryb.xx.internal.cloudapp.net>
This commit is contained in:
Tixxx 2020-04-10 14:13:48 -07:00 committed by GitHub
parent 20c7dd9f5c
commit f5ba9c922d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 9 deletions

View file

@ -56,7 +56,11 @@ def generate_sample_batch(desc, batch_size, device):
sample = generate_sample(desc_, device)
return sample
def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allreduce_post_accumulation, use_simple_model_desc=True):
def runBertTrainingTest(gradient_accumulation_steps,
use_mixed_precision,
allreduce_post_accumulation,
use_simple_model_desc=True,
use_internel_loss_scale=False):
model_desc = bert_model_description()
simple_model_desc = remove_extra_info(model_desc) if use_simple_model_desc else model_desc
learning_rate_description = ort_trainer_learning_rate_description()
@ -64,17 +68,20 @@ def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allred
onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx"))
loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None
model = ORTTrainer(onnx_model, None, simple_model_desc, "LambOptimizer",
map_optimizer_attributes,
learning_rate_description,
device, postprocess_model=None,
gradient_accumulation_steps=gradient_accumulation_steps,
world_rank=0, world_size=1,
loss_scaler=loss_scaler,
use_mixed_precision=use_mixed_precision,
allreduce_post_accumulation=allreduce_post_accumulation,
seed=1)
loss_scaler = LossScaler(model.loss_scale_input_name, True)
if loss_scaler is None:
loss_scaler = LossScaler(model.loss_scale_input_name, True)
input_ids_batches = []
segment_ids_batches = []
@ -105,18 +112,27 @@ def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allred
lr = lr_batch_list[batch_count]
learning_rate = torch.tensor([lr]).to(device)
training_args = [input_ids,
segment_ids,
input_mask,
masked_lm_labels,
next_sentence_labels,
learning_rate]
if use_mixed_precision:
loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device)
actual_loss = model.train_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, learning_rate, loss_scale)
if not use_internel_loss_scale:
loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device)
training_args.append(loss_scale)
actual_loss = model.train_step(*training_args)
if isinstance(actual_loss, (list, tuple)):
assert len(actual_loss) == 2
actual_loss, actual_all_finite = actual_loss
loss_scaler.update_loss_scale(actual_all_finite.item())
actual_all_finites = [*actual_all_finites, actual_all_finite.cpu().numpy().item(0)]
if not use_internel_loss_scale:
loss_scaler.update_loss_scale(actual_all_finite.item())
actual_all_finites = [*actual_all_finites, actual_all_finite.cpu().numpy().item(0)]
actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)]
else:
loss = model(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, learning_rate)
loss = model(*training_args)
actual_losses = [*actual_losses, loss.cpu().numpy().item(0)]
if batch_count == num_batches - 1:
@ -125,7 +141,8 @@ def runBertTrainingTest(gradient_accumulation_steps, use_mixed_precision, allred
eval_loss = model.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, fetches=['loss'])
eval_loss = eval_loss.cpu().numpy().item(0)
if use_mixed_precision:
# When the internal loss scale is used, all_finite is handled internally as well, so it is not returned.
if use_mixed_precision and not use_internel_loss_scale:
return actual_losses, actual_all_finites, eval_loss
else:
return actual_losses, eval_loss

View file

@ -22,6 +22,21 @@ class TestOrtTrainer(unittest.TestCase):
assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
def testBertTrainingMixedPrecisionInternalLossScale(self):
    """Check mixed-precision BERT training when the internal loss scale is requested.

    Runs ``runBertTrainingTest`` with ``use_internel_loss_scale=True`` and
    compares the per-step training losses and the evaluation loss against
    recorded reference values, using a relative tolerance of 1e-01.
    """
    torch.manual_seed(1)
    rtol = 1e-01

    # Reference values recorded for the toy BERT model.
    reference_train_losses = [
        11.078125, 11.0, 11.0390625, 11.0,
        11.015625, 11.0, 10.9921875, 11.0703125,
    ]
    reference_eval_loss = [11.046875]

    run_options = dict(
        gradient_accumulation_steps=1,
        use_mixed_precision=True,
        allreduce_post_accumulation=False,
        use_simple_model_desc=False,
        use_internel_loss_scale=True,
    )
    # In this configuration the helper's result is unpacked as exactly two
    # values: the training losses and the evaluation loss.
    train_losses, eval_loss = runBertTrainingTest(**run_options)

    assert_allclose(reference_train_losses, train_losses, rtol=rtol, err_msg="loss mismatch")
    assert_allclose(reference_eval_loss, eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
def testBertTrainingGradientAccumulationMixedPrecision(self):
torch.manual_seed(1)
expected_losses = [11.046875, 11.171875, 11.0234375, 11.046875, 10.8984375, 10.9921875, 11.078125, 10.96875]

View file

@ -747,7 +747,7 @@ class ORTTrainer():
lr_this_step = self.get_lr_this_step_(self.global_step_)
learning_rate = torch.tensor([lr_this_step])
if self.loss_scaler_ is not None and self.use_mixed_precision:
loss_scale = torch.tensor(self.loss_scaler_.loss_scale_)
loss_scale = torch.tensor([self.loss_scaler_.loss_scale_])
if self.onnx_model_ is None:
sample_input, _ = self.prepare_input_and_fetches(self.model_desc_.inputs_,