From 31a862c3a99ac325577af6dfc4e4a188af77077f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 4 Feb 2020 13:24:09 +0100 Subject: [PATCH] Log success rate --- torchy_baselines/common/base_class.py | 21 ++++++++++++++++----- torchy_baselines/common/monitor.py | 2 +- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 5cb75d7..ee6ec7f 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -95,6 +95,9 @@ class BaseRLModel(ABC): # Track the training progress (from 1 to 0) # this is used to update the learning rate self._current_progress = 1 + # Buffers for logging + self.ep_info_buffer = None # type: deque + self.ep_success_buffer = None # type: deque # Create and wrap the env if needed if env is not None: @@ -194,7 +197,7 @@ class BaseRLModel(ABC): update_learning_rate(optimizer, self.learning_rate(self._current_progress)) @staticmethod - def safe_mean(arr: Union[np.ndarray, list]) -> np.ndarray: + def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray: """ Compute the mean of an array if there is at least one element. For empty array, return NaN. It is used for logging only. @@ -514,6 +517,7 @@ class BaseRLModel(ABC): """ self.start_time = time.time() self.ep_info_buffer = deque(maxlen=100) + self.ep_success_buffer = deque(maxlen=100) if self.action_noise is not None: self.action_noise.reset() @@ -534,17 +538,22 @@ class BaseRLModel(ABC): return episode_num, obs, callback - def _update_info_buffer(self, infos: List[Dict[str, Any]]) -> None: + def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None: """ Retrieve reward and episode length and update the buffer if using Monitor wrapper. :param infos: ([dict]) """ - for info in infos: + if dones is None: + dones = np.array([False] * len(infos)) + for idx, info in enumerate(infos): maybe_ep_info = info.get('episode') + maybe_is_success = info.get('is_success') if maybe_ep_info is not None: self.ep_info_buffer.extend([maybe_ep_info]) + if maybe_is_success is not None and dones[idx]: + self.ep_success_buffer.append(maybe_is_success) @staticmethod def _save_to_file_zip(save_path: str, data: Dict[str, Any] = None, @@ -696,7 +705,6 @@ class OffPolicyRLModel(BaseRLModel): self.on_policy_exploration = False self.actor = None self.replay_buffer = None # type: Optional[ReplayBuffer] - self.ep_info_buffer = None # type: deque self.use_sde_at_warmup = use_sde_at_warmup def save_replay_buffer(self): @@ -800,7 +808,7 @@ class OffPolicyRLModel(BaseRLModel): episode_reward += reward # Retrieve reward and episode length if using Monitor wrapper - self._update_info_buffer(infos) + self._update_info_buffer(infos, done) # Store data in replay buffer if replay_buffer is not None: @@ -858,6 +866,9 @@ class OffPolicyRLModel(BaseRLModel): logger.logkv("total timesteps", self.num_timesteps) if self.use_sde: logger.logkv("std", (self.actor.get_std()).mean().item()) + + if len(self.ep_success_buffer) > 0: + logger.logkv('success rate', self.safe_mean(self.ep_success_buffer)) logger.dumpkvs() mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0 diff --git a/torchy_baselines/common/monitor.py b/torchy_baselines/common/monitor.py index 8d460d2..88bea45 100644 --- a/torchy_baselines/common/monitor.py +++ b/torchy_baselines/common/monitor.py @@ -18,7 +18,7 @@ class Monitor(gym.Wrapper): def __init__(self, env: gym.Env, - filename: Optional[str], + filename: Optional[str] = None, allow_early_resets: bool = True, reset_keywords=(), info_keywords=()):