Merge branch 'master' into feat/mps-support

This commit is contained in:
Quentin Gallouédec 2022-10-04 11:01:49 +02:00 committed by GitHub
commit 6d868c02bb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 164 additions and 96 deletions

View file

@ -216,13 +216,13 @@ It will save the best model if ``best_model_save_path`` folder is specified and
from stable_baselines3.common.callbacks import EvalCallback
# Separate evaluation env
eval_env = gym.make('Pendulum-v1')
eval_env = gym.make("Pendulum-v1")
# Use deterministic actions for evaluation
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',
log_path='./logs/', eval_freq=500,
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/",
log_path="./logs/", eval_freq=500,
deterministic=True, render=False)
model = SAC('MlpPolicy', 'Pendulum-v1')
model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(5000, callback=eval_callback)
@ -242,15 +242,15 @@ Alternatively, you can pass directly a list of callbacks to the ``learn()`` meth
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/')
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path="./logs/")
# Separate evaluation env
eval_env = gym.make('Pendulum-v1')
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/best_model',
log_path='./logs/results', eval_freq=500)
eval_env = gym.make("Pendulum-v1")
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/best_model",
log_path="./logs/results", eval_freq=500)
# Create the callback list
callback = CallbackList([checkpoint_callback, eval_callback])
model = SAC('MlpPolicy', 'Pendulum-v1')
model = SAC("MlpPolicy", "Pendulum-v1")
# Equivalent to:
# model.learn(5000, callback=[checkpoint_callback, eval_callback])
model.learn(5000, callback=callback)
@ -273,12 +273,12 @@ It must be used with the :ref:`EvalCallback` and use the event triggered by a ne
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
# Separate evaluation env
eval_env = gym.make('Pendulum-v1')
eval_env = gym.make("Pendulum-v1")
# Stop training when the model reaches the reward threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)
model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the reward threshold is reached
model.learn(int(1e10), callback=eval_callback)
@ -306,10 +306,10 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t
# this is equivalent to defining CheckpointCallback(save_freq=500)
# checkpoint_callback will be triggered every 500 steps
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs/')
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/")
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
model = PPO('MlpPolicy', 'Pendulum-v1', verbose=1)
model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(int(2e4), callback=event_callback)
@ -338,7 +338,7 @@ and in total for ``max_episodes * n_envs`` episodes.
# Stops training when the model reaches the maximum number of episodes
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)
model = A2C('MlpPolicy', 'Pendulum-v1', verbose=1)
model = A2C("MlpPolicy", "Pendulum-v1", verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the max number of episodes is reached
model.learn(int(1e10), callback=callback_max_episodes)

View file

@ -51,7 +51,7 @@ which defines for the python process, how it should handle floating point error.
import numpy as np
np.seterr(all='raise') # define before your code.
np.seterr(all="raise") # define before your code.
print("numpy test:")
@ -66,7 +66,7 @@ but this will also avoid overflow issues on floating point numbers:
import numpy as np
np.seterr(all='raise') # define before your code.
np.seterr(all="raise") # define before your code.
print("numpy overflow test:")
@ -81,11 +81,11 @@ but will not avoid the propagation issues:
import numpy as np
np.seterr(all='raise') # define before your code.
np.seterr(all="raise") # define before your code.
print("numpy propagation test:")
a = np.float64('NaN')
a = np.float64("NaN")
b = np.float64(1.0)
val = a + b # this will neither warn nor raise anything
print(val)
@ -109,7 +109,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
class NanAndInfEnv(gym.Env):
"""Custom Environment that raised NaNs and Infs"""
metadata = {'render.modes': ['human']}
metadata = {"render.modes": ["human"]}
def __init__(self):
super(NanAndInfEnv, self).__init__()
@ -119,9 +119,9 @@ It will monitor the actions, observations, and rewards, indicating what action o
def step(self, _action):
randf = np.random.rand()
if randf > 0.99:
obs = float('NaN')
obs = float("NaN")
elif randf > 0.98:
obs = float('inf')
obs = float("inf")
else:
obs = randf
return [obs], 0.0, False, {}
@ -129,7 +129,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
def reset(self):
return [0.0]
def render(self, mode='human', close=False):
def render(self, mode="human", close=False):
pass
# Create environment
@ -137,7 +137,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
env = VecCheckNan(env, raise_exception=True)
# Instantiate the agent
model = PPO('MlpPolicy', env)
model = PPO("MlpPolicy", env)
# Train the agent
model.learn(total_timesteps=int(2e5)) # this will crash explaining that the invalid value originated from the environment.

View file

@ -27,7 +27,7 @@ That is to say, your environment must implement the following methods (and inher
class CustomEnv(gym.Env):
"""Custom Environment that follows gym interface"""
metadata = {'render.modes': ['human']}
metadata = {"render.modes": ["human"]}
def __init__(self, arg1, arg2, ...):
super(CustomEnv, self).__init__()
@ -45,7 +45,7 @@ That is to say, your environment must implement the following methods (and inher
def reset(self):
...
return observation # reward, done, info can't be included
def render(self, mode='human'):
def render(self, mode="human"):
...
def close (self):
...
@ -58,7 +58,7 @@ Then you can define and train a RL agent with:
# Instantiate the env
env = CustomEnv(arg1, ...)
# Define and Train the agent
model = A2C('CnnPolicy', env).learn(total_timesteps=1000)
model = A2C("CnnPolicy", env).learn(total_timesteps=1000)
To check that your environment follows the Gym interface that SB3 supports, please use:

View file

@ -71,10 +71,10 @@ In the following example, we will train, save and load a DQN model on the Lunar
# Create environment
env = gym.make('LunarLander-v2')
env = gym.make("LunarLander-v2")
# Instantiate the agent
model = DQN('MlpPolicy', env, verbose=1)
model = DQN("MlpPolicy", env, verbose=1)
# Train the agent
model.learn(total_timesteps=int(2e5))
# Save the agent
@ -138,7 +138,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
set_random_seed(seed)
return _init
if __name__ == '__main__':
if __name__ == "__main__":
env_id = "CartPole-v1"
num_cpu = 4 # Number of processes to use
# Create the vectorized environment
@ -149,7 +149,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
# You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv`
# env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)
model = PPO('MlpPolicy', env, verbose=1)
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25_000)
obs = env.reset()
@ -182,7 +182,7 @@ Multiprocessing with off-policy algorithms
# We collect 4 transitions per call to `ènv.step()`
# and performs 2 gradient steps per call to `ènv.step()`
# if gradient_steps=-1, then we would do 4 gradients steps per call to `ènv.step()`
model = SAC('MlpPolicy', env, train_freq=1, gradient_steps=2, verbose=1)
model = SAC("MlpPolicy", env, train_freq=1, gradient_steps=2, verbose=1)
model.learn(total_timesteps=10_000)
@ -254,7 +254,7 @@ If your callback returns False, training is aborted early.
super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
self.check_freq = check_freq
self.log_dir = log_dir
self.save_path = os.path.join(log_dir, 'best_model')
self.save_path = os.path.join(log_dir, "best_model")
self.best_mean_reward = -np.inf
def _init_callback(self) -> None:
@ -266,7 +266,7 @@ If your callback returns False, training is aborted early.
if self.n_calls % self.check_freq == 0:
# Retrieve training reward
x, y = ts2xy(load_results(self.log_dir), 'timesteps')
x, y = ts2xy(load_results(self.log_dir), "timesteps")
if len(x) > 0:
# Mean training reward over the last 100 episodes
mean_reward = np.mean(y[-100:])
@ -289,14 +289,14 @@ If your callback returns False, training is aborted early.
os.makedirs(log_dir, exist_ok=True)
# Create and wrap the environment
env = gym.make('LunarLanderContinuous-v2')
env = gym.make("LunarLanderContinuous-v2")
env = Monitor(env, log_dir)
# Add some action noise for exploration
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
# Because we use parameter noise, we should use a MlpPolicy with layer normalization
model = TD3('MlpPolicy', env, action_noise=action_noise, verbose=0)
model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=0)
# Create the callback: check every 1000 steps
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
# Train the agent
@ -336,11 +336,11 @@ and multiprocessing for you. To install the Atari environments, run the command
# There already exists an environment generator
# that will make and wrap atari environments correctly.
# Here we are also multi-worker training (n_envs=4 => 4 environments)
env = make_atari_env('PongNoFrameskip-v4', n_envs=4, seed=0)
env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0)
# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)
model = A2C('CnnPolicy', env, verbose=1)
model = A2C("CnnPolicy", env, verbose=1)
model.learn(total_timesteps=25_000)
obs = env.reset()
@ -382,7 +382,7 @@ will compute a running average and standard deviation of input features (it can
env = VecNormalize(env, norm_obs=True, norm_reward=True,
clip_obs=10.)
model = PPO('MlpPolicy', env)
model = PPO("MlpPolicy", env)
model.learn(total_timesteps=2000)
# Don't forget to save the VecNormalize statistics when saving the agent
@ -564,7 +564,7 @@ Behind the scene, SB3 uses an :ref:`EvalCallback <callbacks>`.
# Create the model, the training environment
# and the test environment (for evaluation)
model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1,
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1,
learning_rate=1e-3, create_eval_env=True)
# Evaluate the model every 1000 steps on 5 test episodes
@ -717,7 +717,7 @@ to keep track of the agent progress.
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
# ProcgenEnv is already vectorized
venv = ProcgenEnv(num_envs=2, env_name='starpilot')
venv = ProcgenEnv(num_envs=2, env_name="starpilot")
# To use only part of the observation:
# venv = VecExtractDictObs(venv, "rgb")
@ -753,8 +753,8 @@ Record a mp4 video (here using a random agent).
import gym
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
env_id = 'CartPole-v1'
video_folder = 'logs/videos/'
env_id = "CartPole-v1"
video_folder = "logs/videos/"
video_length = 100
env = DummyVecEnv([lambda: gym.make(env_id)])
@ -792,11 +792,11 @@ Bonus: Make a GIF of a Trained Agent
images = []
obs = model.env.reset()
img = model.env.render(mode='rgb_array')
img = model.env.render(mode="rgb_array")
for i in range(350):
images.append(img)
action, _ = model.predict(obs)
obs, _, _ ,_ = model.env.step(action)
img = model.env.render(mode='rgb_array')
img = model.env.render(mode="rgb_array")
imageio.mimsave('lander_a2c.gif', [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29)
imageio.mimsave("lander_a2c.gif", [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29)

View file

@ -46,29 +46,40 @@ For PPO, assuming a shared feature extactor.
.. code-block:: python
import torch as th
from stable_baselines3 import PPO
import torch
class OnnxablePolicy(torch.nn.Module):
def __init__(self, extractor, action_net, value_net):
super(OnnxablePolicy, self).__init__()
self.extractor = extractor
self.action_net = action_net
self.value_net = value_net
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
action_hidden, value_hidden = self.extractor(observation)
return self.action_net(action_hidden), self.value_net(value_hidden)
class OnnxablePolicy(th.nn.Module):
def __init__(self, extractor, action_net, value_net):
super().__init__()
self.extractor = extractor
self.action_net = action_net
self.value_net = value_net
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
action_hidden, value_hidden = self.extractor(observation)
return self.action_net(action_hidden), self.value_net(value_hidden)
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
model = PPO.load("PathToTrainedModel.zip")
model.policy.to("cpu")
onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)
model = PPO.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(
model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)
dummy_input = torch.randn(1, observation_size)
torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
dummy_input,
"my_ppo_model.onnx",
opset_version=9,
input_names=["input"],
)
##### Load and test with onnx
@ -76,48 +87,97 @@ For PPO, assuming a shared feature extactor.
import onnxruntime as ort
import numpy as np
onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
observation = np.zeros((1, observation_size)).astype(np.float32)
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action, value = ort_sess.run(None, {'input.1': observation})
action, value = ort_sess.run(None, {"input": observation})
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
.. code-block:: python
import torch as th
from stable_baselines3 import SAC
import torch
class OnnxablePolicy(torch.nn.Module):
def __init__(self, actor):
super(OnnxablePolicy, self).__init__()
# Removing the flatten layer because it can't be onnxed
self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)
class OnnxablePolicy(th.nn.Module):
def __init__(self, actor: th.nn.Module):
super().__init__()
# Removing the flatten layer because it can't be onnxed
self.actor = th.nn.Sequential(
actor.latent_pi,
actor.mu,
# For gSDE
# th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
# Squash the output
th.nn.Tanh(),
)
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
return self.actor(observation)
def forward(self, observation: th.Tensor) -> th.Tensor:
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
return self.actor(observation)
model = SAC.load("PathToTrainedModel.zip")
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)
dummy_input = torch.randn(1, observation_size)
onnxable_model.policy.to("cpu")
torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
dummy_input,
"my_sac_actor.onnx",
opset_version=9,
input_names=["input"],
)
##### Load and test with onnx
import onnxruntime as ort
import numpy as np
onnx_path = "my_sac_actor.onnx"
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action = ort_sess.run(None, {"input": observation})
For more discussion around the topic refer to this `issue. <https://github.com/DLR-RM/stable-baselines3/issues/383>`_
Export to C++
-----------------
Trace/Export to C++
-------------------
(using PyTorch JIT)
TODO: help is welcomed!
You can use PyTorch JIT to trace and save a trained model that can be re-used in other applications
(for instance inference code written in C++).
There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl-baselines3-zoo/pull/228
.. code-block:: python
# See "ONNX export" for imports and OnnxablePolicy
jit_path = "sac_traced.pt"
# Trace and optimize the module
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
frozen_module = th.jit.freeze(traced_module)
frozen_module = th.jit.optimize_for_inference(frozen_module)
th.jit.save(frozen_module, jit_path)
##### Load and test with torch
import torch as th
dummy_input = th.randn(1, *observation_size)
loaded_module = th.jit.load(jit_path)
action_jit = loaded_module(dummy_input)
Export to tensorflowjs / ONNX-JS

