From f0555eb4379715c2a9d5ba7bbdb224233f58f14e Mon Sep 17 00:00:00 2001 From: Adam Louly Date: Fri, 13 Jan 2023 14:54:23 -0600 Subject: [PATCH] Improved test cases by using paramerters (#14246) ### Description Completing some missing parts of some test cases for python bindings ### Motivation and Context Some test cases like test_training_module_checkpoint and test_optimizer step were not completed before because we had no access to parameters to check if the parameters are changing after the optimizer step or that the checkpoint saved parameters remains the same. now that we have access to the vector or parameters by exposing get_contiguous_parameters() method. we can complete the tests. --- .../python/orttraining_test_python_bindings.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 5ccc6105c7..c9039a9f91 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -137,9 +137,13 @@ def test_optimizer_step(): optimizer = Optimizer(optimizer_file_path, model) model.train() + old_flatten_params = model.get_contiguous_parameters() model(forward_inputs) + optimizer.step() - # TODO : Check if parameters changed from before and after optimizer step. + new_params = model.get_contiguous_parameters() + # Assert that the parameters are updated. + assert not np.array_equal(old_flatten_params.numpy(), new_params.numpy()) def test_get_and_set_lr(): @@ -226,10 +230,19 @@ def test_training_module_checkpoint(): checkpoint_save_path = os.path.join(temp_dir, "checkpoint_export.ckpt") model.save_checkpoint(checkpoint_save_path) + old_flatten_params = model.get_contiguous_parameters() - # TODO : Load checkpoint to a zeroed model and assert parameters are different. + # Assert the checkpoint was saved. assert os.path.exists(checkpoint_save_path) + # Assert the checkpoint parameters remain after saving. + state = CheckpointState(checkpoint_save_path) + new_model = Module(model_file_path, state) + + new_params = new_model.get_contiguous_parameters() + + assert np.array_equal(old_flatten_params.numpy(), new_params.numpy()) + def test_copy_buffer_to_parameters(): # Initialize Models