mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Improve VecNormalize syncing for evaluation
This commit is contained in:
parent
6c7c8375a4
commit
c6f90b9c3c
2 changed files with 9 additions and 4 deletions
|
|
@ -180,6 +180,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
self.exploration_mat = None
|
||||
self.use_expln = use_expln
|
||||
if squash_output:
|
||||
print("== Using TanhBijector ===")
|
||||
self.bijector = TanhBijector(epsilon)
|
||||
else:
|
||||
self.bijector = None
|
||||
|
|
@ -206,7 +207,6 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
return mean_actions, log_std
|
||||
|
||||
def proba_distribution(self, mean_actions, log_std, latent_pi, deterministic=False):
|
||||
# TODO: try without detach
|
||||
variance = th.mm(latent_pi.detach() ** 2, self.get_std(log_std) ** 2)
|
||||
self.distribution = Normal(mean_actions, th.sqrt(variance))
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from torchy_baselines.common.base_class import BaseRLModel
|
|||
from torchy_baselines.common.evaluation import evaluate_policy
|
||||
from torchy_baselines.common.buffers import RolloutBuffer
|
||||
from torchy_baselines.common.utils import explained_variance, get_schedule_fn
|
||||
from torchy_baselines.common.vec_env import VecNormalize
|
||||
from torchy_baselines.common.vec_env import VecNormalize, VecEnvWrapper
|
||||
from torchy_baselines.common import logger
|
||||
from torchy_baselines.ppo.policies import PPOPolicy
|
||||
|
||||
|
|
@ -294,9 +294,14 @@ class PPO(BaseRLModel):
|
|||
# Evaluate agent
|
||||
if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
|
||||
timesteps_since_eval %= eval_freq
|
||||
# TODO: move that to the base class
|
||||
# Sync eval env and train env when using VecNormalize
|
||||
if isinstance(self.env, VecNormalize):
|
||||
eval_env.obs_rms = deepcopy(self.env.obs_rms)
|
||||
env_tmp, eval_env_tmp = self.env, eval_env
|
||||
while isinstance(env_tmp, VecEnvWrapper):
|
||||
if isinstance(env_tmp, VecNormalize):
|
||||
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
|
||||
env_tmp = env_tmp.venv
|
||||
eval_env_tmp.venv
|
||||
mean_reward, _ = evaluate_policy(self, eval_env, n_eval_episodes)
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_scalar('Eval/reward', mean_reward, self.num_timesteps)
|
||||
|
|
|
|||
Loading…
Reference in a new issue