View file

@ -14,9 +14,9 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
from stable_baselines3 import A2C
env = gym.make('CartPole-v1')
env = gym.make("CartPole-v1")
model = A2C('MlpPolicy', env, verbose=1)
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
obs = env.reset()
@ -40,4 +40,4 @@ the policy is registered:
from stable_baselines3 import A2C
model = A2C('MlpPolicy', 'CartPole-v1').learn(10000)
model = A2C("MlpPolicy", "CartPole-v1").learn(10000)

View file

@ -12,7 +12,7 @@ To use Tensorboard with stable baselines3, you simply need to pass the location
from stable_baselines3 import A2C
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000)
@ -22,7 +22,7 @@ You can also define custom logging name when training (by default it is the algo
from stable_baselines3 import A2C
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000, tb_log_name="first_run")
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
# By default, it will create a new curve
@ -91,7 +91,7 @@ Here is a simple example on how to log both additional tensor or arbitrary scala
def _on_step(self) -> bool:
# Log scalar value (here a random variable)
value = np.random.random()
self.logger.record('random_value', value)
self.logger.record("random_value", value)
return True

View file

@ -3,9 +3,11 @@
Changelog
==========
Release 1.6.1a4 (WIP)
Release 1.6.1 (2022-09-29)
---------------------------
**Bug fix release**
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched minimum tensorboard version to 2.9.1
@ -20,6 +22,7 @@ New Features:
SB3-Contrib
^^^^^^^^^^^
- Fixed the issue of wrongly passing policy arguments when using ``CnnLstmPolicy`` or ``MultiInputLstmPolicy`` with ``RecurrentPPO`` (@mlodel)
Bug Fixes:
^^^^^^^^^^
@ -34,6 +37,7 @@ Bug Fixes:
- Removed ``forward()`` abstract method declaration from ``common.policies.BaseModel`` (already defined in ``torch.nn.Module``) to fix type errors in subclasses (@Rocamonde)
- Fixed the return type of ``.load()`` and ``.learn()`` methods in ``BaseAlgorithm`` so that they now use ``TypeVar`` (@Rocamonde)
- Fixed an issue where keys with different tags but the same key raised an error in ``common.logger.HumanOutputFormat`` (@Rocamonde and @AdamGleave)
- Set importlib-metadata version to `~=4.13`
Deprecations:
^^^^^^^^^^^^^
@ -53,7 +57,9 @@ Documentation:
- Fixed typo in install doc(@jlp-ue)
- Clarified and standardized verbosity documentation
- Added link to a GitHub issue in the custom policy documentation (@AlexPasqua)
- Update doc on exporting models (fixes and added torch jit)
- Fixed typos (@Akhilez)
- Standardized the use of ``"`` for string representation in documentation
Release 1.6.0 (2022-07-11)
---------------------------
@ -1042,4 +1048,4 @@ And all the contributors:
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde

