mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Fix set_env to keep the number of timesteps (#615)
* Fix for `set_env` * Add test and update changelog * Use underscores and f-strings * Add PyPi info * Update comments
This commit is contained in:
parent
1564a85081
commit
e907eca18e
7 changed files with 93 additions and 30 deletions
|
|
@ -149,7 +149,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
|
|||
# env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)
|
||||
|
||||
model = PPO('MlpPolicy', env, verbose=1)
|
||||
model.learn(total_timesteps=25000)
|
||||
model.learn(total_timesteps=25_000)
|
||||
|
||||
obs = env.reset()
|
||||
for _ in range(1000):
|
||||
|
|
@ -177,7 +177,7 @@ These dictionaries are randomly initilaized on the creation of the environment a
|
|||
env = SimpleMultiObsEnv(random_start=False)
|
||||
|
||||
model = PPO("MultiInputPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=1e5)
|
||||
model.learn(total_timesteps=100_000)
|
||||
|
||||
|
||||
Using Callback: Monitoring Training
|
||||
|
|
@ -217,12 +217,12 @@ If your callback returns False, training is aborted early.
|
|||
Callback for saving a model (the check is done every ``check_freq`` steps)
|
||||
based on the training reward (in practice, we recommend using ``EvalCallback``).
|
||||
|
||||
:param check_freq: (int)
|
||||
:param log_dir: (str) Path to the folder where the model will be saved.
|
||||
:param check_freq:
|
||||
:param log_dir: Path to the folder where the model will be saved.
|
||||
It must contains the file created by the ``Monitor`` wrapper.
|
||||
:param verbose: (int)
|
||||
:param verbose: Verbosity level.
|
||||
"""
|
||||
def __init__(self, check_freq: int, log_dir: str, verbose=1):
|
||||
def __init__(self, check_freq: int, log_dir: str, verbose: int = 1):
|
||||
super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
|
||||
self.check_freq = check_freq
|
||||
self.log_dir = log_dir
|
||||
|
|
@ -243,15 +243,15 @@ If your callback returns False, training is aborted early.
|
|||
# Mean training reward over the last 100 episodes
|
||||
mean_reward = np.mean(y[-100:])
|
||||
if self.verbose > 0:
|
||||
print("Num timesteps: {}".format(self.num_timesteps))
|
||||
print("Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}".format(self.best_mean_reward, mean_reward))
|
||||
print(f"Num timesteps: {self.num_timesteps}")
|
||||
print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")
|
||||
|
||||
# New best model, you could save the agent here
|
||||
if mean_reward > self.best_mean_reward:
|
||||
self.best_mean_reward = mean_reward
|
||||
# Example for saving best model
|
||||
if self.verbose > 0:
|
||||
print("Saving new best model to {}".format(self.save_path))
|
||||
print(f"Saving new best model to {self.save_path}")
|
||||
self.model.save(self.save_path)
|
||||
|
||||
return True
|
||||
|
|
@ -313,7 +313,7 @@ and multiprocessing for you.
|
|||
env = VecFrameStack(env, n_stack=4)
|
||||
|
||||
model = A2C('CnnPolicy', env, verbose=1)
|
||||
model.learn(total_timesteps=25000)
|
||||
model.learn(total_timesteps=25_000)
|
||||
|
||||
obs = env.reset()
|
||||
while True:
|
||||
|
|
@ -495,10 +495,10 @@ linear and constant schedules.
|
|||
|
||||
# Initial learning rate of 0.001
|
||||
model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(0.001), verbose=1)
|
||||
model.learn(total_timesteps=20000)
|
||||
model.learn(total_timesteps=20_000)
|
||||
# By default, `reset_num_timesteps` is True, in which case the learning rate schedule resets.
|
||||
# progress_remaining = 1.0 - (num_timesteps / total_timesteps)
|
||||
model.learn(total_timesteps=10000, reset_num_timesteps=True)
|
||||
model.learn(total_timesteps=10_000, reset_num_timesteps=True)
|
||||
|
||||
|
||||
Advanced Saving and Loading
|
||||
|
|
@ -630,7 +630,7 @@ A2C policy gradient updates on the model.
|
|||
|
||||
# Use traditional actor-critic policy gradient updates to
|
||||
# find good initial parameters
|
||||
model.learn(total_timesteps=10000)
|
||||
model.learn(total_timesteps=10_000)
|
||||
|
||||
# Include only variables with "policy", "action" (policy) or "shared_net" (shared layers)
|
||||
# in their name: only these ones affect the action.
|
||||
|
|
@ -698,7 +698,7 @@ to keep track of the agent progress.
|
|||
venv = VecMonitor(venv=venv)
|
||||
|
||||
model = PPO("MultiInputPolicy", venv, verbose=1)
|
||||
model.learn(10000)
|
||||
model.learn(10_000)
|
||||
|
||||
|
||||
Record a Video
|
||||
|
|
@ -726,7 +726,7 @@ Record a mp4 video (here using a random agent).
|
|||
# Record the video starting at the first step
|
||||
env = VecVideoRecorder(env, video_folder,
|
||||
record_video_trigger=lambda x: x == 0, video_length=video_length,
|
||||
name_prefix="random-agent-{}".format(env_id))
|
||||
name_prefix=f"random-agent-{env_id}")
|
||||
|
||||
env.reset()
|
||||
for _ in range(video_length + 1):
|
||||
|
|
@ -750,7 +750,7 @@ Bonus: Make a GIF of a Trained Agent
|
|||
|
||||
from stable_baselines3 import A2C
|
||||
|
||||
model = A2C("MlpPolicy", "LunarLander-v2").learn(100000)
|
||||
model = A2C("MlpPolicy", "LunarLander-v2").learn(100_000)
|
||||
|
||||
images = []
|
||||
obs = model.env.reset()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ To use Tensorboard with stable baselines3, you simply need to pass the location
|
|||
from stable_baselines3 import A2C
|
||||
|
||||
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
|
||||
model.learn(total_timesteps=10000)
|
||||
model.learn(total_timesteps=10_000)
|
||||
|
||||
|
||||
You can also define custom logging name when training (by default it is the algorithm name)
|
||||
|
|
@ -23,11 +23,11 @@ You can also define custom logging name when training (by default it is the algo
|
|||
from stable_baselines3 import A2C
|
||||
|
||||
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
|
||||
model.learn(total_timesteps=10000, tb_log_name="first_run")
|
||||
model.learn(total_timesteps=10_000, tb_log_name="first_run")
|
||||
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
|
||||
# By default, it will create a new curve
|
||||
model.learn(total_timesteps=10000, tb_log_name="second_run", reset_num_timesteps=False)
|
||||
model.learn(total_timesteps=10000, tb_log_name="third_run", reset_num_timesteps=False)
|
||||
model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
|
||||
model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)
|
||||
|
||||
|
||||
Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command:
|
||||
|
|
|
|||
|
|
@ -4,9 +4,15 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.2.1a4 (WIP)
|
||||
Release 1.2.1a5 (WIP)
|
||||
---------------------------
|
||||
|
||||
.. warning::
|
||||
|
||||
This version will be the last one supporting Python 3.6 (end of life in Dec 2021).
|
||||
We highly recommended you to upgrade to Python >= 3.7.
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- ``sde_net_arch`` argument in policies is deprecated and will be removed in a future version.
|
||||
|
|
@ -31,6 +37,7 @@ Bug Fixes:
|
|||
when observation normalization is disabled.
|
||||
- Fixed a bug where ``DQN`` would throw an error when using ``Discrete`` observation and stochastic actions
|
||||
- Fixed a bug where sub-classed observation spaces could not be used
|
||||
- Added ``force_reset`` argument to ``load()`` and ``set_env()`` in order to be able to call ``learn(reset_num_timesteps=False)`` with a new environment
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -40,6 +47,7 @@ Others:
|
|||
- Cap gym max version to 0.19 to avoid issues with atari-py and other breaking changes
|
||||
- Improved error message when using dict observation with the wrong policy
|
||||
- Improved error message when using ``EvalCallback`` with two envs not wrapped the same way.
|
||||
- Added additional infos about supported python version for PyPi in ``setup.py``
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
@ -51,7 +59,7 @@ Documentation:
|
|||
- Fix PPO environment name (@IljaAvadiev)
|
||||
- Fix custom env doc and add env registration example
|
||||
- Update algorithms from SB3 Contrib
|
||||
|
||||
- Use underscores for numeric literals in examples to improve clarity
|
||||
|
||||
Release 1.2.0 (2021-09-03)
|
||||
---------------------------
|
||||
|
|
|
|||
9
setup.py
9
setup.py
|
|
@ -134,6 +134,15 @@ setup(
|
|||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
version=__version__,
|
||||
python_requires=">=3.6",
|
||||
# PyPI package information.
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
],
|
||||
)
|
||||
|
||||
# python setup.py sdist
|
||||
|
|
|
|||
|
|
@ -478,7 +478,7 @@ class BaseAlgorithm(ABC):
|
|||
"""
|
||||
return self._vec_normalize_env
|
||||
|
||||
def set_env(self, env: GymEnv) -> None:
|
||||
def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
|
||||
"""
|
||||
Checks the validity of the environment, and if it is coherent, set it as the current environment.
|
||||
Furthermore wrap any non vectorized env into a vectorized
|
||||
|
|
@ -487,12 +487,19 @@ class BaseAlgorithm(ABC):
|
|||
- action_space
|
||||
|
||||
:param env: The environment for learning a policy
|
||||
:param force_reset: Force call to ``reset()`` before training
|
||||
to avoid unexpected behavior.
|
||||
See issue https://github.com/DLR-RM/stable-baselines3/issues/597
|
||||
"""
|
||||
# if it is not a VecEnv, make it a VecEnv
|
||||
# and do other transformations (dict obs, image transpose) if needed
|
||||
env = self._wrap_env(env, self.verbose)
|
||||
# Check that the observation spaces match
|
||||
check_for_correct_spaces(env, self.observation_space, self.action_space)
|
||||
# Discard `_last_obs`, this will force the env to reset before training
|
||||
# See issue https://github.com/DLR-RM/stable-baselines3/issues/597
|
||||
if force_reset:
|
||||
self._last_obs = None
|
||||
|
||||
self.n_envs = env.num_envs
|
||||
self.env = env
|
||||
|
|
@ -636,6 +643,7 @@ class BaseAlgorithm(ABC):
|
|||
device: Union[th.device, str] = "auto",
|
||||
custom_objects: Optional[Dict[str, Any]] = None,
|
||||
print_system_info: bool = False,
|
||||
force_reset: bool = True,
|
||||
**kwargs,
|
||||
) -> "BaseAlgorithm":
|
||||
"""
|
||||
|
|
@ -654,6 +662,9 @@ class BaseAlgorithm(ABC):
|
|||
file that can not be deserialized.
|
||||
:param print_system_info: Whether to print system info from the saved model
|
||||
and the current system info (useful to debug loading issues)
|
||||
:param force_reset: Force call to ``reset()`` before training
|
||||
to avoid unexpected behavior.
|
||||
See https://github.com/DLR-RM/stable-baselines3/issues/597
|
||||
:param kwargs: extra arguments to change the model when loading
|
||||
"""
|
||||
if print_system_info:
|
||||
|
|
@ -683,6 +694,10 @@ class BaseAlgorithm(ABC):
|
|||
env = cls._wrap_env(env, data["verbose"])
|
||||
# Check if given env is valid
|
||||
check_for_correct_spaces(env, data["observation_space"], data["action_space"])
|
||||
# Discard `_last_obs`, this will force the env to reset before training
|
||||
# See issue https://github.com/DLR-RM/stable-baselines3/issues/597
|
||||
if force_reset and data is not None:
|
||||
data["_last_obs"] = None
|
||||
else:
|
||||
# Use stored env, if one exists. If not, continue as is (can be used for predict)
|
||||
if "env" in data:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.2.1a4
|
||||
1.2.1a5
|
||||
|
|
|
|||
|
|
@ -163,9 +163,10 @@ def test_save_load(tmp_path, model_class):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
def test_set_env(model_class):
|
||||
def test_set_env(tmp_path, model_class):
|
||||
"""
|
||||
Test if set_env function does work correct
|
||||
|
||||
:param model_class: (BaseAlgorithm) A RL model
|
||||
"""
|
||||
|
||||
|
|
@ -176,24 +177,54 @@ def test_set_env(model_class):
|
|||
|
||||
kwargs = {}
|
||||
if model_class in {DQN, DDPG, SAC, TD3}:
|
||||
kwargs = dict(learning_starts=100, train_freq=4)
|
||||
kwargs = dict(learning_starts=50, train_freq=4)
|
||||
elif model_class in {A2C, PPO}:
|
||||
kwargs = dict(n_steps=64)
|
||||
|
||||
# create model
|
||||
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), **kwargs)
|
||||
# learn
|
||||
model.learn(total_timesteps=128)
|
||||
model.learn(total_timesteps=64)
|
||||
|
||||
# change env
|
||||
model.set_env(env2)
|
||||
model.set_env(env2, force_reset=True)
|
||||
# Check that last obs was discarded
|
||||
assert model._last_obs is None
|
||||
# learn again
|
||||
model.learn(total_timesteps=128)
|
||||
model.learn(total_timesteps=64, reset_num_timesteps=True)
|
||||
assert model.num_timesteps == 64
|
||||
|
||||
# change env test wrapping
|
||||
model.set_env(env3)
|
||||
# learn again
|
||||
model.learn(total_timesteps=128)
|
||||
model.learn(total_timesteps=64)
|
||||
|
||||
# Keep the same env, disable reset
|
||||
model.set_env(model.get_env(), force_reset=False)
|
||||
assert model._last_obs is not None
|
||||
# learn again
|
||||
model.learn(total_timesteps=64, reset_num_timesteps=False)
|
||||
assert model.num_timesteps == 2 * 64
|
||||
|
||||
current_env = model.get_env()
|
||||
model.save(tmp_path / "test_save.zip")
|
||||
del model
|
||||
# Check that we can keep the number of timesteps after loading
|
||||
# Here the env kept its state so we don't have to reset
|
||||
model = model_class.load(tmp_path / "test_save.zip", env=current_env, force_reset=False)
|
||||
assert model._last_obs is not None
|
||||
model.learn(total_timesteps=64, reset_num_timesteps=False)
|
||||
assert model.num_timesteps == 3 * 64
|
||||
|
||||
del model
|
||||
# We are changing the env, the env must reset but we should keep the number of timesteps
|
||||
model = model_class.load(tmp_path / "test_save.zip", env=env3, force_reset=True)
|
||||
assert model._last_obs is None
|
||||
model.learn(total_timesteps=64, reset_num_timesteps=False)
|
||||
assert model.num_timesteps == 3 * 64
|
||||
|
||||
# Clear saved file
|
||||
os.remove(tmp_path / "test_save.zip")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
|
|
|
|||
Loading…
Reference in a new issue