rearranged imports

This commit is contained in:
Noah Dormann 2019-11-21 11:44:37 +01:00
parent 4b6234a1c8
commit 5bca52a87d

View file

@ -1,11 +1,10 @@
import os
import pytest
import copy
import numpy as np
import torch
from copy import deepcopy
import torch as th
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
from torchy_baselines.common.noise import NormalActionNoise
from torchy_baselines.common.vec_env import DummyVecEnv
from torchy_baselines.common.identity_env import IdentityEnvBox
@ -30,16 +29,16 @@ def test_save_load(model_class):
# test action probability for given (obs, action) pair
# Get dictionary of current parameters
params = copy.deepcopy(model.get_policy_parameters())
params = deepcopy(model.get_policy_parameters())
# Modify all parameters to be random values
random_params = dict((param_name, torch.rand_like(param)) for param_name, param in params.items())
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
# Update model parameters with the new zeroed values
model.load_parameters(random_params)
# shared items
new_params = model.get_policy_parameters()
shared_items = {k: params[k] for k in params if k in new_params and torch.all(torch.eq(params[k], new_params[k]))}
shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))}
# Check that at least some actions are chosen different now
assert not len(shared_items) == len(new_params), "Selected actions did not change " \
"after changing model parameters."
@ -53,7 +52,7 @@ def test_save_load(model_class):
#check if params are still the same after load
new_params = model.get_policy_parameters()
shared_items = {k: params[k] for k in params if k in new_params and torch.all(torch.eq(params[k], new_params[k]))}
shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))}
# Check that at least some actions are chosen different now
assert len(shared_items) == len(new_params), "Parameters not the same after save and load."
os.remove("test_save.zip")