mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-28 22:56:53 +00:00
Fixed `reset_num_timesteps` behavior
This commit is contained in:
parent
08a22c4834
commit
0e44cdce44
7 changed files with 66 additions and 79 deletions
|
|
@ -35,6 +35,24 @@ Others:
|
|||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
Pre-Release 0.5.0a0 (WIP)
|
||||
------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed ``reset_num_timesteps`` behavior, so ``env.reset()`` is not called if ``reset_num_timesteps=True``
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Cleanup rollout return
|
||||
|
||||
|
||||
|
||||
Pre-Release 0.3.0 (2020-02-14)
|
||||
------------------------------
|
||||
|
|
@ -57,9 +75,6 @@ Bug Fixes:
|
|||
- Fixed colors in ``results_plotter``
|
||||
- Fix entropy computation (now summed over action dim)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- SAC with SDE now sample only one matrix
|
||||
|
|
@ -106,9 +121,6 @@ Bug Fixes:
|
|||
- Fix entropy computation for squashed Gaussian (approximate it now)
|
||||
- Fix seeding when using multiple environments (different seed per env)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Add type check
|
||||
|
|
@ -125,25 +137,11 @@ Pre-Release 0.1.0 (2020-01-20)
|
|||
------------------------------
|
||||
**First Release: base algorithms and state-dependent exploration**
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Initial release of A2C, CEM-RL, PPO, SAC and TD3, working only with ``Box`` input space
|
||||
- State-Dependent Exploration (SDE) for A2C, PPO, SAC and TD3
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Maintainers
|
||||
|
|
|
|||
|
|
@ -93,7 +93,11 @@ class BaseRLModel(ABC):
|
|||
self.start_time = None
|
||||
self.policy = None
|
||||
self.learning_rate = learning_rate
|
||||
self.lr_schedule = None # type: Optional[Callable]
|
||||
self.lr_schedule = None # type: Optional[Callable]
|
||||
self._last_obs = None # type: Optional[np.ndarray]
|
||||
# When using VecNormalize:
|
||||
self._last_original_obs = None # type: Optional[np.ndarray]
|
||||
self._episode_num = 0
|
||||
# Used for SDE only
|
||||
self.use_sde = use_sde
|
||||
self.sde_sample_freq = sde_sample_freq
|
||||
|
|
@ -486,7 +490,7 @@ class BaseRLModel(ABC):
|
|||
n_eval_episodes: int = 5,
|
||||
log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> Tuple[int, np.ndarray, BaseCallback]:
|
||||
) -> 'BaseCallback':
|
||||
"""
|
||||
Initialize different variables needed for training.
|
||||
|
||||
|
|
@ -496,7 +500,7 @@ class BaseRLModel(ABC):
|
|||
:param n_eval_episodes: (int)
|
||||
:param log_path (Optional[str]): Path to a log folder
|
||||
:param reset_num_timesteps: (bool) Whether to reset or not the `num_timesteps` attribute
|
||||
:return: (Tuple[int, np.ndarray, BaseCallback])
|
||||
:return: (BaseCallback)
|
||||
"""
|
||||
self.start_time = time.time()
|
||||
self.ep_info_buffer = deque(maxlen=100)
|
||||
|
|
@ -505,21 +509,26 @@ class BaseRLModel(ABC):
|
|||
if self.action_noise is not None:
|
||||
self.action_noise.reset()
|
||||
|
||||
timesteps_since_eval, episode_num = 0, 0
|
||||
|
||||
if reset_num_timesteps:
|
||||
self.num_timesteps = 0
|
||||
self._episode_num = 0
|
||||
|
||||
# Avoid resetting the environment when calling `.learn()` consecutive times
|
||||
if reset_num_timesteps or self._last_obs is None:
|
||||
self._last_obs = self.env.reset()
|
||||
# Retrieve unnormalized observation for saving into the buffer
|
||||
if self._vec_normalize_env is not None:
|
||||
self._last_original_obs = self._vec_normalize_env.get_original_obs()
|
||||
|
||||
if eval_env is not None and self.seed is not None:
|
||||
eval_env.seed(self.seed)
|
||||
|
||||
eval_env = self._get_eval_env(eval_env)
|
||||
obs = self.env.reset()
|
||||
|
||||
# Create eval callback if needed
|
||||
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)
|
||||
|
||||
return episode_num, obs, callback
|
||||
return callback
|
||||
|
||||
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
|
||||
"""
|
||||
|
|
@ -744,8 +753,6 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
action_noise: Optional[ActionNoise] = None,
|
||||
learning_starts: int = 0,
|
||||
replay_buffer: Optional[ReplayBuffer] = None,
|
||||
obs: Optional[np.ndarray] = None,
|
||||
episode_num: int = 0,
|
||||
log_interval: Optional[int] = None) -> RolloutReturn:
|
||||
"""
|
||||
Collect rollout using the current policy (and possibly fill the replay buffer)
|
||||
|
|
@ -762,8 +769,6 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
(and at the beginning and end of the rollout)
|
||||
:param learning_starts: (int) Number of steps before learning for the warm-up phase.
|
||||
:param replay_buffer: (ReplayBuffer)
|
||||
:param obs: (np.ndarray) Last observation from the environment
|
||||
:param episode_num: (int) Episode index
|
||||
:param log_interval: (int) Log data every `log_interval` episodes
|
||||
:return: (RolloutReturn)
|
||||
"""
|
||||
|
|
@ -773,10 +778,6 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
assert isinstance(env, VecEnv), "You must pass a VecEnv"
|
||||
assert env.num_envs == 1, "OffPolicyRLModel only support single environment"
|
||||
|
||||
# Retrieve unnormalized observation for saving into the buffer
|
||||
if self._vec_normalize_env is not None:
|
||||
obs_ = self._vec_normalize_env.get_original_obs()
|
||||
|
||||
self.rollout_data = None
|
||||
if self.use_sde:
|
||||
self.actor.reset_noise()
|
||||
|
|
@ -804,7 +805,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
else:
|
||||
# Note: we assume that the policy uses tanh to scale the action
|
||||
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
|
||||
unscaled_action, _ = self.predict(obs, deterministic=False)
|
||||
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
|
||||
|
||||
# Rescale the action from [low, high] to [-1, 1]
|
||||
scaled_action = self.policy.scale_action(unscaled_action)
|
||||
|
|
@ -827,7 +828,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
|
||||
# Only stop training if return value is False, not when it is None.
|
||||
if callback.on_step() is False:
|
||||
return RolloutReturn(0.0, total_steps, total_episodes, None, continue_training=False)
|
||||
return RolloutReturn(0.0, total_steps, total_episodes, continue_training=False)
|
||||
|
||||
episode_reward += reward
|
||||
|
||||
|
|
@ -842,25 +843,23 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
reward_ = self._vec_normalize_env.get_original_reward()
|
||||
else:
|
||||
# Avoid changing the original ones
|
||||
obs_, new_obs_, reward_ = obs, new_obs, reward
|
||||
self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward
|
||||
|
||||
replay_buffer.add(obs_, new_obs_, clipped_action, reward_, done)
|
||||
replay_buffer.add(self._last_original_obs, new_obs_, clipped_action, reward_, done)
|
||||
|
||||
if self.rollout_data is not None:
|
||||
# Assume only one env
|
||||
self.rollout_data['observations'].append(obs[0].copy())
|
||||
self.rollout_data['observations'].append(self._last_obs[0].copy())
|
||||
self.rollout_data['actions'].append(scaled_action[0].copy())
|
||||
self.rollout_data['rewards'].append(reward[0].copy())
|
||||
self.rollout_data['dones'].append(done[0].copy())
|
||||
obs_tensor = th.FloatTensor(obs).to(self.device)
|
||||
obs_tensor = th.FloatTensor(self._last_obs).to(self.device)
|
||||
self.rollout_data['values'].append(self.vf_net(obs_tensor)[0].cpu().detach().numpy())
|
||||
|
||||
obs = new_obs
|
||||
# Save the true unnormalized observation
|
||||
# otherwise obs_ = self._vec_normalize_env.unnormalize_obs(obs)
|
||||
# is a good approximation
|
||||
self._last_obs = new_obs
|
||||
# Save the unnormalized observation
|
||||
if self._vec_normalize_env is not None:
|
||||
obs_ = new_obs_
|
||||
self._last_original_obs = new_obs_
|
||||
|
||||
self.num_timesteps += 1
|
||||
episode_timesteps += 1
|
||||
|
|
@ -870,16 +869,16 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
|
||||
if done:
|
||||
total_episodes += 1
|
||||
self._episode_num += 1
|
||||
episode_rewards.append(episode_reward)
|
||||
total_timesteps.append(episode_timesteps)
|
||||
if action_noise is not None:
|
||||
action_noise.reset()
|
||||
|
||||
# Display training infos
|
||||
if self.verbose >= 1 and log_interval is not None and (
|
||||
episode_num + total_episodes) % log_interval == 0:
|
||||
if self.verbose >= 1 and log_interval is not None and (self._episode_num) % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
logger.logkv("episodes", episode_num + total_episodes)
|
||||
logger.logkv("episodes", self._episode_num)
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
|
||||
logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
|
||||
|
|
@ -909,7 +908,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
for step in reversed(range(len(self.rollout_data['rewards']))):
|
||||
if step == len(self.rollout_data['rewards']) - 1:
|
||||
next_non_terminal = 1.0 - done[0]
|
||||
next_value = self.vf_net(th.FloatTensor(obs).to(self.device))[0].detach()
|
||||
next_value = self.vf_net(th.FloatTensor(self._last_obs).to(self.device))[0].detach()
|
||||
last_return = self.rollout_data['rewards'][step] + next_non_terminal * next_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
|
||||
|
|
@ -919,4 +918,4 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return RolloutReturn(mean_reward, total_steps, total_episodes, obs, continue_training)
|
||||
return RolloutReturn(mean_reward, total_steps, total_episodes, continue_training)
|
||||
|
|
|
|||
|
|
@ -38,5 +38,4 @@ class RolloutReturn(NamedTuple):
|
|||
episode_reward: float
|
||||
episode_timesteps: int
|
||||
n_episodes: int
|
||||
obs: Optional[np.ndarray]
|
||||
continue_training: bool
|
||||
|
|
|
|||
|
|
@ -141,12 +141,10 @@ class PPO(BaseRLModel):
|
|||
env: VecEnv,
|
||||
callback: BaseCallback,
|
||||
rollout_buffer: RolloutBuffer,
|
||||
n_rollout_steps: int = 256,
|
||||
obs: Optional[np.ndarray] = None) -> Tuple[Optional[np.ndarray], bool]:
|
||||
n_rollout_steps: int = 256) -> bool:
|
||||
|
||||
assert obs is not None, "No previous observation was provided"
|
||||
assert self._last_obs is not None, "No previous observation was provided"
|
||||
n_steps = 0
|
||||
continue_training = True
|
||||
rollout_buffer.reset()
|
||||
# Sample new weights for the state dependent exploration
|
||||
if self.use_sde:
|
||||
|
|
@ -162,7 +160,7 @@ class PPO(BaseRLModel):
|
|||
|
||||
with th.no_grad():
|
||||
# Convert to pytorch tensor
|
||||
obs_tensor = th.as_tensor(obs).to(self.device)
|
||||
obs_tensor = th.as_tensor(self._last_obs).to(self.device)
|
||||
actions, values, log_probs = self.policy.forward(obs_tensor)
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
|
|
@ -171,11 +169,11 @@ class PPO(BaseRLModel):
|
|||
# Clip the actions to avoid out of bound error
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||
|
||||
if callback.on_step() is False:
|
||||
continue_training = False
|
||||
return None, continue_training
|
||||
return False
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
n_steps += 1
|
||||
|
|
@ -184,14 +182,14 @@ class PPO(BaseRLModel):
|
|||
if isinstance(self.action_space, gym.spaces.Discrete):
|
||||
# Reshape in case of discrete action
|
||||
actions = actions.reshape(-1, 1)
|
||||
rollout_buffer.add(obs, actions, rewards, dones, values, log_probs)
|
||||
obs = new_obs
|
||||
rollout_buffer.add(self._last_obs, actions, rewards, dones, values, log_probs)
|
||||
self._last_obs = new_obs
|
||||
|
||||
rollout_buffer.compute_returns_and_advantage(values, dones=dones)
|
||||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return obs, continue_training
|
||||
return True
|
||||
|
||||
def train(self, n_epochs: int, batch_size: int = 64) -> None:
|
||||
# Update optimizer learning rate
|
||||
|
|
@ -307,9 +305,9 @@ class PPO(BaseRLModel):
|
|||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> 'PPO':
|
||||
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
iteration = 0
|
||||
callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
|
||||
# if self.tensorboard_log is not None and SummaryWriter is not None:
|
||||
# self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name))
|
||||
|
|
@ -318,10 +316,9 @@ class PPO(BaseRLModel):
|
|||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
|
||||
obs, continue_training = self.collect_rollouts(self.env, callback,
|
||||
self.rollout_buffer,
|
||||
n_rollout_steps=self.n_steps,
|
||||
obs=obs)
|
||||
continue_training = self.collect_rollouts(self.env, callback,
|
||||
self.rollout_buffer,
|
||||
n_rollout_steps=self.n_steps)
|
||||
|
||||
if continue_training is False:
|
||||
break
|
||||
|
|
|
|||
|
|
@ -256,8 +256,8 @@ class SAC(OffPolicyRLModel):
|
|||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> OffPolicyRLModel:
|
||||
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
|
|
@ -266,14 +266,11 @@ class SAC(OffPolicyRLModel):
|
|||
callback=callback,
|
||||
learning_starts=self.learning_starts,
|
||||
replay_buffer=self.replay_buffer,
|
||||
obs=obs, episode_num=episode_num,
|
||||
log_interval=log_interval)
|
||||
|
||||
if rollout.continue_training is False:
|
||||
break
|
||||
|
||||
obs = rollout.obs
|
||||
episode_num += rollout.n_episodes
|
||||
self._update_current_progress(self.num_timesteps, total_timesteps)
|
||||
|
||||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
|
|
|
|||
|
|
@ -235,8 +235,8 @@ class TD3(OffPolicyRLModel):
|
|||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> OffPolicyRLModel:
|
||||
|
||||
episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
|
|
@ -247,14 +247,11 @@ class TD3(OffPolicyRLModel):
|
|||
callback=callback,
|
||||
learning_starts=self.learning_starts,
|
||||
replay_buffer=self.replay_buffer,
|
||||
obs=obs, episode_num=episode_num,
|
||||
log_interval=log_interval)
|
||||
|
||||
if rollout.continue_training is False:
|
||||
break
|
||||
|
||||
obs = rollout.obs
|
||||
episode_num += rollout.n_episodes
|
||||
self._update_current_progress(self.num_timesteps, total_timesteps)
|
||||
|
||||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.4.0
|
||||
0.5.0a0
|
||||
|
|
|
|||
Loading…
Reference in a new issue