View file

@ -73,7 +73,7 @@ This example is only to demonstrate the use of the library and its functions, an
env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
# Available strategies (cf paper): future, final, episode
goal_selection_strategy = 'future' # equivalent to GoalSelectionStrategy.FUTURE
goal_selection_strategy = "future" # equivalent to GoalSelectionStrategy.FUTURE
# If True the HER transitions will get sampled online
online_sampling = True
@ -101,7 +101,7 @@ This example is only to demonstrate the use of the library and its functions, an
model.save("./her_bit_env")
# Because it needs access to `env.compute_reward()`
# HER must be loaded with the env
model = model_class.load('./her_bit_env', env=env)
model = model_class.load("./her_bit_env", env=env)
obs = env.reset()
for _ in range(100):

View file

@ -94,6 +94,8 @@ setup(
"pytype",
# Lint code
"flake8>=3.8",
# flake8 not compatible with importlib-metadata>5.0
"importlib-metadata~=4.13",
# Find likely bugs
"flake8-bugbear",
# Sort imports

View file

@ -143,7 +143,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
:param callback: Callback that will be called at each step
(and at the beginning and end of the rollout)
:param rollout_buffer: Buffer to fill with rollouts
:param n_steps: Number of experiences to collect per environment
:param n_rollout_steps: Number of experiences to collect per environment
:return: True if function returned with at least `n_rollout_steps`
collected, False if callback terminated rollout prematurely.
"""

View file

@ -234,7 +234,7 @@ class PPO(OnPolicyAlgorithm):
# No clipping
values_pred = values
else:
# Clip the different between old and new value
# Clip the difference between old and new value
# NOTE: this depends on the reward scaling
values_pred = rollout_data.old_values + th.clamp(
values - rollout_data.old_values, -clip_range_vf, clip_range_vf

View file

@ -1 +1 @@
1.6.1a4
1.6.1