diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 4d7861e..dec98b4 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,6 +1,7 @@ import os import pytest from copy import deepcopy +import numpy as np import torch as th @@ -32,6 +33,11 @@ def test_save_load(model_class): model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) model.learn(total_timesteps=500, eval_freq=250) + env.reset() + observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)]) + observations = np.squeeze(observations) + + # Get dictionary of current parameters params = deepcopy(model.get_policy_parameters()) opt_params = deepcopy(model.get_opt_parameters()) @@ -49,6 +55,10 @@ def test_save_load(model_class): params = new_params + + #get selected actions + selected_actions = [model.predict(observation, deterministic=True) for observation in observations] + # Check model.save("test_save.zip") del model @@ -78,6 +88,11 @@ def test_save_load(model_class): else: assert param_value == new_opt_params[optimizer]['param_groups'][param_group_idx][param_key] + # check if model still selects the same actions + new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations] + for i in range(len(selected_actions)): + assert selected_actions[i] == new_selected_actions[i] + # check if learn still works model.learn(total_timesteps=1000, eval_freq=500) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 970ea5a..cbdb50b 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -38,9 +38,9 @@ class BaseRLModel(object): verbose=0, device='auto', support_multi_env=False, create_eval_env=False, monitor_wrapper=True, seed=None): if isinstance(policy, str) and policy_base is not None: - self.policy = get_policy_from_name(policy_base, policy) + self.policy_class = get_policy_from_name(policy_base, policy) else: - self.policy = policy + self.policy_class = policy if device == 'auto': device = 'cuda' if th.cuda.is_available() else 'cpu' @@ -293,7 +293,7 @@ class BaseRLModel(object): model.set_env(env) model.load_parameters(params, opt_params) # resetup modul after load - model._resetup_model() + #model._setup_model() return model @staticmethod diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index dd52e8b..a0c2231 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -119,7 +119,7 @@ class PPO(BaseRLModel): self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs) - self.policy = self.policy(self.observation_space, self.action_space, + self.policy = self.policy_class(self.observation_space, self.action_space, self.learning_rate, use_sde=self.use_sde, device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index 42771bb..cc539d5 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -123,20 +123,11 @@ class SAC(BaseRLModel): self.ent_coef = float(self.ent_coef) self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) - self.policy = self.policy(self.observation_space, self.action_space, + self.policy = self.policy_class(self.observation_space, self.action_space, self.learning_rate, device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) self._create_aliases() - def _resetup_model(self): - """ - method used to resetup anything that was not saved - :return: - """ - if self.replay_buffer is None: - obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] - self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) - def _create_aliases(self): self.actor = self.policy.actor self.critic = self.policy.critic diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 07036cf..1c3639e 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -80,7 +80,7 @@ class TD3(BaseRLModel): obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] self.set_random_seed(self.seed) self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) - self.policy = self.policy(self.observation_space, self.action_space, + self.policy = self.policy_class(self.observation_space, self.action_space, self.learning_rate, device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) self._create_aliases()