mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-01 03:45:11 +00:00
Added attribute self.policy_class to prevent errors when using self.policy as class
This commit is contained in:
parent
e26564e0ec
commit
9ff59eaf3d
5 changed files with 21 additions and 15 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import pytest
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
|
||||
import torch as th
|
||||
|
||||
|
|
@ -32,6 +33,11 @@ def test_save_load(model_class):
|
|||
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
||||
env.reset()
|
||||
observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
|
||||
observations = np.squeeze(observations)
|
||||
|
||||
|
||||
# Get dictionary of current parameters
|
||||
params = deepcopy(model.get_policy_parameters())
|
||||
opt_params = deepcopy(model.get_opt_parameters())
|
||||
|
|
@ -49,6 +55,10 @@ def test_save_load(model_class):
|
|||
|
||||
params = new_params
|
||||
|
||||
|
||||
# Get the actions selected by the model before saving
|
||||
selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
|
||||
|
||||
# Check
|
||||
model.save("test_save.zip")
|
||||
del model
|
||||
|
|
@ -78,6 +88,11 @@ def test_save_load(model_class):
|
|||
else:
|
||||
assert param_value == new_opt_params[optimizer]['param_groups'][param_group_idx][param_key]
|
||||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
|
||||
for i in range(len(selected_actions)):
|
||||
assert selected_actions[i] == new_selected_actions[i]
|
||||
|
||||
# check if learn still works
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
|
|
|||
|
|
@ -38,9 +38,9 @@ class BaseRLModel(object):
|
|||
verbose=0, device='auto', support_multi_env=False,
|
||||
create_eval_env=False, monitor_wrapper=True, seed=None):
|
||||
if isinstance(policy, str) and policy_base is not None:
|
||||
self.policy = get_policy_from_name(policy_base, policy)
|
||||
self.policy_class = get_policy_from_name(policy_base, policy)
|
||||
else:
|
||||
self.policy = policy
|
||||
self.policy_class = policy
|
||||
|
||||
if device == 'auto':
|
||||
device = 'cuda' if th.cuda.is_available() else 'cpu'
|
||||
|
|
@ -293,7 +293,7 @@ class BaseRLModel(object):
|
|||
model.set_env(env)
|
||||
model.load_parameters(params, opt_params)
|
||||
# Re-setup the model after load
|
||||
model._resetup_model()
|
||||
#model._setup_model()
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ class PPO(BaseRLModel):
|
|||
|
||||
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
|
||||
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
|
||||
self.policy = self.policy(self.observation_space, self.action_space,
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, use_sde=self.use_sde, device=self.device,
|
||||
**self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
|
|
|||
|
|
@ -123,20 +123,11 @@ class SAC(BaseRLModel):
|
|||
self.ent_coef = float(self.ent_coef)
|
||||
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy(self.observation_space, self.action_space,
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self._create_aliases()
|
||||
|
||||
def _resetup_model(self):
|
||||
"""
|
||||
method used to resetup anything that was not saved
|
||||
:return:
|
||||
"""
|
||||
if self.replay_buffer is None:
|
||||
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
|
||||
def _create_aliases(self):
|
||||
self.actor = self.policy.actor
|
||||
self.critic = self.policy.critic
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class TD3(BaseRLModel):
|
|||
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
self.set_random_seed(self.seed)
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy(self.observation_space, self.action_space,
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self._create_aliases()
|
||||
|
|
|
|||
Loading…
Reference in a new issue