mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-04 04:07:27 +00:00
Merge branch 'master' into feat/mps-support
This commit is contained in:
commit
6d868c02bb
13 changed files with 164 additions and 96 deletions
|
|
@ -216,13 +216,13 @@ It will save the best model if ``best_model_save_path`` folder is specified and
|
|||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
|
||||
# Separate evaluation env
|
||||
eval_env = gym.make('Pendulum-v1')
|
||||
eval_env = gym.make("Pendulum-v1")
|
||||
# Use deterministic actions for evaluation
|
||||
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',
|
||||
log_path='./logs/', eval_freq=500,
|
||||
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/",
|
||||
log_path="./logs/", eval_freq=500,
|
||||
deterministic=True, render=False)
|
||||
|
||||
model = SAC('MlpPolicy', 'Pendulum-v1')
|
||||
model = SAC("MlpPolicy", "Pendulum-v1")
|
||||
model.learn(5000, callback=eval_callback)
|
||||
|
||||
|
||||
|
|
@ -242,15 +242,15 @@ Alternatively, you can pass directly a list of callbacks to the ``learn()`` meth
|
|||
from stable_baselines3 import SAC
|
||||
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
|
||||
|
||||
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/')
|
||||
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path="./logs/")
|
||||
# Separate evaluation env
|
||||
eval_env = gym.make('Pendulum-v1')
|
||||
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/best_model',
|
||||
log_path='./logs/results', eval_freq=500)
|
||||
eval_env = gym.make("Pendulum-v1")
|
||||
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/best_model",
|
||||
log_path="./logs/results", eval_freq=500)
|
||||
# Create the callback list
|
||||
callback = CallbackList([checkpoint_callback, eval_callback])
|
||||
|
||||
model = SAC('MlpPolicy', 'Pendulum-v1')
|
||||
model = SAC("MlpPolicy", "Pendulum-v1")
|
||||
# Equivalent to:
|
||||
# model.learn(5000, callback=[checkpoint_callback, eval_callback])
|
||||
model.learn(5000, callback=callback)
|
||||
|
|
@ -273,12 +273,12 @@ It must be used with the :ref:`EvalCallback` and use the event triggered by a ne
|
|||
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
|
||||
|
||||
# Separate evaluation env
|
||||
eval_env = gym.make('Pendulum-v1')
|
||||
eval_env = gym.make("Pendulum-v1")
|
||||
# Stop training when the model reaches the reward threshold
|
||||
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
|
||||
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)
|
||||
|
||||
model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1)
|
||||
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
|
||||
# Almost infinite number of timesteps, but the training will stop
|
||||
# early as soon as the reward threshold is reached
|
||||
model.learn(int(1e10), callback=eval_callback)
|
||||
|
|
@ -306,10 +306,10 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t
|
|||
|
||||
# this is equivalent to defining CheckpointCallback(save_freq=500)
|
||||
# checkpoint_callback will be triggered every 500 steps
|
||||
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs/')
|
||||
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/")
|
||||
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
|
||||
|
||||
model = PPO('MlpPolicy', 'Pendulum-v1', verbose=1)
|
||||
model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
|
||||
|
||||
model.learn(int(2e4), callback=event_callback)
|
||||
|
||||
|
|
@ -338,7 +338,7 @@ and in total for ``max_episodes * n_envs`` episodes.
|
|||
# Stops training when the model reaches the maximum number of episodes
|
||||
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)
|
||||
|
||||
model = A2C('MlpPolicy', 'Pendulum-v1', verbose=1)
|
||||
model = A2C("MlpPolicy", "Pendulum-v1", verbose=1)
|
||||
# Almost infinite number of timesteps, but the training will stop
|
||||
# early as soon as the max number of episodes is reached
|
||||
model.learn(int(1e10), callback=callback_max_episodes)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ which defines for the python process, how it should handle floating point error.
|
|||
|
||||
import numpy as np
|
||||
|
||||
np.seterr(all='raise') # define before your code.
|
||||
np.seterr(all="raise") # define before your code.
|
||||
|
||||
print("numpy test:")
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ but this will also avoid overflow issues on floating point numbers:
|
|||
|
||||
import numpy as np
|
||||
|
||||
np.seterr(all='raise') # define before your code.
|
||||
np.seterr(all="raise") # define before your code.
|
||||
|
||||
print("numpy overflow test:")
|
||||
|
||||
|
|
@ -81,11 +81,11 @@ but will not avoid the propagation issues:
|
|||
|
||||
import numpy as np
|
||||
|
||||
np.seterr(all='raise') # define before your code.
|
||||
np.seterr(all="raise") # define before your code.
|
||||
|
||||
print("numpy propagation test:")
|
||||
|
||||
a = np.float64('NaN')
|
||||
a = np.float64("NaN")
|
||||
b = np.float64(1.0)
|
||||
val = a + b # this will neither warn nor raise anything
|
||||
print(val)
|
||||
|
|
@ -109,7 +109,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
|
|||
|
||||
class NanAndInfEnv(gym.Env):
|
||||
"""Custom Environment that raised NaNs and Infs"""
|
||||
metadata = {'render.modes': ['human']}
|
||||
metadata = {"render.modes": ["human"]}
|
||||
|
||||
def __init__(self):
|
||||
super(NanAndInfEnv, self).__init__()
|
||||
|
|
@ -119,9 +119,9 @@ It will monitor the actions, observations, and rewards, indicating what action o
|
|||
def step(self, _action):
|
||||
randf = np.random.rand()
|
||||
if randf > 0.99:
|
||||
obs = float('NaN')
|
||||
obs = float("NaN")
|
||||
elif randf > 0.98:
|
||||
obs = float('inf')
|
||||
obs = float("inf")
|
||||
else:
|
||||
obs = randf
|
||||
return [obs], 0.0, False, {}
|
||||
|
|
@ -129,7 +129,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
|
|||
def reset(self):
|
||||
return [0.0]
|
||||
|
||||
def render(self, mode='human', close=False):
|
||||
def render(self, mode="human", close=False):
|
||||
pass
|
||||
|
||||
# Create environment
|
||||
|
|
@ -137,7 +137,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
|
|||
env = VecCheckNan(env, raise_exception=True)
|
||||
|
||||
# Instantiate the agent
|
||||
model = PPO('MlpPolicy', env)
|
||||
model = PPO("MlpPolicy", env)
|
||||
|
||||
# Train the agent
|
||||
model.learn(total_timesteps=int(2e5)) # this will crash explaining that the invalid value originated from the environment.
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ That is to say, your environment must implement the following methods (and inher
|
|||
|
||||
class CustomEnv(gym.Env):
|
||||
"""Custom Environment that follows gym interface"""
|
||||
metadata = {'render.modes': ['human']}
|
||||
metadata = {"render.modes": ["human"]}
|
||||
|
||||
def __init__(self, arg1, arg2, ...):
|
||||
super(CustomEnv, self).__init__()
|
||||
|
|
@ -45,7 +45,7 @@ That is to say, your environment must implement the following methods (and inher
|
|||
def reset(self):
|
||||
...
|
||||
return observation # reward, done, info can't be included
|
||||
def render(self, mode='human'):
|
||||
def render(self, mode="human"):
|
||||
...
|
||||
def close (self):
|
||||
...
|
||||
|
|
@ -58,7 +58,7 @@ Then you can define and train a RL agent with:
|
|||
# Instantiate the env
|
||||
env = CustomEnv(arg1, ...)
|
||||
# Define and Train the agent
|
||||
model = A2C('CnnPolicy', env).learn(total_timesteps=1000)
|
||||
model = A2C("CnnPolicy", env).learn(total_timesteps=1000)
|
||||
|
||||
|
||||
To check that your environment follows the Gym interface that SB3 supports, please use:
|
||||
|
|
|
|||
|
|
@ -71,10 +71,10 @@ In the following example, we will train, save and load a DQN model on the Lunar
|
|||
|
||||
|
||||
# Create environment
|
||||
env = gym.make('LunarLander-v2')
|
||||
env = gym.make("LunarLander-v2")
|
||||
|
||||
# Instantiate the agent
|
||||
model = DQN('MlpPolicy', env, verbose=1)
|
||||
model = DQN("MlpPolicy", env, verbose=1)
|
||||
# Train the agent
|
||||
model.learn(total_timesteps=int(2e5))
|
||||
# Save the agent
|
||||
|
|
@ -138,7 +138,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
|
|||
set_random_seed(seed)
|
||||
return _init
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
env_id = "CartPole-v1"
|
||||
num_cpu = 4 # Number of processes to use
|
||||
# Create the vectorized environment
|
||||
|
|
@ -149,7 +149,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
|
|||
# You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv`
|
||||
# env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)
|
||||
|
||||
model = PPO('MlpPolicy', env, verbose=1)
|
||||
model = PPO("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=25_000)
|
||||
|
||||
obs = env.reset()
|
||||
|
|
@ -182,7 +182,7 @@ Multiprocessing with off-policy algorithms
|
|||
# We collect 4 transitions per call to `ènv.step()`
|
||||
# and performs 2 gradient steps per call to `ènv.step()`
|
||||
# if gradient_steps=-1, then we would do 4 gradients steps per call to `ènv.step()`
|
||||
model = SAC('MlpPolicy', env, train_freq=1, gradient_steps=2, verbose=1)
|
||||
model = SAC("MlpPolicy", env, train_freq=1, gradient_steps=2, verbose=1)
|
||||
model.learn(total_timesteps=10_000)
|
||||
|
||||
|
||||
|
|
@ -254,7 +254,7 @@ If your callback returns False, training is aborted early.
|
|||
super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
|
||||
self.check_freq = check_freq
|
||||
self.log_dir = log_dir
|
||||
self.save_path = os.path.join(log_dir, 'best_model')
|
||||
self.save_path = os.path.join(log_dir, "best_model")
|
||||
self.best_mean_reward = -np.inf
|
||||
|
||||
def _init_callback(self) -> None:
|
||||
|
|
@ -266,7 +266,7 @@ If your callback returns False, training is aborted early.
|
|||
if self.n_calls % self.check_freq == 0:
|
||||
|
||||
# Retrieve training reward
|
||||
x, y = ts2xy(load_results(self.log_dir), 'timesteps')
|
||||
x, y = ts2xy(load_results(self.log_dir), "timesteps")
|
||||
if len(x) > 0:
|
||||
# Mean training reward over the last 100 episodes
|
||||
mean_reward = np.mean(y[-100:])
|
||||
|
|
@ -289,14 +289,14 @@ If your callback returns False, training is aborted early.
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# Create and wrap the environment
|
||||
env = gym.make('LunarLanderContinuous-v2')
|
||||
env = gym.make("LunarLanderContinuous-v2")
|
||||
env = Monitor(env, log_dir)
|
||||
|
||||
# Add some action noise for exploration
|
||||
n_actions = env.action_space.shape[-1]
|
||||
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
|
||||
# Because we use parameter noise, we should use a MlpPolicy with layer normalization
|
||||
model = TD3('MlpPolicy', env, action_noise=action_noise, verbose=0)
|
||||
model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=0)
|
||||
# Create the callback: check every 1000 steps
|
||||
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
|
||||
# Train the agent
|
||||
|
|
@ -336,11 +336,11 @@ and multiprocessing for you. To install the Atari environments, run the command
|
|||
# There already exists an environment generator
|
||||
# that will make and wrap atari environments correctly.
|
||||
# Here we are also multi-worker training (n_envs=4 => 4 environments)
|
||||
env = make_atari_env('PongNoFrameskip-v4', n_envs=4, seed=0)
|
||||
env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0)
|
||||
# Frame-stacking with 4 frames
|
||||
env = VecFrameStack(env, n_stack=4)
|
||||
|
||||
model = A2C('CnnPolicy', env, verbose=1)
|
||||
model = A2C("CnnPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=25_000)
|
||||
|
||||
obs = env.reset()
|
||||
|
|
@ -382,7 +382,7 @@ will compute a running average and standard deviation of input features (it can
|
|||
env = VecNormalize(env, norm_obs=True, norm_reward=True,
|
||||
clip_obs=10.)
|
||||
|
||||
model = PPO('MlpPolicy', env)
|
||||
model = PPO("MlpPolicy", env)
|
||||
model.learn(total_timesteps=2000)
|
||||
|
||||
# Don't forget to save the VecNormalize statistics when saving the agent
|
||||
|
|
@ -564,7 +564,7 @@ Behind the scene, SB3 uses an :ref:`EvalCallback <callbacks>`.
|
|||
|
||||
# Create the model, the training environment
|
||||
# and the test environment (for evaluation)
|
||||
model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1,
|
||||
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1,
|
||||
learning_rate=1e-3, create_eval_env=True)
|
||||
|
||||
# Evaluate the model every 1000 steps on 5 test episodes
|
||||
|
|
@ -717,7 +717,7 @@ to keep track of the agent progress.
|
|||
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
|
||||
|
||||
# ProcgenEnv is already vectorized
|
||||
venv = ProcgenEnv(num_envs=2, env_name='starpilot')
|
||||
venv = ProcgenEnv(num_envs=2, env_name="starpilot")
|
||||
|
||||
# To use only part of the observation:
|
||||
# venv = VecExtractDictObs(venv, "rgb")
|
||||
|
|
@ -753,8 +753,8 @@ Record a mp4 video (here using a random agent).
|
|||
import gym
|
||||
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
|
||||
|
||||
env_id = 'CartPole-v1'
|
||||
video_folder = 'logs/videos/'
|
||||
env_id = "CartPole-v1"
|
||||
video_folder = "logs/videos/"
|
||||
video_length = 100
|
||||
|
||||
env = DummyVecEnv([lambda: gym.make(env_id)])
|
||||
|
|
@ -792,11 +792,11 @@ Bonus: Make a GIF of a Trained Agent
|
|||
|
||||
images = []
|
||||
obs = model.env.reset()
|
||||
img = model.env.render(mode='rgb_array')
|
||||
img = model.env.render(mode="rgb_array")
|
||||
for i in range(350):
|
||||
images.append(img)
|
||||
action, _ = model.predict(obs)
|
||||
obs, _, _ ,_ = model.env.step(action)
|
||||
img = model.env.render(mode='rgb_array')
|
||||
img = model.env.render(mode="rgb_array")
|
||||
|
||||
imageio.mimsave('lander_a2c.gif', [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29)
|
||||
imageio.mimsave("lander_a2c.gif", [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29)
|
||||
|
|
|
|||
|
|
@ -46,29 +46,40 @@ For PPO, assuming a shared feature extactor.
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
import torch as th
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
import torch
|
||||
|
||||
class OnnxablePolicy(torch.nn.Module):
|
||||
def __init__(self, extractor, action_net, value_net):
|
||||
super(OnnxablePolicy, self).__init__()
|
||||
self.extractor = extractor
|
||||
self.action_net = action_net
|
||||
self.value_net = value_net
|
||||
|
||||
def forward(self, observation):
|
||||
# NOTE: You may have to process (normalize) observation in the correct
|
||||
# way before using this. See `common.preprocessing.preprocess_obs`
|
||||
action_hidden, value_hidden = self.extractor(observation)
|
||||
return self.action_net(action_hidden), self.value_net(value_hidden)
|
||||
class OnnxablePolicy(th.nn.Module):
|
||||
def __init__(self, extractor, action_net, value_net):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.action_net = action_net
|
||||
self.value_net = value_net
|
||||
|
||||
def forward(self, observation):
|
||||
# NOTE: You may have to process (normalize) observation in the correct
|
||||
# way before using this. See `common.preprocessing.preprocess_obs`
|
||||
action_hidden, value_hidden = self.extractor(observation)
|
||||
return self.action_net(action_hidden), self.value_net(value_hidden)
|
||||
|
||||
|
||||
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
|
||||
model = PPO.load("PathToTrainedModel.zip")
|
||||
model.policy.to("cpu")
|
||||
onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)
|
||||
model = PPO.load("PathToTrainedModel.zip", device="cpu")
|
||||
onnxable_model = OnnxablePolicy(
|
||||
model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
|
||||
)
|
||||
|
||||
dummy_input = torch.randn(1, observation_size)
|
||||
torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9)
|
||||
observation_size = model.observation_space.shape
|
||||
dummy_input = th.randn(1, *observation_size)
|
||||
th.onnx.export(
|
||||
onnxable_model,
|
||||
dummy_input,
|
||||
"my_ppo_model.onnx",
|
||||
opset_version=9,
|
||||
input_names=["input"],
|
||||
)
|
||||
|
||||
##### Load and test with onnx
|
||||
|
||||
|
|
@ -76,48 +87,97 @@ For PPO, assuming a shared feature extactor.
|
|||
import onnxruntime as ort
|
||||
import numpy as np
|
||||
|
||||
onnx_path = "my_ppo_model.onnx"
|
||||
onnx_model = onnx.load(onnx_path)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
|
||||
observation = np.zeros((1, observation_size)).astype(np.float32)
|
||||
observation = np.zeros((1, *observation_size)).astype(np.float32)
|
||||
ort_sess = ort.InferenceSession(onnx_path)
|
||||
action, value = ort_sess.run(None, {'input.1': observation})
|
||||
action, value = ort_sess.run(None, {"input": observation})
|
||||
|
||||
|
||||
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch as th
|
||||
|
||||
from stable_baselines3 import SAC
|
||||
import torch
|
||||
|
||||
class OnnxablePolicy(torch.nn.Module):
|
||||
def __init__(self, actor):
|
||||
super(OnnxablePolicy, self).__init__()
|
||||
|
||||
# Removing the flatten layer because it can't be onnxed
|
||||
self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)
|
||||
class OnnxablePolicy(th.nn.Module):
|
||||
def __init__(self, actor: th.nn.Module):
|
||||
super().__init__()
|
||||
# Removing the flatten layer because it can't be onnxed
|
||||
self.actor = th.nn.Sequential(
|
||||
actor.latent_pi,
|
||||
actor.mu,
|
||||
# For gSDE
|
||||
# th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
|
||||
# Squash the output
|
||||
th.nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, observation):
|
||||
# NOTE: You may have to process (normalize) observation in the correct
|
||||
# way before using this. See `common.preprocessing.preprocess_obs`
|
||||
return self.actor(observation)
|
||||
def forward(self, observation: th.Tensor) -> th.Tensor:
|
||||
# NOTE: You may have to process (normalize) observation in the correct
|
||||
# way before using this. See `common.preprocessing.preprocess_obs`
|
||||
return self.actor(observation)
|
||||
|
||||
model = SAC.load("PathToTrainedModel.zip")
|
||||
|
||||
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
|
||||
model = SAC.load("PathToTrainedModel.zip", device="cpu")
|
||||
onnxable_model = OnnxablePolicy(model.policy.actor)
|
||||
|
||||
dummy_input = torch.randn(1, observation_size)
|
||||
onnxable_model.policy.to("cpu")
|
||||
torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9)
|
||||
observation_size = model.observation_space.shape
|
||||
dummy_input = th.randn(1, *observation_size)
|
||||
th.onnx.export(
|
||||
onnxable_model,
|
||||
dummy_input,
|
||||
"my_sac_actor.onnx",
|
||||
opset_version=9,
|
||||
input_names=["input"],
|
||||
)
|
||||
|
||||
##### Load and test with onnx
|
||||
|
||||
import onnxruntime as ort
|
||||
import numpy as np
|
||||
|
||||
onnx_path = "my_sac_actor.onnx"
|
||||
|
||||
observation = np.zeros((1, *observation_size)).astype(np.float32)
|
||||
ort_sess = ort.InferenceSession(onnx_path)
|
||||
action = ort_sess.run(None, {"input": observation})
|
||||
|
||||
|
||||
For more discussion around the topic refer to this `issue. <https://github.com/DLR-RM/stable-baselines3/issues/383>`_
|
||||
|
||||
Export to C++
|
||||
-----------------
|
||||
Trace/Export to C++
|
||||
-------------------
|
||||
|
||||
(using PyTorch JIT)
|
||||
TODO: help is welcomed!
|
||||
You can use PyTorch JIT to trace and save a trained model that can be re-used in other applications
|
||||
(for instance inference code written in C++).
|
||||
|
||||
There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl-baselines3-zoo/pull/228
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# See "ONNX export" for imports and OnnxablePolicy
|
||||
jit_path = "sac_traced.pt"
|
||||
|
||||
# Trace and optimize the module
|
||||
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
|
||||
frozen_module = th.jit.freeze(traced_module)
|
||||
frozen_module = th.jit.optimize_for_inference(frozen_module)
|
||||
th.jit.save(frozen_module, jit_path)
|
||||
|
||||
##### Load and test with torch
|
||||
|
||||
import torch as th
|
||||
|
||||
dummy_input = th.randn(1, *observation_size)
|
||||
loaded_module = th.jit.load(jit_path)
|
||||
action_jit = loaded_module(dummy_input)
|
||||
|
||||
|
||||
Export to tensorflowjs / ONNX-JS
|
||||
|
|
|
|||
|
|
@ -14,9 +14,9 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
|
|||
|
||||
from stable_baselines3 import A2C
|
||||
|
||||
env = gym.make('CartPole-v1')
|
||||
env = gym.make("CartPole-v1")
|
||||
|
||||
model = A2C('MlpPolicy', env, verbose=1)
|
||||
model = A2C("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=10000)
|
||||
|
||||
obs = env.reset()
|
||||
|
|
@ -40,4 +40,4 @@ the policy is registered:
|
|||
|
||||
from stable_baselines3 import A2C
|
||||
|
||||
model = A2C('MlpPolicy', 'CartPole-v1').learn(10000)
|
||||
model = A2C("MlpPolicy", "CartPole-v1").learn(10000)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ To use Tensorboard with stable baselines3, you simply need to pass the location
|
|||
|
||||
from stable_baselines3 import A2C
|
||||
|
||||
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
|
||||
model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
|
||||
model.learn(total_timesteps=10_000)
|
||||
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ You can also define custom logging name when training (by default it is the algo
|
|||
|
||||
from stable_baselines3 import A2C
|
||||
|
||||
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
|
||||
model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
|
||||
model.learn(total_timesteps=10_000, tb_log_name="first_run")
|
||||
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
|
||||
# By default, it will create a new curve
|
||||
|
|
@ -91,7 +91,7 @@ Here is a simple example on how to log both additional tensor or arbitrary scala
|
|||
def _on_step(self) -> bool:
|
||||
# Log scalar value (here a random variable)
|
||||
value = np.random.random()
|
||||
self.logger.record('random_value', value)
|
||||
self.logger.record("random_value", value)
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.6.1a4 (WIP)
|
||||
Release 1.6.1 (2022-09-29)
|
||||
---------------------------
|
||||
|
||||
**Bug fix release**
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Switched minimum tensorboard version to 2.9.1
|
||||
|
|
@ -20,6 +22,7 @@ New Features:
|
|||
|
||||
SB3-Contrib
|
||||
^^^^^^^^^^^
|
||||
- Fixed the issue of wrongly passing policy arguments when using ``CnnLstmPolicy`` or ``MultiInputLstmPolicy`` with ``RecurrentPPO`` (@mlodel)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -34,6 +37,7 @@ Bug Fixes:
|
|||
- Removed ``forward()`` abstract method declaration from ``common.policies.BaseModel`` (already defined in ``torch.nn.Module``) to fix type errors in subclasses (@Rocamonde)
|
||||
- Fixed the return type of ``.load()`` and ``.learn()`` methods in ``BaseAlgorithm`` so that they now use ``TypeVar`` (@Rocamonde)
|
||||
- Fixed an issue where keys with different tags but the same key raised an error in ``common.logger.HumanOutputFormat`` (@Rocamonde and @AdamGleave)
|
||||
- Set importlib-metadata version to `~=4.13`
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -53,7 +57,9 @@ Documentation:
|
|||
- Fixed typo in install doc(@jlp-ue)
|
||||
- Clarified and standardized verbosity documentation
|
||||
- Added link to a GitHub issue in the custom policy documentation (@AlexPasqua)
|
||||
|
||||
- Update doc on exporting models (fixes and added torch jit)
|
||||
- Fixed typos (@Akhilez)
|
||||
- Standardized the use of ``"`` for string representation in documentation
|
||||
|
||||
Release 1.6.0 (2022-07-11)
|
||||
---------------------------
|
||||
|
|
@ -1042,4 +1048,4 @@ And all the contributors:
|
|||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
|
||||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ This example is only to demonstrate the use of the library and its functions, an
|
|||
env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
|
||||
|
||||
# Available strategies (cf paper): future, final, episode
|
||||
goal_selection_strategy = 'future' # equivalent to GoalSelectionStrategy.FUTURE
|
||||
goal_selection_strategy = "future" # equivalent to GoalSelectionStrategy.FUTURE
|
||||
|
||||
# If True the HER transitions will get sampled online
|
||||
online_sampling = True
|
||||
|
|
@ -101,7 +101,7 @@ This example is only to demonstrate the use of the library and its functions, an
|
|||
model.save("./her_bit_env")
|
||||
# Because it needs access to `env.compute_reward()`
|
||||
# HER must be loaded with the env
|
||||
model = model_class.load('./her_bit_env', env=env)
|
||||
model = model_class.load("./her_bit_env", env=env)
|
||||
|
||||
obs = env.reset()
|
||||
for _ in range(100):
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -94,6 +94,8 @@ setup(
|
|||
"pytype",
|
||||
# Lint code
|
||||
"flake8>=3.8",
|
||||
# flake8 not compatible with importlib-metadata>5.0
|
||||
"importlib-metadata~=4.13",
|
||||
# Find likely bugs
|
||||
"flake8-bugbear",
|
||||
# Sort imports
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
:param callback: Callback that will be called at each step
|
||||
(and at the beginning and end of the rollout)
|
||||
:param rollout_buffer: Buffer to fill with rollouts
|
||||
:param n_steps: Number of experiences to collect per environment
|
||||
:param n_rollout_steps: Number of experiences to collect per environment
|
||||
:return: True if function returned with at least `n_rollout_steps`
|
||||
collected, False if callback terminated rollout prematurely.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -234,7 +234,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
# No clipping
|
||||
values_pred = values
|
||||
else:
|
||||
# Clip the different between old and new value
|
||||
# Clip the difference between old and new value
|
||||
# NOTE: this depends on the reward scaling
|
||||
values_pred = rollout_data.old_values + th.clamp(
|
||||
values - rollout_data.old_values, -clip_range_vf, clip_range_vf
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.6.1a4
|
||||
1.6.1
|
||||
|
|
|
|||
Loading…
Reference in a new issue