mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-29 03:31:08 +00:00
rearranged imports
This commit is contained in:
parent
4b6234a1c8
commit
5bca52a87d
1 changed files with 8 additions and 9 deletions
|
|
@ -1,11 +1,10 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import copy
|
||||
import numpy as np
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
|
||||
import torch as th
|
||||
|
||||
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
|
||||
from torchy_baselines.common.noise import NormalActionNoise
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv
|
||||
from torchy_baselines.common.identity_env import IdentityEnvBox
|
||||
|
||||
|
|
@ -30,16 +29,16 @@ def test_save_load(model_class):
|
|||
# test action probability for given (obs, action) pair
|
||||
|
||||
# Get dictionary of current parameters
|
||||
params = copy.deepcopy(model.get_policy_parameters())
|
||||
params = deepcopy(model.get_policy_parameters())
|
||||
|
||||
# Modify all parameters to be random values
|
||||
random_params = dict((param_name, torch.rand_like(param)) for param_name, param in params.items())
|
||||
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
||||
# Update model parameters with the new zeroed values
|
||||
model.load_parameters(random_params)
|
||||
|
||||
# shared items
|
||||
new_params = model.get_policy_parameters()
|
||||
shared_items = {k: params[k] for k in params if k in new_params and torch.all(torch.eq(params[k], new_params[k]))}
|
||||
shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))}
|
||||
# Check that at least some actions are chosen different now
|
||||
assert not len(shared_items) == len(new_params), "Selected actions did not change " \
|
||||
"after changing model parameters."
|
||||
|
|
@ -53,7 +52,7 @@ def test_save_load(model_class):
|
|||
|
||||
#check if params are still the same after load
|
||||
new_params = model.get_policy_parameters()
|
||||
shared_items = {k: params[k] for k in params if k in new_params and torch.all(torch.eq(params[k], new_params[k]))}
|
||||
shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))}
|
||||
# Check that at least some actions are chosen different now
|
||||
assert len(shared_items) == len(new_params), "Parameters not the same after save and load."
|
||||
os.remove("test_save.zip")
|
||||
|
|
|
|||
Loading…
Reference in a new issue