Improve VecNormalize syncing for evaluation

This commit is contained in:
Antonin Raffin 2019-11-07 11:17:26 +01:00
parent 6c7c8375a4
commit c6f90b9c3c
2 changed files with 9 additions and 4 deletions

View file

@ -180,6 +180,7 @@ class StateDependentNoiseDistribution(Distribution):
self.exploration_mat = None
self.use_expln = use_expln
if squash_output:
print("== Using TanhBijector ===")
self.bijector = TanhBijector(epsilon)
else:
self.bijector = None
@ -206,7 +207,6 @@ class StateDependentNoiseDistribution(Distribution):
return mean_actions, log_std
def proba_distribution(self, mean_actions, log_std, latent_pi, deterministic=False):
# TODO: try without detach
variance = th.mm(latent_pi.detach() ** 2, self.get_std(log_std) ** 2)
self.distribution = Normal(mean_actions, th.sqrt(variance))

View file

@ -17,7 +17,7 @@ from torchy_baselines.common.base_class import BaseRLModel
from torchy_baselines.common.evaluation import evaluate_policy
from torchy_baselines.common.buffers import RolloutBuffer
from torchy_baselines.common.utils import explained_variance, get_schedule_fn
from torchy_baselines.common.vec_env import VecNormalize
from torchy_baselines.common.vec_env import VecNormalize, VecEnvWrapper
from torchy_baselines.common import logger
from torchy_baselines.ppo.policies import PPOPolicy
@ -294,9 +294,14 @@ class PPO(BaseRLModel):
# Evaluate agent
if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
timesteps_since_eval %= eval_freq
# TODO: move that to the base class
# Sync eval env and train env when using VecNormalize
if isinstance(self.env, VecNormalize):
eval_env.obs_rms = deepcopy(self.env.obs_rms)
env_tmp, eval_env_tmp = self.env, eval_env
while isinstance(env_tmp, VecEnvWrapper):
if isinstance(env_tmp, VecNormalize):
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
env_tmp = env_tmp.venv
eval_env_tmp.venv
mean_reward, _ = evaluate_policy(self, eval_env, n_eval_episodes)
if self.tb_writer is not None:
self.tb_writer.add_scalar('Eval/reward', mean_reward, self.num_timesteps)