mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-06 00:03:28 +00:00
Log success rate
This commit is contained in:
parent
8acac6b0f4
commit
31a862c3a9
2 changed files with 17 additions and 6 deletions
|
|
@@ -95,6 +95,9 @@ class BaseRLModel(ABC):
|
|||
# Track the training progress (from 1 to 0)
|
||||
# this is used to update the learning rate
|
||||
self._current_progress = 1
|
||||
# Buffers for logging
|
||||
self.ep_info_buffer = None # type: deque
|
||||
self.ep_success_buffer = None # type: deque
|
||||
|
||||
# Create and wrap the env if needed
|
||||
if env is not None:
|
||||
|
|
@@ -194,7 +197,7 @@ class BaseRLModel(ABC):
|
|||
update_learning_rate(optimizer, self.learning_rate(self._current_progress))
|
||||
|
||||
@staticmethod
|
||||
def safe_mean(arr: Union[np.ndarray, list]) -> np.ndarray:
|
||||
def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
|
||||
"""
|
||||
Compute the mean of an array if there is at least one element.
|
||||
For empty array, return NaN. It is used for logging only.
|
||||
|
|
@@ -514,6 +517,7 @@ class BaseRLModel(ABC):
|
|||
"""
|
||||
self.start_time = time.time()
|
||||
self.ep_info_buffer = deque(maxlen=100)
|
||||
self.ep_success_buffer = deque(maxlen=100)
|
||||
|
||||
if self.action_noise is not None:
|
||||
self.action_noise.reset()
|
||||
|
|
@@ -534,17 +538,22 @@ class BaseRLModel(ABC):
|
|||
|
||||
return episode_num, obs, callback
|
||||
|
||||
def _update_info_buffer(self, infos: List[Dict[str, Any]]) -> None:
|
||||
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
|
||||
"""
|
||||
Retrieve reward and episode length and update the buffer
|
||||
if using Monitor wrapper.
|
||||
|
||||
:param infos: ([dict])
|
||||
"""
|
||||
for info in infos:
|
||||
if dones is None:
|
||||
dones = np.array([False] * len(infos))
|
||||
for idx, info in enumerate(infos):
|
||||
maybe_ep_info = info.get('episode')
|
||||
maybe_is_success = info.get('is_success')
|
||||
if maybe_ep_info is not None:
|
||||
self.ep_info_buffer.extend([maybe_ep_info])
|
||||
if maybe_is_success is not None and dones[idx]:
|
||||
self.ep_success_buffer.append(maybe_is_success)
|
||||
|
||||
@staticmethod
|
||||
def _save_to_file_zip(save_path: str, data: Dict[str, Any] = None,
|
||||
|
|
@@ -696,7 +705,6 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
self.on_policy_exploration = False
|
||||
self.actor = None
|
||||
self.replay_buffer = None # type: Optional[ReplayBuffer]
|
||||
self.ep_info_buffer = None # type: deque
|
||||
self.use_sde_at_warmup = use_sde_at_warmup
|
||||
|
||||
def save_replay_buffer(self):
|
||||
|
|
@@ -800,7 +808,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
episode_reward += reward
|
||||
|
||||
# Retrieve reward and episode length if using Monitor wrapper
|
||||
self._update_info_buffer(infos)
|
||||
self._update_info_buffer(infos, done)
|
||||
|
||||
# Store data in replay buffer
|
||||
if replay_buffer is not None:
|
||||
|
|
@@ -858,6 +866,9 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
logger.logkv("total timesteps", self.num_timesteps)
|
||||
if self.use_sde:
|
||||
logger.logkv("std", (self.actor.get_std()).mean().item())
|
||||
|
||||
if len(self.ep_success_buffer) > 0:
|
||||
logger.logkv('success rate', self.safe_mean(self.ep_success_buffer))
|
||||
logger.dumpkvs()
|
||||
|
||||
mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0
|
||||
|
|
|
|||
|
|
@@ -18,7 +18,7 @@ class Monitor(gym.Wrapper):
|
|||
|
||||
def __init__(self,
|
||||
env: gym.Env,
|
||||
filename: Optional[str],
|
||||
filename: Optional[str] = None,
|
||||
allow_early_resets: bool = True,
|
||||
reset_keywords=(),
|
||||
info_keywords=()):
|
||||
|
|
|
|||
Loading…
Reference in a new issue