mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Improved test cases by using paramerters (#14246)
### Description Completing some missing parts of some test cases for python bindings ### Motivation and Context Some test cases like test_training_module_checkpoint and test_optimizer step were not completed before because we had no access to parameters to check if the parameters are changing after the optimizer step or that the checkpoint saved parameters remains the same. now that we have access to the vector or parameters by exposing get_contiguous_parameters() method. we can complete the tests.
This commit is contained in:
parent
6ac7c894bf
commit
f0555eb437
1 changed files with 15 additions and 2 deletions
|
|
@ -137,9 +137,13 @@ def test_optimizer_step():
|
|||
optimizer = Optimizer(optimizer_file_path, model)
|
||||
|
||||
model.train()
|
||||
old_flatten_params = model.get_contiguous_parameters()
|
||||
model(forward_inputs)
|
||||
|
||||
optimizer.step()
|
||||
# TODO : Check if parameters changed from before and after optimizer step.
|
||||
new_params = model.get_contiguous_parameters()
|
||||
# Assert that the parameters are updated.
|
||||
assert not np.array_equal(old_flatten_params.numpy(), new_params.numpy())
|
||||
|
||||
|
||||
def test_get_and_set_lr():
|
||||
|
|
@ -226,10 +230,19 @@ def test_training_module_checkpoint():
|
|||
checkpoint_save_path = os.path.join(temp_dir, "checkpoint_export.ckpt")
|
||||
|
||||
model.save_checkpoint(checkpoint_save_path)
|
||||
old_flatten_params = model.get_contiguous_parameters()
|
||||
|
||||
# TODO : Load checkpoint to a zeroed model and assert parameters are different.
|
||||
# Assert the checkpoint was saved.
|
||||
assert os.path.exists(checkpoint_save_path)
|
||||
|
||||
# Assert the checkpoint parameters remain after saving.
|
||||
state = CheckpointState(checkpoint_save_path)
|
||||
new_model = Module(model_file_path, state)
|
||||
|
||||
new_params = new_model.get_contiguous_parameters()
|
||||
|
||||
assert np.array_equal(old_flatten_params.numpy(), new_params.numpy())
|
||||
|
||||
|
||||
def test_copy_buffer_to_parameters():
|
||||
# Initialize Models
|
||||
|
|
|
|||
Loading…
Reference in a new issue