Log success rate

This commit is contained in:
Antonin Raffin 2020-02-04 13:24:09 +01:00
parent 8acac6b0f4
commit 31a862c3a9
2 changed files with 17 additions and 6 deletions

View file

@ -95,6 +95,9 @@ class BaseRLModel(ABC):
# Track the training progress (from 1 to 0)
# this is used to update the learning rate
self._current_progress = 1
# Buffers for logging
self.ep_info_buffer = None # type: deque
self.ep_success_buffer = None # type: deque
# Create and wrap the env if needed
if env is not None:
@ -194,7 +197,7 @@ class BaseRLModel(ABC):
update_learning_rate(optimizer, self.learning_rate(self._current_progress))
@staticmethod
def safe_mean(arr: Union[np.ndarray, list]) -> np.ndarray:
def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
"""
Compute the mean of an array if there is at least one element.
For empty array, return NaN. It is used for logging only.
@ -514,6 +517,7 @@ class BaseRLModel(ABC):
"""
self.start_time = time.time()
self.ep_info_buffer = deque(maxlen=100)
self.ep_success_buffer = deque(maxlen=100)
if self.action_noise is not None:
self.action_noise.reset()
@ -534,17 +538,22 @@ class BaseRLModel(ABC):
return episode_num, obs, callback
def _update_info_buffer(self, infos: List[Dict[str, Any]]) -> None:
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
"""
Retrieve reward and episode length and update the buffer
if using Monitor wrapper.
:param infos: ([dict])
"""
for info in infos:
if dones is None:
dones = np.array([False] * len(infos))
for idx, info in enumerate(infos):
maybe_ep_info = info.get('episode')
maybe_is_success = info.get('is_success')
if maybe_ep_info is not None:
self.ep_info_buffer.extend([maybe_ep_info])
if maybe_is_success is not None and dones[idx]:
self.ep_success_buffer.append(maybe_is_success)
@staticmethod
def _save_to_file_zip(save_path: str, data: Dict[str, Any] = None,
@ -696,7 +705,6 @@ class OffPolicyRLModel(BaseRLModel):
self.on_policy_exploration = False
self.actor = None
self.replay_buffer = None # type: Optional[ReplayBuffer]
self.ep_info_buffer = None # type: deque
self.use_sde_at_warmup = use_sde_at_warmup
def save_replay_buffer(self):
@ -800,7 +808,7 @@ class OffPolicyRLModel(BaseRLModel):
episode_reward += reward
# Retrieve reward and episode length if using Monitor wrapper
self._update_info_buffer(infos)
self._update_info_buffer(infos, done)
# Store data in replay buffer
if replay_buffer is not None:
@ -858,6 +866,9 @@ class OffPolicyRLModel(BaseRLModel):
logger.logkv("total timesteps", self.num_timesteps)
if self.use_sde:
logger.logkv("std", (self.actor.get_std()).mean().item())
if len(self.ep_success_buffer) > 0:
logger.logkv('success rate', self.safe_mean(self.ep_success_buffer))
logger.dumpkvs()
mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0

View file

@ -18,7 +18,7 @@ class Monitor(gym.Wrapper):
def __init__(self,
env: gym.Env,
filename: Optional[str],
filename: Optional[str] = None,
allow_early_resets: bool = True,
reset_keywords=(),
info_keywords=()):