Fix reset_num_timesteps

This commit is contained in:
Antonin Raffin 2020-01-31 13:16:28 +01:00
parent 6d59bfd4a0
commit 5d4e73544c
6 changed files with 16 additions and 31 deletions

View file

@ -19,6 +19,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fix loading on CPU of models that were trained on GPU
- Fix `reset_num_timesteps` argument that was ignored (it was never passed on to `_setup_learn`)
Deprecations:
^^^^^^^^^^^^^

View file

@ -103,7 +103,8 @@ class CEMRL(TD3):
eval_env=None, eval_freq=-1, n_eval_episodes=5,
tb_log_name="CEMRL", eval_log_path=None, reset_num_timesteps=True):
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path)
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
n_eval_episodes, eval_log_path, reset_num_timesteps)
actor_steps = 0
continue_training = True

View file

@ -364,7 +364,7 @@ class BaseRLModel(ABC):
@staticmethod
def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]],
Optional[TensorDict],
Optional[OptimizerStateDict]]):
Optional[TensorDict]]):
""" Load model data from a .zip archive
:param load_path: Where to load the model from
@ -503,7 +503,8 @@ class BaseRLModel(ABC):
callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None
log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> Tuple[int, np.ndarray, BaseCallback]:
"""
Initialize different variables needed for training.
@ -513,6 +514,7 @@ class BaseRLModel(ABC):
:param eval_freq: (int)
:param n_eval_episodes: (int)
:param log_path (Optional[str]):
:param reset_num_timesteps: (bool) Whether to reset or not the `num_timesteps` attribute
:return: (Tuple[int, np.ndarray, BaseCallback])
"""
self.start_time = time.time()
@ -523,6 +525,9 @@ class BaseRLModel(ABC):
timesteps_since_eval, episode_num = 0, 0
if reset_num_timesteps:
self.num_timesteps = 0
if eval_env is not None and self.seed is not None:
eval_env.seed(self.seed)
@ -828,27 +833,3 @@ class BaseRLModel(ABC):
params_to_save[name] = attr.state_dict()
self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)
def _eval_policy(self, eval_freq: int, eval_env: Optional[Any], n_eval_episodes: int,
                 timesteps_since_eval: int, render: bool = False, deterministic: bool = True) -> int:
    """
    Evaluate the current policy on a test environment.

    Nothing is evaluated (and `timesteps_since_eval` is returned unchanged) when
    `eval_freq` is not strictly positive, when fewer than `eval_freq` timesteps
    have elapsed since the last evaluation, or when `eval_env` is None.

    :param eval_freq: Evaluate the agent every `eval_freq` timesteps (this may vary a little)
    :param eval_env: Environment used for the evaluation, or None to disable evaluation
    :param n_eval_episodes: Number of episode to evaluate the agent
    :param timesteps_since_eval: Number of timesteps since last evaluation
    :param deterministic: Whether to use deterministic or stochastic actions
    :param render: Whether to render the eval env or not
    :return: Number of timesteps since last evaluation
    """
    if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
        # Keep only the overshoot past the evaluation period so that the
        # next evaluation is not delayed by the extra steps already taken.
        timesteps_since_eval %= eval_freq
        # Synchronise the normalization stats if needed
        sync_envs_normalization(self.env, eval_env)
        mean_reward, std_reward = evaluate_policy(self, eval_env, n_eval_episodes,
                                                  render=render, deterministic=deterministic)
        if self.verbose > 0:
            print(f"Eval num_timesteps={self.num_timesteps}, "
                  f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
            # Clamp the elapsed time: on platforms with a coarse clock two
            # consecutive `time.time()` readings can be equal, which would
            # otherwise raise ZeroDivisionError here.
            elapsed = max(time.time() - self.start_time, 1e-9)
            print(f"FPS: {self.num_timesteps / elapsed:.2f}")
    return timesteps_since_eval

View file

@ -290,7 +290,8 @@ class PPO(BaseRLModel):
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO",
eval_log_path=None, reset_num_timesteps=True):
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path)
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
n_eval_episodes, eval_log_path, reset_num_timesteps)
iteration = 0
if self.tensorboard_log is not None and SummaryWriter is not None:

View file

@ -261,8 +261,8 @@ class SAC(BaseRLModel):
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="SAC",
eval_log_path=None, reset_num_timesteps=True):
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path)
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
n_eval_episodes, eval_log_path, reset_num_timesteps)
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps:

View file

@ -254,7 +254,8 @@ class TD3(BaseRLModel):
eval_env=None, eval_freq=-1, n_eval_episodes=5,
tb_log_name="TD3", eval_log_path=None, reset_num_timesteps=True):
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path)
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
n_eval_episodes, eval_log_path, reset_num_timesteps)
callback.on_training_start(locals(), globals())