mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-26 03:01:19 +00:00
Add support for Discrete observation spaces
This commit is contained in:
parent
52d2cd6a1b
commit
fa599c65a6
7 changed files with 62 additions and 19 deletions
|
|
@ -13,6 +13,7 @@ Breaking Changes:
|
|||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Add support for Discrete observation spaces
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
41
tests/test_identity.py
Normal file
41
tests/test_identity.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from torchy_baselines import A2C, PPO, SAC, TD3
|
||||
from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv
|
||||
from torchy_baselines.common.evaluation import evaluate_policy
|
||||
from torchy_baselines.common.noise import NormalActionNoise
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_discrete(model_class):
    """Train A2C / PPO on the discrete identity environment and check the score.

    The model is trained for 3000 steps with a fixed seed, then evaluated over
    20 episodes. ``evaluate_policy`` is called with ``reward_threshold=90``,
    so the test fails if the mean episode reward is not above that value.

    :param model_class: the algorithm class under test (supplied by
        ``pytest.mark.parametrize``).
    """
    # NOTE(review): 10 is presumably the size of the discrete space -- confirm
    # against IdentityEnv's signature.
    identity_env = IdentityEnv(10)

    untrained = model_class('MlpPolicy', identity_env, gamma=0.5, seed=0)
    # ``learn`` returns the object that is evaluated below.
    model = untrained.learn(3000)

    evaluate_policy(
        model,
        identity_env,
        n_eval_episodes=20,
        reward_threshold=90,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
def test_continuous(model_class):
    """Train each algorithm on the continuous identity environment and check the score.

    A2C and PPO are trained for 3000 steps, SAC and TD3 for 500 steps. Every
    model uses a ``[64, 64]`` network architecture, ``seed=0`` and
    ``gamma=0.95``. TD3 additionally receives a ``NormalActionNoise`` with zero
    mean and a sigma of 0.1. The trained model is evaluated over 20 episodes
    with ``reward_threshold=90``, so the test fails if the mean episode reward
    is not above that value.

    :param model_class: the algorithm class under test (supplied by
        ``pytest.mark.parametrize``).
    """
    box_env = IdentityEnvBox(eps=0.5)

    # Per-algorithm training budget.
    steps_per_algo = {
        A2C: 3000,
        PPO: 3000,
        SAC: 500,
        TD3: 500,
    }
    n_steps = steps_per_algo[model_class]

    model_kwargs = {
        'policy_kwargs': {'net_arch': [64, 64]},
        'seed': 0,
        'gamma': 0.95,
    }

    if model_class is TD3:
        # NOTE(review): a single action dimension is assumed here -- confirm
        # against IdentityEnvBox's action space.
        n_actions = 1
        model_kwargs['action_noise'] = NormalActionNoise(
            mean=np.zeros(n_actions),
            sigma=0.1 * np.ones(n_actions),
        )

    untrained = model_class('MlpPolicy', box_env, **model_kwargs)
    # ``learn`` returns the object that is evaluated below.
    model = untrained.learn(n_steps)

    evaluate_policy(
        model,
        box_env,
        n_eval_episodes=20,
        reward_threshold=90,
    )
|
||||
|
|
@ -5,8 +5,8 @@ import pytest
|
|||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure,
|
||||
info, debug, set_level, configure, logkv, logkvs, dumpkvs, logkv_mean, warn,
|
||||
error, reset)
|
||||
info, debug, set_level, configure, logkv, logkvs,
|
||||
dumpkvs, logkv_mean, warn, error, reset)
|
||||
|
||||
KEY_VALUES = {
|
||||
"test": 1,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import numpy as np
|
||||
import os
|
||||
import pytest
|
||||
import torch as th
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
from torchy_baselines import A2C, PPO, SAC, TD3
|
||||
from torchy_baselines.common.identity_env import IdentityEnvBox
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv
|
||||
|
|
|
|||
|
|
@ -338,15 +338,15 @@ class BaseRLModel(ABC):
|
|||
"Box environment, please use {} ".format(observation_space.shape) +
|
||||
"or (n_env, {}) for the observation shape."
|
||||
.format(", ".join(map(str, observation_space.shape))))
|
||||
# TODO: add support for Discrete, MultiDiscrete and MultiBinary observation spaces
|
||||
# elif isinstance(observation_space, gym.spaces.Discrete):
|
||||
# if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
|
||||
# return False
|
||||
# elif len(observation.shape) == 1:
|
||||
# return True
|
||||
# else:
|
||||
# raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
|
||||
# "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
|
||||
elif isinstance(observation_space, gym.spaces.Discrete):
|
||||
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
|
||||
return False
|
||||
elif len(observation.shape) == 1:
|
||||
return True
|
||||
else:
|
||||
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
|
||||
"Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
|
||||
# TODO: add support for MultiDiscrete and MultiBinary observation spaces
|
||||
# elif isinstance(observation_space, gym.spaces.MultiDiscrete):
|
||||
# if observation.shape == (len(observation_space.nvec),):
|
||||
# return False
|
||||
|
|
|
|||
|
|
@ -49,8 +49,8 @@ def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
|
|||
mean_reward = np.mean(episode_rewards)
|
||||
std_reward = np.std(episode_rewards)
|
||||
if reward_threshold is not None:
|
||||
assert mean_reward > reward_threshold, (f'Mean reward below threshold: '
|
||||
'{mean_reward:.2f} < {reward_threshold:.2f}')
|
||||
assert mean_reward > reward_threshold, ('Mean reward below threshold: '
|
||||
f'{mean_reward:.2f} < {reward_threshold:.2f}')
|
||||
if return_episode_rewards:
|
||||
return episode_rewards, episode_lengths
|
||||
return mean_reward, std_reward
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space,
|
|||
return obs.float()
|
||||
elif isinstance(observation_space, spaces.Discrete):
|
||||
# One hot encoding and convert to float to avoid errors
|
||||
return F.one_hot(obs, num_classes=observation_space.n).float()
|
||||
return F.one_hot(obs.long(), num_classes=observation_space.n).float()
|
||||
else:
|
||||
# TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
|
||||
raise NotImplementedError()
|
||||
|
|
@ -88,8 +88,8 @@ def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
|
|||
# raise NotImplementedError()
|
||||
return np.prod(observation_space.shape)
|
||||
elif isinstance(observation_space, spaces.Discrete):
|
||||
# Observation is an int
|
||||
return 1
|
||||
# Observation is a one hot vector
|
||||
return observation_space.n
|
||||
else:
|
||||
# TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
Loading…
Reference in a new issue