Add ONNX BERT Frozen Weights and Save as ONNX Tests (#4859)

Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
This commit is contained in:
Rayan-Krishnan 2020-08-19 21:31:38 -07:00 committed by GitHub
parent 25cc6158a8
commit 7589445e6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 100 additions and 3 deletions

View file

@ -267,8 +267,9 @@ class ORTTrainer(object):
ValueError: raised when `path` is not valid path
"""
if not self._training_session:
raise RuntimeWarning("Training session is not initialized yet. "
warnings.warn("Training session is not initialized yet. "
"'train_step' or 'eval_step' methods must be executed at least once before calling 'save_as_onnx()'.")
return
state_tensors = self._training_session.get_state()
self._update_onnx_model_initializers(state_tensors)
@ -276,7 +277,8 @@ class ORTTrainer(object):
dir_name = os.path.dirname(path)
file_name = os.path.basename(path)
if not dir_name or not os.path.exists(dir_name) or not file_name:
raise ValueError("'path' is not valid. It must contain an existing folder + filename")
warnings.warn("'path' is not valid. It must contain an existing folder + filename")
return
with open(path, "wb") as f:
f.write(self._onnx_model.SerializeToString())

View file

@ -79,6 +79,7 @@ def optimizer_parameters(model):
if any(key in initializer.name for key in no_decay_keys):
no_decay_param_group.append(initializer.name)
params = [{'params': no_decay_param_group, "alpha": 0.9, "beta": 0.999, "lambda_coef": 0.0, "epsilon": 1e-6}]
return params
@ -593,6 +594,100 @@ def testToyBertCheckpointFrozenWeights():
loaded_state_dict = checkpoint.experimental_state_dict(trainer2)
assert state_dict.keys() == loaded_state_dict.keys()
@pytest.mark.parametrize("model_params", [
    (['bert.embeddings.LayerNorm.bias']),
    (['bert.embeddings.LayerNorm.bias',
      'bert.embeddings.LayerNorm.weight',
      'bert.encoder.layer.0.attention.output.LayerNorm.bias']),
])
def testORTTrainerFrozenWeights(model_params):
    """Check that frozen weights are left out of the training session state.

    Two trainers are built from the same toy BERT ONNX model and each is
    run for ``total_steps`` training steps on generated inputs:

    1. Without ``utils.frozen_weights``: every name in ``model_params``
       must be present in the state returned by the training session.
    2. With ``utils.frozen_weights`` set to ``model_params``: none of
       those names may be present in that state.

    Args:
        model_params: list of initializer names to freeze in the second run.
    """
    device = 'cuda'
    total_steps = 10
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()

    def _run_training(ort_trainer):
        # Same loop for both trainers: the step index is passed to the
        # input generator exactly as in the hand-written loops.
        for step in range(total_steps):
            batch = generate_random_input_from_model_desc(model_desc, step)
            ort_trainer.train_step(*batch)

    # Setup ORTTrainer WITHOUT frozen weights
    opts_dict = {
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    }
    opts = orttrainer.ORTTrainerOptions(opts_dict)

    torch.manual_seed(seed)
    set_seed(seed)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    _run_training(trainer)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    state = trainer._training_session.get_state()
    assert all(name in state for name in model_params)

    # Setup ORTTrainer WITH frozen weights
    opts_dict['utils'] = {'frozen_weights': model_params}
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    _run_training(trainer)

    # All model_params CANNOT be in the session state
    assert trainer._onnx_model is not None
    state = trainer._training_session.get_state()
    assert not any(name in state for name in model_params)
def testToyBERTSaveAsONNX():
    """Round-trip a toy BERT ORTTrainer through ``save_as_onnx()``.

    The trainer's ONNX model is written to a file, read back, and used to
    build a second trainer. The test then asserts that the two trainers
    hold distinct model objects with matching initializer names, identical
    printable graphs and (via ``_test_helpers.assert_onnx_weights``)
    matching weights.

    The file is written inside a temporary directory so the test does not
    depend on the current working directory and never leaves a stray file
    behind, even when an assertion fails midway. ``save_as_onnx`` requires
    the path to contain an existing folder plus a file name, which the
    temporary directory satisfies.
    """
    import tempfile  # local import: only this test needs a scratch directory

    device = 'cuda'

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    # NOTE(review): these options are built but not passed to either
    # ORTTrainer below (the `options=opts` argument was commented out in the
    # original test) -- confirm whether the trainers should use them.
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config)

    # Persist the model and read the raw bytes back; the directory (and the
    # file in it) is removed when the `with` block exits, pass or fail.
    with tempfile.TemporaryDirectory() as scratch_dir:
        onnx_file_name = os.path.join(scratch_dir, 'temp_toy_bert_onnx_model.onnx')
        assert not os.path.exists(onnx_file_name)

        trainer.save_as_onnx(onnx_file_name)
        assert os.path.exists(onnx_file_name)

        with open(onnx_file_name, "rb") as f:
            bin_str = f.read()
    reload_onnx_model = onnx.load_model_from_string(bin_str)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config)
    assert trainer_from_onnx._onnx_model is not None
    assert trainer_from_onnx._onnx_model is not trainer._onnx_model

    original_initializers = trainer._onnx_model.graph.initializer
    reloaded_initializers = trainer_from_onnx._onnx_model.graph.initializer
    # zip() stops at the shorter sequence, so check the counts explicitly
    # before comparing names pairwise.
    assert len(original_initializers) == len(reloaded_initializers)
    for initializer, loaded_initializer in zip(original_initializers, reloaded_initializers):
        assert initializer.name == loaded_initializer.name

    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph)
            == onnx.helper.printable_graph(trainer._onnx_model.graph))
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)
###############################################################################
# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############
@ -682,7 +777,7 @@ def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, lega
for i in range(total_steps):
sample_input = generate_random_input_from_model_desc(model_desc, i)
experimental_losses.append(trainer.train_step(*sample_input).cpu().item())
assert trainer.options.lr_scheduler.get_last_lr()[0] == legacy_lr_scheduler(i)
assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(i))
# LEGACY IMPLEMENTATION
torch.manual_seed(seed)