From 5bca52a87dc3068184afc1612e8f4bc85973e632 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 11:44:37 +0100 Subject: [PATCH] rearranged imports --- tests/test_save_load.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 9941e15..aff6481 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,11 +1,10 @@ import os - import pytest -import copy -import numpy as np -import torch +from copy import deepcopy + +import torch as th + from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3 -from torchy_baselines.common.noise import NormalActionNoise from torchy_baselines.common.vec_env import DummyVecEnv from torchy_baselines.common.identity_env import IdentityEnvBox @@ -30,16 +29,16 @@ def test_save_load(model_class): # test action probability for given (obs, action) pair # Get dictionary of current parameters - params = copy.deepcopy(model.get_policy_parameters()) + params = deepcopy(model.get_policy_parameters()) # Modify all parameters to be random values - random_params = dict((param_name, torch.rand_like(param)) for param_name, param in params.items()) + random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) # Update model parameters with the new zeroed values model.load_parameters(random_params) # shared items new_params = model.get_policy_parameters() - shared_items = {k: params[k] for k in params if k in new_params and torch.all(torch.eq(params[k], new_params[k]))} + shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))} # Check that at least some actions are chosen different now assert not len(shared_items) == len(new_params), "Selected actions did not change " \ "after changing model parameters." @@ -53,7 +52,7 @@ def test_save_load(model_class): #check if params are still the same after load new_params = model.get_policy_parameters() - shared_items = {k: params[k] for k in params if k in new_params and torch.all(torch.eq(params[k], new_params[k]))} + shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))} # Check that at least some actions are chosen different now assert len(shared_items) == len(new_params), "Parameters not the same after save and load." os.remove("test_save.zip")