Mirror of https://github.com/saymrwulf/stable-baselines3.git (synced 2026-05-14 20:58:03 +00:00)
Standardized the use of `"` for string representation (#1086)
* Replace ``'`` by ``"`` in Python code
* Update changelog
* Remove whitespace
parent d3eb0e3ed6
commit a697401e03
8 changed files with 54 additions and 53 deletions
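The change below is mechanical: every touched line swaps single quotes for double quotes in the Python snippets embedded in the docs. As a hedged illustration only (the commit may just as well have been done by hand, or with a tool such as blacken-docs, which runs Black over code blocks in documentation), here is a minimal sketch of how such a rewrite can be mechanized with the standard-library tokenize module; normalize_quotes is a hypothetical helper, not part of stable-baselines3:

import io
import tokenize

def normalize_quotes(source: str) -> str:
    """Rewrite simple single-quoted string literals to double quotes.

    Skips triple-quoted strings and any literal containing a double quote
    or a backslash, so the rewrite is length-preserving and always safe.
    """
    tokens = []
    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        if (
            tok.type == tokenize.STRING
            and tok.string.startswith("'")
            and not tok.string.startswith("'''")
            and '"' not in tok.string
            and "\\" not in tok.string
        ):
            # Same length before and after, so token positions stay valid.
            tok = tok._replace(string='"' + tok.string[1:-1] + '"')
        tokens.append(tok)
    return tokenize.untokenize(tokens)

print(normalize_quotes("model = SAC('MlpPolicy', 'Pendulum-v1')"))
# -> model = SAC("MlpPolicy", "Pendulum-v1")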
@@ -216,13 +216,13 @@ It will save the best model if ``best_model_save_path`` folder is specified and
 from stable_baselines3.common.callbacks import EvalCallback

 # Separate evaluation env
-eval_env = gym.make('Pendulum-v1')
+eval_env = gym.make("Pendulum-v1")
 # Use deterministic actions for evaluation
-eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',
-                             log_path='./logs/', eval_freq=500,
+eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/",
+                             log_path="./logs/", eval_freq=500,
                              deterministic=True, render=False)

-model = SAC('MlpPolicy', 'Pendulum-v1')
+model = SAC("MlpPolicy", "Pendulum-v1")
 model.learn(5000, callback=eval_callback)

@@ -242,15 +242,15 @@ Alternatively, you can pass directly a list of callbacks to the ``learn()`` meth
 from stable_baselines3 import SAC
 from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback

-checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/')
+checkpoint_callback = CheckpointCallback(save_freq=1000, save_path="./logs/")
 # Separate evaluation env
-eval_env = gym.make('Pendulum-v1')
-eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/best_model',
-                             log_path='./logs/results', eval_freq=500)
+eval_env = gym.make("Pendulum-v1")
+eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/best_model",
+                             log_path="./logs/results", eval_freq=500)
 # Create the callback list
 callback = CallbackList([checkpoint_callback, eval_callback])

-model = SAC('MlpPolicy', 'Pendulum-v1')
+model = SAC("MlpPolicy", "Pendulum-v1")
 # Equivalent to:
 # model.learn(5000, callback=[checkpoint_callback, eval_callback])
 model.learn(5000, callback=callback)

@@ -273,12 +273,12 @@ It must be used with the :ref:`EvalCallback` and use the event triggered by a ne
 from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

 # Separate evaluation env
-eval_env = gym.make('Pendulum-v1')
+eval_env = gym.make("Pendulum-v1")
 # Stop training when the model reaches the reward threshold
 callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
 eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)

-model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1)
+model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
 # Almost infinite number of timesteps, but the training will stop
 # early as soon as the reward threshold is reached
 model.learn(int(1e10), callback=eval_callback)

@@ -306,10 +306,10 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t
 # this is equivalent to defining CheckpointCallback(save_freq=500)
 # checkpoint_callback will be triggered every 500 steps
-checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs/')
+checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/")
 event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

-model = PPO('MlpPolicy', 'Pendulum-v1', verbose=1)
+model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)

 model.learn(int(2e4), callback=event_callback)

@@ -338,7 +338,7 @@ and in total for ``max_episodes * n_envs`` episodes.
 # Stops training when the model reaches the maximum number of episodes
 callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)

-model = A2C('MlpPolicy', 'Pendulum-v1', verbose=1)
+model = A2C("MlpPolicy", "Pendulum-v1", verbose=1)
 # Almost infinite number of timesteps, but the training will stop
 # early as soon as the max number of episodes is reached
 model.learn(int(1e10), callback=callback_max_episodes)

@@ -51,7 +51,7 @@ which defines for the python process, how it should handle floating point error.
 import numpy as np

-np.seterr(all='raise')  # define before your code.
+np.seterr(all="raise")  # define before your code.

 print("numpy test:")

@@ -66,7 +66,7 @@ but this will also avoid overflow issues on floating point numbers:
 import numpy as np

-np.seterr(all='raise')  # define before your code.
+np.seterr(all="raise")  # define before your code.

 print("numpy overflow test:")

@@ -81,11 +81,11 @@ but will not avoid the propagation issues:
 import numpy as np

-np.seterr(all='raise')  # define before your code.
+np.seterr(all="raise")  # define before your code.

 print("numpy propagation test:")

-a = np.float64('NaN')
+a = np.float64("NaN")
 b = np.float64(1.0)
 val = a + b  # this will neither warn nor raise anything
 print(val)

@@ -109,7 +109,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
 class NanAndInfEnv(gym.Env):
     """Custom Environment that raised NaNs and Infs"""
-    metadata = {'render.modes': ['human']}
+    metadata = {"render.modes": ["human"]}

     def __init__(self):
         super(NanAndInfEnv, self).__init__()

@@ -119,9 +119,9 @@ It will monitor the actions, observations, and rewards, indicating what action o
     def step(self, _action):
         randf = np.random.rand()
         if randf > 0.99:
-            obs = float('NaN')
+            obs = float("NaN")
         elif randf > 0.98:
-            obs = float('inf')
+            obs = float("inf")
         else:
             obs = randf
         return [obs], 0.0, False, {}

@@ -129,7 +129,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
     def reset(self):
         return [0.0]

-    def render(self, mode='human', close=False):
+    def render(self, mode="human", close=False):
         pass

 # Create environment

@@ -137,7 +137,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
 env = VecCheckNan(env, raise_exception=True)

 # Instantiate the agent
-model = PPO('MlpPolicy', env)
+model = PPO("MlpPolicy", env)

 # Train the agent
 model.learn(total_timesteps=int(2e5))  # this will crash explaining that the invalid value originated from the environment.

@@ -27,7 +27,7 @@ That is to say, your environment must implement the following methods (and inher
 class CustomEnv(gym.Env):
     """Custom Environment that follows gym interface"""
-    metadata = {'render.modes': ['human']}
+    metadata = {"render.modes": ["human"]}

     def __init__(self, arg1, arg2, ...):
         super(CustomEnv, self).__init__()

@@ -45,7 +45,7 @@ That is to say, your environment must implement the following methods (and inher
     def reset(self):
         ...
         return observation  # reward, done, info can't be included
-    def render(self, mode='human'):
+    def render(self, mode="human"):
         ...
     def close (self):
         ...

@@ -58,7 +58,7 @@ Then you can define and train a RL agent with:
 # Instantiate the env
 env = CustomEnv(arg1, ...)
 # Define and Train the agent
-model = A2C('CnnPolicy', env).learn(total_timesteps=1000)
+model = A2C("CnnPolicy", env).learn(total_timesteps=1000)


 To check that your environment follows the Gym interface that SB3 supports, please use:

@@ -71,10 +71,10 @@ In the following example, we will train, save and load a DQN model on the Lunar

 # Create environment
-env = gym.make('LunarLander-v2')
+env = gym.make("LunarLander-v2")

 # Instantiate the agent
-model = DQN('MlpPolicy', env, verbose=1)
+model = DQN("MlpPolicy", env, verbose=1)
 # Train the agent
 model.learn(total_timesteps=int(2e5))
 # Save the agent

@@ -138,7 +138,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
     set_random_seed(seed)
     return _init

-if __name__ == '__main__':
+if __name__ == "__main__":
     env_id = "CartPole-v1"
     num_cpu = 4  # Number of processes to use
     # Create the vectorized environment

@@ -149,7 +149,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
 # You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv`
 # env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)

-model = PPO('MlpPolicy', env, verbose=1)
+model = PPO("MlpPolicy", env, verbose=1)
 model.learn(total_timesteps=25_000)

 obs = env.reset()

@@ -182,7 +182,7 @@ Multiprocessing with off-policy algorithms
 # We collect 4 transitions per call to `ènv.step()`
 # and performs 2 gradient steps per call to `ènv.step()`
 # if gradient_steps=-1, then we would do 4 gradients steps per call to `ènv.step()`
-model = SAC('MlpPolicy', env, train_freq=1, gradient_steps=2, verbose=1)
+model = SAC("MlpPolicy", env, train_freq=1, gradient_steps=2, verbose=1)
 model.learn(total_timesteps=10_000)

@@ -254,7 +254,7 @@ If your callback returns False, training is aborted early.
         super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
         self.check_freq = check_freq
         self.log_dir = log_dir
-        self.save_path = os.path.join(log_dir, 'best_model')
+        self.save_path = os.path.join(log_dir, "best_model")
         self.best_mean_reward = -np.inf

     def _init_callback(self) -> None:

@@ -266,7 +266,7 @@ If your callback returns False, training is aborted early.
         if self.n_calls % self.check_freq == 0:

             # Retrieve training reward
-            x, y = ts2xy(load_results(self.log_dir), 'timesteps')
+            x, y = ts2xy(load_results(self.log_dir), "timesteps")
             if len(x) > 0:
                 # Mean training reward over the last 100 episodes
                 mean_reward = np.mean(y[-100:])

@@ -289,14 +289,14 @@ If your callback returns False, training is aborted early.
 os.makedirs(log_dir, exist_ok=True)

 # Create and wrap the environment
-env = gym.make('LunarLanderContinuous-v2')
+env = gym.make("LunarLanderContinuous-v2")
 env = Monitor(env, log_dir)

 # Add some action noise for exploration
 n_actions = env.action_space.shape[-1]
 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
 # Because we use parameter noise, we should use a MlpPolicy with layer normalization
-model = TD3('MlpPolicy', env, action_noise=action_noise, verbose=0)
+model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=0)
 # Create the callback: check every 1000 steps
 callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
 # Train the agent

@@ -336,11 +336,11 @@ and multiprocessing for you. To install the Atari environments, run the command
 # There already exists an environment generator
 # that will make and wrap atari environments correctly.
 # Here we are also multi-worker training (n_envs=4 => 4 environments)
-env = make_atari_env('PongNoFrameskip-v4', n_envs=4, seed=0)
+env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0)
 # Frame-stacking with 4 frames
 env = VecFrameStack(env, n_stack=4)

-model = A2C('CnnPolicy', env, verbose=1)
+model = A2C("CnnPolicy", env, verbose=1)
 model.learn(total_timesteps=25_000)

 obs = env.reset()

@@ -382,7 +382,7 @@ will compute a running average and standard deviation of input features (it can
 env = VecNormalize(env, norm_obs=True, norm_reward=True,
                    clip_obs=10.)

-model = PPO('MlpPolicy', env)
+model = PPO("MlpPolicy", env)
 model.learn(total_timesteps=2000)

 # Don't forget to save the VecNormalize statistics when saving the agent

@@ -564,7 +564,7 @@ Behind the scene, SB3 uses an :ref:`EvalCallback <callbacks>`.

 # Create the model, the training environment
 # and the test environment (for evaluation)
-model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1,
+model = SAC("MlpPolicy", "Pendulum-v1", verbose=1,
             learning_rate=1e-3, create_eval_env=True)

 # Evaluate the model every 1000 steps on 5 test episodes

@@ -717,7 +717,7 @@ to keep track of the agent progress.
 from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor

 # ProcgenEnv is already vectorized
-venv = ProcgenEnv(num_envs=2, env_name='starpilot')
+venv = ProcgenEnv(num_envs=2, env_name="starpilot")

 # To use only part of the observation:
 # venv = VecExtractDictObs(venv, "rgb")

@@ -753,8 +753,8 @@ Record a mp4 video (here using a random agent).
 import gym
 from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv

-env_id = 'CartPole-v1'
-video_folder = 'logs/videos/'
+env_id = "CartPole-v1"
+video_folder = "logs/videos/"
 video_length = 100

 env = DummyVecEnv([lambda: gym.make(env_id)])

@@ -792,11 +792,11 @@ Bonus: Make a GIF of a Trained Agent

 images = []
 obs = model.env.reset()
-img = model.env.render(mode='rgb_array')
+img = model.env.render(mode="rgb_array")
 for i in range(350):
     images.append(img)
     action, _ = model.predict(obs)
     obs, _, _ ,_ = model.env.step(action)
-    img = model.env.render(mode='rgb_array')
+    img = model.env.render(mode="rgb_array")

-imageio.mimsave('lander_a2c.gif', [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29)
+imageio.mimsave("lander_a2c.gif", [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29)

@@ -14,9 +14,9 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
 from stable_baselines3 import A2C

-env = gym.make('CartPole-v1')
+env = gym.make("CartPole-v1")

-model = A2C('MlpPolicy', env, verbose=1)
+model = A2C("MlpPolicy", env, verbose=1)
 model.learn(total_timesteps=10000)

 obs = env.reset()

@@ -40,4 +40,4 @@ the policy is registered:
 from stable_baselines3 import A2C

-model = A2C('MlpPolicy', 'CartPole-v1').learn(10000)
+model = A2C("MlpPolicy", "CartPole-v1").learn(10000)

@@ -12,7 +12,7 @@ To use Tensorboard with stable baselines3, you simply need to pass the location
 from stable_baselines3 import A2C

-model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
+model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
 model.learn(total_timesteps=10_000)

@@ -22,7 +22,7 @@ You can also define custom logging name when training (by default it is the algo
 from stable_baselines3 import A2C

-model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
+model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
 model.learn(total_timesteps=10_000, tb_log_name="first_run")
 # Pass reset_num_timesteps=False to continue the training curve in tensorboard
 # By default, it will create a new curve

@@ -91,7 +91,7 @@ Here is a simple example on how to log both additional tensor or arbitrary scala
     def _on_step(self) -> bool:
        # Log scalar value (here a random variable)
        value = np.random.random()
-       self.logger.record('random_value', value)
+       self.logger.record("random_value", value)
        return True

@@ -58,6 +58,7 @@ Documentation:
 - Added link to a GitHub issue in the custom policy documentation (@AlexPasqua)
 - Update doc on exporting models (fixes and added torch jit)
 - Fixed typos (@Akhilez)
+- Standardized the use of ``"`` for string representation in documentation

 Release 1.6.0 (2022-07-11)
 ---------------------------

@@ -73,7 +73,7 @@ This example is only to demonstrate the use of the library and its functions, an
 env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)

 # Available strategies (cf paper): future, final, episode
-goal_selection_strategy = 'future'  # equivalent to GoalSelectionStrategy.FUTURE
+goal_selection_strategy = "future"  # equivalent to GoalSelectionStrategy.FUTURE

 # If True the HER transitions will get sampled online
 online_sampling = True

@@ -101,7 +101,7 @@ This example is only to demonstrate the use of the library and its functions, an
 model.save("./her_bit_env")
 # Because it needs access to `env.compute_reward()`
 # HER must be loaded with the env
-model = model_class.load('./her_bit_env', env=env)
+model = model_class.load("./her_bit_env", env=env)

 obs = env.reset()
 for _ in range(100):
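To double-check that no single-quoted literals survive in the docs after a change like this, a rough scan can help. A hedged sketch, assuming the documentation lives under docs/ (the regex is deliberately crude and will also flag prose contractions, so treat hits as candidates for review, not errors):

import pathlib
import re

# Crude heuristic: any pair of single quotes on a line of an .rst file.
single_quoted = re.compile(r"'[^']*'")

for path in pathlib.Path("docs").rglob("*.rst"):
    for lineno, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
        if single_quoted.search(line):
            print(f"{path}:{lineno}: {line.strip()}")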