Add timeout handling for on-policy algorithms (#658)

* Add timeout handling for on-policy algorithms
* Fixes
* Fix infinite loop in eval
* Skip type check for python 3.9
* Fix for discrete obs + add docstring
* Fix A2C test
* Removed unused helper
* Add test for infinite horizon
* typed ast should be fixed
* Apply suggestions from code review

Co-authored-by: Anssi <kaneran21@hotmail.com>
parent e75e1de4c1
commit d228364ccf

8 changed files with 82 additions and 10 deletions

.github/workflows/ci.yml (vendored, 2 lines changed)
@@ -40,8 +40,6 @@ jobs:
       - name: Type check
         run: |
           make type
-        # Skip Type Check for python 3.9, see https://github.com/python/typed_ast/issues/169
-        if: ${{ matrix.python-version != 3.9 }}
       - name: Check codestyle
         run: |
           make check-codestyle
@@ -4,7 +4,7 @@ Changelog
 ==========


-Release 1.3.1a1 (WIP)
+Release 1.3.1a2 (WIP)
 ---------------------------

 Breaking Changes:
@ -13,6 +13,8 @@ Breaking Changes:
|
|||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``norm_obs_keys`` param for ``VecNormalize`` wrapper to configure which observation keys to normalize (@kachayev)
|
||||
- Handle timeout termination properly for on-policy algorithms (when using ``TimeLimit``)
|
||||
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
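For context, a minimal usage sketch (not part of this diff) of how the new behaviour is triggered from user code; the environment id and training budget below are illustrative assumptions, not values taken from the commit.

# Sketch: exercising the new timeout handling from user code.
import gym
from stable_baselines3 import A2C

# gym.make wraps registered envs in gym.wrappers.TimeLimit, which sets
# info["TimeLimit.truncated"] = True when an episode ends only because of the step limit.
env = gym.make("CartPole-v1")

# With this change, on-policy algorithms bootstrap the reward of a truncated step with
# the critic's value of the terminal observation instead of treating the timeout as a
# true terminal state.
model = A2C("MlpPolicy", env, gamma=0.99)
model.learn(total_timesteps=2_000)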
@@ -133,7 +133,6 @@ class A2C(OnPolicyAlgorithm):
                 # Convert discrete action from float to long
                 actions = actions.long().flatten()

-            # TODO: avoid second computation of everything because of the gradient
             values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
             values = values.flatten()
@@ -190,14 +190,27 @@ class OnPolicyAlgorithm(BaseAlgorithm):
             if isinstance(self.action_space, gym.spaces.Discrete):
                 # Reshape in case of discrete action
                 actions = actions.reshape(-1, 1)
+
+            # Handle timeout by bootstraping with value function
+            # see GitHub issue #633
+            for idx, done_ in enumerate(dones):
+                if (
+                    done_
+                    and infos[idx].get("terminal_observation") is not None
+                    and infos[idx].get("TimeLimit.truncated", False)
+                ):
+                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
+                    with th.no_grad():
+                        terminal_value = self.policy.predict_values(terminal_obs)[0]
+                    rewards[idx] += self.gamma * terminal_value
+
             rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
             self._last_obs = new_obs
             self._last_episode_starts = dones

         with th.no_grad():
             # Compute value for the last timestep
-            obs_tensor = obs_as_tensor(new_obs, self.device)
-            _, values, _ = self.policy.forward(obs_tensor)
+            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))

         rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
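To make the added block easier to follow in isolation, here is a standalone restatement of the bootstrapping step (a sketch, not SB3 code): the helper name is invented, and value_fn stands in for self.policy.predict_values.

# Sketch: on a TimeLimit truncation, augment the last reward with the discounted
# value of the terminal observation, so the transition is treated as non-terminal.
import numpy as np

def bootstrap_truncated_rewards(rewards, dones, infos, value_fn, gamma):
    """Add gamma * V(terminal_obs) to rewards of steps that ended only because of a timeout."""
    rewards = rewards.copy()
    for idx, done_ in enumerate(dones):
        if (
            done_
            and infos[idx].get("terminal_observation") is not None
            and infos[idx].get("TimeLimit.truncated", False)
        ):
            rewards[idx] += gamma * value_fn(infos[idx]["terminal_observation"])
    return rewards

# Toy usage: env 0 hit the step limit, env 1 reached a genuine terminal state.
rewards = np.array([1.0, 1.0])
dones = np.array([True, True])
infos = [
    {"TimeLimit.truncated": True, "terminal_observation": np.zeros(3)},
    {"terminal_observation": np.zeros(3)},  # no truncation flag: reward left untouched
]
print(bootstrap_truncated_rewards(rewards, dones, infos, value_fn=lambda obs: 50.0, gamma=0.98))
# -> [1.0 + 0.98 * 50.0, 1.0] = [50.0, 1.0]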
@@ -1 +1 @@
-1.3.1a1
+1.3.1a2
@@ -66,7 +66,7 @@ class DummyDictEnv(gym.Env):

     def step(self, action):
         reward = 0.0
-        done = np.random.rand() > 0.8
+        done = False
         return self.observation_space.sample(), reward, done, {}

     def compute_reward(self, achieved_goal, desired_goal, info):
@@ -266,7 +266,7 @@ def test_vec_normalize(model_class):
     Additional tests for PPO/A2C/SAC/DDPG/TD3/DQN to check observation space support
     for GoalEnv and VecNormalize using MultiInputPolicy.
     """
-    env = DummyVecEnv([lambda: DummyDictEnv(use_discrete_actions=model_class == DQN)])
+    env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class == DQN), 100)])
     env = VecNormalize(env, norm_obs_keys=["vec"])

     kwargs = {}
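Mirroring the test change above, a self-contained sketch of the same setup: episodes end via TimeLimit (and are flagged as truncations), while VecNormalize normalizes only the "vec" observation key. TinyDictEnv and all its numbers are invented for illustration and are not part of the repository.

import gym
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

class TinyDictEnv(gym.Env):
    """Throwaway Dict-observation env used only to illustrate norm_obs_keys + TimeLimit."""

    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Dict(
            {
                "vec": gym.spaces.Box(low=-10.0, high=10.0, shape=(4,), dtype=np.float32),
                "discrete": gym.spaces.Discrete(3),
            }
        )
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        # Never terminates on its own; episode length is controlled by TimeLimit.
        return self.observation_space.sample(), 0.0, False, {}

# TimeLimit makes episode ends deterministic and marks them as truncations;
# VecNormalize normalizes only obs["vec"], leaving the discrete key untouched.
env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(TinyDictEnv(), max_episode_steps=100)])
env = VecNormalize(env, norm_obs_keys=["vec"])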
@@ -3,7 +3,7 @@ import numpy as np
 import pytest
 import torch as th

-from stable_baselines3 import A2C, PPO
+from stable_baselines3 import A2C, PPO, SAC
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.policies import ActorCriticPolicy
@@ -35,6 +35,23 @@ class CustomEnv(gym.Env):
         return self.observation_space.sample(), reward, done, {}


+class InfiniteHorizonEnv(gym.Env):
+    def __init__(self, n_states=4):
+        super().__init__()
+        self.n_states = n_states
+        self.observation_space = gym.spaces.Discrete(n_states)
+        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+        self.current_state = 0
+
+    def reset(self):
+        self.current_state = 0
+        return self.current_state
+
+    def step(self, action):
+        self.current_state = (self.current_state + 1) % self.n_states
+        return self.current_state, 1.0, False, {}
+
+
 class CheckGAECallback(BaseCallback):
     def __init__(self):
         super(CheckGAECallback, self).__init__(verbose=0)
@@ -112,3 +129,43 @@ def test_gae_computation(model_class, gae_lambda, gamma, num_episodes):
     # Change constant value so advantage != returns
     model.policy.constant_value = 1.0
     model.learn(rollout_size, callback=CheckGAECallback())
+
+
+@pytest.mark.parametrize("model_class", [A2C, SAC])
+@pytest.mark.parametrize("handle_timeout_termination", [False, True])
+def test_infinite_horizon(model_class, handle_timeout_termination):
+    max_steps = 8
+    gamma = 0.98
+    env = gym.wrappers.TimeLimit(InfiniteHorizonEnv(n_states=4), max_steps)
+    kwargs = {}
+    if model_class == SAC:
+        policy_kwargs = dict(net_arch=[64], n_critics=1)
+        kwargs = dict(
+            replay_buffer_kwargs=dict(handle_timeout_termination=handle_timeout_termination),
+            tau=0.5,
+            learning_rate=0.005,
+        )
+    else:
+        policy_kwargs = dict(net_arch=[64])
+        kwargs = dict(learning_rate=0.002)
+        # A2C always handle timeouts
+        if not handle_timeout_termination:
+            return
+
+    model = model_class("MlpPolicy", env, gamma=gamma, seed=1, policy_kwargs=policy_kwargs, **kwargs)
+    model.learn(1500)
+    # Value of the initial state
+    obs_tensor = model.policy.obs_to_tensor(0)[0]
+    if model_class == A2C:
+        value = model.policy.predict_values(obs_tensor).item()
+    else:
+        value = model.critic(obs_tensor, model.actor(obs_tensor))[0].item()
+    # True value (geometric series with a reward of one at each step)
+    infinite_horizon_value = 1 / (1 - gamma)
+
+    if handle_timeout_termination:
+        # true value +/- 1
+        assert abs(infinite_horizon_value - value) < 1.0
+    else:
+        # wrong estimation
+        assert abs(infinite_horizon_value - value) > 1.0
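For intuition about the assertions above: with a reward of 1 at every step and gamma = 0.98, the true value of any state is the geometric series 1 / (1 - 0.98) = 50, whereas treating the 8-step timeout as a real terminal state caps the return target at roughly 7.5, which is why the learned value lands within 1 of 50 only when timeouts are handled. A quick numeric check (not part of the diff):

# Back-of-the-envelope check of the constants asserted in test_infinite_horizon.
gamma = 0.98
max_steps = 8

# Reward of 1 at every step, forever: the true value is the geometric series.
true_value = 1 / (1 - gamma)  # 50.0

# If the TimeLimit truncation were treated as a genuine terminal state, the return
# target would only cover max_steps rewards and badly underestimate the value.
truncated_target = sum(gamma ** k for k in range(max_steps))  # ~7.46

print(true_value, truncated_target)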
@@ -22,6 +22,9 @@ def test_discrete(model_class, env):
     # DQN only support discrete actions
     if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
         return
+    elif model_class == A2C:
+        # slightly higher budget
+        n_steps = 3500

     model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps)