mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
* Add rollout_buffer_class and rollout_buffer_kwargs parameters to OnPolicyAlgorithm * Add rollout_buffer_class and rollout_buffer_kwargs to PPO. * Add rollout_buffer_class and rollout_buffer_kwargs to A2C. * Make use of the rollout buffer kwargs. * Update version * Add test and update doc --------- Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
166 lines
6.4 KiB
Python
166 lines
6.4 KiB
Python
import gymnasium as gym
|
|
import numpy as np
|
|
import pytest
|
|
import torch as th
|
|
from gymnasium import spaces
|
|
|
|
from stable_baselines3 import A2C
|
|
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
|
|
from stable_baselines3.common.env_checker import check_env
|
|
from stable_baselines3.common.env_util import make_vec_env
|
|
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
|
|
from stable_baselines3.common.utils import get_device
|
|
from stable_baselines3.common.vec_env import VecNormalize
|
|
|
|
|
|
class DummyEnv(gym.Env):
    """
    Custom gym environment for testing purposes.

    Observations cycle deterministically through the values 1.0 to 5.0
    (one per step) and the reward follows the same cycle. The action is
    ignored. Episodes are truncated after ``_ep_length`` steps and never
    terminate.
    """

    def __init__(self):
        self.action_space = spaces.Box(1, 5, (1,))
        self.observation_space = spaces.Box(1, 5, (1,))
        # Lookup tables indexed by ``t % 5``
        self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32)
        self._rewards = [1, 2, 3, 4, 5]
        # Number of steps taken in the current episode
        self._t = 0
        self._ep_length = 100

    def reset(self, *, seed=None, options=None):
        """
        Start a new episode.

        :param seed: Optional seed, forwarded to ``gym.Env.reset``
        :param options: Unused
        :return: The first observation and an empty info dict
        """
        # Forward the seed to the base class so that it is not silently dropped
        super().reset(seed=seed)
        self._t = 0
        obs = self._observations[0]
        return obs, {}

    def step(self, action):
        """
        Advance the episode by one step (``action`` is ignored).

        :param action: Unused
        :return: observation, reward, terminated, truncated and an empty info dict
        """
        self._t += 1
        index = self._t % len(self._observations)
        obs = self._observations[index]
        terminated = False
        # Time limit only: the episode never terminates by itself
        truncated = self._t >= self._ep_length
        reward = self._rewards[index]
        return obs, reward, terminated, truncated, {}
|
|
|
|
|
|
class DummyDictEnv(gym.Env):
    """
    Custom gym environment with a Dict observation space, for testing purposes.

    Every key of the observation dict receives the same value, which cycles
    deterministically through 1.0 to 5.0 (one per step); the reward follows
    the same cycle. The action is ignored. Episodes are truncated after
    ``_ep_length`` steps and never terminate.
    """

    def __init__(self):
        # Test for multi-dim action space
        self.action_space = spaces.Box(1, 5, shape=(10, 7))
        space = spaces.Box(1, 5, (1,))
        self.observation_space = spaces.Dict({"observation": space, "achieved_goal": space, "desired_goal": space})
        # Lookup tables indexed by ``t % 5``
        self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32)
        self._rewards = [1, 2, 3, 4, 5]
        # Number of steps taken in the current episode
        self._t = 0
        self._ep_length = 100

    def reset(self, seed=None, options=None):
        """
        Start a new episode.

        :param seed: Optional seed, forwarded to ``gym.Env.reset``
        :param options: Unused
        :return: The first observation dict and an empty info dict
        """
        # Forward the seed to the base class so that it is not silently dropped
        super().reset(seed=seed)
        self._t = 0
        obs = {key: self._observations[0] for key in self.observation_space.spaces.keys()}
        return obs, {}

    def step(self, action):
        """
        Advance the episode by one step (``action`` is ignored).

        :param action: Unused
        :return: observation dict, reward, terminated, truncated and an empty info dict
        """
        self._t += 1
        index = self._t % len(self._observations)
        obs = {key: self._observations[index] for key in self.observation_space.spaces.keys()}
        terminated = False
        # Time limit only: the episode never terminates by itself
        truncated = self._t >= self._ep_length
        reward = self._rewards[index]
        return obs, reward, terminated, truncated, {}
|
|
|
|
|
|
@pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv])
def test_env(env_cls):
    """Run the env checker on the dummy environments used by the buffer tests."""
    env_under_test = env_cls()
    # Do not warn for asymmetric space
    check_env(env_under_test, warn=False, skip_render_check=True)
