From d228364ccffefef9632b4352ca6f7a7a6071cfa5 Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Tue, 16 Nov 2021 17:19:16 +0100
Subject: [PATCH] Add timeout handling for on-policy algorithms (#658)

* Add timeout handling for on-policy algorithms

* Fixes

* Fix infinite loop in eval

* Skip type check for python 3.9

* Fix for discrete obs + add docstring

* Fix A2C test

* Removed unused helper

* Add test for infinite horizon

* typed_ast should be fixed

* Apply suggestions from code review

Co-authored-by: Anssi
Co-authored-by: Anssi
---
 .github/workflows/ci.yml                        |  2 -
 docs/misc/changelog.rst                         |  4 +-
 stable_baselines3/a2c/a2c.py                    |  1 -
 .../common/on_policy_algorithm.py               | 17 +++++-
 stable_baselines3/version.txt                   |  2 +-
 tests/test_dict_env.py                          |  4 +-
 tests/test_gae.py                               | 59 ++++++++++++++++++-
 tests/test_identity.py                          |  3 +
 8 files changed, 82 insertions(+), 10 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 75ce35a..6626122 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -40,8 +40,6 @@ jobs:
       - name: Type check
         run: |
           make type
-        # Skip Type Check for python 3.9, see https://github.com/python/typed_ast/issues/169
-        if: ${{ matrix.python-version != 3.9 }}
       - name: Check codestyle
         run: |
           make check-codestyle
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index adb669c..8be5910 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -4,7 +4,7 @@
 Changelog
 ==========
 
-Release 1.3.1a1 (WIP)
+Release 1.3.1a2 (WIP)
 ---------------------------
 
 Breaking Changes:
@@ -13,6 +13,8 @@ Breaking Changes:
 New Features:
 ^^^^^^^^^^^^^
 - Added ``norm_obs_keys`` param for ``VecNormalize`` wrapper to configure which observation keys to normalize (@kachayev)
+- Handle timeout termination properly for on-policy algorithms (when using ``TimeLimit``)
+
 
 Bug Fixes:
 ^^^^^^^^^^
diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py
index 6641177..837ec42 100644
--- a/stable_baselines3/a2c/a2c.py
+++ b/stable_baselines3/a2c/a2c.py
@@ -133,7 +133,6 @@ class A2C(OnPolicyAlgorithm):
             # Convert discrete action from float to long
             actions = actions.long().flatten()
 
-            # TODO: avoid second computation of everything because of the gradient
             values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
             values = values.flatten()
 
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 8e783dd..0aff9bb 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -190,14 +190,27 @@ class OnPolicyAlgorithm(BaseAlgorithm):
             if isinstance(self.action_space, gym.spaces.Discrete):
                 # Reshape in case of discrete action
                 actions = actions.reshape(-1, 1)
+
+            # Handle timeout by bootstrapping with value function
+            # see GitHub issue #633
+            for idx, done_ in enumerate(dones):
+                if (
+                    done_
+                    and infos[idx].get("terminal_observation") is not None
+                    and infos[idx].get("TimeLimit.truncated", False)
+                ):
+                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
+                    with th.no_grad():
+                        terminal_value = self.policy.predict_values(terminal_obs)[0]
+                    rewards[idx] += self.gamma * terminal_value
+
             rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
             self._last_obs = new_obs
             self._last_episode_starts = dones
 
         with th.no_grad():
             # Compute value for the last timestep
-            obs_tensor = obs_as_tensor(new_obs, self.device)
-            _, values, _ = self.policy.forward(obs_tensor)
+            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
 
         rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
 
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index 690d925..c5813fb 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-1.3.1a1
+1.3.1a2
diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py
index 34d9dbd..f781999 100644
--- a/tests/test_dict_env.py
+++ b/tests/test_dict_env.py
@@ -66,7 +66,7 @@ class DummyDictEnv(gym.Env):
 
     def step(self, action):
         reward = 0.0
-        done = np.random.rand() > 0.8
+        done = False
         return self.observation_space.sample(), reward, done, {}
 
     def compute_reward(self, achieved_goal, desired_goal, info):
@@ -266,7 +266,7 @@ def test_vec_normalize(model_class):
     Additional tests for PPO/A2C/SAC/DDPG/TD3/DQN to check observation space support
     for GoalEnv and VecNormalize using MultiInputPolicy.
     """
-    env = DummyVecEnv([lambda: DummyDictEnv(use_discrete_actions=model_class == DQN)])
+    env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class == DQN), 100)])
     env = VecNormalize(env, norm_obs_keys=["vec"])
 
     kwargs = {}
diff --git a/tests/test_gae.py b/tests/test_gae.py
index 7f095c0..54e03b8 100644
--- a/tests/test_gae.py
+++ b/tests/test_gae.py
@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 import torch as th
 
-from stable_baselines3 import A2C, PPO
+from stable_baselines3 import A2C, PPO, SAC
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.policies import ActorCriticPolicy
@@ -35,6 +35,23 @@
         return self.observation_space.sample(), reward, done, {}
 
 
+class InfiniteHorizonEnv(gym.Env):
+    def __init__(self, n_states=4):
+        super().__init__()
+        self.n_states = n_states
+        self.observation_space = gym.spaces.Discrete(n_states)
+        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+        self.current_state = 0
+
+    def reset(self):
+        self.current_state = 0
+        return self.current_state
+
+    def step(self, action):
+        self.current_state = (self.current_state + 1) % self.n_states
+        return self.current_state, 1.0, False, {}
+
+
 class CheckGAECallback(BaseCallback):
     def __init__(self):
         super(CheckGAECallback, self).__init__(verbose=0)
@@ -112,3 +129,43 @@ def test_gae_computation(model_class, gae_lambda, gamma, num_episodes):
     # Change constant value so advantage != returns
     model.policy.constant_value = 1.0
     model.learn(rollout_size, callback=CheckGAECallback())
+
+
+@pytest.mark.parametrize("model_class", [A2C, SAC])
+@pytest.mark.parametrize("handle_timeout_termination", [False, True])
+def test_infinite_horizon(model_class, handle_timeout_termination):
+    max_steps = 8
+    gamma = 0.98
+    env = gym.wrappers.TimeLimit(InfiniteHorizonEnv(n_states=4), max_steps)
+    kwargs = {}
+    if model_class == SAC:
+        policy_kwargs = dict(net_arch=[64], n_critics=1)
+        kwargs = dict(
+            replay_buffer_kwargs=dict(handle_timeout_termination=handle_timeout_termination),
+            tau=0.5,
+            learning_rate=0.005,
+        )
+    else:
+        policy_kwargs = dict(net_arch=[64])
+        kwargs = dict(learning_rate=0.002)
+        # A2C always handles timeouts
+        if not handle_timeout_termination:
+            return
+
+    model = model_class("MlpPolicy", env, gamma=gamma, seed=1, policy_kwargs=policy_kwargs, **kwargs)
+    model.learn(1500)
+    # Value of the initial state
+    obs_tensor = model.policy.obs_to_tensor(0)[0]
+    if model_class == A2C:
+        value = model.policy.predict_values(obs_tensor).item()
+    else:
+        value = model.critic(obs_tensor, model.actor(obs_tensor))[0].item()
+    # True value (geometric series with a reward of one at each step)
+    infinite_horizon_value = 1 / (1 - gamma)
+
+    if handle_timeout_termination:
+        # true value +/- 1
+        assert abs(infinite_horizon_value - value) < 1.0
+    else:
+        # wrong estimation
+        assert abs(infinite_horizon_value - value) > 1.0
diff --git a/tests/test_identity.py b/tests/test_identity.py
index 6226580..f5bbc49 100644
--- a/tests/test_identity.py
+++ b/tests/test_identity.py
@@ -22,6 +22,9 @@ def test_discrete(model_class, env):
     # DQN only support discrete actions
     if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
         return
+    elif model_class == A2C:
+        # slightly higher budget
+        n_steps = 3500
 
     model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps)
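
The core of the change, for readers skimming the diff: ``gym.wrappers.TimeLimit``
ends episodes by the clock, not because the state is terminal, so treating those
``done`` flags as true terminals biases the value targets (GitHub issue #633).
The patch compensates by folding ``gamma * V(terminal_obs)`` into the last reward
before the transition is stored, which is why the infinite-horizon test expects
the learned value to approach the geometric-series limit ``1 / (1 - gamma)``.

Below is a minimal, self-contained sketch of that idea against the plain ``gym``
API, not stable-baselines3 code: ``policy_fn`` and ``value_fn`` are hypothetical
stand-ins for the rollout policy and the critic (``self.policy.predict_values``
in the patch), and because a plain env is not auto-reset, the last ``step()``
return plays the role of ``info["terminal_observation"]``, a key that only SB3's
vectorized envs provide.

import gym


def collect_with_timeout_bootstrap(env, policy_fn, value_fn, n_steps, gamma=0.99):
    """Collect transitions, bootstrapping the reward at TimeLimit truncations.

    `policy_fn` and `value_fn` are illustrative stand-ins, not SB3 API.
    """
    obs = env.reset()
    transitions = []
    for _ in range(n_steps):
        action = policy_fn(obs)
        next_obs, reward, done, info = env.step(action)
        # TimeLimit sets this key only when it cut the episode short;
        # an episode that truly ended leaves it unset (or False).
        if done and info.get("TimeLimit.truncated", False):
            # Not a real terminal state: credit the discounted return the
            # agent would have kept collecting, as estimated by the critic.
            reward += gamma * value_fn(next_obs)
        transitions.append((obs, action, reward, done))
        obs = env.reset() if done else next_obs
    return transitions


if __name__ == "__main__":
    # Toy usage: random policy, placeholder critic. gym.make already wraps
    # CartPole-v1 in a TimeLimit of 500 steps.
    env = gym.make("CartPole-v1")
    rollout = collect_with_timeout_bootstrap(
        env,
        policy_fn=lambda obs: env.action_space.sample(),
        value_fn=lambda obs: 0.0,
        n_steps=200,
    )
    print(f"collected {len(rollout)} transitions")

Applying the correction to the stored reward, rather than changing the buffer,
keeps ``RolloutBuffer.compute_returns_and_advantage`` untouched: the ``done``
flag still zeroes the bootstrap there, but the reward already equals
``r + gamma * V(s_T)``, so the target is the same as if the episode had continued.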