From 5d4e73544ca8af22310bb86129c325021065a579 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 31 Jan 2020 13:16:28 +0100 Subject: [PATCH] Fix `reset_num_timesteps` --- docs/misc/changelog.rst | 1 + torchy_baselines/cem_rl/cem_rl.py | 3 ++- torchy_baselines/common/base_class.py | 33 ++++++--------------------- torchy_baselines/ppo/ppo.py | 3 ++- torchy_baselines/sac/sac.py | 4 ++-- torchy_baselines/td3/td3.py | 3 ++- 6 files changed, 16 insertions(+), 31 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 917199c..d0d69c1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -19,6 +19,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fix loading model on CPU that were trained on GPU +- Fix `reset_num_timesteps` that was not used Deprecations: ^^^^^^^^^^^^^ diff --git a/torchy_baselines/cem_rl/cem_rl.py b/torchy_baselines/cem_rl/cem_rl.py index 2e3be4d..35c6e3c 100644 --- a/torchy_baselines/cem_rl/cem_rl.py +++ b/torchy_baselines/cem_rl/cem_rl.py @@ -103,7 +103,8 @@ class CEMRL(TD3): eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="CEMRL", eval_log_path=None, reset_num_timesteps=True): - episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path) + episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, + n_eval_episodes, eval_log_path, reset_num_timesteps) actor_steps = 0 continue_training = True diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index e3575b0..1f711be 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -364,7 +364,7 @@ class BaseRLModel(ABC): @staticmethod def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], - Optional[OptimizerStateDict]]): + Optional[TensorDict]]): """ Load model data from a .zip archive :param load_path: Where to load the model from @@ -503,7 +503,8 
@@ class BaseRLModel(ABC): callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None, eval_freq: int = 10000, n_eval_episodes: int = 5, - log_path: Optional[str] = None + log_path: Optional[str] = None, + reset_num_timesteps: bool = True, ) -> Tuple[int, np.ndarray, BaseCallback]: """ Initialize different variables needed for training. @@ -513,6 +514,7 @@ class BaseRLModel(ABC): :param eval_freq: (int) :param n_eval_episodes: (int) :param log_path (Optional[str]): + :param reset_num_timesteps: (bool) Whether to reset the `num_timesteps` attribute to 0 :return: (Tuple[int, np.ndarray, BaseCallback]) """ self.start_time = time.time() @@ -523,6 +525,9 @@ class BaseRLModel(ABC): timesteps_since_eval, episode_num = 0, 0 + if reset_num_timesteps: + self.num_timesteps = 0 + if eval_env is not None and self.seed is not None: eval_env.seed(self.seed) @@ -828,27 +833,3 @@ class BaseRLModel(ABC): params_to_save[name] = attr.state_dict() self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors) - - def _eval_policy(self, eval_freq: int, eval_env: int, n_eval_episodes: int, - timesteps_since_eval: int, render: bool = False, deterministic: bool = True) -> int: - """ - Evaluate the current policy on a test environment. 
- - :param eval_freq: Evaluate the agent every `eval_freq` timesteps (this may vary a little) - :param n_eval_episodes: Number of episode to evaluate the agent - :parma timesteps_since_eval: Number of timesteps since last evaluation - :param deterministic: Whether to use deterministic or stochastic actions - :param render: Whether to render the eval env or not - :return: Number of timesteps since last evaluation - """ - if 0 < eval_freq <= timesteps_since_eval and eval_env is not None: - timesteps_since_eval %= eval_freq - # Synchronise the normalization stats if needed - sync_envs_normalization(self.env, eval_env) - mean_reward, std_reward = evaluate_policy(self, eval_env, n_eval_episodes, - render=render, deterministic=deterministic) - if self.verbose > 0: - print(f"Eval num_timesteps={self.num_timesteps}, " - f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") - print(f"FPS: {self.num_timesteps / (time.time() - self.start_time):.2f}") - return timesteps_since_eval diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 188dfca..c5916ea 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -290,7 +290,8 @@ class PPO(BaseRLModel): eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO", eval_log_path=None, reset_num_timesteps=True): - episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path) + episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, + n_eval_episodes, eval_log_path, reset_num_timesteps) iteration = 0 if self.tensorboard_log is not None and SummaryWriter is not None: diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index 03caf7e..907e921 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -261,8 +261,8 @@ class SAC(BaseRLModel): eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="SAC", eval_log_path=None, reset_num_timesteps=True): - episode_num, obs, callback 
= self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path) - + episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, + n_eval_episodes, eval_log_path, reset_num_timesteps) callback.on_training_start(locals(), globals()) while self.num_timesteps < total_timesteps: diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index c39fdb7..1c2e417 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -254,7 +254,8 @@ class TD3(BaseRLModel): eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="TD3", eval_log_path=None, reset_num_timesteps=True): - episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path) + episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, + n_eval_episodes, eval_log_path, reset_num_timesteps) callback.on_training_start(locals(), globals())