mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Add ONNX BERT Frozen Weights and Save as ONNX Tests (#4859)
Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
This commit is contained in:
parent
25cc6158a8
commit
7589445e6e
2 changed files with 100 additions and 3 deletions
|
|
@ -267,8 +267,9 @@ class ORTTrainer(object):
|
|||
ValueError: raised when `path` is not valid path
|
||||
"""
|
||||
if not self._training_session:
|
||||
raise RuntimeWarning("Training session is not initialized yet. "
|
||||
warnings.warn("Training session is not initialized yet. "
|
||||
"'train_step' or 'eval_step' methods must be executed at least once before calling 'save_as_onnx()'.")
|
||||
return
|
||||
state_tensors = self._training_session.get_state()
|
||||
self._update_onnx_model_initializers(state_tensors)
|
||||
|
||||
|
|
@ -276,7 +277,8 @@ class ORTTrainer(object):
|
|||
dir_name = os.path.dirname(path)
|
||||
file_name = os.path.basename(path)
|
||||
if not dir_name or not os.path.exists(dir_name) or not file_name:
|
||||
raise ValueError("'path' is not valid. It must contain an existing folder + filename")
|
||||
warnings.warn("'path' is not valid. It must contain an existing folder + filename")
|
||||
return
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(self._onnx_model.SerializeToString())
|
||||
|
|
|
|||
|
|
@ -79,6 +79,7 @@ def optimizer_parameters(model):
|
|||
if any(key in initializer.name for key in no_decay_keys):
|
||||
no_decay_param_group.append(initializer.name)
|
||||
params = [{'params': no_decay_param_group, "alpha": 0.9, "beta": 0.999, "lambda_coef": 0.0, "epsilon": 1e-6}]
|
||||
|
||||
return params
|
||||
|
||||
|
||||
|
|
@ -593,6 +594,100 @@ def testToyBertCheckpointFrozenWeights():
|
|||
loaded_state_dict = checkpoint.experimental_state_dict(trainer2)
|
||||
assert state_dict.keys() == loaded_state_dict.keys()
|
||||
|
||||
@pytest.mark.parametrize("model_params", [
    (['bert.embeddings.LayerNorm.bias']),
    (['bert.embeddings.LayerNorm.bias',
      'bert.embeddings.LayerNorm.weight',
      'bert.encoder.layer.0.attention.output.LayerNorm.bias']),
])
def testORTTrainerFrozenWeights(model_params):
    """Check that weights listed as frozen are absent from the session state.

    The same toy BERT ONNX model is trained twice for a few steps: first
    with no frozen weights (every name in ``model_params`` must show up in
    the training session state) and then with ``model_params`` passed as
    ``utils.frozen_weights`` (none of the names may show up).

    Args:
        model_params: list of initializer names to freeze in the second run.
    """
    device = 'cuda'
    total_steps = 10
    seed = 1

    def run_training(ort_trainer):
        # Feed `total_steps` generated batches to the trainer;
        # the step index is forwarded to the input generator.
        for step in range(total_steps):
            batch = generate_random_input_from_model_desc(model_desc, step)
            ort_trainer.train_step(*batch)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()

    # --- First run: trainer WITHOUT frozen weights ---
    opts_dict = {
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
    }
    opts = orttrainer.ORTTrainerOptions(opts_dict)

    torch.manual_seed(seed)
    set_seed(seed)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    run_training(trainer)

    # Every requested parameter must be present in the session state
    assert trainer._onnx_model is not None
    state = trainer._training_session.get_state()
    assert all(name in state for name in model_params)

    # --- Second run: trainer WITH frozen weights ---
    opts_dict['utils'] = {'frozen_weights': model_params}
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    run_training(trainer)

    # None of the frozen parameters may be present in the session state
    assert trainer._onnx_model is not None
    state = trainer._training_session.get_state()
    assert not any(name in state for name in model_params)
|
||||
|
||||
def testToyBERTSaveAsONNX():
    """Round-trip the toy BERT model through ``ORTTrainer.save_as_onnx``.

    A trainer is built from the toy BERT ONNX model, saved to disk, and the
    saved bytes are reloaded into a second trainer. The two trainers must
    hold distinct model objects with the same initializer names, the same
    printable graph and the same weights.

    The model is written into a fresh temporary directory rather than a
    path relative to the current working directory, so the test does not
    depend on where pytest is launched from and the file is removed even
    when an assertion or the read-back fails.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()

    # NOTE(review): the original test built ORTTrainerOptions (device 'cuda',
    # deterministic compute) but never passed them to the trainers. Both
    # trainers are therefore still created with default options; confirm
    # whether explicit options should be supplied.
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config)

    # 'path' handed to save_as_onnx must contain an existing folder plus a
    # filename; the temporary directory satisfies that.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_name = os.path.join(tmp_dir, 'temp_toy_bert_onnx_model.onnx')
        assert not os.path.exists(onnx_file_name)

        trainer.save_as_onnx(onnx_file_name)
        assert os.path.exists(onnx_file_name)

        with open(onnx_file_name, "rb") as f:
            bin_str = f.read()

    reload_onnx_model = onnx.load_model_from_string(bin_str)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config)
    assert trainer_from_onnx._onnx_model is not None
    assert trainer_from_onnx._onnx_model is not trainer._onnx_model

    original_initializers = trainer._onnx_model.graph.initializer
    loaded_initializers = trainer_from_onnx._onnx_model.graph.initializer
    # zip() stops at the shorter sequence, so check the counts explicitly.
    assert len(original_initializers) == len(loaded_initializers)
    for initializer, loaded_initializer in zip(original_initializers, loaded_initializers):
        assert initializer.name == loaded_initializer.name

    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph)
            == onnx.helper.printable_graph(trainer._onnx_model.graph))
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############
|
||||
|
|
@ -682,7 +777,7 @@ def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, lega
|
|||
for i in range(total_steps):
|
||||
sample_input = generate_random_input_from_model_desc(model_desc, i)
|
||||
experimental_losses.append(trainer.train_step(*sample_input).cpu().item())
|
||||
assert trainer.options.lr_scheduler.get_last_lr()[0] == legacy_lr_scheduler(i)
|
||||
assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(i))
|
||||
|
||||
# LEGACY IMPLEMENTATION
|
||||
torch.manual_seed(seed)
|
||||
|
|
|
|||
Loading…
Reference in a new issue