Add timeout handling for on-policy algorithms (#658)

* Add timeout handling for on-policy algorithms

* Fixes

* Fix infinite loop in eval

* Skip type check for python 3.9

* Fix for discrete obs + add docstring

* Fix A2C test

* Removed unused helper

* Add test for infinite horizon

* typed_ast should be fixed

* Apply suggestions from code review

Co-authored-by: Anssi <kaneran21@hotmail.com>

Co-authored-by: Anssi <kaneran21@hotmail.com>
Antonin RAFFIN 2021-11-16 17:19:16 +01:00 committed by GitHub
parent e75e1de4c1
commit d228364ccf
8 changed files with 82 additions and 10 deletions


@@ -40,8 +40,6 @@ jobs:
       - name: Type check
         run: |
           make type
-        # Skip Type Check for python 3.9, see https://github.com/python/typed_ast/issues/169
-        if: ${{ matrix.python-version != 3.9 }}
       - name: Check codestyle
         run: |
           make check-codestyle


@@ -4,7 +4,7 @@ Changelog
 ==========
 
-Release 1.3.1a1 (WIP)
+Release 1.3.1a2 (WIP)
 ---------------------------
 
 Breaking Changes:
@@ -13,6 +13,8 @@ Breaking Changes:
 New Features:
 ^^^^^^^^^^^^^
 - Added ``norm_obs_keys`` param for ``VecNormalize`` wrapper to configure which observation keys to normalize (@kachayev)
+- Handle timeout termination properly for on-policy algorithms (when using ``TimeLimit``)
 
 Bug Fixes:
 ^^^^^^^^^^


@@ -133,7 +133,6 @@ class A2C(OnPolicyAlgorithm):
                 # Convert discrete action from float to long
                 actions = actions.long().flatten()
 
-            # TODO: avoid second computation of everything because of the gradient
             values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
             values = values.flatten()


@@ -190,14 +190,27 @@ class OnPolicyAlgorithm(BaseAlgorithm):
             if isinstance(self.action_space, gym.spaces.Discrete):
                 # Reshape in case of discrete action
                 actions = actions.reshape(-1, 1)
+
+            # Handle timeout by bootstrapping with the value function
+            # see GitHub issue #633
+            for idx, done_ in enumerate(dones):
+                if (
+                    done_
+                    and infos[idx].get("terminal_observation") is not None
+                    and infos[idx].get("TimeLimit.truncated", False)
+                ):
+                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
+                    with th.no_grad():
+                        terminal_value = self.policy.predict_values(terminal_obs)[0]
+                    rewards[idx] += self.gamma * terminal_value
+
             rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
             self._last_obs = new_obs
             self._last_episode_starts = dones
 
         with th.no_grad():
             # Compute value for the last timestep
-            obs_tensor = obs_as_tensor(new_obs, self.device)
-            _, values, _ = self.policy.forward(obs_tensor)
+            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
 
         rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
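
Note on the change above: when an episode ends only because the TimeLimit wrapper truncated it, the done flag no longer means "no future reward", so the reward stored for that step is corrected from r to r + gamma * V(s_terminal). Below is a minimal, self-contained sketch of that correction; value_fn is a hypothetical stand-in for policy.predict_values() and the transition data is made up, so this is an illustration of the idea rather than the SB3 implementation.

import numpy as np

gamma = 0.99

def value_fn(obs):
    # Hypothetical value estimate; SB3 uses policy.predict_values() here
    return 0.5

# One transition as a VecEnv would report it when a TimeLimit truncation occurs
rewards = np.array([1.0])
dones = np.array([True])
infos = [{"terminal_observation": np.zeros(3), "TimeLimit.truncated": True}]

for idx, done in enumerate(dones):
    if (
        done
        and infos[idx].get("terminal_observation") is not None
        and infos[idx].get("TimeLimit.truncated", False)
    ):
        # The episode was cut short by the time limit, not by a true terminal state,
        # so bootstrap with the value of the observation that was cut off.
        rewards[idx] += gamma * value_fn(infos[idx]["terminal_observation"])

print(rewards)  # [1.495] instead of [1.0]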


@@ -1 +1 @@
-1.3.1a1
+1.3.1a2


@@ -66,7 +66,7 @@ class DummyDictEnv(gym.Env):
     def step(self, action):
         reward = 0.0
-        done = np.random.rand() > 0.8
+        done = False
         return self.observation_space.sample(), reward, done, {}
 
     def compute_reward(self, achieved_goal, desired_goal, info):
@@ -266,7 +266,7 @@ def test_vec_normalize(model_class):
     Additional tests for PPO/A2C/SAC/DDPG/TD3/DQN to check observation space support
     for GoalEnv and VecNormalize using MultiInputPolicy.
     """
-    env = DummyVecEnv([lambda: DummyDictEnv(use_discrete_actions=model_class == DQN)])
+    env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class == DQN), 100)])
     env = VecNormalize(env, norm_obs_keys=["vec"])
 
     kwargs = {}
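
Background for this test change: the bootstrap logic relies on two pieces of metadata that only appear when the env is wrapped in gym's TimeLimit and run through an SB3 VecEnv. The wrapper sets info["TimeLimit.truncated"] on truncation, and the VecEnv stores the last observation as info["terminal_observation"] before auto-resetting. A rough illustration follows; it assumes the legacy 4-tuple gym step API and the registered 200-step limit of Pendulum (adjust the env id to your gym version).

import gym
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv

# Pendulum never terminates on its own; gym.make() wraps it in a 200-step TimeLimit
venv = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
obs = venv.reset()

for _ in range(200):
    action = np.array([venv.action_space.sample()])
    obs, rewards, dones, infos = venv.step(action)

# At step 200 the episode is truncated, not terminated:
assert dones[0]
assert infos[0].get("TimeLimit.truncated", False)
assert "terminal_observation" in infos[0]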


@@ -3,7 +3,7 @@ import numpy as np
 import pytest
 import torch as th
 
-from stable_baselines3 import A2C, PPO
+from stable_baselines3 import A2C, PPO, SAC
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.policies import ActorCriticPolicy
@@ -35,6 +35,23 @@ class CustomEnv(gym.Env):
         return self.observation_space.sample(), reward, done, {}
 
 
+class InfiniteHorizonEnv(gym.Env):
+    def __init__(self, n_states=4):
+        super().__init__()
+        self.n_states = n_states
+        self.observation_space = gym.spaces.Discrete(n_states)
+        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+        self.current_state = 0
+
+    def reset(self):
+        self.current_state = 0
+        return self.current_state
+
+    def step(self, action):
+        self.current_state = (self.current_state + 1) % self.n_states
+        return self.current_state, 1.0, False, {}
+
+
 class CheckGAECallback(BaseCallback):
     def __init__(self):
         super(CheckGAECallback, self).__init__(verbose=0)
@@ -112,3 +129,43 @@ def test_gae_computation(model_class, gae_lambda, gamma, num_episodes):
         # Change constant value so advantage != returns
         model.policy.constant_value = 1.0
         model.learn(rollout_size, callback=CheckGAECallback())
+
+
+@pytest.mark.parametrize("model_class", [A2C, SAC])
+@pytest.mark.parametrize("handle_timeout_termination", [False, True])
+def test_infinite_horizon(model_class, handle_timeout_termination):
+    max_steps = 8
+    gamma = 0.98
+    env = gym.wrappers.TimeLimit(InfiniteHorizonEnv(n_states=4), max_steps)
+
+    kwargs = {}
+    if model_class == SAC:
+        policy_kwargs = dict(net_arch=[64], n_critics=1)
+        kwargs = dict(
+            replay_buffer_kwargs=dict(handle_timeout_termination=handle_timeout_termination),
+            tau=0.5,
+            learning_rate=0.005,
+        )
+    else:
+        policy_kwargs = dict(net_arch=[64])
+        kwargs = dict(learning_rate=0.002)
+        # A2C always handles timeouts
+        if not handle_timeout_termination:
+            return
+
+    model = model_class("MlpPolicy", env, gamma=gamma, seed=1, policy_kwargs=policy_kwargs, **kwargs)
+    model.learn(1500)
+
+    # Value of the initial state
+    obs_tensor = model.policy.obs_to_tensor(0)[0]
+    if model_class == A2C:
+        value = model.policy.predict_values(obs_tensor).item()
+    else:
+        value = model.critic(obs_tensor, model.actor(obs_tensor))[0].item()
+
+    # True value (geometric series with a reward of one at each step)
+    infinite_horizon_value = 1 / (1 - gamma)
+
+    if handle_timeout_termination:
+        # true value +/- 1
+        assert abs(infinite_horizon_value - value) < 1.0
+    else:
+        # wrong estimation
+        assert abs(infinite_horizon_value - value) > 1.0
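
To make the tolerance in the assertions above concrete: with gamma = 0.98 and a reward of 1 at every step, the true value of any state is the geometric series sum of gamma^t, i.e. 1 / (1 - gamma) = 50, whereas an agent that treats the 8-step timeout as a real terminal state learns targets much closer to the truncated 8-step return of roughly 7.46. A quick sanity check of those numbers (plain Python, not part of the test):

gamma = 0.98

true_value = 1 / (1 - gamma)                         # 50.0: infinite-horizon value
truncated_return = sum(gamma**t for t in range(8))   # ~7.46: 8-step return if timeouts count as terminals

print(true_value, truncated_return)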


@@ -22,6 +22,9 @@ def test_discrete(model_class, env):
         # DQN only support discrete actions
         if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
             return
+    elif model_class == A2C:
+        # slightly higher budget
+        n_steps = 3500
 
     model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps)