diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 9d642b4..ce633e9 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -299,7 +299,7 @@ PyBullet: Normalizing input features Normalizing input features may be essential to successful training of an RL agent (by default, images are scaled but not other types of input), -for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_. For that, a wrapper exists and +for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_ environments. For that, a wrapper exists and will compute a running average and standard deviation of input features (it can do the same for rewards). @@ -311,12 +311,13 @@ will compute a running average and standard deviation of input features (it can .. code-block:: python import gym + import pybullet_envs from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize from stable_baselines3 import PPO env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) - # Automatically normalize the input features + # Automatically normalize the input features and reward env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.) @@ -325,8 +326,23 @@ will compute a running average and standard deviation of input features (it can # Don't forget to save the VecNormalize statistics when saving the agent log_dir = "/tmp/" - model.save(log_dir + "ppo_reacher") - env.save(os.path.join(log_dir, "vec_normalize.pkl")) + model.save(log_dir + "ppo_halfcheetah") + stats_path = os.path.join(log_dir, "vec_normalize.pkl") + env.save(stats_path) + + # To demonstrate loading + del model, env + + # Load the agent + model = PPO.load(log_dir + "ppo_halfcheetah") + + # Load the saved statistics + env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) + env = VecNormalize.load(stats_path, env) + # do not update them at test time + env.training = False + # reward normalization is not needed at test time + env.norm_reward = False Record a Video