Update PyBullet example

This commit is contained in:
Antonin RAFFIN 2020-05-09 14:38:57 +02:00
parent b1f5db1bb2
commit a06c4a7859

View file

@ -299,7 +299,7 @@ PyBullet: Normalizing input features
Normalizing input features may be essential to successful training of an RL agent
(by default, images are scaled but not other types of input),
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_. For that, a wrapper exists and
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_ environments. For that, a wrapper exists and
will compute a running average and standard deviation of input features (it can do the same for rewards).
@ -311,12 +311,13 @@ will compute a running average and standard deviation of input features (it can
.. code-block:: python
import gym
import pybullet_envs
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3 import PPO
env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
# Automatically normalize the input features
# Automatically normalize the input features and reward
env = VecNormalize(env, norm_obs=True, norm_reward=True,
clip_obs=10.)
@ -325,8 +326,23 @@ will compute a running average and standard deviation of input features (it can
# Don't forget to save the VecNormalize statistics when saving the agent
log_dir = "/tmp/"
model.save(log_dir + "ppo_reacher")
env.save(os.path.join(log_dir, "vec_normalize.pkl"))
model.save(log_dir + "ppo_halfcheetah")
stats_path = os.path.join(log_dir, "vec_normalize.pkl")
env.save(stats_path)
# To demonstrate loading
del model, env
# Load the agent
model = PPO.load(log_dir + "ppo_halfcheetah")
# Load the saved statistics
env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
env = VecNormalize.load(stats_path, env)
# do not update them at test time
env.training = False
# reward normalization is not needed at test time
env.norm_reward = False
Record a Video