From c6f90b9c3c62f3f4e6869f49982d0989f15e6aa4 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 7 Nov 2019 11:17:26 +0100 Subject: [PATCH] Improve VecNormalize syncing for evaluation --- torchy_baselines/common/distributions.py | 2 +- torchy_baselines/ppo/ppo.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py index 90105ee..37f4cf8 100644 --- a/torchy_baselines/common/distributions.py +++ b/torchy_baselines/common/distributions.py @@ -180,6 +180,7 @@ class StateDependentNoiseDistribution(Distribution): self.exploration_mat = None self.use_expln = use_expln if squash_output: + print("== Using TanhBijector ===") self.bijector = TanhBijector(epsilon) else: self.bijector = None @@ -206,7 +207,6 @@ class StateDependentNoiseDistribution(Distribution): return mean_actions, log_std def proba_distribution(self, mean_actions, log_std, latent_pi, deterministic=False): - # TODO: try without detach variance = th.mm(latent_pi.detach() ** 2, self.get_std(log_std) ** 2) self.distribution = Normal(mean_actions, th.sqrt(variance)) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index bbf880b..32e5f95 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -17,7 +17,7 @@ from torchy_baselines.common.base_class import BaseRLModel from torchy_baselines.common.evaluation import evaluate_policy from torchy_baselines.common.buffers import RolloutBuffer from torchy_baselines.common.utils import explained_variance, get_schedule_fn -from torchy_baselines.common.vec_env import VecNormalize +from torchy_baselines.common.vec_env import VecNormalize, VecEnvWrapper from torchy_baselines.common import logger from torchy_baselines.ppo.policies import PPOPolicy @@ -294,9 +294,14 @@ class PPO(BaseRLModel): # Evaluate agent if 0 < eval_freq <= timesteps_since_eval and eval_env is not None: timesteps_since_eval %= eval_freq + # TODO: move that to the base class # Sync eval env and train env when using VecNormalize - if isinstance(self.env, VecNormalize): - eval_env.obs_rms = deepcopy(self.env.obs_rms) + env_tmp, eval_env_tmp = self.env, eval_env + while isinstance(env_tmp, VecEnvWrapper): + if isinstance(env_tmp, VecNormalize): + eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) + env_tmp = env_tmp.venv + eval_env_tmp.venv mean_reward, _ = evaluate_policy(self, eval_env, n_eval_episodes) if self.tb_writer is not None: self.tb_writer.add_scalar('Eval/reward', mean_reward, self.num_timesteps)