diff --git a/tests/optimization/test_optimization.py b/tests/optimization/test_optimization.py index adf3039bb..6982583d2 100644 --- a/tests/optimization/test_optimization.py +++ b/tests/optimization/test_optimization.py @@ -59,7 +59,7 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10): file_name = os.path.join(tmpdirname, "schedule.bin") torch.save(scheduler.state_dict(), file_name) - state_dict = torch.load(file_name) + state_dict = torch.load(file_name, weights_only=False) scheduler.load_state_dict(state_dict) return lrs