diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 4813df4..6a31f4f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: stablebaselines/stable-baselines3-cpu:0.9.0a2 +image: stablebaselines/stable-baselines3-cpu:0.11.0a4 type-check: script: diff --git a/README.md b/README.md index f3baf41..600c5da 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ # Stable Baselines3 -Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). +Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). You can read a detailed presentation of Stable Baselines in the [Medium article](https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-reinforcement-learning-made-easy-df87c4b2fc82). @@ -50,7 +50,6 @@ A migration guide from SB2 to SB3 can be found in the [documentation](https://st Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/) - ## RL Baselines3 Zoo: A Collection of Trained RL Agents [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo). is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines3. @@ -68,6 +67,15 @@ Github repo: https://github.com/DLR-RM/rl-baselines3-zoo Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.html +## SB3-Contrib: Experimental RL Features + +We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) + +This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC) or Quantile Regression DQN (QR-DQN). + +Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/) + + ## Installation **Note:** Stable-Baselines3 supports PyTorch 1.4+. diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 2ca362d..887bfb9 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -31,6 +31,10 @@ Actions ``gym.spaces``: - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination. +.. note:: + + More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo `. + .. note:: Some logging values (like ``ep_rew_mean``, ``ep_len_mean``) are only available when using a ``Monitor`` wrapper diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst index 29199a1..c207c21 100644 --- a/docs/guide/rl_tips.rst +++ b/docs/guide/rl_tips.rst @@ -87,8 +87,6 @@ Looking at the training curve (episode reward function of the timesteps) is a go - - We suggest you reading `Deep Reinforcement Learning that Matters `_ for a good discussion about RL evaluation. You can also take a look at this `blog post `_ @@ -122,6 +120,7 @@ Discrete Actions - Single Process ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ DQN with extensions (double DQN, prioritized replay, ...) are the recommended algorithms. +We notably provide QR-DQN in our :ref:`contrib repo `. DQN is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer). Discrete Actions - Multiprocessed @@ -136,7 +135,7 @@ Continuous Actions Continuous Actions - Single Process ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Current State Of The Art (SOTA) algorithms are ``SAC`` and ``TD3``. +Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3`` and ``TQC`` (available in our :ref:`contrib repo `). Please use the hyperparameters in the `RL zoo `_ for best results. @@ -156,7 +155,7 @@ Goal Environment ----------------- If your environment follows the ``GoalEnv`` interface (cf :ref:`HER `), then you should use -HER + (SAC/TD3/DDPG/DQN) depending on the action space. +HER + (SAC/TD3/DDPG/DQN/TQC) depending on the action space. .. note:: diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst new file mode 100644 index 0000000..3d2d15e --- /dev/null +++ b/docs/guide/sb3_contrib.rst @@ -0,0 +1,97 @@ +.. _sb3_contrib: + +================== +SB3 Contrib +================== + +We implement experimental features in a separate contrib repository: +`SB3-Contrib`_ + +This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still +providing the latest features, like Truncated Quantile Critics (TQC) or +Quantile Regression DQN (QR-DQN). + +Why create this repository? +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Over the span of stable-baselines and stable-baselines3, the community +has been eager to contribute in form of better logging utilities, +environment wrappers, extended support (e.g. different action spaces) +and learning algorithms. + +However sometimes these utilities were too niche to be considered for +stable-baselines or proved to be too difficult to integrate well into +the existing code without creating a mess. sb3-contrib aims to fix this by not +requiring the neatest code integration with existing code and not +setting limits on what is too niche: almost everything remotely useful +goes! +We hope this allows us to provide reliable implementations +following stable-baselines usual standards (consistent style, documentation, etc) +beyond the relatively small scope of utilities in the main repository. + +Features +-------- + +See documentation for the full list of included features. + +**RL Algorithms**: + +- `Truncated Quantile Critics (TQC)`_ +- `Quantile Regression DQN (QR-DQN)`_ + +**Gym Wrappers**: + +- `Time Feature Wrapper`_ + +Documentation +------------- + +Documentation is available online: https://sb3-contrib.readthedocs.io/ + +Installation +------------ + +To install Stable-Baselines3 contrib with pip, execute: + +:: + + pip install sb3-contrib + +We recommend to use the ``master`` version of Stable Baselines3 and SB3-Contrib. + +To install Stable Baselines3 ``master`` version: + +:: + + pip install git+https://github.com/DLR-RM/stable-baselines3 + +To install Stable Baselines3 contrib ``master`` version: + +:: + + pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib + + +Example +------- + +SB3-Contrib follows the SB3 API and folder structure. So, if you are familiar with SB3, +using SB3-Contrib should be easy too. + +Here is an example of training a Quantile Regression DQN (QR-DQN) agent on the CartPole environment. + +.. code-block:: python + + from sb3_contrib import QRDQN + + policy_kwargs = dict(n_quantiles=50) + model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("qrdqn_cartpole") + + + +.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib +.. _Truncated Quantile Critics (TQC): https://arxiv.org/abs/2005.04269 +.. _Quantile Regression DQN (QR-DQN): https://arxiv.org/abs/1710.10044 +.. _Time Feature Wrapper: https://arxiv.org/abs/1712.00378 diff --git a/docs/index.rst b/docs/index.rst index c60e5f3..61ac1d5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,10 +3,10 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to Stable Baselines3 docs! - RL Baselines Made Easy -=========================================================== +Stable-Baselines3 Docs - Reliable Reinforcement Learning Implementations +======================================================================== -`Stable Baselines3 `_ is a set of improved implementations of reinforcement learning algorithms in PyTorch. +`Stable Baselines3 (SB3) `_ is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of `Stable Baselines `_. @@ -16,6 +16,8 @@ RL Baselines3 Zoo (collection of pre-trained agents): https://github.com/DLR-RM/ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning. +SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib + Main Features -------------- @@ -45,6 +47,7 @@ Main Features guide/callbacks guide/tensorboard guide/rl_zoo + guide/sb3_contrib guide/imitation guide/migration guide/checking_nan diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 10ff71d..d3b7b3f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Pre-Release 0.11.0a2 (WIP) +Pre-Release 0.11.0a4 (WIP) ------------------------------- Breaking Changes: @@ -44,6 +44,9 @@ Others: - Add signatures to callable type annotations (@ernestum) - Improve error message in ``NatureCNN`` - Added checks for supported action spaces to improve clarity of error messages for the user +- Renamed variables in the ``train()`` method of ``SAC``, ``TD3`` and ``DQN`` to match SB3-Contrib. +- Updated docker base image to Ubuntu 18.04 +- Set tensorboard min version to 2.2.0 (earlier version are apparently not working with PyTorch) Documentation: ^^^^^^^^^^^^^^ @@ -55,6 +58,7 @@ Documentation: - Added example of learning rate schedule - Added SUMO-RL as example project (@LucasAlegre) - Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre) +- Added SB3-Contrib page Pre-Release 0.10.0 (2020-10-28) ------------------------------- diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh index 0c599a6..13ac86b 100755 --- a/scripts/build_docker.sh +++ b/scripts/build_docker.sh @@ -1,7 +1,7 @@ #!/bin/bash -CPU_PARENT=ubuntu:16.04 -GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu16.04 +CPU_PARENT=ubuntu:18.04 +GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 TAG=stablebaselines/stable-baselines3 VERSION=$(cat ./stable_baselines3/version.txt) diff --git a/setup.py b/setup.py index 72146ad..0ef4e9b 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ long_description = """ # Stable Baselines3 -Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). +Stable Baselines3 is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details. @@ -29,6 +29,9 @@ https://stable-baselines3.readthedocs.io/en/master/ RL Baselines3 Zoo: https://github.com/DLR-RM/rl-baselines3-zoo +SB3 Contrib: +https://github.com/Stable-Baselines-Team/stable-baselines3-contrib + ## Quick example Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms using Gym. @@ -112,7 +115,7 @@ setup( "atari_py~=0.2.0", "pillow", # Tensorboard support - "tensorboard", + "tensorboard>=2.2.0", # Checking memory taken by replay buffer "psutil", ], diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 045c377..8377202 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -155,23 +155,23 @@ class DQN(OffPolicyAlgorithm): replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) with th.no_grad(): - # Compute the target Q values - target_q = self.q_net_target(replay_data.next_observations) + # Compute the next Q-values using the target network + next_q_values = self.q_net_target(replay_data.next_observations) # Follow greedy policy: use the one with the highest value - target_q, _ = target_q.max(dim=1) + next_q_values, _ = next_q_values.max(dim=1) # Avoid potential broadcast issue - target_q = target_q.reshape(-1, 1) + next_q_values = next_q_values.reshape(-1, 1) # 1-step TD target - target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q + target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values - # Get current Q estimates - current_q = self.q_net(replay_data.observations) + # Get current Q-values estimates + current_q_values = self.q_net(replay_data.observations) # Retrieve the q-values for the actions from the replay buffer - current_q = th.gather(current_q, dim=1, index=replay_data.actions.long()) + current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long()) # Compute Huber loss (less sensitive to outliers) - loss = F.smooth_l1_loss(current_q, target_q) + loss = F.smooth_l1_loss(current_q_values, target_q_values) losses.append(loss.item()) # Optimize the policy diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index f72424e..cd0d17e 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -74,7 +74,6 @@ class QNetwork(BasePolicy): features_dim=self.features_dim, activation_fn=self.activation_fn, features_extractor=self.features_extractor, - epsilon=self.epsilon, ) ) return data diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index a0c299a..cd7a413 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -223,20 +223,20 @@ class SAC(OffPolicyAlgorithm): with th.no_grad(): # Select action according to policy next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) - # Compute the target Q value: min over all critics targets - targets = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) - target_q, _ = th.min(targets, dim=1, keepdim=True) + # Compute the next Q values: min over all critics targets + next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) + next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True) # add entropy term - target_q = target_q - ent_coef * next_log_prob.reshape(-1, 1) + next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1) # td error + entropy term - q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q + target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values - # Get current Q estimates for each critic network + # Get current Q-values estimates for each critic network # using action from the replay buffer - current_q_estimates = self.critic(replay_data.observations, replay_data.actions) + current_q_values = self.critic(replay_data.observations, replay_data.actions) # Compute critic loss - critic_loss = 0.5 * sum([F.mse_loss(current_q, q_backup) for current_q in current_q_estimates]) + critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values]) critic_losses.append(critic_loss.item()) # Optimize the critic diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index ed74830..1a3d059 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -146,16 +146,16 @@ class TD3(OffPolicyAlgorithm): noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip) next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1) - # Compute the target Q value: min over all critics targets - targets = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) - target_q, _ = th.min(targets, dim=1, keepdim=True) - target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q + # Compute the next Q-values: min over all critics targets + next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) + next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True) + target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values - # Get current Q estimates for each critic network - current_q_estimates = self.critic(replay_data.observations, replay_data.actions) + # Get current Q-values estimates for each critic network + current_q_values = self.critic(replay_data.observations, replay_data.actions) # Compute critic loss - critic_loss = sum([F.mse_loss(current_q, target_q) for current_q in current_q_estimates]) + critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values]) critic_losses.append(critic_loss.item()) # Optimize the critics diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index a09c7eb..1b742ef 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.11.0a2 +0.11.0a4 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index f7c5521..b5b733f 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -408,6 +408,82 @@ def test_save_load_policy(tmp_path, model_class, policy_str): os.remove(tmp_path / "actor.pkl") +@pytest.mark.parametrize("model_class", [DQN]) +@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) +def test_save_load_q_net(tmp_path, model_class, policy_str): + """ + Test saving and loading q-network/quantile net only. + + :param model_class: (BaseAlgorithm) A RL model + :param policy_str: (str) Name of the policy. + """ + kwargs = dict(policy_kwargs=dict(net_arch=[16])) + if policy_str == "MlpPolicy": + env = select_env(model_class) + else: + if model_class in [DQN]: + # Avoid memory error when using replay buffer + # Reduce the size of the features + kwargs = dict( + buffer_size=250, + learning_starts=100, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), + ) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN) + + env = DummyVecEnv([lambda: env]) + + # create model + model = model_class(policy_str, env, verbose=1, **kwargs) + model.learn(total_timesteps=300) + + env.reset() + observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) + + q_net = model.q_net + q_net_class = q_net.__class__ + + # Get dictionary of current parameters + params = deepcopy(q_net.state_dict()) + + # Modify all parameters to be random values + random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + + # Update model parameters with the new random values + q_net.load_state_dict(random_params) + + new_params = q_net.state_dict() + # Check that all params are different now + for k in params: + assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." + + params = new_params + + # get selected actions + selected_actions, _ = q_net.predict(observations, deterministic=True) + + # Save and load q_net + q_net.save(tmp_path / "q_net.pkl") + + del q_net + + q_net = q_net_class.load(tmp_path / "q_net.pkl") + + # check if params are still the same after load + new_params = q_net.state_dict() + + # Check that all params are the same as before save load procedure now + for key in params: + assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load." + + # check if model still selects the same actions + new_selected_actions, _ = q_net.predict(observations, deterministic=True) + assert np.allclose(selected_actions, new_selected_actions, 1e-4) + + # clear file from os + os.remove(tmp_path / "q_net.pkl") + + @pytest.mark.parametrize("pathtype", [str, pathlib.Path]) def test_open_file_str_pathlib(tmp_path, pathtype): # check that suffix isn't added because we used open_path first