|
|
|
|
|
|
@pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer])
def test_replay_buffer_normalization(replay_buffer_cls):
    """
    Check that samples drawn with a ``VecNormalize`` env are normalized.

    Original (unnormalized) transitions are stored in the buffer, then a
    batch is sampled with the env passed to ``sample``. The mean of the
    sampled observations and rewards must be within 1 of zero.
    """
    env_cls_for_buffer = {ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv}
    env = VecNormalize(make_vec_env(env_cls_for_buffer[replay_buffer_cls]))

    buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu")

    # Interact and store the original (unnormalized) transitions
    env.reset()
    obs = env.get_original_obs()
    for _ in range(100):
        action = env.action_space.sample()
        _, _, done, info = env.step(action)
        next_obs = env.get_original_obs()
        buffer.add(obs, next_obs, action, env.get_original_reward(), done, info)
        obs = next_obs

    sample = buffer.sample(50, env)
    zero = th.zeros(1)

    # Test observation normalization
    for obs_batch in (sample.observations, sample.next_observations):
        if isinstance(sample, DictReplayBufferSamples):
            for obs_tensor in obs_batch.values():
                assert th.allclose(obs_tensor.mean(0), zero, atol=1)
        elif isinstance(sample, ReplayBufferSamples):
            assert th.allclose(obs_batch.mean(0), zero, atol=1)

    # Test reward normalization
    assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)
|
|
|
|
|
|
@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer])
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_device_buffer(replay_buffer_cls, device):
    """
    Check that the data returned by every buffer class lives on the requested device.

    :param replay_buffer_cls: Buffer class under test
    :param device: Device string passed to the buffer ("cpu", "cuda" or "auto")
    """
    if device == "cuda" and not th.cuda.is_available():
        pytest.skip("CUDA not available")

    is_rollout_buffer = replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]

    env = {
        RolloutBuffer: DummyEnv,
        DictRolloutBuffer: DummyDictEnv,
        ReplayBuffer: DummyEnv,
        DictReplayBuffer: DummyDictEnv,
    }[replay_buffer_cls]
    env = make_vec_env(env)

    buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device)

    # Interact and store transitions
    obs = env.reset()
    for _ in range(100):
        action = env.action_space.sample()
        next_obs, reward, done, info = env.step(action)
        if is_rollout_buffer:
            episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1)
            buffer.add(obs, action, reward, episode_start, values, log_prob)
        else:
            buffer.add(obs, next_obs, action, reward, done, info)
        obs = next_obs

    # Get data from the buffer as a list of minibatches
    if is_rollout_buffer:
        # ``get`` is a generator over minibatches: iterating over it directly
        # yields whole sample tuples, not their tensors
        minibatches = list(buffer.get(50))
    else:
        minibatches = [buffer.sample(50)]

    # Check that all data are on the desired device
    desired_device = get_device(device).type
    n_checked = 0
    for minibatch in minibatches:
        for value in minibatch:
            if isinstance(value, dict):
                for tensor in value.values():
                    assert tensor.device.type == desired_device
                    n_checked += 1
            elif isinstance(value, th.Tensor):
                assert value.device.type == desired_device
                n_checked += 1

    # Guard against the test passing vacuously without inspecting any tensor
    assert n_checked > 0, "No tensor was checked"
|
|
|
|
|
|
def test_custom_rollout_buffer():
    """
    Check ``rollout_buffer_class`` / ``rollout_buffer_kwargs`` handling in A2C.

    An empty kwargs dict must be accepted, while an unknown keyword, a
    duplicated ``gamma`` keyword and a Dict buffer on a non-Dict observation
    space must each raise the expected error.
    """

    def make_model(buffer_cls, **extra):
        # Same model construction for every case, only the buffer settings vary
        return A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=buffer_cls, **extra)

    # Valid configuration: no extra keyword arguments
    make_model(RolloutBuffer, rollout_buffer_kwargs={})

    # Keyword not accepted by the buffer constructor
    with pytest.raises(TypeError, match="unexpected keyword argument 'wrong_keyword'"):
        make_model(RolloutBuffer, rollout_buffer_kwargs={"wrong_keyword": 1})

    # ``gamma`` given a second time through the kwargs
    with pytest.raises(TypeError, match="got multiple values for keyword argument 'gamma'"):
        make_model(RolloutBuffer, rollout_buffer_kwargs={"gamma": 1})

    # Dict buffer used with a non-Dict observation space
    with pytest.raises(AssertionError, match="DictRolloutBuffer must be used with Dict obs space only"):
        make_model(DictRolloutBuffer)
|