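Add attribute self.policy_class to store the policy class separately, so that self.policy only ever holds the policy instance; this prevents errors caused by using self.policy as a class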

This commit is contained in:
Noah Dormann 2019-11-28 15:25:01 +01:00
parent e26564e0ec
commit 9ff59eaf3d
5 changed files with 21 additions and 15 deletions

View file

@ -1,6 +1,7 @@
import os
import pytest
from copy import deepcopy
import numpy as np
import torch as th
@ -32,6 +33,11 @@ def test_save_load(model_class):
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
model.learn(total_timesteps=500, eval_freq=250)
env.reset()
observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
observations = np.squeeze(observations)
# Get dictionary of current parameters
params = deepcopy(model.get_policy_parameters())
opt_params = deepcopy(model.get_opt_parameters())
@ -49,6 +55,10 @@ def test_save_load(model_class):
params = new_params
#get selected actions
selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
# Check
model.save("test_save.zip")
del model
@ -78,6 +88,11 @@ def test_save_load(model_class):
else:
assert param_value == new_opt_params[optimizer]['param_groups'][param_group_idx][param_key]
# check if model still selects the same actions
new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
for i in range(len(selected_actions)):
assert selected_actions[i] == new_selected_actions[i]
# check if learn still works
model.learn(total_timesteps=1000, eval_freq=500)

View file

@ -38,9 +38,9 @@ class BaseRLModel(object):
verbose=0, device='auto', support_multi_env=False,
create_eval_env=False, monitor_wrapper=True, seed=None):
if isinstance(policy, str) and policy_base is not None:
self.policy = get_policy_from_name(policy_base, policy)
self.policy_class = get_policy_from_name(policy_base, policy)
else:
self.policy = policy
self.policy_class = policy
if device == 'auto':
device = 'cuda' if th.cuda.is_available() else 'cpu'
@ -293,7 +293,7 @@ class BaseRLModel(object):
model.set_env(env)
model.load_parameters(params, opt_params)
# resetup modul after load
model._resetup_model()
#model._setup_model()
return model
@staticmethod

View file

@ -119,7 +119,7 @@ class PPO(BaseRLModel):
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
self.policy = self.policy(self.observation_space, self.action_space,
self.policy = self.policy_class(self.observation_space, self.action_space,
self.learning_rate, use_sde=self.use_sde, device=self.device,
**self.policy_kwargs)
self.policy = self.policy.to(self.device)

View file

@ -123,20 +123,11 @@ class SAC(BaseRLModel):
self.ent_coef = float(self.ent_coef)
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
self.policy = self.policy(self.observation_space, self.action_space,
self.policy = self.policy_class(self.observation_space, self.action_space,
self.learning_rate, device=self.device, **self.policy_kwargs)
self.policy = self.policy.to(self.device)
self._create_aliases()
def _resetup_model(self):
"""
method used to resetup anything that was not saved
:return:
"""
if self.replay_buffer is None:
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
def _create_aliases(self):
self.actor = self.policy.actor
self.critic = self.policy.critic

View file

@ -80,7 +80,7 @@ class TD3(BaseRLModel):
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
self.set_random_seed(self.seed)
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
self.policy = self.policy(self.observation_space, self.action_space,
self.policy = self.policy_class(self.observation_space, self.action_space,
self.learning_rate, device=self.device, **self.policy_kwargs)
self.policy = self.policy.to(self.device)
self._create_aliases()