Add support for Discrete observation spaces

This commit is contained in:
Antonin RAFFIN 2020-03-25 16:42:05 +01:00
parent 52d2cd6a1b
commit fa599c65a6
7 changed files with 62 additions and 19 deletions

View file

@ -13,6 +13,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Add support for Discrete observation spaces
Bug Fixes:
^^^^^^^^^^

41
tests/test_identity.py Normal file
View file

@ -0,0 +1,41 @@
import numpy as np
import pytest
from torchy_baselines import A2C, PPO, SAC, TD3
from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv
from torchy_baselines.common.evaluation import evaluate_policy
from torchy_baselines.common.noise import NormalActionNoise
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_discrete(model_class):
    """
    Train A2C / PPO on ``IdentityEnv`` and check the evaluated reward.

    The test passes when ``evaluate_policy`` does not complain about the
    ``reward_threshold`` of 90 over 20 evaluation episodes.

    :param model_class: the algorithm class under test (A2C or PPO)
    """
    # NOTE(review): the argument 10 is presumably the size of the
    # discrete space -- confirm against IdentityEnv's signature.
    identity_env = IdentityEnv(10)

    # Fixed seed so that the short training run is reproducible
    agent = model_class('MlpPolicy', identity_env, gamma=0.5, seed=0)
    agent = agent.learn(3000)

    evaluate_policy(agent, identity_env, n_eval_episodes=20, reward_threshold=90)
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
def test_continuous(model_class):
    """
    Train each algorithm on ``IdentityEnvBox`` and check the evaluated reward.

    A2C and PPO are trained for 3000 steps, SAC and TD3 for 500 steps.
    TD3 additionally receives Gaussian action noise. The test passes when
    ``evaluate_policy`` does not complain about the ``reward_threshold``
    of 90 over 20 evaluation episodes.

    :param model_class: the algorithm class under test
    """
    identity_env = IdentityEnvBox(eps=0.5)

    # Training budget (argument passed to ``learn``) for each algorithm
    training_budget = {A2C: 3000, PPO: 3000, SAC: 500, TD3: 500}
    total_timesteps = training_budget[model_class]

    # Keyword arguments shared by every algorithm
    model_kwargs = {
        'policy_kwargs': {'net_arch': [64, 64]},
        'seed': 0,
        'gamma': 0.95,
    }

    # Only TD3 is given an explicit action noise object
    if model_class in [TD3]:
        # NOTE(review): presumably matches the action dimension of
        # IdentityEnvBox -- confirm against the environment definition.
        n_actions = 1
        model_kwargs['action_noise'] = NormalActionNoise(
            mean=np.zeros(n_actions),
            sigma=0.1 * np.ones(n_actions),
        )

    agent = model_class('MlpPolicy', identity_env, **model_kwargs)
    agent = agent.learn(total_timesteps)

    evaluate_policy(agent, identity_env, n_eval_episodes=20, reward_threshold=90)

View file

@ -5,8 +5,8 @@ import pytest
import numpy as np
from torchy_baselines.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure,
info, debug, set_level, configure, logkv, logkvs, dumpkvs, logkv_mean, warn,
error, reset)
info, debug, set_level, configure, logkv, logkvs,
dumpkvs, logkv_mean, warn, error, reset)
KEY_VALUES = {
"test": 1,

View file

@ -1,9 +1,10 @@
import numpy as np
import os
import pytest
import torch as th
from copy import deepcopy
import pytest
import numpy as np
import torch as th
from torchy_baselines import A2C, PPO, SAC, TD3
from torchy_baselines.common.identity_env import IdentityEnvBox
from torchy_baselines.common.vec_env import DummyVecEnv

View file

@ -338,15 +338,15 @@ class BaseRLModel(ABC):
"Box environment, please use {} ".format(observation_space.shape) +
"or (n_env, {}) for the observation shape."
.format(", ".join(map(str, observation_space.shape))))
# TODO: add support for Discrete, MultiDiscrete and MultiBinary observation spaces
# elif isinstance(observation_space, gym.spaces.Discrete):
# if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
# return False
# elif len(observation.shape) == 1:
# return True
# else:
# raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
# "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
elif isinstance(observation_space, gym.spaces.Discrete):
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
return False
elif len(observation.shape) == 1:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
"Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
# TODO: add support for MultiDiscrete and MultiBinary observation spaces
# elif isinstance(observation_space, gym.spaces.MultiDiscrete):
# if observation.shape == (len(observation_space.nvec),):
# return False

View file

@ -49,8 +49,8 @@ def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)
if reward_threshold is not None:
assert mean_reward > reward_threshold, (f'Mean reward below threshold: '
'{mean_reward:.2f} < {reward_threshold:.2f}')
assert mean_reward > reward_threshold, ('Mean reward below threshold: '
f'{mean_reward:.2f} < {reward_threshold:.2f}')
if return_episode_rewards:
return episode_rewards, episode_lengths
return mean_reward, std_reward

View file

@ -52,7 +52,7 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space,
return obs.float()
elif isinstance(observation_space, spaces.Discrete):
# One hot encoding and convert to float to avoid errors
return F.one_hot(obs, num_classes=observation_space.n).float()
return F.one_hot(obs.long(), num_classes=observation_space.n).float()
else:
# TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
raise NotImplementedError()
@ -88,8 +88,8 @@ def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
# raise NotImplementedError()
return np.prod(observation_space.shape)
elif isinstance(observation_space, spaces.Discrete):
# Observation is an int
return 1
# Observation is a one hot vector
return observation_space.n
else:
# TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
raise NotImplementedError()