mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Fix reset_num_timesteps
This commit is contained in:
parent
6d59bfd4a0
commit
5d4e73544c
6 changed files with 16 additions and 31 deletions
|
|
@ -19,6 +19,7 @@ New Features:
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fix loading models on CPU that were trained on GPU
|
||||
- Fix `reset_num_timesteps` argument that was being ignored
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -103,7 +103,8 @@ class CEMRL(TD3):
|
|||
eval_env=None, eval_freq=-1, n_eval_episodes=5,
|
||||
tb_log_name="CEMRL", eval_log_path=None, reset_num_timesteps=True):
|
||||
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path)
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
actor_steps = 0
|
||||
continue_training = True
|
||||
|
||||
|
|
|
|||
|
|
@ -364,7 +364,7 @@ class BaseRLModel(ABC):
|
|||
@staticmethod
|
||||
def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]],
|
||||
Optional[TensorDict],
|
||||
Optional[OptimizerStateDict]]):
|
||||
Optional[TensorDict]]):
|
||||
""" Load model data from a .zip archive
|
||||
|
||||
:param load_path: Where to load the model from
|
||||
|
|
@ -503,7 +503,8 @@ class BaseRLModel(ABC):
|
|||
callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
|
||||
eval_freq: int = 10000,
|
||||
n_eval_episodes: int = 5,
|
||||
log_path: Optional[str] = None
|
||||
log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> Tuple[int, np.ndarray, BaseCallback]:
|
||||
"""
|
||||
Initialize different variables needed for training.
|
||||
|
|
@ -513,6 +514,7 @@ class BaseRLModel(ABC):
|
|||
:param eval_freq: (int)
|
||||
:param n_eval_episodes: (int)
|
||||
:param log_path: (Optional[str])
|
||||
:param reset_num_timesteps: (bool) Whether or not to reset the `num_timesteps` attribute
|
||||
:return: (Tuple[int, np.ndarray, BaseCallback])
|
||||
"""
|
||||
self.start_time = time.time()
|
||||
|
|
@ -523,6 +525,9 @@ class BaseRLModel(ABC):
|
|||
|
||||
timesteps_since_eval, episode_num = 0, 0
|
||||
|
||||
if reset_num_timesteps:
|
||||
self.num_timesteps = 0
|
||||
|
||||
if eval_env is not None and self.seed is not None:
|
||||
eval_env.seed(self.seed)
|
||||
|
||||
|
|
@ -828,27 +833,3 @@ class BaseRLModel(ABC):
|
|||
params_to_save[name] = attr.state_dict()
|
||||
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)
|
||||
|
||||
def _eval_policy(self, eval_freq: int, eval_env: int, n_eval_episodes: int,
|
||||
timesteps_since_eval: int, render: bool = False, deterministic: bool = True) -> int:
|
||||
"""
|
||||
Evaluate the current policy on a test environment.
|
||||
|
||||
:param eval_freq: Evaluate the agent every `eval_freq` timesteps (this may vary a little)
|
||||
:param n_eval_episodes: Number of episodes to evaluate the agent
|
||||
:param timesteps_since_eval: Number of timesteps since last evaluation
|
||||
:param deterministic: Whether to use deterministic or stochastic actions
|
||||
:param render: Whether to render the eval env or not
|
||||
:return: Number of timesteps since last evaluation
|
||||
"""
|
||||
if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
|
||||
timesteps_since_eval %= eval_freq
|
||||
# Synchronise the normalization stats if needed
|
||||
sync_envs_normalization(self.env, eval_env)
|
||||
mean_reward, std_reward = evaluate_policy(self, eval_env, n_eval_episodes,
|
||||
render=render, deterministic=deterministic)
|
||||
if self.verbose > 0:
|
||||
print(f"Eval num_timesteps={self.num_timesteps}, "
|
||||
f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
|
||||
print(f"FPS: {self.num_timesteps / (time.time() - self.start_time):.2f}")
|
||||
return timesteps_since_eval
|
||||
|
|
|
|||
|
|
@ -290,7 +290,8 @@ class PPO(BaseRLModel):
|
|||
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO",
|
||||
eval_log_path=None, reset_num_timesteps=True):
|
||||
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path)
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
iteration = 0
|
||||
|
||||
if self.tensorboard_log is not None and SummaryWriter is not None:
|
||||
|
|
|
|||
|
|
@ -261,8 +261,8 @@ class SAC(BaseRLModel):
|
|||
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="SAC",
|
||||
eval_log_path=None, reset_num_timesteps=True):
|
||||
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path)
|
||||
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
|
|
|
|||
|
|
@ -254,7 +254,8 @@ class TD3(BaseRLModel):
|
|||
eval_env=None, eval_freq=-1, n_eval_episodes=5,
|
||||
tb_log_name="TD3", eval_log_path=None, reset_num_timesteps=True):
|
||||
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq, n_eval_episodes, eval_log_path)
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue