mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-04 04:07:27 +00:00
Update PyBullet example
This commit is contained in:
parent
b1f5db1bb2
commit
a06c4a7859
1 changed files with 20 additions and 4 deletions
|
|
@ -299,7 +299,7 @@ PyBullet: Normalizing input features
|
|||
|
||||
Normalizing input features may be essential to successful training of an RL agent
|
||||
(by default, images are scaled but not other types of input),
|
||||
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_. For that, a wrapper exists and
|
||||
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_ environments. For that, a wrapper exists and
|
||||
will compute a running average and standard deviation of input features (it can do the same for rewards).
|
||||
|
||||
|
||||
|
|
@ -311,12 +311,13 @@ will compute a running average and standard deviation of input features (it can
|
|||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
import pybullet_envs
|
||||
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
|
||||
# Automatically normalize the input features
|
||||
# Automatically normalize the input features and reward
|
||||
env = VecNormalize(env, norm_obs=True, norm_reward=True,
|
||||
clip_obs=10.)
|
||||
|
||||
|
|
@ -325,8 +326,23 @@ will compute a running average and standard deviation of input features (it can
|
|||
|
||||
# Don't forget to save the VecNormalize statistics when saving the agent
|
||||
log_dir = "/tmp/"
|
||||
model.save(log_dir + "ppo_reacher")
|
||||
env.save(os.path.join(log_dir, "vec_normalize.pkl"))
|
||||
model.save(log_dir + "ppo_halfcheetah")
|
||||
stats_path = os.path.join(log_dir, "vec_normalize.pkl")
|
||||
env.save(stats_path)
|
||||
|
||||
# To demonstrate loading
|
||||
del model, env
|
||||
|
||||
# Load the agent
|
||||
model = PPO.load(log_dir + "ppo_halfcheetah")
|
||||
|
||||
# Load the saved statistics
|
||||
env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
|
||||
env = VecNormalize.load(stats_path, env)
|
||||
# do not update them at test time
|
||||
env.training = False
|
||||
# reward normalization is not needed at test time
|
||||
env.norm_reward = False
|
||||
|
||||
|
||||
Record a Video
|
||||
|
|
|
|||
Loading…
Reference in a new issue