diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 5ccc6105c7..c9039a9f91 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -137,9 +137,13 @@ def test_optimizer_step(): optimizer = Optimizer(optimizer_file_path, model) model.train() + old_flatten_params = model.get_contiguous_parameters() model(forward_inputs) + optimizer.step() - # TODO : Check if parameters changed from before and after optimizer step. + new_params = model.get_contiguous_parameters() + # Assert that the parameters are updated. + assert not np.array_equal(old_flatten_params.numpy(), new_params.numpy()) def test_get_and_set_lr(): @@ -226,10 +230,19 @@ def test_training_module_checkpoint(): checkpoint_save_path = os.path.join(temp_dir, "checkpoint_export.ckpt") model.save_checkpoint(checkpoint_save_path) + old_flatten_params = model.get_contiguous_parameters() - # TODO : Load checkpoint to a zeroed model and assert parameters are different. + # Assert the checkpoint was saved. assert os.path.exists(checkpoint_save_path) + # Assert the checkpoint parameters remain after saving. + state = CheckpointState(checkpoint_save_path) + new_model = Module(model_file_path, state) + + new_params = new_model.get_contiguous_parameters() + + assert np.array_equal(old_flatten_params.numpy(), new_params.numpy()) + def test_copy_buffer_to_parameters(): # Initialize Models